힙(Heap)

Hyeon·2022년 9월 30일
0

자료구조

목록 보기
3/3
post-thumbnail
post-custom-banner

힙(Heap)

부모가 자식보다 항상 큰(또는 작은) key값을 가지는 완전 이진 트리이다.

  • 최대 힙 : 부모 \ge 자식
  • 최소 힙 : 부모 \le 자식

최대 힙 구현

완전 이진 트리를 1차원 배열로 구현하기 위해, 다음과 같은 방법을 사용할 수 있다.

  • 첫번째 index(root 노드)는 1부터
  • 부모노드 index = 자식노드 index // 2 (나눗셈의 몫)
  • 왼쪽 자식노드 index = 부모노드 index ×\times 2
  • 오른쪽 자식노드 index = 부모노드 index ×\times 2 ++ 1

push()

최대 힙에 값 n을 저장하는 방법은 아래와 같다.

  1. 배열의 끝에 n을 저장한다.
  2. 부모 노드와 n을 비교하여, n이 부모 노드의 값보다 더 크면 위치를 바꾼다.
  3. '2번'을 반복하다가,
    n이 부모 노드의 값 이하이거나 root 노드에 도달하면 멈춘다.

[ 코드 ]

max_heap = [None] * (4 * n)
max_heap[0] = 999999999999
last_index = 0

def push(num):
    global max_heap
    global last_index
	
    # 배열의 끝에 값을 저장하기위해 현재까지 기록된 마지막 index를 1증가시킨다.
    last_index += 1
    
    # 증가된 마지막 index에 저장할 값을 넣어준다.
    max_heap[last_index] = num
    
    # 부모 노드와 값을 비교하며 위치를 결정한다.
    cur = last_index
    while cur != 0:
    	# index 0은 사용하지 않는다.
        # 저장한 값이 부모 노드보다 작을 때 까지 반복한다.
		if cur//2 != 0 and max_heap[cur//2] < max_heap[cur]:
            max_heap[cur], max_heap[cur//2] = max_heap[cur//2], max_heap[cur]
            cur //= 2
        else:
            break

pop()

최대 힙에서 가장 큰 값은 root 노드(1번 index)에 있다.
값을 꺼내고 삭제하면 힙을 유지하기 위해 다시 값을 정렬해 주어야 한다.

값을 꺼낸 뒤 힙을 정렬시키는 과정은 다음과 같다.

  1. 마지막 index에 저장된 n을 index 1에 저장한다.
  2. max(왼쪽 자식 노드 key값, 오른쪽 자식 노드 key값)n보다 작다면
    서로 위치를 바꿔준다.
  3. '2번'을 반복하다가, n이 두 자식 노드의 key값보다 크거나,
    마지막 index에 도달해서 더 이상 진행할 수 없다면 중단한다.

[ 코드 ]

max_heap = [None] * (4 * n)
max_heap[0] = 999999999999
last_index = 0

def pop():
    global max_heap
    global last_index
    
    if last_index == 0:
        return 0
    else:
        top = max_heap[1]
        cur = 1
        max_heap[cur] = max_heap[last_index]
        max_heap[last_index] = None
        last_index -= 1
        while cur <= last_index:
            left = cur * 2
            right = cur * 2 + 1
            next_idx = get_next(cur, left, right)
            if max_heap[next_idx] > max_heap[cur]:
                max_heap[next_idx], max_heap[cur] = max_heap[cur], max_heap[next_idx]
                cur = next_idx
            else:
                break
        return top


def get_next(cur, left, right):
    global last_index
    global max_heap

    if right <= last_index:
        return left if max_heap[left] > max_heap[right] else right
    elif left <= last_index:
        return left
    else:
        return cur

def get_next(cur, left, right) 는 반복문 내부 조건문을 줄이기 위해 임의로 만들어준 함수로,
좌, 우 자식 노드의 index마지막 index보다 작은지 확인하고,
조건에 부합하는 indexreturn해준다.
조건문을 모두 통과하면 현재 nindex를 돌려주어
반복문 내부 조건에 의해 반복문이 종료된다.

문제 풀이

문제 : BOJ 11279 최대 힙

문제의 조건에 따라, 최대 힙에 저장되는 N은 1N100,0001 \le N \le 100,000 이며,
00이 입력되면 최대값을 출력 후 제거해야한다.

[ 전체 코드 ]

import sys

n = int(sys.stdin.readline().rstrip())

max_heap = [None] * (n+2)
last_index = 0

def add(num):
    global max_heap
    global last_index

    last_index += 1
    max_heap[last_index] = num
    cur = last_index
    while cur != 0:
        if cur//2 != 0 and max_heap[cur//2] < max_heap[cur]:
            max_heap[cur], max_heap[cur//2] = max_heap[cur//2], max_heap[cur]
            cur //= 2
        else:
            break    
    

def pop():
    global max_heap
    global last_index
    if last_index == 0:
        return 0
    else:
        top = max_heap[1]
        cur = 1
        max_heap[cur] = max_heap[last_index]
        max_heap[last_index] = None
        last_index -= 1
        while cur <= last_index:
            left = cur * 2
            right = cur * 2 + 1
            next_idx = get_next(cur, left, right)
            if max_heap[next_idx] > max_heap[cur]:
                max_heap[next_idx], max_heap[cur] = max_heap[cur], max_heap[next_idx]
                cur = next_idx
            else:
                break
        return top


def get_next(cur, left, right):
    global last_index
    global max_heap

    if right <= last_index:
        return left if max_heap[left] > max_heap[right] else right
    elif left <= last_index:
        return left
    else:
        return cur


while n > 0:
    k = int(sys.stdin.readline().rstrip())
    if k == 0:
        x = pop()
        sys.stdout.write(f'{x}\n')
    else:
        add(k)
    n -= 1
profile
그럼에도 불구하고
post-custom-banner

0개의 댓글