본문 바로가기

Dot Algo∙ DS/PS

[BOJ] 백준 11438번 LCA2 - SegmentTree (Java)

    #11438 LCA2

    난이도 : 플레 5

    유형 : 트리 / LCA / 세그먼트 트리

     

    11438번: LCA 2

    첫째 줄에 노드의 개수 N이 주어지고, 다음 N-1개 줄에는 트리 상에서 연결된 두 정점이 주어진다. 그 다음 줄에는 가장 가까운 공통 조상을 알고싶은 쌍의 개수 M이 주어지고, 다음 M개 줄에는 정

    www.acmicpc.net

    ▸ 문제

    N(2 ≤ N ≤ 100,000)개의 정점으로 이루어진 트리가 주어진다. 트리의 각 정점은 1번부터 N번까지 번호가 매겨져 있으며, 루트는 1번이다.

    두 노드의 쌍 M(1 ≤ M ≤ 100,000)개가 주어졌을 때, 두 노드의 가장 가까운 공통 조상이 몇 번인지 출력한다.

     입력

    첫째 줄에 노드의 개수 N이 주어지고, 다음 N-1개 줄에는 트리 상에서 연결된 두 정점이 주어진다. 그 다음 줄에는 가장 가까운 공통 조상을 알고싶은 쌍의 개수 M이 주어지고, 다음 M개 줄에는 정점 쌍이 주어진다.

     출력

    M개의 줄에 차례대로 입력받은 두 정점의 가장 가까운 공통 조상을 출력한다.

     

    문제 풀이  

    세그먼트 트리를 활용한 LCA 풀이법이 있다고하여 공부를 해봤다. 세그먼트 트리 리프노드에 트리의 전위 순회한 방문 순서를 노드를 기준으로 저장하여 최소 세그먼트 트리의 형태로 부모 노드에 저장하는 식으로 형성된다.

     

    예를 들어, 1 2 3 4 3 2 5 2 1 6 1의 순으로 전위 순회를 한다면 구간 2~6 (2 3 4 3 2 5)에서 LCA를 구해본다면 2번째 순서를 갖는 노드가 가장 빠른 순으로 조회되기 때문에 LCA임을 알아내는 방식이다. 

    • 왜 가장 빠른 순서인 노드가 LCA이나면 전위 순회는 root > left > right순으로 조회하므로 부모노드인 root를 가장 빠르게 순회하기 때문이다.

     

    다소 복잡한 풀이로 인해 아래의 트리를 예시로 설명을 하는 것이 편할 것 같다. 해당 노드 개수는 총 6개로, 루트 노드는 1이다.

    전위 순회  (1 2 4 6 4 2 5 2 1 3 1)

    세그먼트 트리를 활용한 LCA 풀이에 필요한 데이터 자료구조는 다음과 같다.

    • depth : 해당 노드의 높이  1번 노드 1, 4번노드 3  → 해당 문제에서는 필요없음
    • trip: 해당 트리 전위 순회 방문순서  ( 1 2 3 4 3 2 5 2 1 6 1)
    • locInTrip : trip에 처음으로 저장된 노드번호 index  [0 1 9 2 6 3 ] // 1번노드 trip.get(0), 3번노드 trip.get(9)
    • no2serial : 노드 번호에 따른 전위 순회 탐색 순서   [1 2 6 3 5 4] // 3번노드는 6번째 순서로 방문
    • serial2no : 전위 순회 탐색 순서에 따른 노드 번호  [1 2 4 6 5 3] // 3번째 순서로 방문한 노드는 4번노드

     

    구상

    1. 주어진 트리를 전위 순회 탐색한다.
      1. 노드 번호에 따른 순서(no2serial)를 기록한다. 
      2. 순서에 따른 노드 번호(serial2no)를 기록한다. 
      3. 전위 순회로 방문되는 순서(no2serial)를 차례대로 모두 trip이라는 List 자료구조에 기록해준다. (방문한 노드의 경로도 모두 기록해야 한다.)
      4. 노드 번호가 첫 방문된 trip의 위치(locInTrip)를 기록한다. 
    2. 1번에서 기록한 전위 순회에 대한 노드 번호 순서(serial2no)를 세그먼트 트리 리프노드에 넣어 초기화 시켜준다.
      1. 해당 세그먼트 트리는 먼저 탐색한 노드(MIN)를 부모노드에 기록해주어 최소 세그먼트 트리의 형태를 가진다.
      2. 전위 순회 탐색은 root > left > right순이기 때문에 결국 가장 최솟값은 루트노드가 될 것이다.
    3. LCA를 구하기 위해 주어진 a,b의 노드가 처음으로 저장된 trip의 index를 찾아준다. (locInTrip[a], locInTrip[b])
      1. 왜냐하면 해당 index(locInTrip)이 세그먼트 트리 구간이기 때문이다. (ex, a=6, b=5, locInTrip[a]=3, locInTrip[b]=6)
      2. query문을 사용하여 해당 구간에 가장 먼저 들어온(MIN) 순서를 찾는다. 
      3. 해당 순서에 따른 노드 번호(serial2no)를 출력해주면된다.

     

    로직 1번. traversal (전위 순회)

    static void traversal(int cur, int h, int pa) {
    	depth[cur] = h;
    	no2serial[cur] = serialNum;
    	serial2no[serialNum] = cur;
    	serialNum++;
    	
    	locInTrip[cur] = trip.size();
    	System.out.println(cur +" : "+ no2serial[cur]);
    	trip.add(no2serial[cur]);
    	
    	for(int nxt : list[cur]) {
    		if(nxt != pa) {
    			traversal(nxt, h+1, cur);
    			System.out.println("#" +cur +" : "+ no2serial[cur]);
    			trip.add(no2serial[cur]);
    		}
    	}
    }

     

    로직 2번. init (최소 세그먼트 트리  초기화)

    static int init(int start, int end, int node) {
    	if(start == end) {
    		System.out.println(node+":"+trip.get(start));
    		return tree[node] = trip.get(start);
    	}
    	int mid = (start+end)/2;
    	return tree[node] = Math.min(init(start, mid, node*2), init(mid+1, end, node*2+1));
    }

     

    로직 3번. query (a,b 노드의 LCA 찾기)

    해당 query는 LCA의 해당 구간에 가장 먼저 들어온 순서를 출력하기 때문에 답은 해당 순서에 따른 노드 번호 serial2no[lca]를 출력해줘야 한다.

    static int query(int start, int end, int left, int right, int node) {
    	if(right < start || end < left) return Integer.MAX_VALUE;
    	if(left <= start && end <= right) {
    		return tree[node];
    	}
    	
    	int mid = (start+end) /2;
    	
    	return Math.min(query(start, mid, left, right, node*2),
    			query(mid+1, end, left, right, node*2+1));
    }

     

     

    시뮬레이션

    1. 초기화

    위의 예제 노드를 최소 세그먼트 트리로 초기화시키면 상태는 다음과 같다.

     

    최소 세그먼트 트리

    → 해당 리프노드는 전위 순회 방문순서에 대한 값이다. 이에 대응하는 노드번호는 serial2no으로 알아낼 수 있다.

     

     

    2. LCA(4,3)

    a=4과 b=3에 대응하는 trip의 방문순서는 locInTrip으로 구할 수 있다.


    locInTrip[a] = 2, locInTrip[b] = 9이므로, 이제 세그먼트 트리에서 [2~9]구간에서 최솟값을 구해주는 query를 동작시킨다.

    (만약 locInTrip[a]가 더 클 경우 구간이 꼬이므로 swap해줘야한다.)

     

    LCA(4,3)

    그러면 query는 1이라는 값을 출력해준다. 해당 데이터는 순서이므로 순서에 따른 노드번호를 뽑아줘야한다.

     

    따라서 3번, 4번 노드LCA는 serial2no[1] = 1번 노드임을 알 수 있다.

     

    풀이 코드 

    import java.io.BufferedReader;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.StringTokenizer;
    
    public class Main {
    	static int n, h, serialNum=1;
    	static List<Integer>[] list;
    	static List<Integer> trip;
    	static int[] tree;
    	static int[] depth, no2serial, serial2no, locInTrip;
    	public static void main(String[] args) throws IOException{
    		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    		n = Integer.parseInt(br.readLine());
    		
    		list = new ArrayList[n+1];
    		for(int i=1; i<n+1; i++) {
    			list[i] = new ArrayList<>();
    		}
    		StringTokenizer st = null;
    		for(int i=0; i<n-1; i++) {
    			st = new StringTokenizer(br.readLine());
    			int a = Integer.parseInt(st.nextToken());
    			int b = Integer.parseInt(st.nextToken());
    			list[a].add(b);
    			list[b].add(a);
    		}
    		
    		h = (int)Math.ceil(Math.log(n)/Math.log(2)) +1;
    		
    		trip = new ArrayList<>(); 
    		depth = new int[n+1]; no2serial = new int[n+1];
    		serial2no = new int[n+1]; locInTrip = new int[n+1];
    		traversal(1,1,0);
    
    		int len = trip.size();
    		int size = getTreeSize(len);
    		tree = new int[size];
    		init(0, len-1, 1);
    		
    		StringBuilder sb = new StringBuilder();
    		int m = Integer.parseInt(br.readLine());
    		for(int i=0; i<m; i++) {
    			st = new StringTokenizer(br.readLine());
    			int a = Integer.parseInt(st.nextToken());
    			int b = Integer.parseInt(st.nextToken());
    			a = locInTrip[a];
    			b = locInTrip[b];
    			if(a>b) {
    				int tmp = a;
    				a = b;
    				b = tmp;
    			}
    			int lca = query(0, len-1, a, b ,1);
    			sb.append(serial2no[lca]+"\n");
    			
    		}
    		System.out.println(sb.toString());
    	}
    	
    	static int getTreeSize(int size) {
    		int h = (int)Math.ceil(Math.log(size)/Math.log(2)) +1;
    		return (int)Math.pow(2, h)-1;
    	}
    	
    	static void traversal(int cur, int h, int pa) {
    		depth[cur] = h;
    		no2serial[cur] = serialNum;
    		serial2no[serialNum] = cur;
    		serialNum++;
    		
    		locInTrip[cur] = trip.size();
    		trip.add(no2serial[cur]);
    		
    		for(int nxt : list[cur]) {
    			if(nxt != pa) {
    				traversal(nxt, h+1, cur);
    				trip.add(no2serial[cur]);
    			}
    		}
    	}
    	
    	static int init(int start, int end, int node) {
    		if(start == end) {
    			return tree[node] = trip.get(start);
    		}
    		int mid = (start+end)/2;
    		return tree[node] = Math.min(init(start, mid, node*2), init(mid+1, end, node*2+1));
    	}
    	
    	static int query(int start, int end, int left, int right, int node) {
    		if(right < start || end < left) return Integer.MAX_VALUE;
    		if(left <= start && end <= right) {
    			return tree[node];
    		}
    		
    		int mid = (start+end) /2;
    		
    		return Math.min(query(start, mid, left, right, node*2),
    				query(mid+1, end, left, right, node*2+1));
    	}
    }

     

    성능 비교

    세그먼트 풀이
    dp 풀이

     

    풀이 결과 DP를 활용환 최적화 풀이보다 세그먼트 트리를 이용한 풀이가 연산속도가 더 빨랐다. 그러나 자료구조가 간단하게 설계되어 있지 않아서 구현이 상당히 복잡하다 실제로 공간복잡도는 좀 더 높은 것을 볼 수 있다.. 효율성을 위한 알고리즘 설계가 아니라면 평소에는 DP로 접근하고 후에 성능 이슈를 다룰 때 세그먼트 트리를 활용하는 것이 좋을 듯 싶다.

     

    백준 11438번 LCA2 dp풀이 보러가기