결정 알고리즘

이영규·2023년 6월 10일
0
post-custom-banner

결정 알고리즘(Parametric Search)은 이진 탐색을 통해 문제를 해결하는 방법론이다.

사용 가능한 케이스

다음과 같은 경우, 결정 알고리즘을 통해 문제를 해결할 수 있다.

  • 찾고자 하는 답이 특정 범위 내에 있는 경우
  • 특정 값이 답으로 유효한지 판단이 가능한 경우

구현

결정 알고리즘은 다음과 같은 과정으로 구현된다.

  • 값이 존재할 수 있는 범위 안에서 이진 탐색을 진행한다. binarySearch(l, r)
  • 중간 값을 찾아 해당 값이 유효한지 확인한다. mid = l + (r - l) / 2
  • 값이 유효하다면 기록해둔다. ans = mid
  • 값이 유효한지 유효하지 않은지에 따라 탐색의 범위를 바꿔가며 최적값을 찾아 계속 진행한다.
    • binarySearch(l, mid - 1)
    • binarySearch(mid + 1, r)
  • 더 이상 탐색할 수 있는 범위가 없으면 종료한다. if(l > r) return ans

예시) 가장 가까운 거리가 가장 크게 만들기

  • 1차원 정수 범위(x ~ y) 내에 n 개의 노드를 배치한다.
  • 노드는 아무 곳에나 위치할 수 없으며, 각 노드가 위치할 수 있는 후보(spots)가 주어진다.
  • 각 노드는 서로 최대한 멀리 떨어져야 한다.
    • = 노드 간 가장 가까운 거리가 가장 크게 만들고 싶다.
  • 이 때의 노드 간의 가장 가까운 거리의 최댓값을 구하라.

문제정의

이 문제를 메서드로 정의하자면 아래와 같을 것이다.

def find_valid_max_dist(x: int, y: int, node_count: int, spots: list):

단 이 때, node_countspots 의 크기보다 커서는 안 된다.

def find_valid_max_dist(x: int, y: int, node_count: int, spots: list):
    if node_count > len(spots):
        raise Exception("count of spots is smaller than node_count!")

이제 값이 존재할 수 있는 범위 내에서 이진 탐색을 진행하면 된다.

def find_valid_max_dist(x: int, y: int, node_count: int, spots: list):
    if node_count > len(spots):
        raise Exception("count of spots is smaller than node_count!")

    return binary_search(min_dist=1,
                         max_dist=y - x,
                         node_count=node_count,
                         spots=spots,
                         valid_max_dist=-1)
  • 정수 범위이므로, 거리의 최솟값은 1이다.
  • 거리의 최댓값은 y-x
  • 최적값 valid_max_dist 는 -1 로 초기화했다.

이진탐색

이제 이진탐색 메서드를 정의해보자.

def binary_search(min_dist: int, max_dist: int, node_count: int, spots: list, valid_max_dist: int):
    if min_dist > max_dist: # 종료 조건
        return valid_max_dist

    mid_dist = min_dist + int((max_dist - min_dist) / 2)
    
    if is_valid_dist(dist=mid_dist, node_count=node_count, spots=spots): # 유효하다면
        return binary_search(min_dist=mid_dist + 1,  # 더 큰 값을 찾는다.
                             max_dist=max_dist,
                             node_count=node_count,
                             spots=spots,
                             valid_max_dist=mid_dist)	# 값을 기록 한다.
    else: # 유효 하지 않다면
        return binary_search(min_dist=min_dist,
                             max_dist=mid_dist - 1,  # 유효값을 찾는다.
                             node_count=node_count,
                             spots=spots,
                             valid_max_dist=valid_max_dist)
  • 더 이상 탐색할 범위가 없으면 탐색을 종료한다.
  • 유효한 거리 값이라면 값을 기록하고 더 큰 값을 찾아 떠난다.
  • 유효하지 않다면, 더 작은 값의 범위에서 유효값을 찾아본다.

유효값 검증

유효한 값인지는 아래와 같이 검사해 볼 수 있을 것이다.

def is_valid_dist(dist: int, node_count: int, spots: list):
    cnt = 0
    last_spot = None

    for spot in spots:
        if last_spot:
            if spot - last_spot >= dist:
                cnt += 1
                last_spot = spot
        else:
            last_spot = spot
            cnt = 1

    return cnt >= node_count
  • 해당 거리 값으로 모든 노드의 배치가 가능하다면 유효한 거리 값이다.

번외) 조합 찾기

만약, 최대한 멀리 떨어진 노드들의 조합을 찾아야 한다면 dfs 를 사용해 구할 수 있다.

def find_node_combinations(node_count: int, spots: list, dist: int):
    ans = list()
    dfs(node_count, spots, dist, ans, [], 0, None)
    return ans


def dfs(node_count, spots, dist, ans, now_list, now_idx, last_spot):
    if len(now_list) == node_count:
        ans.append(now_list)
        return

    if now_idx >= len(spots):
        return

    now_spot = spots[now_idx]
    if not last_spot or now_spot - last_spot >= dist:
    	# 선택
        dfs(node_count, spots, dist, ans, now_list + [now_spot], now_idx + 1, now_spot)

	# 선택x
    dfs(node_count, spots, dist, ans, now_list, now_idx + 1, last_spot)
  • spots 를 탐색하며 거리 상 유효하다면 노드를 배치할 수 있다. (하지 않을 수도 있다)
  • 유효하지 않다면 노드를 배치하지 않고 넘어간다.

최종 코드 전체

def find_best_node_spots(x: int, y: int, node_count: int, spots: list):
    if node_count > len(spots):
        raise Exception("count of spots is smaller than node_count!")

    valid_max_dist = binary_search(min_dist=1,
                                   max_dist=y - x,
                                   node_count=node_count,
                                   spots=spots,
                                   ans=-1)
    print(valid_max_dist)

    if valid_max_dist == -1:
        raise Exception("Error! No valid distance found!")
    else:
        return find_node_combinations(node_count=node_count, spots=spots, dist=valid_max_dist)


def binary_search(min_dist: int, max_dist: int, node_count: int, spots: list, ans: int):
    if min_dist > max_dist: # 종료 조건
        return ans

    mid_dist = min_dist + int((max_dist - min_dist) / 2)
    if is_valid_dist(dist=mid_dist, node_count=node_count, spots=spots):    # 유효 하다면
        return binary_search(min_dist=mid_dist + 1,  # 더 큰 값을 찾는다.
                             max_dist=max_dist,
                             node_count=node_count,
                             spots=spots,
                             ans=mid_dist)  # 값을 기록 한다.
    else:   # 유효 하지 않다면
        return binary_search(min_dist=min_dist,
                             max_dist=mid_dist - 1,  # 유효 값을 찾는다.
                             node_count=node_count,
                             spots=spots,
                             ans=ans)


def is_valid_dist(dist: int, node_count: int, spots: list):
    cnt = 0
    last_spot = None

    for spot in spots:
        if last_spot:
            if spot - last_spot >= dist:
                cnt += 1
                last_spot = spot
        else:
            last_spot = spot
            cnt = 1

    return cnt >= node_count


def find_node_combinations(node_count: int, spots: list, dist: int):
    ans = list()
    dfs(node_count, spots, dist, ans, [], 0, None)
    return ans


def dfs(node_count, spots, dist, ans, now_list, now_idx, last_spot):
    if len(now_list) == node_count:
        ans.append(now_list)
        return

    if now_idx >= len(spots):
        return

    now_spot = spots[now_idx]
    if not last_spot or now_spot - last_spot >= dist:
        dfs(node_count, spots, dist, ans, now_list + [now_spot], now_idx + 1, now_spot)

    dfs(node_count, spots, dist, ans, now_list, now_idx + 1, last_spot)


# 테스트
if __name__ == '__main__':
    best_comb = find_best_node_spots(1, 31, 8,
                                     [1, 2, 3, 4, 6, 7, 8, 9,
                                      11, 12, 13, 14, 16, 17,
                                      18, 19, 21, 22, 23, 24,
                                      26, 27, 28, 29])
    print(best_comb)
profile
더 빠르게 더 많이 성장하고 싶은 개발자입니다
post-custom-banner

0개의 댓글