본문 바로가기

Dot Algo∙ DS/PS

[BOJ] 백준 11438번 LCA2 - DP (Java)

    #11438 LCA2

    난이도 : 플레 5

    유형 : 트리 / LCA 

     

    11438번: LCA 2

    첫째 줄에 노드의 개수 N이 주어지고, 다음 N-1개 줄에는 트리 상에서 연결된 두 정점이 주어진다. 그 다음 줄에는 가장 가까운 공통 조상을 알고싶은 쌍의 개수 M이 주어지고, 다음 M개 줄에는 정

    www.acmicpc.net

    ▸ 문제

    N(2 ≤ N ≤ 100,000)개의 정점으로 이루어진 트리가 주어진다. 트리의 각 정점은 1번부터 N번까지 번호가 매겨져 있으며, 루트는 1번이다.

    두 노드의 쌍 M(1 ≤ M ≤ 100,000)개가 주어졌을 때, 두 노드의 가장 가까운 공통 조상이 몇 번인지 출력한다.

     입력

    첫째 줄에 노드의 개수 N이 주어지고, 다음 N-1개 줄에는 트리 상에서 연결된 두 정점이 주어진다. 그 다음 줄에는 가장 가까운 공통 조상을 알고싶은 쌍의 개수 M이 주어지고, 다음 M개 줄에는 정점 쌍이 주어진다.

     출력

    M개의 줄에 차례대로 입력받은 두 정점의 가장 가까운 공통 조상을 출력한다.

     

    문제 풀이  

    LCA와 다른 점은 N과 M의 범위가 더 커졌다는 점이다. 더 효율적인 탐색을 위해 DP나 세그먼트 트리를 사용해줘야 한다.

     

    LCA + DP 

    일반적으로 풀이 할 경우 계속해서 LCA를 구할 a,b노드가 주어질 때마다 계속해서 새롭게 탐색을 해줘야해서 최악의 경우 O(NM)의 시간복잡도가 발생한다. 이를 DP를 사용하여 반복되는 연산을 줄여주는 방식으로 효율적인 알고리즘을 설계할 수 있다.

     

    DP 값 할당

    • dp[cur][h] = 해당 cur 노드의 2^h번째 부모노드를 저장해준다.

    1. 트리의 최대 높이(h)를 구해준다.

    (편향 트리는 고려하지 않고 최소 이진 트리라고 가정하고 높이를 구해줬는데 통과했다.)

    static int getTreeHeight() {
    	return(int)Math.ceil(Math.log(n)/Math.log(2)) +1;
    }

     

    2. DFS탐색을 통해 각 노드의 높이(depth)와  2^0(1)번째 부모노드의 값으로 초기화시켜준다. dp[cur][0] = 1번째 부모노드

    static void init(int cur, int h, int pa) {
    	depth[cur] = h;
    	for(int nxt : list[cur]) {
    		if(nxt != pa) {
    			init(nxt, h+1, cur);
    			parent[nxt][0] = cur; // nxt의 부모 cur 
    		}
    	}
    }

     

    3. 나머지 2^0, 2^1, ... , 2^h-1번째의 부모노드를 채워준다.

    static void fillParents() {
    	for(int i=1; i<h; i++) {
    		for(int j=1; j<n+1; j++) {
    			parent[j][i] = parent[parent[j][i-1]][i-1];
    		}
    	}
    }

     

     

    LCA 구하기

    dp값을 모두 할당해줬으면 이제 해당 데이터로 LCA를 구해주면 된다.

    1. a와 b노드가 주어지면 해당 노드의 높이가 낮은 노드를 기준으로 높이를 맞춰준다.
      1. 높이를 맞췄는데 a==b이면, LCA = a이므로 바로 출력해준다. (LCA가 1일때 예외처리이기도 함)
    2. a와 b노드의 dp값을 비교해가며 LCA를 찾아준다.
    static int LCA(int a, int b) {
    	int ah = depth[a];
    	int bh = depth[b];
    	 // ah > bh로 세팅
    	if(ah < bh) {
    		int tmp = a;
    		a = b;
    		b = tmp;
    	} 
    	// 1. 높이 맞추기 
    	for (int i=h-1; i>=0; i--) {
    		if(Math.pow(2, i) <= depth[a] - depth[b]){
    			a = parent[a][i];
    		}
    	}
    	if(a==b) return a;
        
    	// 2. LCA찾기
    	for(int i=h-1; i>=0; i--) {
    		if(parent[a][i] != parent[b][i]) {
    			a = parent[a][i];
    			b = parent[b][i];
    		}
    	}
    	return parent[a][0];
    }

     

    시뮬레이션

    해당 트리의 dp값을 할당해보자.

    예제 트리

    dp[cur][log(h)]의 값은 다음과 같이 저장된다. 

    h 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
    1 0 1 1 2 2 2 3 3 4 4 5 5 7 7 11
    2 0 0 0 1 1 1 1 1 2 2 2 2 3 3 5
    4 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1

    ☛ 예로, 15번의 4번째 부모노드 dp[15][log4]= dp[15][2] = 1 임을 알 수 있다.

     

     

    그럼 이제 6번, 11번 노드의 LCA를 구해보자.

    1. 11번 노드의 높이가 더 높으므로 11번 노드를 a로 설정해준다. a = 11, b= 6

     

     

    2. 11번 노드와 6번 노드의 높이를 맞춰준다. if(Math.pow(2, i) <= depth[a] - depth[b]) a = parent[a][i]

    dp에는 2^i번째의 부모노드들을 저장했기 때문에 높이가 더 높은 a번 노드를 2^i씩 점프하여 depth를 맞춰준다.  

    i=0일때 성립되므로, a= parent[a][0] = 5

     

    3. 이제 두 노드의 높이가 같으므로 dp값을 비교하면서 LCA를 구해준다.

    6번, 11번 노드의 LCA는 2

     

     

    풀이 코드 

    import java.io.BufferedReader;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.StringTokenizer;
    
    public class Main {
    	static int n, h;
    	static List<Integer>[] list;
    	static int[][] parent;
    	static int[] depth;
    	public static void main(String[] args) throws IOException{
    		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    		n = Integer.parseInt(br.readLine());
    		
    		list = new ArrayList[n+1];
    		for(int i=1; i<n+1; i++) {
    			list[i] = new ArrayList<>();
    		}
    		StringTokenizer st = null;
    		for(int i=0; i<n-1; i++) {
    			st = new StringTokenizer(br.readLine());
    			int a = Integer.parseInt(st.nextToken());
    			int b = Integer.parseInt(st.nextToken());
    			list[a].add(b);
    			list[b].add(a);
    		}
    		
    		h = getTreeHeight();
    		depth = new int[n+1];
    		parent = new int[n+1][h];
     
    		init(1,1,0);
    		fillParents();
    
    		StringBuilder sb = new StringBuilder();
    		int m = Integer.parseInt(br.readLine());
    		for(int i=0; i<m; i++) {
    			st = new StringTokenizer(br.readLine());
    			int a = Integer.parseInt(st.nextToken());
    			int b = Integer.parseInt(st.nextToken());
    			sb.append(LCA(a,b)+"\n");
    		}
    		System.out.println(sb.toString());
    	}
    	
    	static int getTreeHeight() {
    		return(int)Math.ceil(Math.log(n)/Math.log(2)) +1;
    	}
    	static void init(int cur, int h, int pa) {
    		depth[cur] = h;
    		for(int nxt : list[cur]) {
    			if(nxt != pa) {
    				init(nxt, h+1, cur);
    				parent[nxt][0] = cur; // nxt의 부모 cur 
    			}
    		}
    	}
    	
    	static void fillParents() {
    		for(int i=1; i<h; i++) {
    			for(int j=1; j<n+1; j++) {
    				parent[j][i] = parent[parent[j][i-1]][i-1];
    			}
    		}
    	}
    	
    	static int LCA(int a, int b) {
    		int ah = depth[a];
    		int bh = depth[b];
    		 // ah > bh로 세팅
    		if(ah < bh) {
    			int tmp = a;
    			a = b;
    			b = tmp;
    		} 
    		
    		for (int i=h-1; i>=0; i--) {
    			if(Math.pow(2, i) <= depth[a] - depth[b]){
    				a = parent[a][i];
    			}
    		}
    		if(a==b) return a;
    		
    		for(int i=h-1; i>=0; i--) {
    			if(parent[a][i] != parent[b][i]) {
    				a = parent[a][i];
    				b = parent[b][i];
    			}
    		}
    		return parent[a][0];
    	}
    }