본문 바로가기

Dot Algo∙ DS/PS

[BOJ] 백준 1761번 정점들의 거리 - LCA, DP (Java)

    #1761 정점들의 거리

    난이도 : 플레 5

    유형 : 트리 / LCA / DP

     

    1761번: 정점들의 거리

    첫째 줄에 노드의 개수 N이 입력되고 다음 N-1개의 줄에 트리 상에 연결된 두 점과 거리를 입력받는다. 그 다음 줄에 M이 주어지고, 다음 M개의 줄에 거리를 알고 싶은 노드 쌍이 한 줄에 한 쌍씩

    www.acmicpc.net

    ▸ 문제

    N(2 ≤ N ≤ 40,000)개의 정점으로 이루어진 트리가 주어지고 M(1 ≤ M ≤ 10,000)개의 두 노드 쌍을 입력받을 때 두 노드 사이의 거리를 출력하라.

     입력

    첫째 줄에 노드의 개수 N이 입력되고 다음 N-1개의 줄에 트리 상에 연결된 두 점과 거리를 입력받는다. 그 다음 줄에 M이 주어지고, 다음 M개의 줄에 거리를 알고 싶은 노드 쌍이 한 줄에 한 쌍씩 입력된다. 두 점 사이의 거리는 10,000보다 작거나 같은 자연수이다.

    정점은 1번부터 N번까지 번호가 매겨져 있다.

     출력

    M개의 줄에 차례대로 입력받은 두 노드 사이의 거리를 출력한다.

     

    문제 풀이  

    그냥 일반적으로 DFS를 통해 두 정점 a,b의 거리를 구하게 되면 만약 트리가 편향트리일 경우 대략 (N-1)*M정도의 연산이 발생하므로 시간초과가 발생하게 된다.

     

    어떻게든 연산속도를 향상시켜야 한다. 정점과 쿼리의 수를 줄일 수는 없으니 두 정점의 거리를 구하는 연산 로직의 시간 복잡도를 줄여주는 방향으로 생각해 볼 수 있다. 문제에서 트리의 구조로 데이터가 주어졌기 때문에 트리를 활용한 알고리즘을 생각해보면, 두 정점이 연관된 알고리즘으로 LCA(최소 공통 조상) 알고리즘을 떠올려 볼 수 있다.

     

    LCA(최소 공통 조상) 알고리즘은 두 정점의 부모 노드를 구하는 알고리즘으로 정점 N개가 있을 때 O(logN)으로 쿼리를 처리해낼 수 있다. 따라서 O(MlogN)의 시간복잡도로 해당 문제를 해결할 수 있게 된다. 그럼 이제 어떻게 최소 공통 조상 노드의 데이터로 두 정점간의 거리를 구할 수 있을까? 트리의 구조를 한 번 그려보면 쉽게 알아낼 수 있다.

     

    2번 노드를 최소 공통 조상으로 갖는 5번과 9번 노드의 거리는 다음과 같이 이어져있다. 문제에서 주어진 데이터는 5,9번 노드의 값이고 우리가 LCA알고리즘을 통해 구한 노드는 2번 노드라고 보면 된다.

     

    1번 루트 노드를 기준으로 각 노드의 거리를 구한 데이터를 dis[]에 저장했다고 하면 각 데이터는 다음과 같이 표현할 수 있다.

    • dis[2] = A
    • dis[5] = A+C
    • dis[9]=  A+B+D

    우리가 구하고자 하는 값은 5와 9사이의 거리은 B+C+D이다. 따라서 이를 구하기 위해서 단순히 두 정점의 거리에서 중복되는 최소 공통 조상 노드의 거리*2를 빼주면 된다.

    • 정점 5, 9 사이의 거리 = dis[5]+dis[9] - 2*dis[2] = B+C+D

    정점 5,9 사이의 거리 구하기

     

     

    설계

    • 주어진 트리 데이터로 LCA구하는 데 필요한 데이터를 만든다. 루트노드가 주어지지 않았으므로 임의로 1번 노드를 루트노드라고 가정하고 푼다. init(1,1,0);
      • depth[정점] : 해당 정점이 위치한 트리 높이를 구해준다.
      • dp[정점][log(트리 높이)] : 해당 정점의 log(트리 높이)에 해당하는 부모 노드를 구해준다.
        • log(0) = 1층 위에 있는 부모 노드를 구해줬으면 나머지 log(h)높이의 부모노드의 값도 구해준다. fillParents();
      • dis[정점] : 1번을 루트 노드로 가정하고 1번에서 정점노드 까지의 거리를 구해준다.
    • 두 정점 a,b가 주어지면 LCA 알고리즘을 통해 공통 조상 노드를 구한 후 거리를 계산해준다. LCA(a,b);
      • dis[a] + dis[b] - 2*dis[LCA(a,b)]

     

    풀이 코드 

    세그먼트 트리를 활용하여 LCA 알고리즘을 푸는 방법도 존재한다. DP는 O(NlogN + MlogN)이고 세그먼트 트리는 O(N + MlogN)으로 세그먼트 트리가 더 빠르지만 과정이 꽤나 복잡하여 정말 높은 효율성을 요구하는 설계가 아니면 DP로 해결해도 될 것 같으니 알아두기만 하자.

    import java.io.*;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.StringTokenizer;
    
    public class Main {
    
    	static class Node{
    		int to;
    		int w;
    		
    		public Node(int to, int w) {
    			this.to = to;
    			this.w = w;
    		}
    	}
    	
    	static int n,h;
    	static List<Node>[] list;
    	static int[][] dp;
    	static int[] dis;
    	static int[] depth;
    	static StringBuilder sb = new StringBuilder();
    	public static void main(String[] args) throws IOException{
    		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    		n = Integer.parseInt(br.readLine());
    		StringTokenizer st =null;
    		
    		list = new ArrayList[n+1];
    		for(int i=0; i<n+1; i++) {
    			list[i] = new ArrayList<>();
    		}
    		
    		for(int i=0; i<n-1; i++) {
    			st = new StringTokenizer(br.readLine());
    			int from = Integer.parseInt(st.nextToken());
    			int to = Integer.parseInt(st.nextToken());
    			int w= Integer.parseInt(st.nextToken());
    			
    			list[from].add(new Node(to,w));
    			list[to].add(new Node(from,w));
    		}
    		
    		h = getTreeH();
    		depth = new int[n+1];
    		dis = new int[n+1];
    		dp = new int[n+1][h];
    		
    		init(1,1,0);
    		fillParents();
    		
    		int m = Integer.parseInt(br.readLine());
    		for(int i=0; i<m; i++) {
    			st = new StringTokenizer(br.readLine());
    			int a = Integer.parseInt(st.nextToken());
    			int b = Integer.parseInt(st.nextToken());
    			
    			int res = LCA(a,b);
    			sb.append(dis[a] + dis[b] -2*dis[res]).append("\n");
    			
    		}
    		
    		System.out.println(sb.toString());
    		
    	}
    	static int getTreeH() {
    		return (int)Math.ceil(Math.log(n)/Math.log(2))+1;
    	}
    	
    	static void init(int cur, int h, int pa) {
    		depth[cur] = h;
    		for(Node nxt : list[cur]) {
    			if(nxt.to!=pa) {
    				dis[nxt.to] = dis[cur] + nxt.w;
    				init(nxt.to, h+1, cur);
    				dp[nxt.to][0] = cur;
    			}
    		}
    	}
    	
    	static void fillParents() {
    		for(int i=1; i<h; i++) {
    			for(int j=1; j<n+1; j++) {
    				dp[j][i] = dp[dp[j][i-1]][i-1];
    			}
    		}
    	}
    	
    	static int LCA(int a, int b) {
    		int ah = depth[a];
    		int bh = depth[b];
    		
    		if(ah < bh) {
    			int tmp =a;
    			a= b;
    			b= tmp;
    		}
    		
    		for(int i=h-1; i>=0; i--) {
    			if(Math.pow(2, i) <= depth[a] - depth[b]) {
    				a = dp[a][i];
    			}
    		}
    		
    		if(a==b) return a;
    		
    		for(int i=h-1; i>=0; i--) {
    			if(dp[a][i] != dp[b][i]) {
    				a = dp[a][i];
    				b = dp[b][i];
    			}
    		}
    		return dp[a][0];
    	}
    }