[python] 백준 11505 : 구간 곱 구하기

장선규·2022년 2월 5일
0

알고리즘

목록 보기
26/40
post-custom-banner

문제 링크
https://www.acmicpc.net/problem/11505

접근

연속된 구간에서 구간 곱을 구하는 문제이다. 그리고 숫자들이 바뀌고 그것대로 값을 갱신하는 문제이므로 세그먼트 트리로 풀어야겠다고 생각했다.

풀이

arr = [int(input()) for _ in range(n)]

b = math.ceil(math.log2(n)) + 1
node_n = 1 << b
seg = [0 for _ in range(node_n)]


def make_seg(idx, s, e):
    if s == e:
        seg[idx] = arr[s]
        return seg[idx]

    mid = (s + e) // 2

    l = make_seg(idx * 2, s, mid)
    r = make_seg(idx * 2 + 1, mid + 1, e)
    seg[idx] = (l * r) % MOD

    return seg[idx]

우선 세그먼트 트리를 만든다.
왼쪽 절반 구간의 구간곱과 오른쪽 절반 구간의 구간곱을 곱하여
현재 구간곱으로 취한다.

이를 반복하면 세그먼트 트리가 완성된다.


def change(idx, s, e):
    if b - 1 < s or e < b - 1:  # 범위 밖
        return seg[idx]

    if s == e:
        seg[idx] = new
        return new

    # seg[idx] = seg[idx] // cur * new
    # a/b mod M = (a mob M)/b mod M이 성립하지 않습니다

    mid = (s + e) // 2

    l = change(idx * 2, s, mid)
    r = change(idx * 2 + 1, mid + 1, e)
    seg[idx] = (l * r) % MOD
    return seg[idx]

값을 바꿔주는 역할을 하는 change() 함수이다.
이 부분에서 WA가 나왔는데, 나머지 연산을 하는 것에서 오답이 나왔다.

처음엔 구간합에서 한 것과 같이 바꿀 값 cur을 나눠주고 새 값 new를 곱해주는 식으로 하면 될 것이라 생각했다.

그러나 나머지 연산에서 a/b mod M(a mob M)/b mod M 의 값이 다르게 나온다는 것을 깨달았다.
무슨 말인가 하면,
우리의 구간 곱이 1000000000000 (1조) 라고 하자.
이 구간 안에 어떤 수가 2였다고 치자.
이 수 2를 3으로 바꾸려고 한다.

그러면 (1조 // 2) * 3 을 해주면 될 것이다.
하지만 그러다보면 숫자가 너무 커져서 용량또한 커질 것이다.

그래서 우리는 이 1조를 MOD = 1,000,000,007 로 나누었던 것이다.
1조를 MOD 로 나누면 999993007 라는 나머지가 나오는데... 이 수는 2로 나누어지지도 않는다.

결론적으로 모듈로 연산이 된 수는 나누기 연산에서는 전혀 다른 값이 될 수 있다는 것이다.
우리가 예상한 값이 보장되지 않으므로 기존 값을 나누고 새 값을 곱하는 방법은 사용할 수 없다.


그리하여 어쩔 수 없이 해당 숫자가 포함된 범위의 모든 세그먼트 트리 노드들을 다시 만들어주는 방법을 사용하였다.

def change(idx, s, e):
    if b - 1 < s or e < b - 1:  # 범위 밖
        return seg[idx]

    if s == e:
        seg[idx] = new
        return new

    mid = (s + e) // 2

    l = change(idx * 2, s, mid)
    r = change(idx * 2 + 1, mid + 1, e)
    seg[idx] = (l * r) % MOD
    return seg[idx]

곱하기는 모듈로 연산에서 항상 그 값이 보장이 된다.


def get(idx, s, e):
    # 탐색 영역 : s~e
    if to < s or e < frm:  # 범위 밖
        return 1

    mid = (s + e) // 2
    if frm <= s and e <= to:  # 탐색 영역이 b~c 완전히 안에 있음
        return seg[idx]

    else:  # 탐색 영역이 더 큰 경우나 범위 걸친 경우
        l = get(idx * 2, s, mid)
        r = get(idx * 2 + 1, mid + 1, e)
        return (l * r) % MOD

마지막으로 구간 곱을 반환해주는 get() 함수이다.
탐색 영역이 s~e 이고, 우리가 구하고자 하는 구간 곱의 영역이 frm~to 라고 쳤을 때,
총 세가지 경우로 나눌 수 있다.


먼저 s~e 와 frm~to 가 전혀 다른 곳에 있는 경우이다.
이 경우 s~e 전 범위를 다 뒤져도 우리가 원하는 구간이 나오지 않는다. 아니 그냥 구간 자체가 다르므로 생각할 필요가 없는 것이다. 1을 반환해주자.


다음으로 탐색 영역 s~e 가 frm~to 사이에 완전히 들어가 있는 경우이다.
frm <= s~e <= to 이런 식으로 말이다.
이 경우 굳이 더 내려가 탐색하지 않고 그 자체가 다 쓰이므로 해당 구간 곱을 바로 리턴해준다.
(s~e 구간이 다 쓰여짐!)


마지막으로 탐색 영역 s~e 가 frm~to 사이에 완전히 들어가 있지 않은 경우이다.
범위가 걸쳐있는 경우 + s <= frm~to <= e 와 같은 경우이다.
이 경우 좌 우로 나누어 해당 구간 곱들을 구하고 그것을 곱한것을 반환한다.
(더 탐색할 필요가 있음!)

정답 코드

import math
import sys

sys.setrecursionlimit(10 ** 8)  # pypy 제출시 삭제!
input = lambda: sys.stdin.readline().rstrip()
# in_range = lambda y,x: 0<=y<n and 0<=x<m
MOD = 1000000007

n, m, k = map(int, input().split())
arr = [int(input()) for _ in range(n)]

b = math.ceil(math.log2(n)) + 1
node_n = 1 << b
seg = [0 for _ in range(node_n)]


def make_seg(idx, s, e):
    if s == e:
        seg[idx] = arr[s]
        return seg[idx]

    mid = (s + e) // 2

    l = make_seg(idx * 2, s, mid)
    r = make_seg(idx * 2 + 1, mid + 1, e)
    seg[idx] = (l * r) % MOD

    return seg[idx]


def change(idx, s, e):
    if b - 1 < s or e < b - 1:  # 범위 밖
        return seg[idx]

    if s == e:
        seg[idx] = new
        return new


    mid = (s + e) // 2

    l = change(idx * 2, s, mid)
    r = change(idx * 2 + 1, mid + 1, e)
    seg[idx] = (l * r) % MOD
    return seg[idx]


def get(idx, s, e):
    # 탐색 영역 : s~e
    if to < s or e < frm:  # 범위 밖
        return 1

    mid = (s + e) // 2
    if frm <= s and e <= to:  # 탐색 영역이 b~c 완전히 안에 있음
        return seg[idx]

    else:  # 탐색 영역이 더 큰 경우나 범위 걸친 경우
        l = get(idx * 2, s, mid)
        r = get(idx * 2 + 1, mid + 1, e)
        return (l * r) % MOD


make_seg(1, 0, len(arr) - 1)


for _ in range(m + k):
    a, b, c = map(int, input().split())
    if a == 1:
        cur = arr[b - 1]
        new = c
        change(1, 0, len(arr) - 1)
    else:
        frm, to = b - 1, c - 1
        print(get(1, 0, len(arr) - 1) % MOD)
profile
코딩연습
post-custom-banner

0개의 댓글