BOJ_11505 - 구간 곱 구하기

Meantint·2021년 7월 17일
0

BOJ

목록 보기
11/16
post-thumbnail

 세그먼트 트리 문제. 원래 값에서 나머지를 구해야 하기 때문에 단말 노드부터 올라오면서 모두 update를 해줘야한다.

 여담으로 코드의 가독성과 타이핑의 귀찮음을 깨닫고 나머지 값을 MOD#define 처리해야겠다고 마음먹게 되었다.

문제/코드 링크

풀이

  • init 함수 작성

    • 리프 노드인 경우 seg[node] = cost[left] 실행

    • 리프 노드가 아닌 경우 왼쪽 서브 트리와 오른쪽 서브 트리 결괏값을 곱한 후 1000000007로 나눠준다.

  • update 함수 작성

    • 범위 안에 target_index가 없다면 return seg[node] 실행한다.

    • 단말 노드인 경우 cost[left]의 값을 new_value로 갱신해주고 cost[left]의 값을 seg[node]에도 넣어서 갱신해준다.

    • 단말 노드가 아닌 경우 왼쪽, 오른쪽으로 재귀 분할한 두 개의 결괏값을 곱해주고 1000000007로 나눠준 값을 seg[node]에 저장하고 리턴한다.

  • query 함수 작성

    • 범위를 완전히 벗어나면 곱셈에 영향을 주지 않는 값인 1을 리턴한다.

    • 완전 범위 안이면 return seg[node] 실행.

    • 걸쳐 있다면 왼쪽, 오른쪽 서브 트리 값을 곱한 후 1000000007을 나눈 값을 리턴한다.

Code

#include <iostream>
#include <vector>

#define ll long long

using namespace std;

int n, m, k;
vector<ll> cost;
vector<ll> seg;

ll init(int node, int left, int right)
{
    if (left == right) {
        return seg[node] = cost[left];
    }

    int mid = ((left + right) >> 1);
    ll left_value = init(node * 2, left, mid);
    ll right_value = init(node * 2 + 1, mid + 1, right);

    return seg[node] = (left_value * right_value) % (ll)1000000007;
}

ll update(int node, int left, int right, int target_index, ll new_value)
{
    if (target_index < left || right < target_index) {
        return seg[node];
    }
    if (left != right) {
        int mid = ((left + right) >> 1);
        ll left_value = update(node * 2, left, mid, target_index, new_value);
        ll right_value = update(node * 2 + 1, mid + 1, right, target_index, new_value);

        return seg[node] = (left_value * right_value) % (ll)1000000007;
    }
    else {
        return seg[node] = cost[left] = new_value;
    }
}

ll query(int node, int left, int right, int start, int end)
{
    if (end < left || right < start) {
        return 1;
    }
    if (start <= left && right <= end) {
        return seg[node];
    }

    // 단말노드가 아니라면
    if (left != right) {
        int mid = ((left + right) >> 1);
        ll left_value = query(node * 2, left, mid, start, end);
        ll right_value = query(node * 2 + 1, mid + 1, right, start, end);

        return (left_value * right_value) % (ll)1000000007;
    }
}

int main()
{
    cin >> n >> m >> k;
    cost.resize(n);

    int seg_size = 1;
    while (seg_size < n) {
        seg_size <<= 1;
    }
    seg.resize(seg_size * 2, 0);

    for (int i = 0; i < n; ++i) {
        cin >> cost[i];
    }
    init(1, 0, n - 1);
    for (int i = 0; i < m + k; ++i) {
        int case_num;
        cin >> case_num;

        // Change
        if (case_num == 1) {
            int target_index;
            ll new_value;
            cin >> target_index >> new_value;

            update(1, 0, n - 1, target_index - 1, new_value);
            // cost[target_index - 1] = new_value;
        }
        // Query
        else {
            int start, end;
            cin >> start >> end;

            cout << query(1, 0, n - 1, start - 1, end - 1) << '\n';
        }
    }
    cout << '\n';

    return 0;
}

0개의 댓글

관련 채용 정보