세그먼트 트리 구현

박진형·2021년 8월 4일
1

algorithm

목록 보기
55/111

세그먼트트리 유형이 코테에서 몇번 나오는 것을 봤기때문에 이참에 세그먼트 트리에대해서 완벽하게 알고 넘어가고자 공부를 했다.

세그먼트 트리는 이진트리로 배열의 구간에 대한 문제를 빠르게 해결할 수 있다.
이진트리이므로 O(logN)의 시간복잡도로 성능이 굉장히 좋다.

이번 기회를 통해 세그먼트 트리를 만들고, 구간의 합을 구하고, 특정 배열값을 변경하는 코드를 한번 짜봤다.

핵심은 탐색 범위를 반으로 쪼개어 왼쪽 서브트리, 오른쪽 서브트리 형식으로 탐색을한다는 것.

세그먼트 트리를 통해서
https://www.acmicpc.net/problem/2042 구간 합 구하기

이번 포스팅에서 구현한 세그먼트 트리를 응용을 해서
https://www.acmicpc.net/problem/2357 최솟값과 최댓값

등의 문제를 풀 수 있다.

최솟값과 최댓값 문제는 다음번에 포스팅을 해봐야겠다.

설명 및 구현

import java.io.*;
import java.util.StringTokenizer;

/*
	세그먼트 트리 연습
 */
public class Main {

	static int segTree[];
	static int arr[];

	static double logB(double x, double base) { return Math.log(x)/Math.log(base); }


	/*
	세그먼트 트리 만들기
	왼쪽 서브 트리, 오른쪽 서브 트리로 나눈다.
	왼쪽 서브트리의 범위 - start ~ (start + end) / 2
	오른쪽 서브트리 범위 - (start + end) / 2
	start와 end가 같을 경우 종료를 해준다(리프 노드).
	 */
	static int make_seg(int node, int start, int end)
	{
		if(start == end)
			return segTree[node] = arr[start];
		segTree[node] =make_seg(node * 2,start, (start + end)/2)+ make_seg(node * 2 + 1,(start+end)/2+1,end);
		return segTree[node];
	}

	/*
	구간합 구하기
	세 가지 경우로 나뉜다.
	1. 구하고자 하는 범위안에 현재 탐색하고 있는 범위가 완전히 포함될 때
	2. 구하고자 하는 범위안에 현재 탐색하고 있는 범위가 겹치는 부분이 하나도 없을 때
	3. 일부만 겹칠 때

	1번 경우에는 현재 탐색하고 있는 범위를 반환하면 됨.
	2번 경우에는 필요 없는 탐색이므로 종료하면 됨.
	3번 경우에는 더 깊은 탐색을 해보면 됨.
	 */
	static int sum_seg(int node, int start, int end, int left, int right)
	{
		/*
		node 		- 현재 세그먼트 트리의 노드 번호
		start, end	- 현재 탐색하고 있는 범위
		left, right	- 구간합을 구할 범위
		 */

		//2번 케이스
		if (left > end || right < start) return 0;
		//1번 케이스
		if (left <= start && end <= right)
			return segTree[node];
		//3번 케이스
		int sum = sum_seg(node * 2, start, (start + end) / 2 , left, right) +
			sum_seg(node * 2 + 1, (start + end) / 2 + 1, end, left, right);
		return sum;
	}

	/*
	값 바꾸기

	두 가지 경우로 나뉜다.
	1. 바꾸고자 하는 index가 범위에 포함되지 않을경우
	2. 바꾸고자 하는 index가 범위에 포함될 경우

	1번 경우에는 더 이상 탐색할 필요가 없다.
	2번 경우에는 계속 탐색해야 한다.

	예를 들어
	서브 트리의 부모노드의 값과 바꾸고자하는 값의 차이가 3이라면 그 서브트리에 속한 모든 노드들의 값도 3차이만큼 변경시켜줘야 한다.

	 */
	static void update_seg(int node, int start, int end, int idx, int diff)
	{
		//1번 경우
		if(idx < start || idx > end) return;
		segTree[node] += diff;
		//끝까지 왔다면 더 이상 탐색할 필요가 없다.
		if(start==end)
			return;
		update_seg(node * 2, start, (start + end)/2,idx,diff);
		update_seg(node * 2 + 1, (start + end) / 2 + 1, end, idx, diff);
	}
	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));

		int n = Integer.parseInt(br.readLine());
		int treeHeight = (int)Math.ceil(logB(n,2));
		int treeSize = (int) Math.pow(2,treeHeight+1);
		arr = new int[n+1];
		segTree = new int[treeSize+2];
		StringTokenizer st = new StringTokenizer(br.readLine());
		for(int i=0;i<n;i++)
		{
			arr[i] = Integer.parseInt(st.nextToken());
		}
		make_seg(1,0,n-1);

		/*
		예제
		n = 10
		arr = {75 30 100 38 50 51 52 20 81 5}
		 */
		//바꾸기 전의 arr[0] ~ arr[9] 까지의 구간 합 = 293
		System.out.println(sum_seg(1, 0, n - 1,0,4));

		//arr[0]을 6으로 바꾼다.
		update_seg(1,0,n-1,3,6 - arr[0] );

		//바꾸고난 후의 arr[0] ~ arr[9] 까지의 합 224
		System.out.println(sum_seg(1, 0, n - 1,0,4));
	}

}

0개의 댓글