[2042] 구간 합 구하기

HeeSeong·2021년 3월 21일
0

백준

목록 보기
8/79
post-thumbnail

🔗 문제 링크

https://www.acmicpc.net/problem/2042


❔ 문제 설명


어떤 N개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 만약에 1,2,3,4,5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 합을 구하라고 한다면 17을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 합을 구하라고 한다면 12가 될 것이다.


⚠️ 제한사항


  • 첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다.

  • M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다.

  • 그리고 N+2번째 줄부터 N+M+K+1번째 줄까지 세 개의 정수 a, b, c가 주어지는데, a가 1인 경우 b(1 ≤ b ≤ N)번째 수를 c로 바꾸고 a가 2인 경우에는 b(1 ≤ b ≤ N)번째 수부터 c(b ≤ c ≤ N)번째 수까지의 합을 구하여 출력하면 된다.

  • 입력으로 주어지는 모든 수는 263-2^{63}보다 크거나 같고, 26312^{63}-1보다 작거나 같은 정수이다.

  • 첫째 줄부터 K줄에 걸쳐 구한 구간의 합을 출력한다. 단, 정답은 263-2^{63}보다 크거나 같고, 26312^{63}-1보다 작거나 같은 정수이다.



💡 풀이 (언어 : Python)


처음에는 단순한 구현 문제로 생각해서 풀었는데 시간 초과 판정이 나왔다.
두번째로 DP를 이용해서 구간 합을 미리 계산하고 구간합 끼리 뺄셈으로 정답을 구해 시간 복잡도를 줄이는 코드를 작성했지만 이것도 시간초과 판정이 나왔다....

내 풀이

# 1번째 풀이
import sys

n, m, k = map(int, sys.stdin.readline().split())

soo = [0]
answer = []

for i in range(n):
    soo.append( int(sys.stdin.readline()) )

for j in range(m+k):
    nn, mm, kk = map(int, sys.stdin.readline().split())

    if nn == 1:
        soo[mm] = kk
    else:
        answer.append( sum(soo[mm:kk+1]) )

for a in answer:
    print(a)
    
    
## 2번째 풀이
import sys

n, m, k = map(int, sys.stdin.readline().split())

soo = [0 for i in range(n+1)]
ssum = [0 for i in range(n+1)]
answer = []

for i in range(1, n+1):
    soo[i] = int(sys.stdin.readline())

for i in range(1, n+1):
    ssum[i] = ssum[i-1] + soo[i]

for j in range(m+k):
    a, b, c = map(int, sys.stdin.readline().split())

    if a == 1:
        soo[b] = c
        for i in range(b, n+1):
            ssum[i] = ssum[i-1] + soo[i] 
    else:
        answer.append( ssum[c] - ssum[b-1] )

for ans in answer:
    print(ans)

이 문제는 Binary Indexed Tree (Fenwik Tree) 라고 불리는 알고리즘의 대표적 문제이다. 처음에 이해하기 어려웠는데 상당히 기발한 알고리즘인 것 같다. 구간 합 계산의 시간 복잡도를 획기적으로 줄일 수 있는 알고리즘이다.

정답 코드

import sys
input = sys.stdin.readline

# 데이터의 개수(n), 변경 횟수(m), 구간 합 계산 횟수(k)
n, m, k = map(int, input().split())

# 전체 데이터의 개수는 최대 1,000,000개
arr = [0] * (n + 1)
tree = [0] * (n + 1)

# i번째 수까지의 누적 합을 계산하는 함수
def prefix_sum(i):
    result = 0
    while i > 0:
        result += tree[i]
        # 0이 아닌 마지막 비트만큼 빼가면서 이동
        i -= (i & -i)
    return result

# i번째 수를 dif만큼 더하는 함수
def update(i, dif):
    while i <= n:
        tree[i] += dif
        i += (i & -i)

# start부터 end까지의 구간 합을 계산하는 함수
def interval_sum(start, end):
    return prefix_sum(end) - prefix_sum(start - 1)

for i in range(1, n + 1):
    x = int(input())
    arr[i] = x
    update(i, x)

for i in range(m + k):
    a, b, c = map(int, input().split())
    # 업데이트(update) 연산인 경우
    if a == 1:
        update(b, c - arr[b]) # 바뀐 크기(dif)만큼 적용
        arr[b] = c
    # 구간 합(interval sum) 연산인 경우
    else:
        print(interval_sum(b, c))
profile
끊임없이 성장하고 싶은 개발자

0개의 댓글