본문 바로가기
스터디/알고리즘

최소 공통 조상 (Lowest Common Ancestor, LCA)

by 궁금한 준이 2024. 7. 11.
728x90
반응형

최소 공통 조상 (LCA) 알고리즘

Problem Setup

$N$개의 node로 이루어진 트리가 주어진다. (트리이므로 간선의 개수는 $N-1$개이다)

$q$개의 query에 대하여 두 node사이의 거리를 계산한다.

백준 11437 예제

 

위 그림은 백준 11437 LCA 문제이다. 1번 노드가 루트이다.

6번과 11번 노드의 최소 공통 조상은 2번 노드이다. (1번 노드 역시 공통 조상이지만, 최소가 아니다.)

10번과 9번 노드의 최소 공통 조상은 4번이다. 즉, LCA(10, 9) = 4

8번과 15번 노드의 최소 공통 조상은 1번이다. 즉, LCA(8, 15) - 1

 

루트노드에서 DFS/BFS 알고리즘을 통해 각 노드의 깊이를 계산하여 depth 배열에 저장한다.

1의 depth는 0, 2와 3의 depth는 1, 4번~8번 노드의 depth는 2, 9번~14번 노드의 depth는 3, 15번 노드의 depth는 4가 된다.

 

그리고 쿼리의 두 노드 $u$, $v$의 깊이가 같으면 반복문을 통해 부모가 같은 노드를 찾는다.

예를 들어, 10번과 9번의 노드는 깊이가 같기 때문에, 바로 반복문을 통해 공통 조상을 찾는다.

from collections import deque

def dfs(adj, visited, depth, parent, start):
    visited[start] = True
    depth[start] = 0
    stk = deque([(start, 0)]) # stack, (x)
    while stk:
        cur, cur_depth = stk.pop()
        for nxt in adj[cur]:
            if not visited[nxt]:
                visited[nxt] = True
                depth[nxt] = cur_depth + 1
                parent[nxt] = cur
                stk.append((nxt, cur_depth + 1))

N = int(input()) # node id: [1, N]
parent = [0] * (N + 1)
adj = [[] for _ in range(N + 1)]
    for _ in range(N - 1):
        A, B = map(int, input().split())
        adj[A].append(B)
        adj[B].append(A)
            
    depth = [0] * (N + 1)
    visited = [False] * (N + 1)
    dfs(adj, visited, depth, parent, 1)



LCA (lazy)

만약 두 노드의 깊이가 다르면, 깊이가 깊은 노드를 얕은 노드의 깊이의 조상까지 찾고, 위의 과정을 반복한다.

예를 들어, 8번과 15번의 노드는 깊이가 다르기 때문에, 더 깊은 15번 노드의 (8번 노드와 깊이가 같은) 조상 노드를 찾는다. 여기서 5번 노드가 된다. 이후, 8번 노드와 5번 노드에서 반복문을 통해 공통조상을 찾으면 1번 노드가 된다.

def find_lca(depth, x, y):
    # check same depth
    while depth[x] != depth[y]:
        if depth[x] > depth[y]:
            x = parent[x]
        else:
            y = parent[y]
            
    # find same node
    while x != y:
        x = parent[x]
        y = parent[y]
        
    return x

LCA (fast)

단순 반복문을 통해 최소공통조상을 찾는 과정은 $O(depth)$가 된다.

즉, 트리의 깊이가 깊어지면 비효율적이다.

이진법을 이용하여 더 빠르게 조상을 거슬러 올라갈 수 있다.

예를 들어, 14의 경우 이진수로 나타내면 1110이고, 이 숫자를 계속 shifting하여 1이 나오는 횟수만큼 거슬러 올라가면 된다. 이 경우, $2^1, 2^2, 2^3$번째 조상노드로 거슬러 올라가면 된다. 4번의 반복문(이진수의 길이가 4)만 이용하면 된다.

따라서 이진수로 나타낼 수 있는 최대 수 MAX_LOG만큼만 시간이 소요된다.

 

MAX_LOG는 N을 2진수로 나타낼 수 있는 비트 수와 동일하다.

최악의 경우, 트리의 노드들이 일렬로 연결된 것으로 생각할 수 있다.

이때 최대 길이는 edge 수와 같은 N-1이므로 2진수의 길이는 N이면 충분하다. 

 

이제, parent[x]가 아니라 parent[x][i]라는 2차원 배열을 이용하여 깊이가 깊은 노드의 조상을 미리 찾는다.

parent[x][i]는 x번 노드의 $2^i$번째 조상 번호를 저장한다. ($i=0, 1, 2, \dots$)

초기화를 $-1$로 하여 $2^i$번째 조상이 없으면 $-1$로 간주한다.

(이러한 자료구조를 sparse table, 희소 테이블이라고 부른다)

 

우선 parent를 (N + 1, MAX_LOG + 1) shape의 2차원 배열을 -1로 초기화하자.

0번째 인덱스는 사용하지 않으므로, 각각 +1을 해준다.

# number of nodes. [1, 2, ..., N]
N = int(input()) 
adj = [[] for _ in range(N + 1)]
MAX_DEPTH = math.ceil(math.log2(N)) + 1
for _ in range(N - 1):
    A, B = map(int, input().split())
    adj[A].append(B)
    adj[B].append(A)

parent = [[-1 for _ in range(MAX_DEPTH)] for _ in range(N + 1)]
depth = [0] * (N + 1)​

 

DFS(혹은 BFS)를 이용하여 parent[x][0] = depth를 저장한다. (자기 자신의 depth)

그리고 preprocess라는 함수를 이용하여 parent[x][i]를 초기화하자.

def preprocess(parent):
    N = len(parent) - 1
    MAX_LOG = len(parent[0]) - 1
    for j in range(1, MAX_LOG + 1):
        for i in range(1, N + 1):
            if parent[i][j - 1] != -1:
                parent[i][j] = parent[parent[i][j - 1]][j - 1]

 

LCA (x, y)함수도 수정한다.

x가 더 깊은 노드라고 가정하자.

그러면 x를 y의 깊이만큼 조상을 거슬러 올라간다.

그런데 parent[x][i]가 $2^i$번째 조상이므로 선형시간이 아니라 $\log$ 시간만큼 빠르게 거슬러 올라갈 수 있다.

def find_lca(parent, depth, x, y):
    MAX_LOG = len(parent[0]) - 1
    # setup: x is deeper node
    if depth[x] < depth[y]:
        x, y = y, x
        
    diff = depth[x] - depth[y]
    for i in range(MAX_LOG + 1):
        if (diff >> i) & 1:
            x = parent[x][i]
            
    if x == y:
        return x
    
    for i in range(MAX_LOG, -1, -1):
        if parent[x][i] != parent[y][i]:
            x = parent[x][i]
            y = parent[y][i]
            
    return parent[x][0]

 

전체 코드는 다음과 같다.

import sys
#sys.stdin = open('input.txt', 'r')
input = sys.stdin.readline

import math
from collections import deque

# DFS 또는 BFS로 같은 depth 배열 저장
def bfs(adj, visited, depth, parent, start):
    q = deque([(start, 0)]) # (node, depth)
    visited[start] = True
    depth[start] = 0
    
    while q:
        cur, cur_depth = q.popleft()
        for nxt in adj[cur]:
            if not visited[nxt]:
                visited[nxt] = True
                depth[nxt] = cur_depth + 1
                parent[nxt][0] = cur
                q.append((nxt, depth[nxt]))
                
def preprocess(parent):
    N = len(parent) - 1
    MAX_LOG = len(parent[0]) - 1
    for j in range(1, MAX_LOG + 1):
        for i in range(1, N + 1):
            if parent[i][j - 1] != -1:
                parent[i][j] = parent[parent[i][j - 1]][j - 1]
    
            
def find_lca(parent, depth, x, y):
    MAX_LOG = len(parent[0]) - 1
    # setup: x is deeper node
    if depth[x] < depth[y]:
        x, y = y, x
        
    diff = depth[x] - depth[y]
    for i in range(MAX_LOG + 1):
        if (diff >> i) & 1:
            x = parent[x][i]
            
    if x == y:
        return x
    
    for i in range(MAX_LOG, -1, -1):
        if parent[x][i] != parent[y][i]:
            x = parent[x][i]
            y = parent[y][i]
            
    return parent[x][0]

if __name__ == '__main__':
    # number of nodes: [1, 2, ..., N]
    N = int(input()) 
    adj = [[] for _ in range(N + 1)]
    MAX_LOG = math.ceil(math.log2(N))
    for _ in range(N - 1):
        A, B = map(int, input().split())
        adj[A].append(B)
        adj[B].append(A)
        
    parent = [[-1 for _ in range(MAX_LOG + 1)] for _ in range(N + 1)]
    depth = [0] * (N + 1)
    visited = [False] * (N + 1)
    bfs(adj, visited, depth, parent, 1)
    preprocess(parent)
    
    # number of query
    M = int(input())
    for _ in range(M):
        u, v = map(int, input().split())
        lca = find_lca(parent, depth, u, v)
        print(lca)

시간 복잡도

$N$을 노드 수라고 하자. 그러면 트리의 깊이는 $\log N$ 이다.

DFS/BFS 단계에서 $O(V+E) = O(N + N - 1) = O(N)$의 시간복잡도가 소요된다.

전처리 단계에서(parent 초기화) $O(N \log N)$의 시간복잡도가 소요된다.

LCA(u, v)를 계산할 때는 $O(\log N)$의 시간복잡도가 소요된다.

$q$개의 LCA(u, v) 쿼리가 주어지면 $O(q \log N)$이 된다.

 

BOJ 11437 (LCA 2): 전체 시간복잡도는 $O(N \log N + M \log N)$ ($M$은 쿼리 개수.)

BOJ 3584 (가장 가까운 공통 조상): 루트 노드를 찾는 과정이 필요하다. parent를 -1로 초기화하고, parent[x] (lazy version) 혹은 parent[x][0]의 값이 -1인 x가 루트 노드이다. (부모 노드가 없다는 뜻)

BOJ 1761 (정점들의 거리): lca를 구하고, dist[x] + dist[y] - 2 * dist[lca(x, y)]

728x90
반응형