[C++] 백준 8462번: 배열의 힘

be_clever·2022년 3월 23일
0

Baekjoon Online Judge

목록 보기
127/172

문제 링크

8462번: 배열의 힘

문제 요약

배열의 부분 배열 안의 자연수 s가 있을 때, KsK_s는 이 부분 배열 안의 s의 갯수이다. 부분 배열의 힘이란, Ks×Ks×sK_s \times K_s \times s를 의미한다. 배열의 부분 배열의 범위가 주어지면, 각 부분 배열의 힘을 구해야 한다.

접근 방법

입력되는 수의 범위가 100만까지이기 때문에, 크기가 100만인 cnt 배열을 선언해줬습니다. cnt 배열은 현재 구간 내에, 해당 값을 가진 원소의 수가 몇개인지를 저장하고 있습니다.

만약에 구간에 새로운 원소 s가 들어오게 된다면, s에 대한 Ks×Ks×sK_s \times K_s \times s 값은 바뀌게 됩니다.

따라서, 현재의 합에서 Ks×Ks×sK_s \times K_s \times s를 빼준 다음에, KsK_s를 1 증가시켜주고 Ks×Ks×sK_s \times K_s \times s를 다시 합에 더해주면 됩니다. 여기서 KsK_s는 cnt 배열에 저장되어 있습니다.

원소를 빼야하는 경우는 반대로 해줘야 합니다.

이러한 점을 이용해서 mo's algorithm으로 적절히 구현을 해주면 쉽게 풀립니다. 단, 연산 시에 int 범위를 벗어나는 값이 나올 수 있어 오버플로가 발생할 수 있기 때문에 주의해야 합니다.

코드

#include <bits/stdc++.h>

using namespace std;

const int MAX = 100001, RANGE = 1000001;
int n, t, sqrtN, arr[MAX], cnt[RANGE];
long long sum, ans[MAX];

struct Query {
	int left, right, pos;

	bool operator<(const Query& a) {
		if (left / sqrtN == a.left / sqrtN)
			return right < a.right;
		return left / sqrtN < a.left / sqrtN;
	}
};

void add(int idx) {
	sum -= (long long)cnt[arr[idx]] * (long long)cnt[arr[idx]] * (long long)arr[idx];
	cnt[arr[idx]]++;
	sum += (long long)cnt[arr[idx]] * (long long)cnt[arr[idx]] * (long long)arr[idx];
}

void sub(int idx) {
	sum -= (long long)cnt[arr[idx]] * (long long)cnt[arr[idx]] * (long long)arr[idx];
	cnt[arr[idx]]--;
	sum += (long long)cnt[arr[idx]] * (long long)cnt[arr[idx]] * (long long)arr[idx];
}

int main(void) {
	ios::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);

	cin >> n >> t;
	vector<Query> query(t);
	sqrtN = sqrt(n);

	for (int i = 1; i <= n; i++)
		cin >> arr[i];

	for (int i = 0; i < t; i++) {
		cin >> query[i].left >> query[i].right;
		query[i].pos = i;
	}

	sort(query.begin(), query.end());

	int l = query[0].left, r = query[0].right;

	for (int i = l; i <= r; i++) add(i);

	ans[query[0].pos] = sum;
	for (int i = 1; i < t; i++)
	{
		while (l < query[i].left) sub(l++);
		while (l > query[i].left) add(--l);
		while (r < query[i].right) add(++r);
		while (r > query[i].right) sub(r--);
		ans[query[i].pos] = sum;
	}

	for (int i = 0; i < t; i++)
		cout << ans[i] << '\n';
	
	return 0;
}
profile
똑똑해지고 싶어요

0개의 댓글