CS/자료구조

[자료구조] 트리 : 최소 공통 조상(LCA)

mhko411 2021. 10. 19. 23:52
728x90

트리에서 최소 공통 조상이 무엇이며 이를 구하기 위한 기본적인 방법과 메모리를 사용하여 시간 복잡도를 개선한 알고리즘을 알아보자.


최소 공통 조상(Lowest Common Ancestor)

최소 공통 조상은 두 노드의 공통된 조상 중에서 가장 가까운 조상을 의미한다. 아래의 트리를 통해 9번 노드와 11번 노드의 LCA를 알아보자.

9번 노드의 부모 노드는 4 -> 2 -> 1이며 11번 노드의 부모 노드는 5 -> 2 -> 1이다. 이때 2와 1은 두 노드의 공통 조상이된다. 이처럼 두 개의 노드에서 공통 조상이 될 수 있는 노드는 여러 개일 수 있으며 여기서 레벨이 가장 높은(두 개의 노드에서 가장 가까운)노드가 최소 공통 조상이 된다.

 

LCA 알고리즘 Ⅰ

기본적인 LCA 알고리즘의 개념은 다음과 같다.

  • 모든 노드가 위치한 Level을 계산한다.
  • LCA를 구할 두 개의 노드가 위치한 Level이 다르다면 맞춰준다.
  • 동일한 Level에서 출발하여 같은 부모 노드가 나올 때까지 거슬러 올라간다.

기본적인 LCA 알고리즘을 위의 트리에서 적용해보자. 9번과 7번 노드의 LCA를 구해보자. 먼저 9번은 Level 3에 위치하고 있으며 7번은 Level 2에 위치하고 있다. 따라서 9번 노드의 Level을 끌어올려 Level 2로 맞춰준다. 여기서 맞춰준다는 것은 자신의 부모 노드로 이동하여 Level을 맞춰주는 것이다.

 

Level을 맞췄다면 노드 4번, 노드 7번에서 동시에 출발하여 같은 부모 노드가 나올 때까지 이동한다. 결과적으로 노드 1번을 LCA로 구할 수 있을 것이다. 이제 이를 Python으로 구현해보자.

 

코드로 구현하기

  • 루트 노드를 시작으로 DFS를 통해 노드들의 깊이를 조사한다.
  • 이를 활용하여  LCA를 구할 두 개의 노드가 깊이가 다를 때 깊이를 맞춰준다.
  • 깊이가 맞춰졌다면 해당 노드부터 동시에 거슬러 올라가서 같은 노드가 될 때를 찾는다.
import sys
sys.setrecursionlimit(int(1e5))

def dfs(cur, level):
    visited[cur] = 1 # 현재 노드를 방문했다는 표시를 한다.
    depth[cur] = level # 현재 레벨을 방문한 노드의 레벨로 기록한다.

    for child in tree[cur]: # 현재 노드의 자식 노드들을 방문한다.
        if not visited[child]:
            parent[child] = cur # child 노드의 부모 노드를 현재 방문한 노드로 기록한다.
            dfs(child, level+1)

def lca(a, b):
    while depth[a] != depth[b]: # 깊이가 맞춰질 때까지
        if depth[a] > depth[b]: # 더 깊이 있는 곳의 노드를 부모 노드로 이동시킨다.
            a = parent[a]
        else:
            b = parent[b]

    while a != b: # a와 b가 같아질 때까지
        a = parent[a] # 거슬러 올라간다.
        b = parent[b]

    return a

N = int(input())
tree = [[] for _ in range(N+1)]

for _ in range(N-1):
    a, b = map(int, input().split())
    tree[a].append(b)
    tree[b].append(a)

parent = [0] * (N+1) # 노드의 부모 노드를 기록할 리스트
depth = [0] * (N+1) # 노드의 깊이를 기록할 리스트
visited = [0] * (N+1) # 노드의 방문 표시

dfs(1, 0)
M = int(input())
for _ in range(M):
    a, b = map(int, input().split())
    print(lca(a, b))