Algorithm/PS 알고리즘 정리

[알고리즘] 최소 공통 조상 LCA(Lowest Common Ancestor)

샤아이인 2023. 12. 8. 17:57

 

1. LCA란?

LCA(Lowest Common Ancestor)는 주어진 두 노드 a와 b의 최소 공통 조상을 찾는 알고리즘입니다.

 

예를 들어 다음 그림과 같은 트리가 있다고 했을 때, 12번과 15번 노드의 최소 공통 조상 LCA는 5번 노드가 됩니다.

 

우선 두 노드의 LCA를 찾는 가장 간단한 O(N)이 걸리 방법으로 2가지 정도가 있다.

2가지에 대하여 간단하게 알아본 후, 이글의 Main 목표인 O(LogN) 안에 해결하는 알고리즘에 대하여 알아보자.

 

2. O(N) 알고리즘

2-1) 두 Node의 Level을 통일시키는 방식

  1. A와 B의 깊이가 다를 경우 더 깊은 정점의 부모를 따라 하나씩 올라가면서 A와 B의 깊이를 맞춘다.
  2. 위의 결과 A와 B가 같다면 A(or B)가 최소 공통조상이다.
  3. A != B 인 경우 A와 B를 동시에 부모를 따라 한 칸씩 올라가면서 A와 B가 같아진다면 그 정점이 LCA이다.

​위의 메커니즘은 LCA를 구하는 과정에서 하나씩 올라가는 특징을 가진다.

즉, 트리의 깊이가 H라면 O(H)의 시간복잡도를 가진다. 최악의 경우 (skewed형 트리) O(N)이 된다.

 

2-2) DFS를 사용하는 방식

  1. 시작 정점 a로부터 DFS, BFS 탐색을 시작한다
  2. 모든 Node를 방문하면, 최종적으로 a부터 b까지 가는 Path를 찾을 수 있다.
  3. 찾은 Path를 따라 이동하면서 방문하는 모든 Node를 저장한다.
  4. 해당 Path를 방문하면서 저장된 Node들 중 Level이 가장 작은 값을 갖는 Node가 LCA이다.

이 방식 또한 DFS 한 번에 O(N)의 시간이 걸린다.

조금 더 빠른 방법이 없을까? 이에 대하여 알아보자!

 

3. O(logN) 알고리즘

3-1) 대략적인 이해

 

필기해 둔 부분을 다시 글로 적어보면

  1. 일단 비교하고자 하는 두 Node a와 b의 높이를 동일하게 만들어준다.
  2. 두 노드로부터 2^i(2의 i승) 만큼 멀리 떨어진 노드부터 2^0승까지 노드를 비교한다.
    1. 비교하는 노드의 번호가 같다면, 즉 (a로부터 2^j만큼 떨어진 조상) == (b로부터 2^j만큼 떨어진 조상)이라면 2^(j-1) 승을 다시 비교한다.
    2. 비교하는 노드의 번호가 다르다면, 해당 지점부터 위로 탐색을 진행한다.
  3. 위로 계속 탐색을 진행하다 처음으로 동일한 Node가 나온다면, 해당 Node가 LCA이다.

이를 그림으로 확인해 보면 다음과 같다.

3-1-1) 시작 트리

다음 트리의 16번 Node와 13번 Node의 LCA를 찾아보자!

 

3-1-2) 시작 높이 동일하게 만들기

두 비교 노드의 높이 중 더 낮은 Level 4로 동일하게 만들면 각각 12와 13번 노드가 된다.

 

3-1-3) 서로 처음으로 달라지는 조상 노드로 이동

현재 보고 있는 두 정점의 높이를 맞춘 후(12, 13번)에는 lca를 확인하기 위해 서로의 조상을 비교하는 작업도 2^i씩 거슬러 올라갈 수 있습니다. 이때 반복문을 통하여 i(height-1)부터 -> j -> 0까지, 즉 12, 13번 노드로부터 2^i만큼 위에 있는 노드부터 ~ 2^0만큼 위에 있는 노드까지 비교를 하되, 2^j번째 조상이 서로 달라지는 순간 2^j번째 조상으로 동시에 점프하게 됩니다.

따라서 위 그림에서는 12, 13번 노드에서 2^2만큼 위에 있는 노드인 1번은 공통 노드이니,

다음으로 2^1만큼 위에 있는 조상을 비교하면 6과 7로 처음으로 다른 값이 나오며, 해당 6과 7로 각각 이동하게 됩니다.

 

3-1-4) 위로 계속 탐색하여 공통 조상 찾기

6과 7을 시작으로 계속 위로 올라가면서 비교하여 , 공통의 Node가 처음 나오면 해당 Node가 LCA에 해당됩니다. 

 

3-2) 전처리

위와 같은 알고리즘을 작성하기 전에 전처리 과정이 필요하다.

dp 2차원 배열에 해당 cur 노드의 2^h번째 부모노드를 저장해 줌으로써 연산 횟수를 줄여주고 중복되는 연산을 제거해 준다.

  • parent[cur][h]

 

1. DFS탐색을 통해 각 노드의 높이(depth)와  2^0(1) 번째 부모노드의 값으로 초기화시켜 준다. 

parent[nextNode][0] = nowNode  해당 코드 부분에서 "나"를 기준으로 자식을 호출하여 "나"로 부모를 초기화하는 작업이다.

public static void dfs(int nowNode, int nowDepth, int parentNode) {
    depth[nowNode] = nowDepth;

    for (Integer nextNode : adj[nowNode]) {
        if (nextNode != parentNode) {
            parent[nextNode][0] = nowNode;
            dfs(nextNode, nowDepth + 1, nowNode);
        }
    }
}

 

2. 나머지 2^0, 2^1,... , 2^h-1번째의 부모노드를 채워준다.

public static void fillParents() {
    for (int i = 1; i <= 20; i++) {
        for (int j = 1; j <= N; j++) {
            parent[j][i] = parent[parent[j][i - 1]][i - 1];
        }
    }
}

 

3-3) LCA 구하기

Parent값을 모두 할당해 줬으면 이제 해당 데이터로 LCA를 구해주면 된다.

  1. a와 b노드가 주어지면 해당 노드의 높이가 낮은 노드를 기준으로 높이를 맞춰준다.
    1. 높이를 맞췄는데 a==b이면, LCA = a이므로 바로 출력해 준다. (LCA가 1일 때 예외처리이기도 함)
  2. a와 b노드의 parent값을 비교해 가며 LCA를 찾아준다.
private static int LCA(int a, int b) {
    int aDepth = depth[a];
    int bDepth = depth[b];

    if (aDepth < bDepth) { // a의 Depth가 더 크도록
        int tmp = a;
        a = b;
        b = tmp;
    }

    for (int i = 20; i >= 0; i--) { // 높이 동일하게
        if (depth[a] - depth[b] >= (1 << i)) {
            a = parent[a][i];
        }
    }
    if (a == b) return a; // 높이를 맞췄는데 공통 조상인 경우

    for (int i = 20; i >= 0; i--) { // LCA 찾기
        if (parent[a][i] != parent[b][i]) {
            a = parent[a][i];
            b = parent[b][i];
        }
    }

    return parent[a][0];
}

 

4. 관련 문제

https://www.acmicpc.net/problem/11438

 

11438번: LCA 2

첫째 줄에 노드의 개수 N이 주어지고, 다음 N-1개 줄에는 트리 상에서 연결된 두 정점이 주어진다. 그 다음 줄에는 가장 가까운 공통 조상을 알고싶은 쌍의 개수 M이 주어지고, 다음 M개 줄에는 정

www.acmicpc.net

 

풀이

import java.io.*;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;

public class Main {

    private static int N;
    private static int M;
    private static List<Integer> adj[];
    private static int[][] parent;
    private static int[] depth;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));

        N = Integer.parseInt(br.readLine().trim());

        initArray();
        initGraph(br);

        dfs(1, 1, 0);
        fillParents();
        String result = findAllLCA(br);

        bw.write(result);
        bw.flush();
        bw.close();
        br.close();
    }

    private static int LCA(int a, int b) {
        int aDepth = depth[a];
        int bDepth = depth[b];

        if (aDepth < bDepth) { // a의 Depth가 더 크도록
            int tmp = a;
            a = b;
            b = tmp;
        }

        for (int i = 20; i >= 0; i--) { // 높이 동일하게
            if (depth[a] - depth[b] >= (1 << i)) {
                a = parent[a][i];
            }
        }
        if (a == b) return a; // 높이를 맞췄는데 공통 조상인 경우

        for (int i = 20; i >= 0; i--) { // LCA 찾기
            if (parent[a][i] != parent[b][i]) {
                a = parent[a][i];
                b = parent[b][i];
            }
        }

        return parent[a][0];
    }

    public static void dfs(int nowNode, int nowDepth, int parentNode) {
        depth[nowNode] = nowDepth;

        for (Integer nextNode : adj[nowNode]) {
            if (nextNode != parentNode) {
                parent[nextNode][0] = nowNode;
                dfs(nextNode, nowDepth + 1, nowNode);
            }
        }
    }

    public static void fillParents() {
        for (int i = 1; i <= 20; i++) {
            for (int j = 1; j <= N; j++) {
                parent[j][i] = parent[parent[j][i - 1]][i - 1];
            }
        }
    }

    private static void initArray() {
        parent = new int[N + 1][21];
        depth = new int[N + 1];
        adj = new List[N + 1];
        for (int i = 0; i <= N; i++) {
            adj[i] = new ArrayList<>();
        }
    }

    private static void initGraph(BufferedReader br) throws IOException {
        StringTokenizer st = null;
        for (int i = 0; i < N - 1; i++) {
            st = new StringTokenizer(br.readLine().trim());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            adj[a].add(b);
            adj[b].add(a);
        }
    }

    private static String findAllLCA(BufferedReader br) throws IOException {
        StringBuilder sb = new StringBuilder();

        StringTokenizer st = null;
        M = Integer.parseInt(br.readLine().trim());
        for (int i = 0; i < M; i++) {
            st = new StringTokenizer(br.readLine().trim());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            sb.append(LCA(a, b)).append("\n");
        }

        return sb.toString();
    }
}