[백준] 2042번 구간 합 구하기

HL·2021년 5월 24일
0

백준

목록 보기
94/104
post-custom-banner

문제 링크

https://www.acmicpc.net/problem/2042

문제 설명

  • 숫자 리스트 주어짐 (100만개)
  • 수정, 구간 합 구하기 반복 (2만번)

풀이

  • 세그먼트 트리
  • init 구현
    • 리프가 아닐때 왼쪽 노드 생성
    • 리프가 아닐때 오른쪽 노드 생성
    • 재귀
    • O(NlogN)
  • update 구현
    • 루트 노드부터 리프 노드까지 재귀
    • 리프 노드부터 수정
    • 유니온 파인드처럼
    • O(logN)
  • sum 구현
    • 구간을 나누어 재귀적으로 return
    • O(logN)

후기

  • 처음 공부하는 자료구조여서 이것저것 찾아봤는데
  • 이해는 되는데 구현이 너무 어려웠다
  • 그래서 그냥 혼자 구현해봤다

파이썬 코드

import sys


class Node:
    def __init__(self, start, end, value):
        self.start = start
        self.end = end
        self.value = value
        self.left = None
        self.right = None


def solution():

    # 입력 받기
    read = sys.stdin.readline
    n, m, k = map(int, read().split())
    numbers = [int(read()) for _ in range(n)]
    commands = [list(map(int, read().split())) for _ in range(m + k)]

    root = Node(0, n, sum(numbers))
    if n >= 2:
        init_child(root, numbers)

    for a, b, c in commands:
        # b번째 수를 c로 변경
        if a == 1:
            update(root, b-1, c)
        # b부터 c까지 합 구하기
        elif a == 2:
            print(get_sum(root, b-1, c-1))


def init_child(curr, numbers):

    mid = (curr.start + curr.end) // 2

    left_sum = sum(numbers[curr.start:mid])
    curr.left = Node(curr.start, mid, left_sum)

    right_sum = sum(numbers[mid:curr.end])
    curr.right = Node(mid, curr.end, right_sum)

    if mid - curr.start > 1:
        init_child(curr.left, numbers)
    if curr.end - mid > 1:
        init_child(curr.right, numbers)


def update(curr, b, c):

    # 리프 노드일 때
    if curr.end - curr.start == 1:
        diff = c - curr.value
        curr.value = c
        return diff

    mid = (curr.start + curr.end) // 2
    diff = 0
    
    # 재귀적으로 자식 노드 수정
    if b < mid:
        diff = update(curr.left, b, c)
    else:
        diff = update(curr.right, b, c)

    # 자식 노드 수정 후 현재 노드 수정
    curr.value += diff
    return diff


def get_sum(curr, b, c):

    # 리프 노드일 때
    if curr.end - curr.start == 1:
        return curr.value

    mid = (curr.start + curr.end) // 2
    
    # 딱 맞을 때
    if curr.start == b and curr.end-1 == c:
        return curr.value
    # 왼쪽 + 오른쪽
    elif curr.start <= b < mid and mid <= c < curr.end:
        return get_sum(curr.left, b, mid-1) + get_sum(curr.right, mid, c)
    # 왼쪽
    elif curr.start <= c < mid:
        return get_sum(curr.left, b, c)
    # 오른쪽
    elif mid <= b < curr.end:
        return get_sum(curr.right, b, c)


solution()
profile
Frontend 개발자입니다.
post-custom-banner

0개의 댓글