[C++] 백준 1280: 나무 심기

Cyan·2024년 5월 28일
0

코딩 테스트

목록 보기
154/166

백준 1280: 나무 심기

문제 요약

1번부터 N번까지 번호가 매겨져 있는 N개의 나무가 있다. i번 나무는 좌표 X[i]에 심어질 것이다.

동호는 나무를 1번 나무부터 차례대로 좌표 X[i]에 심으려고 한다. 1번 나무를 심는 비용은 없고, 각각의 나무를 심는데 드는 비용은 현재 심어져있는 모든 나무 까지 거리의 합이다. 만약 3번 나무를 심는다면, 1번 나무와의 거리 + 2번 나무와의 거리가 3번 나무를 심는데 드는 비용이다.

2번 나무부터 N번 나무까지를 심는 비용의 곱을 출력하는 프로그램을 작성하시오.

문제 분류

  • 자료 구조
  • 세그먼트 트리

문제 풀이

우선, 어떠한 나무를 심을 경우 해당 나무를 심는 비용에 대해 알아보자.
가령 1, 2, 7, 9의 좌표에 나무가 1그루씩 있고, 나는 5번 좌표에 나무를 심으려한다.
이를 수식으로 나타내면, (51)+(52)+(75)+(95)(5-1)+(5-2)+(7-5)+(9-5)이다.
이에 대해 ((52)(1+2))+((9+7)(52))((5*2)-(1+2))+((9+7)-(5*2))로 나타낼 수 있다.
즉, 5의 좌표에 나무를 심을 경우, 5의 좌표에 왼쪽에 있는 나무의 개수와 그 누적 합, 오른쪽에 있는 나무의 개수와 그 누적합의 합으로 표현할 수 있다.
일반화하자면, x의 좌표에 나무를 심는다고 한다면 그 나무의 비용은
((x*cnt(0,x-1)-sum(0,x-1))+(sum(x+1, MAX)-x*cnt(x+1,MAX)))
가 된다.
해당 범위의 개수와 누적합을 구하는 세그먼트 트리를 각각 구축하면 된다.

여기서 자신이 나무를 심은 자리에 중복하여 심을 수 있으므로, 누적합의 update()seg[idx] = loc;이 아닌 seg[idx] += loc;로 작성한다. 개수를 구하는 세그먼트 트리의
update2()도 마찬가지로 seg2[idx]++;로 누적해준다.

또한 입력은 0부터 200,000미만의 자연수이므로 MAX199999가 될 것이다.

풀이 코드

#include <stdio.h>
#include <iostream>
#include <algorithm>
#define MAX 199999

using namespace std;

long long seg[800000], seg2[800000];

long long update(int l, int r, int idx, int loc)
{
	if (loc < l || loc > r) return seg[idx];
	if (l == r) {
		if (l == loc) seg[idx] += loc;
	}
	else {
		int mid = (l + r) / 2;
		seg[idx] = update(l, mid, idx * 2 + 1, loc) + update(mid + 1, r, idx * 2 + 2, loc);
	}
	return seg[idx];
}

long long update2(int l, int r, int idx, int loc)
{
	if (loc < l || loc > r) return seg2[idx];
	if (l == r) {
		if (l == loc) seg2[idx]++;
	}
	else {
		int mid = (l + r) / 2;
		seg2[idx] = update2(l, mid, idx * 2 + 1, loc) + update2(mid + 1, r, idx * 2 + 2, loc);
	}
	return seg2[idx];
}

long long sum(int start, int end, int l, int r, int idx)
{
	if (r < start || l > end) return 0;
	if (start <= l && r <= end) return seg[idx];
	int mid = (l + r) / 2;
	return sum(start, end, l, mid, idx * 2 + 1) + sum(start, end, mid + 1, r, idx * 2 + 2);
}

long long cnt(int start, int end, int l, int r, int idx)
{
	if (r < start || l > end) return 0;
	if (start <= l && r <= end) return seg2[idx];
	int mid = (l + r) / 2;
	return cnt(start, end, l, mid, idx * 2 + 1) + cnt(start, end, mid + 1, r, idx * 2 + 2);
}

int main()
{
	int n, in;
	long long res = 1, temp;
	scanf("%d", &n);
	for (int i = 0; i < n; i++) {
		scanf("%d", &in);
		update(0, MAX, 0, in);
		update2(0, MAX, 0, in);
		temp = (cnt(0, in - 1, 0, MAX, 0) * in) - sum(0, in - 1, 0, MAX, 0)
			+ sum(in + 1, MAX, 0, MAX, 0) - (cnt(in + 1, MAX, 0, MAX, 0) * in);
		if (i > 0) {
			res *= temp % 1000000007;
			res %= 1000000007;
		}
	}
	cout << res;

	return 0;
}

0개의 댓글