Min-max heap
을 구현하여 제출했지만 계속 시간초과.. Min-max heap
에 대한 내용은 밑에 있다.
[정답 코드]
import sys
import heapq
input = sys.stdin.readline
def insert(min_list, max_list, list_dict, num):
heapq.heappush(min_list, num)
heapq.heappush(max_list, (-num, num))
if list_dict.get(num) == None:
list_dict[num] = 1
else:
list_dict[num] += 1
def delete_min(min_list, list_dict):
while min_list:
pop_item = heapq.heappop(min_list)
if list_dict.get(pop_item) > 0:
list_dict[pop_item] -= 1
break
def delete_max(max_list, list_dict):
while max_list:
pop_item = heapq.heappop(max_list)[1]
if list_dict.get(pop_item) > 0:
list_dict[pop_item] -= 1
break
t = int(input())
for i in range(t):
min_list = []
max_list = []
list_dict = {}
k = int(input())
for j in range(k):
command = input().rstrip()
if command[0] == 'I':
insert(min_list, max_list, list_dict, int(command[2:]))
elif command[0] == 'D':
if int(command[2:]) == -1:
delete_min(min_list, list_dict)
else:
delete_max(max_list, list_dict)
ans = [key for key, value in list_dict.items() if value > 0]
if ans:
ans.sort()
print(ans[-1], ans[0])
else:
print("EMPTY")
[풀이]
list_dict[num] = 1
로 추가해준다. dict에 이미 존재한다면(중복되는 값), += 1을 해준다.(최소 힙과 최대 힙의 데이터 동기화를 위해 처리하는 과정이다)defaultdict
으로 하는게 편하지 않았을까if list_dict.get(pop_item) > 0:
을 통해 item이 정말 heap에 있는지를 검사한다. 이를 반복하여 최솟값, 최댓값을 뽑아낸다.[오류 해결]
채점해보니 처음엔 30%에서, 그 이후엔 90%에서 틀렸습니다가 나왔다. 처음엔 input한 command를 처리하는 과정에서, 두 번째는 답을 print하는 과정에서 오류가 있었다.
command[2] == '-1'
->int(command[2:]) == -1
차라리 input을 split()으로 받고, int()처리하는 것이 가시적으로 좋을 수 있다.
정답코드의 print하는 과정은 다음과 같다.
ans = [key for key, value in list_dict.items() if value > 0]
if ans:
ans.sort()
print(ans[-1], ans[0])
else:
print("EMPTY")
처음에는 delete_min
과 delete_max
가 모두 최솟값 최댓값 (힙이 비어있다면 -1)을 반환하였다. 즉, 반환값을 기준으로 result를 print하였는데 90%에서 계속 틀렸습니다가 나왔다.
min_item = delete_min(min_list, list_dict)
max_item = delete_max(max_list, list_dict)
if min_item == -1 and max_item == -1:
print('EMPTY')
elif min_item == -1:
print(f'{max_item} {max_item}')
elif max_item == -1:
print(f'{min_item} {min_item}')
else:
print(f'{max_item} {min_item}')
원인..음..
[적용 자료구조 및 알고리즘]
[구현 코드]
import sys
def check_level(index): # index starts from 1
# min_level -> 0
# max_level -> 1
cnt = 0
while index != 1:
index //= 2
cnt += 1
return 0 if cnt % 2 == 0 else 1
def push_up_max(hp, i):
while (i//2)//2 != 0 and hp[i] > hp[(i//2)//2]:
hp[i], hp[(i//2)//2] = hp[(i//2)//2], hp[i]
i = (i//2)//2
def push_up_min(hp, i):
while (i//2)//2 != 0 and hp[i] < hp[(i//2)//2]:
hp[i], hp[(i//2)//2] = hp[(i//2)//2], hp[i]
i = (i//2)//2
def push_up(hp, i):
# adjust the heap after appending the value the end of the array
if i != 1: # not the root
if check_level(i) == 0: # min_level
if hp[i] > hp[i//2]:
hp[i], hp[i//2] = hp[i//2], hp[i]
push_up_max(hp, i//2)
else:
push_up_min(hp, i)
else:
if hp[i] < hp[i//2]: # max_level
hp[i], hp[i//2] = hp[i//2], hp[i]
push_up_min(hp, i//2)
else:
push_up_max(hp, i)
def find_the_smallest_child(hp, i):
if len(hp) - 1 >= i*4:
temp = hp[i*4]
index = i*4
for j in range(i*4, len(hp)):
if temp >= hp[j]:
temp = hp[j]
index = j
if j > i*4 + 3:
break
return index
elif len(hp) - 1 >= i*2 + 1:
return i*2 + 1 if hp[i*2] > hp[i*2 + 1] else i*2
else:
return i*2
def find_the_largest_child(hp, i):
index = 0
if len(hp) - 1 >= i*4:
temp = hp[i*4]
for j in range(i*4, len(hp)):
if temp <= hp[j]:
temp = hp[j]
index = j
if j > i*4 + 3:
break
return index
elif len(hp) - 1 >= i*2 + 1:
return i*2 + 1 if hp[i*2] < hp[i*2 + 1] else i*2
else:
return i*2
def push_down_min(hp, m):
while len(hp) - 1 >= m*2:
i = m
m = find_the_smallest_child(hp, i)
# print(f'the smallest child {m}')
if hp[m] < hp[i]:
if m >= i*4:
hp[m], hp[i] = hp[i], hp[m]
if hp[m] > hp[m//2]:
hp[m], hp[m//2] = hp[m//2], hp[m]
else:
hp[m], hp[i] = hp[i], hp[m]
else:
break
def push_down_max(hp, m):
while len(hp) - 1 >= m*2:
i = m
m = find_the_largest_child(hp, i)
if hp[m] > hp[i]:
if m >= i*4:
hp[m], hp[i] = hp[i], hp[m]
if hp[m] < hp[m//2]:
hp[m], hp[m//2] = hp[m//2], hp[m]
else:
hp[m], hp[i] = hp[i], hp[m]
else:
break
def push_down(hp, i):
if check_level(i) == 0: # min_level
push_down_min(hp, i)
else:
push_down_max(hp, i)
def remove_min(hp):
temp = hp[1]
hp[1] = hp[len(hp) - 1]
hp.pop()
push_down(hp, 1)
return temp
def remove_max(hp):
if len(hp) == 2:
max_index = 1
elif len(hp) == 3:
max_index = 2
else:
max_index = 2 if hp[2] >= hp[3] else 3
temp = hp[max_index]
hp[max_index] = hp[len(hp) - 1]
hp.pop()
push_down(hp, max_index)
return temp
t = int(sys.stdin.readline())
for i in range(t):
max_min_hp = [-1]
k = int(sys.stdin.readline())
for j in range(k):
command = sys.stdin.readline().rstrip()
if command[0] == 'I':
max_min_hp.append(int(command[2:]))
push_up(max_min_hp, len(max_min_hp) - 1)
elif command[0] == 'D':
if len(max_min_hp) == 1:
continue
if command[2] == '1':
remove_max(max_min_hp)
else: # -1
remove_min(max_min_hp)
if len(max_min_hp) > 2:
print(remove_max(max_min_hp), end=' ')
print(remove_min(max_min_hp))
elif len(max_min_hp) == 2:
temp = remove_min(max_min_hp)
print(temp, end=' ')
print(temp)
else:
print('EMPTY')
Wikipedia(Min-max heap)를 참조하여 코드를 구현했다. 제대로 구현했다면 시간복잡도는 O(logN)인데, 왜 시간초과가 나는지 모르겠다.(사실 이해하기 조금 복잡해서 review하기 귀찮다..)