본문 바로가기

Dot Algo∙ DS/PS

[BOJ] 백준 13511번 트리와 쿼리2 (Java)

    #13511 트리와 쿼리2 

    난이도 : 플레 3

    유형 : 트리 / LCA

     

    13511번: 트리와 쿼리 2

    N개의 정점으로 이루어진 트리(무방향 사이클이 없는 연결 그래프)가 있다. 정점은 1번부터 N번까지 번호가 매겨져 있고, 간선은 1번부터 N-1번까지 번호가 매겨져 있다. 아래의 두 쿼리를 수행하

    www.acmicpc.net

    ▸ 문제

    N개의 정점으로 이루어진 트리(무방향 사이클이 없는 연결 그래프)가 있다. 정점은 1번부터 N번까지 번호가 매겨져 있고, 간선은 1번부터 N-1번까지 번호가 매겨져 있다.

    아래의 두 쿼리를 수행하는 프로그램을 작성하시오.

    • 1 u v: u에서 v로 가는 경로의 비용을 출력한다.
    • 2 u v k: u에서 v로 가는 경로에 존재하는 정점 중에서 k번째 정점을 출력한다. k는 u에서 v로 가는 경로에 포함된 정점의 수보다 작거나 같다.

     입력

    첫째 줄에 N (2 ≤ N ≤ 100,000)이 주어진다.

    둘째 줄부터 N-1개의 줄에는 i번 간선이 연결하는 두 정점 번호 u와 v와 간선의 비용 w가 주어진다.

    다음 줄에는 쿼리의 개수 M (1 ≤ M ≤ 100,000)이 주어진다.

    다음 M개의 줄에는 쿼리가 한 줄에 하나씩 주어진다.

    간선의 비용은 항상 1,000,000보다 작거나 같은 자연수이다.

     출력

    각각의 쿼리의 결과를 순서대로 한 줄에 하나씩 출력한다.

     

    문제 풀이  

    트리의 장점은 탐색 속도가 배열이나 리스트보다 훨씬 빠르다는 것이다. 그래서 쿼리를 로그 함수 속도로 빠르게 처리할 수 있다.  그래서 최소공통조상(LCA) 알고리즘을 응용하여 해당 문제를 풀이해주면 된다. 그런데 그것을 알아도 구현에 꽤나 시간이 걸리는 문제이다. 쿼리가 2가지나 주어졌고 하나만 최적화하기도 힘든데 두 쿼리 모두 트리 탐색의 장점을 살려서 구현해야하기 때문이다.

     

    1번 쿼리

    1번 쿼리는 각 비용을 dp배열에 전위순회 탐색으로 저장한 다음 cost[u] + cost[v] - 2*cost[lca]로 구할 수 있다.  그림을 그려보고 계산해보면 해당 풀이를 이해하기 쉽다.  아래의 트리에서 9번에서 5번으로 이동하는 거리의 비용을 계산하려면 다음과 같이 구해주면 된다.

    • cost[9] = A + B + D
    • cost[5] = A + C
    • cost[2] = A
    • 9~5 이동 비용 : B+C+D = cost[9] + cost[5] - 2*cost[2]

    백준 1716번 문제는 해당 1번 쿼리만 구현하는 문제로 연습삼아 풀기 좋다.

     

    정점 9에서 5로 이동하는 비용 구하기

     

     

     

    2번 쿼리

    2번 쿼리는 u에서 v로 가는 정점 중 k번째 정점을 출력해주면 된다. 이또한 lca를 활용하여 쿼리 시간을 단축할 수 있다.

    1. 9번 노드와 5번노드 사이에 있는 k번째 노드를 구한다고 하자. 그러면 LCA를 기준으로 왼쪽 트리와 오른쪽 트리로 나뉘게 되므로 LCA의 순서를 구해준다.
      1. int mid = 9번 depth - 2번(lca) depth +1 = xh - depth[lca] + 1 = 3, 3번째 순서가 LCA임을 구할 수 있다.
    2. if(mid > k) k== 1 or 2,  k번째 노드는 왼쪽 트리에 존재한다.
      1. k-1번째에 존재하는 노드는 9번노드의 k-1번째 부모노드이므로, parents 배열에 저장되어 있는 값으로 구해준다.
    3. if(mid< k) k == 4, k번째 노드는 오른쪽 트리에 존재한다.
      1. 오른쪽 트리 또한 왼쪽트리와 같은 방식으로 구해주면 된다. 
      2. k = mid + yh - depth[lca] - k;

     

    정점 9에서 5의 경로 중 k번째 노드 구하기

     

    오른쪽 트리의 k번째 노드를 구할 때, k = mid + yh - depth[lca] - k인 이유는 다음과 같다.

    • xh + yh - 2*depth[lca]+1은 x와 y사이에 존재하는 노드의 개수이다.
    • 거기서 k를 빼주면 순서가 반대방향으로 바뀌게 된다.

     

    풀이 코드 

    import java.io.BufferedReader;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.StringTokenizer;
    
    public class Main {
    	
    	static class Node{
    		int to;
    		int cost;
    		
    		public Node(int to, int cost) {
    			this.to = to;
    			this.cost = cost;
    		}
    	}
    	static int n,h;
    	static List<Node>[] list;
    	static int[] depth;
    	static long[] cost;
    	static int[][] parents;
    	public static void main(String[] args) throws IOException{
    		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    		n = Integer.parseInt(br.readLine());
    		h = (int)Math.ceil(Math.log(n)/Math.log(2)) +1;
    		list = new ArrayList[n+1];
    		parents = new int[n+1][h];
    		depth = new int[n+1];
    		cost = new long[n+1];
    		for(int i=1; i<n+1; i++) {
    			list[i] = new ArrayList<>();
    		}
    		StringTokenizer st;
    		for(int i=0; i<n-1; i++) {
    			st = new StringTokenizer(br.readLine());
    			int u = Integer.parseInt(st.nextToken());
    			int v = Integer.parseInt(st.nextToken());
    			int w = Integer.parseInt(st.nextToken());
    			
    			list[u].add(new Node(v,w));
    			list[v].add(new Node(u,w));
    			
    		}
    		
    		init(1,0,-1);
    		fillParents();
    		StringBuilder sb = new StringBuilder();
    		int m = Integer.parseInt(br.readLine());
    		for(int i=0; i<m; i++) {
    			st = new StringTokenizer(br.readLine());
    			int op = Integer.parseInt(st.nextToken());
    			int u = Integer.parseInt(st.nextToken());
    			int v = Integer.parseInt(st.nextToken());
    			
    			int lca = LCA(u,v);
    			if(op == 1) {
    				sb.append(cost[u] + cost[v] - 2*cost[lca] +"\n");
    			}else {
    				int k = Integer.parseInt(st.nextToken());
    				sb.append(kNode(u, v, lca, k) +"\n");
    			}
    		}
    		System.out.println(sb.toString());
    	}
    	static int kNode(int x, int y, int root, int k) {
    		int xh = depth[x];
    		int yh = depth[y];
    		
    		int mid = xh-depth[root]+1;
    		int tmp= 0;
    		if(mid == k) { // mid 
    			return root;
    		}else if(mid > k) { // left
    			k -=1;
    			tmp = x;
    		}else { // right
    			k = mid + yh - depth[root] - k;
    			tmp = y;
    		}
    		
    		for(int i=h-1; i>=0; i--) {
    			if((k & (1<<i)) !=0) { 
    				k ^= (1<<i); // 2^i번째 부모로 이동 
    				tmp = parents[tmp][i];
    			}
    		}
    		return tmp;
    	}
    	
    	static int LCA(int x, int y) {
    		int xh = depth[x];
    		int yh = depth[y];
    		
    		// make (xh > yh)
    		if(xh <yh) {
    			int tmp  = x;
    			x = y;
    			y = tmp;
    			xh = depth[x];
    			yh = depth[y];
    		}
    		
    		// matching depth 
    		for(int i=h-1; i>=0; i--) {
    			if(Math.pow(2, i) <= depth[x]-depth[y]) {
    				x = parents[x][i];
    			}
    		}
    		if(x==y) return x;
            
    		// find LCA
    		for(int i=h-1; i>=0; i--) {
    			if(parents[x][i] != parents[y][i]) {
    				x = parents[x][i];
    				y = parents[y][i];
    			}
    		}
    		return parents[x][0];
    	}
    	
    	static void init(int cur, int h, int pa) {
    		depth[cur] = h;
    		for(Node nxt : list[cur]) {
    			if(nxt.to != pa) {
    				cost[nxt.to] += cost[cur] + nxt.cost;
    				init(nxt.to, h+1, cur);
    				parents[nxt.to][0] = cur; // nxt의 부모 cur 
    			}
    		}
    	}
    	
    	static void fillParents() {
    		for(int i=1; i<h; i++) {
    			for(int j=1; j<n+1; j++) {
    				parents[j][i] = parents[parents[j][i-1]][i-1];
    			}
    		}
    	}
    }