이 문제를 푸느라 시간도 너무 오래 걸렸고 멘탈도 나갔는데(!) 일단 맞았다고 떴으니 됐지 않나 생각한다. (이 포스팅을 쓸 때까지도 멘탈이 회복되지 않아서 오타가 넘쳐나고 있다. ㅋㅋㅋ)
https://www.acmicpc.net/problem/1654
우선 문제를 이해하도록 해보자.
첫째줄 입력값은 가지고 있는 랜선의 개수 K, 필요한 랜선의 개수 N이다. 둘째줄부터 입력되는 값들은 자연수로 랜선들의 길이이다. K≤N이다. 그리고 기존의 K개의 랜선으로 N개의 랜선을 만들 수 없는 경우는 없다고 가정한다고 했다.
솔직히 문제가 이해가 안 가는데 아무튼 '이진 탐색'으로 푸는 문제로 분류되어 있다.
그렇다면 예상되는 값들에서 중간값을 계속 탐색하는 방법을 통해 풀도록 코드를 작성하면 될 것이다.
탐색하는 대상은 어떤 값들일까?
우선 랜선의 길이의 최솟값은 1이다. 못 만드는 경우는 없다고 가정했으니까 더 긴 길이를 끝끝내 못 찾았다면 1cm씩 N개를 잘라도 된다.
또한 랜선의 길이의 최댓값은...
처음에는 가장 짧은 길이라고 생각했는데 아니었다.
왜냐하면 모든 랜선을 반드시 잘라서 사용해야 하는 게 아니기 때문이다.
그러므로 N=1이라면 그냥 제일 긴 랜선(그런데 K≤N
이라고 했으니까 랜선도 그냥 한 개겠지?) 하나를 안 자르고 쓰는 거고
N=2라면 두 번째로 긴 랜선의 길이만큼 제일 긴 랜선을 자를 수도 있고, 가장 긴 랜선 하나가 압도적으로 길다면 그거 하나를 반으로 자를 수도 있고...
엄청 긴 랜선과 엄청 무의미하게 짧은 랜선이 섞여 있다면 짧은 랜선들은 무시하고 긴 랜선들만 사용하는 게 답일 수도 있다.
아무튼 선택할 수 있는 범위의 랜선 길이의 최댓값은 k[]
의 최댓값이거나
또는 N이 1보다 크다면 sum(k[])//N
일 수도 있으므로 이렇게 1과 최댓값 사이에서 탐색을 실시한다.
그러면 비교하는 과정은 이렇게 된다.
임의의 값 L이 정답인지 확인하려면
k[]
)을 각각 L로 나눈 나눗셈의 몫의 합 S를 구한다S==N
이라면 L이 정답일 수도 있다. 그러나 더 긴 길이에서도 성립할 수 있으니까 현재 최댓값을 업데이트해놓고 L보다 큰 값 중에서 계속 탐색해야 한다.S<N
이라면 길이 L로는 N만큼의 랜선을 구할 수 없으니 L보다 작은 값 중에서 탐색해야 한다.S>N
이라면 길이 L로는 N보다 많은 랜선이 생기는 것이니까 L보다 큰 값 중에서 탐색해야 한다. 단! N보다 많이 자른 것도 N만큼 자른 것이 포함된다고 했다. 따라서 L보다 큰 값을 못 구할 수도 있으니까 현재의 최댓값을 일단 저장하고 L보다 큰 값 중에서 탐색한다.이렇게 해서 비교를 하는데...
여기에서 중요한 함정은 두 개가 있다.
N개보다 많이 만드는 것도 N개를 만드는 것에 포함된다.
예를 들어서 K=3
, N=3
이고 갖고 있는 랜선은 k[]= 10000,10000,1
이라고 가정해 보자.
5000을 넣어서 탐색해보자. k[]
의 각각의 값들을 나눈 몫은 각 2+2+0
으로 4
개의 랜선을 만들었다. 그러면 못 만든 거니까 5001부터 10000까지 다시 탐색해야 하나? 아니다. 어차피 5001부터 10000까지의 수로 k[]
의 값들을 나누어 보면 몫은 각각 1,1,0이니까 만들 수 있는 랜선은 2개로 줄어들어 버린다. 따라서 그렇게 탐색하면 딱 3개의 랜선을 만들 수 있는 결과를 끝까지 찾을 수 없고 최댓값을 찾지 못해서 1로 줄어든 결과가 나온다. 그런데 3개 이상의 랜선을 만들 수 있는 최대 길이는 1이 아니고 5000이다.
따라서 N개보다 많이 만든 것도 N개를 만드는 데 성공했다고 간주하고 최댓값 업데이트에 반영한 뒤에 계속 N개 이상의 랜선을 만들면서 최댓값이 늘어나게 할 수 있는지 탐색을 계속해야 한다.
다 찾아야 하면 왜 이진 탐색을 하지? 아마도 1
부터 max(k[])
까지 1씩 늘려가며 계속 찾으면 값이 매우 큰 경우에는 시간이 오래 걸릴 수 있으니까 중간값부터 시작해서 전혀 아닌 값들을 배제할 수 있도록 하기 위해서 이진 탐색을 쓰는 것 같다.
풀이의 과정은 이렇게 되었다.
우선 입력값을 받는 과정을 시작한다.
import sys
K,N=map(int,sys.stdin.readline().split())
k=[]
for i in range(K):
k.append(int(sys.stdin.readline()))
k.sort()
기본적인 이진탐색 코드부터 만들어본다.
참고문헌들:
이관용·김진욱 (2018) 『알고리즘』 서울: 한국방송통신대학교출판문화원
https://stackoverflow.com/questions/62746868/binary-search-recursive-returns-none
파이썬에서 배열과 함수를 처리하는 방법은 아무래도 다른 언어와 다르기 때문에 자료에 나온 수도코드들을 참고해서 작성해 본 뒤에 오류가 발생하는 경우는 비슷한 오류에 대해 검색해서 수정했다.
def BinarySearch(arr, left, right, x):
if left<=right:
mid=(left+right)//2
#print("Check if",x,"== arr[",mid,"]")
if x==arr[mid]:
#print(x,"== arr[",mid,"]")
return int(mid)
elif x<arr[mid]:
#print(x,"< arr[",mid,"]")
return BinarySearch(arr, left, mid-1,x)
else:
#print(x,"> arr[",mid,"]")
return BinarySearch(arr, mid+1, right, x)
else:
return -1
물론 기본 이진탐색과 다른 점이 있다.
첫째, 배열에 있는 값들이 키값과 동일한지 확인하는 것이 아니라 키값으로 k[]
에서 계산한 랜선 개수를 계산한 뒤에 비교해야 하고
둘째, 랜선 개수가 동일할 때 탈출하지 않고 더 큰 최댓값을 찾을 수 있는지 계속 탐색해야 하고
셋째, 랜선 개수가 더 많으면 탐색에 실패했다고 간주하는 게 아니라 일단 하나 찾았다고 간주하고 더 큰 값(랜선 개수가 더 줄어들 수 있는 값들)들에 대해서 탐색을 계속해야 한다.
그래서 위의 이진탐색 함수는 이렇게 수정되었다.
def wires_count(arr,L):
table=[]
for unit in arr:
table.append(unit//L)
return sum(table)
def BinarySearchL(k_min, k_max, key,arr,current_max):
if k_min<=k_max:
mid=(k_min+k_max)//2
if wires_count(arr,mid)==key:
if mid>current_max:
current_max=mid
return BinarySearchL(mid+1, k_max,key,arr,current_max)
elif wires_count(arr,mid)<key:
return BinarySearchL(k_min, mid-1,key,arr,current_max)
else:
if mid>current_max:
current_max=mid
return BinarySearchL(mid+1, k_max,key,arr,current_max)
else:
return current_max
그래서 전체 코드는 이렇게 되었다.
import sys
#import math
def wires_count(arr,L):
table=[]
for unit in arr:
table.append(unit//L)
return sum(table)
def BinarySearchL(k_min, k_max, key,arr,current_max):
if k_min<=k_max:
mid=(k_min+k_max)//2
if wires_count(arr,mid)==key:
if mid>current_max:
current_max=mid
return BinarySearchL(mid+1, k_max,key,arr,current_max)
elif wires_count(arr,mid)<key:
return BinarySearchL(k_min, mid-1,key,arr,current_max)
else:
if mid>current_max:
current_max=mid
return BinarySearchL(mid+1, k_max,key,arr,current_max)
else:
return current_max
K,N=map(int,sys.stdin.readline().split())
k=[]
for i in range(K):
k.append(int(sys.stdin.readline()))
k.sort()
r = BinarySearchL(1,min(sum(k)//N,k[-1]),N,k,1)
print(r)
처음에는 도무지 어디가 막혔는지 찾을 수가 없었다.
그래서 위 문제의 질문게시판을 보니 주요 '반례'들이 댓글에 나와 있었고 해당 반례들을 입력해보면서 오답이 나오는 경우 코드의 오류를 하나하나 찾아나갔다.
https://www.acmicpc.net/board/search/all/problem/1654
헷갈리기 쉬운 점들은 이런 것들이었다.
코드를 검사하는 데 유용했던 반례를 좀 정리해보려고 했는데, 문제를 처음부터 다시 읽어보니 K≤N이라서 대부분의 극단적인 반례는 아예 입력 대상에 포함되지도 않는다. 물론 계산이 맞게 되고 있는지 확인하기 위해서 K>N인 반례들도 입력해보는 것도 나쁘지 않다. 예를 들어서 K,N=3,1
이고 k[]=[100,100,1]
이라면 계산 결과는 100이 나와야 하는 건 맞다. 다만 K≤N
이니 유효한 테스트 케이스는 아니다.
이진 탐색을 원하는 값만 찾으면 즉시 탈출하는 기본적인 방법이 아니라 최댓값을 찾을 때까지 계속 찾기 위해서 사용하는 것은 처음 해보았다.
그래도 끝까지 포기하지 않고 해서 마침내 합격했더니 뿌듯함을 넘어서 감동적이고... 솔직히 너무 어려워서 가능할 거라고 생각하지 않았다.