코드
import sys
input = sys.stdin.readline
N, M = map(int, input().split())
arr = []
for _ in range(N):
arr.append(list(map(int, input().split())))
for i in range(N):
for j in range(N):
a = b = c = d = 0
a = 0 if (i-1 < 0 or j-1 < 0) else arr[i-1][j-1]
b = 0 if i-1 < 0 else arr[i-1][j]
c = 0 if j-1 < 0 else arr[i][j-1]
d = arr[i][j]
arr[i][j] = (b + c + d) - a
for _ in range(M):
x1, y1, x2, y2 = map(lambda x: x-1, map(int, input().split()))
a = 0 if (x1-1 < 0 or y1-1 < 0) else arr[x1-1][y1-1]
b = 0 if x1-1 < 0 else arr[x1-1][y2]
c = 0 if y1-1 < 0 else arr[x2][y1-1]
d = arr[x2][y2]
print((a+d)-(b+c))
결과
풀이 방법
- 표에서 직사각형 구간의 합을 구하는 문제이다.
- 기존 표의 (x,y)좌표 각각에 대해 (0,0)~(x,y) 구간의 합을 구해놓으면 직사각형 (x1,y1)~(x2,y2) 구간합을 O(1)의 연산으로 구할 수 있다.
- 컴퓨터그래픽스 교과목 시간에 구간합을 빠르게 구하는 방법을 공부한 적이 있어서 실마리가 기억났던 문제이다.