Diameter of Binary Tree (Tree)

하루히즘·2021년 5월 2일
1

LeetCode

목록 보기
13/17

설명

LeetCode의 Diameter of Binary Tree다.

이진 트리의 지름(diameter)을 구하라는 게 무슨 소리인지 헷갈릴 수 있지만 문제에서는 트리의 두 노드 간 가장 먼 거리를 구하라고 명시하고 있다.
예를 들어 위와 같은 트리에서는 3을 반환해야 한다. 왜냐하면 노드 4와 노드 3이 거리 3으로 제일 멀리 떨어져 있기 때문이다. 트리 구조에 따라 가장 멀리 떨어져 있는 노드들이 여럿 존재(노드 4 ~ 노드 3 또는 노드 5 ~ 노드 3)할 수 있으나 어쨌든 거리만 반환하면 된다.

풀이

Tree Traversal #1(실패)

처음으로 시도했던 풀이는 루트 노드를 기준으로 트리를 왼쪽 서브 트리, 오른쪽 서브 트리로 나눈 후 각각의 높이 즉 가장 깊은 곳에 있는 노드까지 탐색 후 두 높이를 합한 결과를 반환하는 것이다.

위의 예시에서는 루트 노드 1을 기준으로 왼쪽 서브 트리, 오른쪽 서브 트리를 나누었을 때 왼쪽 서브 트리에서는 가장 깊은 곳에 있는 노드 4, 노드 5의 높이 2를 탐색할 수 있다. 오른쪽 서브 트리에서는 가장 깊은 곳에 있는 노드 3의 높이 1을 탐색할 수 있으며 이 둘을 합하면 2 + 1 = 3으로 지름을 구할 수 있다.

그러나 문제를 다시 읽어보면 다음과 같은 문장이 있었다.

This path may or may not pass through the root.

즉 가장 멀리 떨어져 있는 노드들이 항상 중간에 루트 노드를 두진 않는다는 것이다. 위에서 접근했던 풀이는 루트 노드를 기준으로 왼쪽, 오른쪽 서브 트리를 나누며 이는 가장 멀리 떨어진 노드 간 경로가 항상 루트 노드를 지나갈 것이라는 가정 하에 구현했던 방법이었다.

그러나 다음과 같은 반례를 보면 루트 노드는 상관없이 한쪽 서브 트리 내에서 가장 멀리 떨어진 노드들이 탐색되는 것을 볼 수 있다.
파란색 라인으로 그려진 경로가 처음 구상했던 풀이로 얻어진 경로며 빨간색 라인으로 그려진 경로가 문제에서 요구하는 경로다. 좌우 서브 트리의 높이를 더하는 방법으로 구한 파란색 라인은 7로 계산되지만 실제로 가장 먼 노드간 거리는 빨간색 라인으로 계산된 8이다.

즉 이 방법은 아예 접근 자체가 잘못됐기 때문에 올바르지 않은 풀이였다. 구현했던 코드는 다음과 같다.

import collections


class Solution:
    def diameterOfBinaryTree(self, root: TreeNode) -> int:
        # 좌우 서브 트리의 높이.
        leftFurthest = 0 if root.left is None else 1
        rightFurthest = 0 if root.right is None else 1
        
        # 루트 노드의 왼쪽 서브 트리부터 탐색 시작.
        queue = collections.deque()
        if root.left is not None:
            queue.append((root.left, 1)) # (노드, 높이) 튜플로 삽입.
            
        while queue:
            # 왼쪽 서브 트리에서 현재 노드의 레벨에 따라 트리의 높이 갱신.
            currentNode, currentLevel = queue.popleft()
            leftFurthest = max(leftFurthest, currentLevel)
            
            # 자식 노드가 있다면 탐색 대기열에 추가.
            if currentNode.left is not None:
                queue.append((currentNode.left, currentLevel+1))
            if currentNode.right is not None:
                queue.append((currentNode.right, currentLevel+1))
                
        queue.clear()
        # 루트 노드의 오른쪽 서브 트리 탐색.
        if root.right is not None:
            queue.append((root.right, 1))
            
        while queue:
            # 오른쪽 서브 트리에서 현재 노드의 레벨에 따라 트리의 높이 갱신.
            currentNode, currentLevel = queue.popleft()
            rightFurthest = max(rightFurthest, currentLevel)
            
            if currentNode.left is not None:
                queue.append((currentNode.left, currentLevel+1))
            if currentNode.right is not None:
                queue.append((currentNode.right, currentLevel+1))
         
        # 좌우 서브 트리의 높이의 합 반환. 
        return leftFurthest + rightFurthest

Tree Traversal #2(44 ms)

그렇다면 어떻게 가장 먼 노드간 거리를 구할 수 있을까? 해결책은 루트 노드부터 시작하는 게 아니라 말단 노드부터 탐색하면서 지름을 갱신하는 방법이었다.

말단 노드부터 갱신한다는 것은 위의 풀이처럼 트리를 루트 노드에서 좌우 서브 트리로 나누면서 말단 노드 방향으로 탐색하는 게 아니라 말단 노드부터 루트 노드로 탐색하면서 서브 트리를 슈퍼 트리로 합치는 것이다. 그리고 새로운 지름이 발견됐을 때 이를 실시간으로 갱신한다.

헷갈릴 수 있으니 일단 구현된 코드를 보자.

class Solution:
    # 재귀 호출에서 전역적으로 접근하기 위한 클래스 변수 '지름'.
    maxDiameter = 0
    
    # 지름 탐색 재귀 호출 함수.
    def findMaxDiameter(self, node):
        # 말단 노드라면 0을 반환하여 탐색 시작.
        if node is None:
            return 0
        
        # 왼쪽 서브 트리의 지름 탐색.
        leftSubTreeResult = self.findMaxDiameter(node.left)
        # 오른쪽 서브 트리의 지름 탐색.
        rightSubTreeResult= self.findMaxDiameter(node.right)
        # 최대 지름을 왼쪽, 오른쪽 서브 트리의 높이의 합과 비교하여 갱신.
        # 함수에서 더 큰 서브 트리의 높이를 기준으로 반환하기 때문에 가능한 것.
        self.maxDiameter = max(self.maxDiameter, leftSubTreeResult + rightSubTreeResult)
        
        # 서브 트리의 높이 중 더 큰 값에 1을 더해서 반환.
        # 상위 노드에서는 현재 서브 트리의 결과를 이용해서 지름을 계산해야 함.
        # 두 서브 트리 중 더 큰 쪽의 서브 트리를 이용하여 슈퍼 트리로 합치는 과정.
        return max(leftSubTreeResult, rightSubTreeResult) + 1
    
    def diameterOfBinaryTree(self, root: TreeNode) -> int:
        # 루트 노드부터 말단 노드로 탐색 시작.
        self.findMaxDiameter(root)
        return self.maxDiameter

최대한 말로 설명해보려 했지만 너무 까다로워서 코드의 주석으로 대체했다. 중요한 점은 트리의 지름을 별도의 외부 변수로 두고 말단 노드부터 탐색하면서 계산하여 실시간으로 갱신해나가는 것이다.
위의 그림을 보면 서브 트리에서 반환한 값을 상위 노드가 어떻게 활용하는지 알 수 있다. 상위 노드 입장에서는 서브 트리가 얼마나 복잡하든 결과적으로는 두 개의 자식 노드로 취급되어야 한다. 그리고 각 서브 트리가 자식 노드로 추상화될 때 두 자식 노드의 높이의 합을 지금까지 계산된 지름과 비교하여 갱신한다.

이처럼 말단 노드부터 루트 노드까지 탐색하게 된다면 결국에는 위처럼 추상화될 것이다. 높이의 합을 다룬다는 점에서 첫번째 풀이와 유사하기도 하지만 최종적으로 계산된 값이 아닌 탐색 중간에 실시간으로 지름이 갱신된다는 점이 특징이다.

몇줄 되지 않는 간단한 코드기 때문에 손으로 직접 탐색하는 것도 좋다. 특히 'return max(leftSubTreeResult, rightSubTreeResult) + 1'이 잘 이해가 되지 않았는데 위처럼 직접 그림을 그리면서 실행해보니 금방 이해할 수 있었다.

후기

Easy 난이도지만 생각보다 까다로운 문제였다. 특히 이 풀이를 이해하느라 몇 시간이 걸린 것 같은데 재귀는 역시 머릿속으로는 이해하기 어렵고 직접 그림을 그려봐야 더 잘 이해할 수 있는것 같다.

profile
YUKI.N > READY?

0개의 댓글