[C++] 2042: 구간 합 구하기

쩡우·2023년 1월 13일
0

BOJ algorithm

목록 보기
33/65

문제

어떤 N개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 만약에 1,2,3,4,5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 합을 구하라고 한다면 17을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 합을 구하라고 한다면 12가 될 것이다.

입력

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그리고 N+2번째 줄부터 N+M+K+1번째 줄까지 세 개의 정수 a, b, c가 주어지는데, a가 1인 경우 b(1 ≤ b ≤ N)번째 수를 c로 바꾸고 a가 2인 경우에는 b(1 ≤ b ≤ N)번째 수부터 c(b ≤ c ≤ N)번째 수까지의 합을 구하여 출력하면 된다.

입력으로 주어지는 모든 수는 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.

출력

첫째 줄부터 K줄에 걸쳐 구한 구간의 합을 출력한다. 단, 정답은 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.

풀이

세그먼트 트리의 베이직 문제.

수의 변경이 일어나지 않는 수열의 부분의 합을 구할 때는 누적 합을 이용하면 되지만, 중간에 수의 변경이 빈번히 일어나는 경우에는 세그먼트 트리를 사용하여야 한다.

전산수학에서 배웠던 이분 트리를 기초로 사용한다. 수열의 각 인덱스의 숫자들은 리프 노드에 저장되고, 각 부모 노드는 2개의 자식 노드의 합을 저장한다.

init()에서는 트리를 생성한다. 재귀 방식을 이용한다. start, mid, end를 정해주고, 현재 노드에 값을 저장한 후 그 값을 리턴한다. 현재 노드의 값은 두 자식 노드의 리턴 값을 더한 값이다. start와 end가 같다면, 리프 노드이므로, 해당 수열의 index(=start=end)의 값을 노드에 저장하고 그 값을 리턴한다.

update()에서는 수열의 수를 바꾼다. 리프 노드의 수를 바꾸면 해당 노드의 모든 조상 노드의 수도 바뀌므로, 그 과정을 처리해준다. update를 실행하기 전에, 현재 arr[i]와 바꿀 값의 차이를 구해 주고, 현재 arr[i]를 수정한다. 이 차이를 해당 리프 노드와 모든 조상 노드에 더하여 트리를 수정할 것이다. 루트 노드부터 내려가면서, start와 end의 범위 내에 해당 index가 포함된다면, 그 노드는 조상 노드이므로 처음에 구했던 차를 더해준다. 그 후 리프 노드가 아니라면 각 자식 노드로 내려가 update()를 실행한다.

sum()에서는 구간의 합을 구한다. 범위의 구간을 매개변수로 준다. 목표 구간과 해당 노드의 범위가 전혀 겹치지 않는다면, 의미가 없으므로 0을 리턴한다. 해당 노드의 범위가 목표 구간 내에 완벽히 포함된다면, 해당 노드의 값은 그 노드에 속하는 모든 리프 노드의 합이므로, 더 들어갈 필요가 없다. 따라서 해당 노드의 값을 리턴한다. 두 경우가 아니라면, 해당 노드의 범위가 목표 구간 내에 포함되지는 않지만 겹치는 경우이므로, 완벽히 포함되는 범위와 포함되지 않는 범위를 구별하여 더해주기 위해 두 자식 노드에 sum을 실행한다.

코드

#include <iostream>

using namespace std;

void input_data(void);
long long init(int, int, int);
void update(int, long long, int, int, int);
long long sum(int, int, int, int, int);

int n, m, k, count;
long long arr[1000001];
long long tree[4000004];

int main(void)
{
	input_data();
	init(1, 1, n);
	while (count-- > 0)
	{
		long long a, b, c;
		cin >> a >> b >> c;
		if (a == 1)
		{
			long long difference = c - arr[b];
			arr[b] = c;
			update(b, difference, 1, 1, n);
		}
		else
			cout << sum(b, c, 1, 1, n) << '\n';
	}
	
	return (0);
}

void input_data(void)
{
	ios_base::sync_with_stdio(0);
	cin.tie(0);

	cin >> n >> m >> k;

	int i = 0;
	while (++i <= n)
		cin >> arr[i];
	count = m + k;

	return ;
}

long long init(int node, int start, int end)
{
	if (start == end)
		return (tree[node] = arr[start]);

	int mid = (start + end) / 2;
	tree[node] = init(node * 2, start, mid) + init(node * 2 + 1, mid + 1, end);
	return (tree[node]);
}

void update(int target_index, long long difference, int node, int start, int end)
{
	if (target_index < start || end < target_index)
		return ;
	tree[node] += difference;
	if (start != end)
	{
		int mid = (start + end) / 2;
		update(target_index, difference, node * 2, start, mid);
		update(target_index, difference, node * 2 + 1, mid + 1, end);
	}	

	return ;
}

long long sum(int range_left, int range_right, int node, int tree_start, int tree_end)
{
	if (tree_end < range_left || range_right < tree_start)
    	return (0);

	if (range_left <= tree_start && tree_end <= range_right)
    	return (tree[node]);

	int mid = (tree_start + tree_end) / 2;
	return (sum(range_left, range_right, node * 2, tree_start, mid) + sum(range_left, range_right, node * 2 + 1, mid + 1, tree_end));
}

어렵당 ..

profile
Jeongwoo's develop story

0개의 댓글