결정 알고리즘(Parametric Search)은 이진 탐색을 통해 문제를 해결하는 방법론이다.
다음과 같은 경우, 결정 알고리즘을 통해 문제를 해결할 수 있다.
결정 알고리즘은 다음과 같은 과정으로 구현된다.
binarySearch(l, r)
mid = l + (r - l) / 2
ans = mid
binarySearch(l, mid - 1)
binarySearch(mid + 1, r)
if(l > r) return ans
이 문제를 메서드로 정의하자면 아래와 같을 것이다.
def find_valid_max_dist(x: int, y: int, node_count: int, spots: list):
단 이 때, node_count
는 spots
의 크기보다 커서는 안 된다.
def find_valid_max_dist(x: int, y: int, node_count: int, spots: list):
if node_count > len(spots):
raise Exception("count of spots is smaller than node_count!")
이제 값이 존재할 수 있는 범위 내에서 이진 탐색을 진행하면 된다.
def find_valid_max_dist(x: int, y: int, node_count: int, spots: list):
if node_count > len(spots):
raise Exception("count of spots is smaller than node_count!")
return binary_search(min_dist=1,
max_dist=y - x,
node_count=node_count,
spots=spots,
valid_max_dist=-1)
y-x
valid_max_dist
는 -1 로 초기화했다.이제 이진탐색 메서드를 정의해보자.
def binary_search(min_dist: int, max_dist: int, node_count: int, spots: list, valid_max_dist: int):
if min_dist > max_dist: # 종료 조건
return valid_max_dist
mid_dist = min_dist + int((max_dist - min_dist) / 2)
if is_valid_dist(dist=mid_dist, node_count=node_count, spots=spots): # 유효하다면
return binary_search(min_dist=mid_dist + 1, # 더 큰 값을 찾는다.
max_dist=max_dist,
node_count=node_count,
spots=spots,
valid_max_dist=mid_dist) # 값을 기록 한다.
else: # 유효 하지 않다면
return binary_search(min_dist=min_dist,
max_dist=mid_dist - 1, # 유효값을 찾는다.
node_count=node_count,
spots=spots,
valid_max_dist=valid_max_dist)
유효한 값인지는 아래와 같이 검사해 볼 수 있을 것이다.
def is_valid_dist(dist: int, node_count: int, spots: list):
cnt = 0
last_spot = None
for spot in spots:
if last_spot:
if spot - last_spot >= dist:
cnt += 1
last_spot = spot
else:
last_spot = spot
cnt = 1
return cnt >= node_count
만약, 최대한 멀리 떨어진 노드들의 조합을 찾아야 한다면 dfs 를 사용해 구할 수 있다.
def find_node_combinations(node_count: int, spots: list, dist: int):
ans = list()
dfs(node_count, spots, dist, ans, [], 0, None)
return ans
def dfs(node_count, spots, dist, ans, now_list, now_idx, last_spot):
if len(now_list) == node_count:
ans.append(now_list)
return
if now_idx >= len(spots):
return
now_spot = spots[now_idx]
if not last_spot or now_spot - last_spot >= dist:
# 선택
dfs(node_count, spots, dist, ans, now_list + [now_spot], now_idx + 1, now_spot)
# 선택x
dfs(node_count, spots, dist, ans, now_list, now_idx + 1, last_spot)
spots
를 탐색하며 거리 상 유효하다면 노드를 배치할 수 있다. (하지 않을 수도 있다)def find_best_node_spots(x: int, y: int, node_count: int, spots: list):
if node_count > len(spots):
raise Exception("count of spots is smaller than node_count!")
valid_max_dist = binary_search(min_dist=1,
max_dist=y - x,
node_count=node_count,
spots=spots,
ans=-1)
print(valid_max_dist)
if valid_max_dist == -1:
raise Exception("Error! No valid distance found!")
else:
return find_node_combinations(node_count=node_count, spots=spots, dist=valid_max_dist)
def binary_search(min_dist: int, max_dist: int, node_count: int, spots: list, ans: int):
if min_dist > max_dist: # 종료 조건
return ans
mid_dist = min_dist + int((max_dist - min_dist) / 2)
if is_valid_dist(dist=mid_dist, node_count=node_count, spots=spots): # 유효 하다면
return binary_search(min_dist=mid_dist + 1, # 더 큰 값을 찾는다.
max_dist=max_dist,
node_count=node_count,
spots=spots,
ans=mid_dist) # 값을 기록 한다.
else: # 유효 하지 않다면
return binary_search(min_dist=min_dist,
max_dist=mid_dist - 1, # 유효 값을 찾는다.
node_count=node_count,
spots=spots,
ans=ans)
def is_valid_dist(dist: int, node_count: int, spots: list):
cnt = 0
last_spot = None
for spot in spots:
if last_spot:
if spot - last_spot >= dist:
cnt += 1
last_spot = spot
else:
last_spot = spot
cnt = 1
return cnt >= node_count
def find_node_combinations(node_count: int, spots: list, dist: int):
ans = list()
dfs(node_count, spots, dist, ans, [], 0, None)
return ans
def dfs(node_count, spots, dist, ans, now_list, now_idx, last_spot):
if len(now_list) == node_count:
ans.append(now_list)
return
if now_idx >= len(spots):
return
now_spot = spots[now_idx]
if not last_spot or now_spot - last_spot >= dist:
dfs(node_count, spots, dist, ans, now_list + [now_spot], now_idx + 1, now_spot)
dfs(node_count, spots, dist, ans, now_list, now_idx + 1, last_spot)
# 테스트
if __name__ == '__main__':
best_comb = find_best_node_spots(1, 31, 8,
[1, 2, 3, 4, 6, 7, 8, 9,
11, 12, 13, 14, 16, 17,
18, 19, 21, 22, 23, 24,
26, 27, 28, 29])
print(best_comb)