좋아! 이번엔 **동적 계획법이랑 함께 자주 나오는 자료구조인 세그먼트 트리(Segment Tree)**를 아주 간단한 예제와 함께 보여줄게.
정수 배열이 주어진다.
이 배열에서 구간 합을 구하는 프로그램을 세그먼트 트리로 구현해보자.
배열: [1, 2, 3, 4]
질문: 1번 인덱스부터 2번 인덱스까지의 합은?
👉 정답: 2 + 3 = 5
public class Main {
static int[] arr = {1, 2, 3, 4}; // 원본 배열
static int[] tree = new int[16]; // 세그먼트 트리 배열 (충분히 크게)
public static void main(String[] args) {
int n = arr.length;
build(1, 0, n - 1); // 세그먼트 트리 생성
int left = 1, right = 2; // 구간 [1, 2]의 합을 구하자
System.out.println(query(1, 0, n - 1, left, right)); // 정답 출력
}
static void build(int node, int start, int end) {
if (start == end) {
tree[node] = arr[start];
} else {
int mid = (start + end) / 2;
build(2 * node, start, mid);
build(2 * node + 1, mid + 1, end);
tree[node] = tree[2 * node] + tree[2 * node + 1];
}
}
static int query(int node, int start, int end, int left, int right) {
if (right < start || end < left) {
return 0; // 겹치지 않음
}
if (left <= start && end <= right) {
return tree[node]; // 완전히 포함됨
}
int mid = (start + end) / 2;
int l = query(2 * node, start, mid, left, right);
int r = query(2 * node + 1, mid + 1, end, left, right);
return l + r;
}
}
static int[] arr = {1, 2, 3, 4};arr[0] = 1, arr[1] = 2, arr[2] = 3, arr[3] = 4static int[] tree = new int[16];4 * n크기만큼 잡음.int n = arr.length;build(1, 0, n - 1);node = 1 → 트리의 루트start = 0, end = 3 → 전체 구간build() 함수 동작 흐름 (상황 예시)트리 구조 그림:
[1]
/ \
[2] [3]
/ \ / \
[4] [5] [6] [7]
0 1 2 3
세그먼트 트리에 저장될 값:
tree[4] = arr[0] = 1
tree[5] = arr[1] = 2
tree[2] = tree[4] + tree[5] = 3
tree[6] = arr[2] = 3
tree[7] = arr[3] = 4
tree[3] = tree[6] + tree[7] = 7
tree[1] = tree[2] + tree[3] = 10
query(1, 0, 3, 1, 2);
node 1 → 범위 걸침 → 왼쪽, 오른쪽 자식 탐색
왼쪽: node 2 = [0,1]
오른쪽: node 3 = [2,3]
결과: 2 + 3 = 5
5
| 항목 | 내용 |
|---|---|
| 세그먼트트리 사용 이유 | 구간합을 빠르게 구하려고 |
build() | 트리 만들기 (재귀적으로) |
query() | 주어진 범위의 합 구하기 |
| 입력 예시 | [1, 2, 3, 4], 쿼리 [1, 2] |
| 출력 | 5 |
좋아! 너가 지정한 이 build() 함수는 세그먼트 트리를 만드는 핵심 부분이야.
그 안에서 재귀 호출을 통해 트리를 완성시키는 구조인데,
초보자가 이해하려면 진짜 "디버깅 수준으로 천천히" 해설해줘야 해.
static void build(int node, int start, int end)
arr[start ~ end] 구간에 대해node번째 위치에| 파라미터 | 뜻 |
|---|---|
node | 트리의 인덱스 번호 (루트는 1번부터 시작) |
start | 배열에서 현재 노드가 담당하는 구간의 시작 인덱스 |
end | 배열에서 현재 노드가 담당하는 구간의 끝 인덱스 |
arr = {1, 2, 3, 4}우리가 만드는 트리는 다음과 같은 구조로 만들어짐:
node 1 : [0~3]
/ \
node 2:[0~1] node 3:[2~3]
/ \ / \
node 4:[0] node5:[1] node6:[2] node7:[3]
if (start == end)arr[0], arr[1]처럼 하나만 담당하는 구간처음엔 build(1, 0, 3)부터 시작
→ 이후 쪼개지며 build(4, 0, 0)이 호출됨
→ 이때 start == end == 0이므로 리프노드 도달!
tree[node] = arr[start];build(4, 0, 0); // arr[0] = 1
→ tree[4] = arr[0] = 1;
| 트리 인덱스 | 값 |
|---|---|
| tree[4] | 1 |
else {start != end인 경우→ 이 노드는 하나 이상의 원소를 포함하므로,
→ 왼쪽 자식/오른쪽 자식으로 쪼개서 재귀 호출을 해야 돼
int mid = (start + end) / 2;build(1, 0, 3); → mid = (0 + 3) / 2 = 1
→ 이제 구간 [0,1], [2,3]로 나눌 수 있음
build(2 * node, start, mid);2 * node는 현재 노드의 왼쪽 자식 번호build(1, 0, 3); → 왼쪽 자식: build(2, 0, 1);
→ 이 호출도 다시 쪼개져서:
build(4, 0, 0) → tree[4] = arr[0] = 1build(5, 1, 1) → tree[5] = arr[1] = 2build(2 * node + 1, mid + 1, end);2 * node + 1은 현재 노드의 오른쪽 자식 번호build(1, 0, 3); → 오른쪽 자식: build(3, 2, 3);
→ 그 안에서 build(6,2,2), build(7,3,3)
→ 각각 tree[6] = arr[2] = 3, tree[7] = arr[3] = 4
tree[node] = tree[2 * node] + tree[2 * node + 1];tree[4] = 1, tree[5] = 2 → tree[2] = 1 + 2 = 3tree[6] = 3, tree[7] = 4 → tree[3] = 3 + 4 = 7tree[1] = 3 + 7 = 10| 트리 노드 | 의미 | 값 |
|---|---|---|
| tree[1] | 전체 [0~3] | 10 |
| tree[2] | [0~1] | 3 |
| tree[3] | [2~3] | 7 |
| tree[4] | [0] | 1 |
| tree[5] | [1] | 2 |
| tree[6] | [2] | 3 |
| tree[7] | [3] | 4 |
start == end면 leaf node → 그대로 복사tree[node]에 저장🔸 이 줄들이 왜 이 위치에 있어야 하나요?
좋아! 이제 build()로 세그먼트 트리를 만들었으니,
이제는 그 트리를 이용해서 특정 구간의 합을 빠르게 구하는 query() 함수도 디버깅 수준으로 한 줄 한 줄 완벽하게 해설해줄게.
static int query(int node, int start, int end, int left, int right) {
if (right < start || end < left) {
return 0;
}
if (left <= start && end <= right) {
return tree[node];
}
int mid = (start + end) / 2;
int l = query(2 * node, start, mid, left, right);
int r = query(2 * node + 1, mid + 1, end, left, right);
return l + r;
}
세그먼트 트리를 이용해서
arr[left] ~ arr[right]구간의 합을 구하는 함수
O(log n)시간에 구간합을 찾는 게 목표야
arr = {1, 2, 3, 4}
query(1, 0, 3, 1, 2) // arr[1] + arr[2] = 2 + 3 = 5
| 매개변수 | 설명 |
|---|---|
node | 현재 트리 노드 번호 (루트는 1) |
start | 이 노드가 담당하는 배열의 시작 인덱스 |
end | 이 노드가 담당하는 배열의 끝 인덱스 |
left | 구하고 싶은 구간의 시작 인덱스 |
right | 구하고 싶은 구간의 끝 인덱스 |
if (right < start || end < left) { return 0; }[start, end][left, right]if (left <= start && end <= right) { return tree[node]; }[start ~ end][left ~ right]에 완전히 포함된 경우int mid = (start + end) / 2;int l = query(2 * node, start, mid, left, right);[start, mid]2 * nodeint r = query(2 * node + 1, mid + 1, end, left, right);[mid+1, end]2 * node + 1return l + r;query(1, 0, 3, 1, 2)| node | 구간 | 값 |
|---|---|---|
| 1 | [0~3] | 10 |
| 2 | [0~1] | 3 |
| 3 | [2~3] | 7 |
| 4 | [0] | 1 |
| 5 | [1] | 2 |
| 6 | [2] | 3 |
| 7 | [3] | 4 |
query(1, 0, 3, 1, 2)
├─ query(2, 0, 1, 1, 2)
│ ├─ query(4, 0, 0, 1, 2) → 겹치지 않음 → 0
│ ├─ query(5, 1, 1, 1, 2) → 완전 포함 → tree[5] = 2
│ └─ 합계: 0 + 2 = 2
├─ query(3, 2, 3, 1, 2)
│ ├─ query(6, 2, 2, 1, 2) → 완전 포함 → tree[6] = 3
│ ├─ query(7, 3, 3, 1, 2) → 겹치지 않음 → 0
│ └─ 합계: 3 + 0 = 3
└─ 최종 반환: 2 + 3 = **5**
System.out.println(query(1, 0, 3, 1, 2)); // 출력: 5
| 조건 | 행동 |
|---|---|
[left ~ right]와 전혀 겹치지 않음 | 0 반환 |
[start ~ end]가 완전히 포함됨 | tree[node] 바로 반환 |
| 일부만 겹침 | 왼쪽/오른쪽 나눠서 합쳐 계산 |