본문 바로가기

Dot Algo∙ DS/PS

[BOJ] 백준 11438번 LCA2 - DP (Java)

#11438 LCA2

난이도 : 플레 5

유형 : 트리 / LCA 

 

11438번: LCA 2

첫째 줄에 노드의 개수 N이 주어지고, 다음 N-1개 줄에는 트리 상에서 연결된 두 정점이 주어진다. 그 다음 줄에는 가장 가까운 공통 조상을 알고싶은 쌍의 개수 M이 주어지고, 다음 M개 줄에는 정

www.acmicpc.net

▸ 문제

N(2 ≤ N ≤ 100,000)개의 정점으로 이루어진 트리가 주어진다. 트리의 각 정점은 1번부터 N번까지 번호가 매겨져 있으며, 루트는 1번이다.

두 노드의 쌍 M(1 ≤ M ≤ 100,000)개가 주어졌을 때, 두 노드의 가장 가까운 공통 조상이 몇 번인지 출력한다.

 입력

첫째 줄에 노드의 개수 N이 주어지고, 다음 N-1개 줄에는 트리 상에서 연결된 두 정점이 주어진다. 그 다음 줄에는 가장 가까운 공통 조상을 알고싶은 쌍의 개수 M이 주어지고, 다음 M개 줄에는 정점 쌍이 주어진다.

 출력

M개의 줄에 차례대로 입력받은 두 정점의 가장 가까운 공통 조상을 출력한다.

 

문제 풀이  

LCA와 다른 점은 N과 M의 범위가 더 커졌다는 점이다. 더 효율적인 탐색을 위해 DP나 세그먼트 트리를 사용해줘야 한다.

 

LCA + DP 

일반적으로 풀이 할 경우 계속해서 LCA를 구할 a,b노드가 주어질 때마다 계속해서 새롭게 탐색을 해줘야해서 최악의 경우 O(NM)의 시간복잡도가 발생한다. 이를 DP를 사용하여 반복되는 연산을 줄여주는 방식으로 효율적인 알고리즘을 설계할 수 있다.

 

DP 값 할당

  • dp[cur][h] = 해당 cur 노드의 2^h번째 부모노드를 저장해준다.

1. 트리의 최대 높이(h)를 구해준다.

(편향 트리는 고려하지 않고 최소 이진 트리라고 가정하고 높이를 구해줬는데 통과했다.)

static int getTreeHeight() {
	return(int)Math.ceil(Math.log(n)/Math.log(2)) +1;
}

 

2. DFS탐색을 통해 각 노드의 높이(depth)와  2^0(1)번째 부모노드의 값으로 초기화시켜준다. dp[cur][0] = 1번째 부모노드

static void init(int cur, int h, int pa) {
	depth[cur] = h;
	for(int nxt : list[cur]) {
		if(nxt != pa) {
			init(nxt, h+1, cur);
			parent[nxt][0] = cur; // nxt의 부모 cur 
		}
	}
}

 

3. 나머지 2^0, 2^1, ... , 2^h-1번째의 부모노드를 채워준다.

static void fillParents() {
	for(int i=1; i<h; i++) {
		for(int j=1; j<n+1; j++) {
			parent[j][i] = parent[parent[j][i-1]][i-1];
		}
	}
}

 

 

LCA 구하기

dp값을 모두 할당해줬으면 이제 해당 데이터로 LCA를 구해주면 된다.

  1. a와 b노드가 주어지면 해당 노드의 높이가 낮은 노드를 기준으로 높이를 맞춰준다.
    1. 높이를 맞췄는데 a==b이면, LCA = a이므로 바로 출력해준다. (LCA가 1일때 예외처리이기도 함)
  2. a와 b노드의 dp값을 비교해가며 LCA를 찾아준다.
static int LCA(int a, int b) {
	int ah = depth[a];
	int bh = depth[b];
	 // ah > bh로 세팅
	if(ah < bh) {
		int tmp = a;
		a = b;
		b = tmp;
	} 
	// 1. 높이 맞추기 
	for (int i=h-1; i>=0; i--) {
		if(Math.pow(2, i) <= depth[a] - depth[b]){
			a = parent[a][i];
		}
	}
	if(a==b) return a;
    
	// 2. LCA찾기
	for(int i=h-1; i>=0; i--) {
		if(parent[a][i] != parent[b][i]) {
			a = parent[a][i];
			b = parent[b][i];
		}
	}
	return parent[a][0];
}

 

시뮬레이션

해당 트리의 dp값을 할당해보자.

예제 트리

dp[cur][log(h)]의 값은 다음과 같이 저장된다. 

h 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
1 0 1 1 2 2 2 3 3 4 4 5 5 7 7 11
2 0 0 0 1 1 1 1 1 2 2 2 2 3 3 5
4 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1

☛ 예로, 15번의 4번째 부모노드 dp[15][log4]= dp[15][2] = 1 임을 알 수 있다.

 

 

그럼 이제 6번, 11번 노드의 LCA를 구해보자.

1. 11번 노드의 높이가 더 높으므로 11번 노드를 a로 설정해준다. a = 11, b= 6

 

 

2. 11번 노드와 6번 노드의 높이를 맞춰준다. if(Math.pow(2, i) <= depth[a] - depth[b]) a = parent[a][i]

dp에는 2^i번째의 부모노드들을 저장했기 때문에 높이가 더 높은 a번 노드를 2^i씩 점프하여 depth를 맞춰준다.  

i=0일때 성립되므로, a= parent[a][0] = 5

 

3. 이제 두 노드의 높이가 같으므로 dp값을 비교하면서 LCA를 구해준다.

6번, 11번 노드의 LCA는 2

 

 

풀이 코드 

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;

public class Main {
	static int n, h;
	static List<Integer>[] list;
	static int[][] parent;
	static int[] depth;
	public static void main(String[] args) throws IOException{
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		n = Integer.parseInt(br.readLine());
		
		list = new ArrayList[n+1];
		for(int i=1; i<n+1; i++) {
			list[i] = new ArrayList<>();
		}
		StringTokenizer st = null;
		for(int i=0; i<n-1; i++) {
			st = new StringTokenizer(br.readLine());
			int a = Integer.parseInt(st.nextToken());
			int b = Integer.parseInt(st.nextToken());
			list[a].add(b);
			list[b].add(a);
		}
		
		h = getTreeHeight();
		depth = new int[n+1];
		parent = new int[n+1][h];
 
		init(1,1,0);
		fillParents();

		StringBuilder sb = new StringBuilder();
		int m = Integer.parseInt(br.readLine());
		for(int i=0; i<m; i++) {
			st = new StringTokenizer(br.readLine());
			int a = Integer.parseInt(st.nextToken());
			int b = Integer.parseInt(st.nextToken());
			sb.append(LCA(a,b)+"\n");
		}
		System.out.println(sb.toString());
	}
	
	static int getTreeHeight() {
		return(int)Math.ceil(Math.log(n)/Math.log(2)) +1;
	}
	static void init(int cur, int h, int pa) {
		depth[cur] = h;
		for(int nxt : list[cur]) {
			if(nxt != pa) {
				init(nxt, h+1, cur);
				parent[nxt][0] = cur; // nxt의 부모 cur 
			}
		}
	}
	
	static void fillParents() {
		for(int i=1; i<h; i++) {
			for(int j=1; j<n+1; j++) {
				parent[j][i] = parent[parent[j][i-1]][i-1];
			}
		}
	}
	
	static int LCA(int a, int b) {
		int ah = depth[a];
		int bh = depth[b];
		 // ah > bh로 세팅
		if(ah < bh) {
			int tmp = a;
			a = b;
			b = tmp;
		} 
		
		for (int i=h-1; i>=0; i--) {
			if(Math.pow(2, i) <= depth[a] - depth[b]){
				a = parent[a][i];
			}
		}
		if(a==b) return a;
		
		for(int i=h-1; i>=0; i--) {
			if(parent[a][i] != parent[b][i]) {
				a = parent[a][i];
				b = parent[b][i];
			}
		}
		return parent[a][0];
	}
}