본문 바로가기

Dot Algo∙ DS/PS

[BOJ] 백준 13209번 검역소 (Java)

    #13209 검역소

    난이도 : 플레 3

    유형 : 트리 DP 

     

    13209번: 검역소

    3번 도시와 5번 도시를 잇는 도로와 4번 도시와 3번 도시를 잇는 도로에 검역소를 설치하면 치료제를 11 인분만 비축해도 된다. 1번 도시에 전염병이 발생할 경우 1번 도시와 3번 도시의 10명의 사

    www.acmicpc.net

    ▸ 문제

    연약한 사람들이 모여 사는 나라가 있다. 이 곳에는 N 개의 도시들이 있고, 두 도시 사이를 연결하는 길이 N − 1개 있어 어느 두 도시도 오직 하나의 경로로만 서로 통행할 수 있게 되어 있다.

    이 곳에는 몇 년에 한 번씩 전염병이 창궐하여 큰 피해가 일어난다. 정부에서는 이 문제를 해결하기 위해 N −1 개의 길들 중 K 개의 길에 검역소를 운영하려고 한다. 검역소는 감염된 사람이 지나가지 못하게 함으로서 전염병이 전파될 수 없는 장벽과 같은 역할을 해 준다.

    하지만 검역소들만으로는 전염병이 일어나지 못하게 할 수는 없다. 따라서 어떤 사람이 전염병에 감염될 경우를 대비하여 치료제를 비축해 두려고 한다. 어떤 한 사람이 전염병에 감염될 때에도 전염병에 걸릴 수 있는 모든 사람들이 치료제를 하나씩 받을 수 있게 하기 위해 비축해야 할 치료제의 최소 개수를 구하여라.

     입력

    첫 번째 줄에 테스트 케이스의 수 T 가 주어진다.

    각 테스트 케이스의 첫 번째 줄에는 도시의 개수를 의미하는 자연수 N(2 ≤ N ≤ 100, 000)과 설치할 수 있는 검역소의 개수 K(1 ≤ K  N − 1)가 주어진다.

    두 번째 줄에는 N 개의 자연수가 주어지며, i번째 자연수는 i번 도시의 인구 Xi(1 ≤ Xi ≤ 1, 000, 000, 000)를 의미한다.

    세 번째 줄부터 N − 1개의 줄에는 도로의 정보가 한 줄에 하나씩 주어진다. 한 줄에는 두 개의 정수 Ai , Bi (1 ≤ Ai ≤ N, 1 ≤ Bi ≤ N)가 주어지며 이것은 Ai 번 도시와 Bi 번 도시가 도로로 연결되어 있음을 의미한다.

     출력

    전염병이 퍼질 경우에 대비해 정부에서 비축해야 하는 치료제의 개수를 하나의 정수로 출력한다.

     

    문제 풀이  

    후위 순회 + DP (solve x)

    실패한 풀이 ↓

    더보기

    이진트리라면 두 개의 자식노드만 비교해주면 되지만 해당 문제는 이진트리가 아니기 때문에 한 노드에 자식 노드가 여러 개 있을 수도 있다는 것을 고려해야 한다. 

     

    DP 값 할당

    • dp[pos][0] : 현재 pos노드까지 세워진 장벽의 개수
    • dp[pos][1] : 현재 pos노드를 포함한 장벽이 세워져있지 않고 연결된 노드의 총 인원 

    탐색 조건

    1. 현재 노드와 자식 노드 합쳐주기 if(dp[pos][1]+dp[child][1] <= mid)
    2. 1번에 해당하지않으면, 자식 노드와 장벽 세우기 if(dp[pos][1] <= mid)
      1. 만약 해당 노드 2번째 방문(자식 노드가 2개 이상이면)이고 1번에 해당하지 않을 경우 if(now + dp[child][1] <= mid)
        1. 이전 자식노드가 더 크면 현재 자식노드와 그룹맺고 이전 자식노드와는 장벽을 세움 if((dp[pos][1] -now) > dp[child][1])
        2. 현재 자식노드가 더 크면 이전 자식노드와 그대로 그룹 상태 유지
    3. 1,2번에 해당하지않으면, 해당 mid값으로는 장벽 나누기가 불가능 

     

    적절한 값을 찾아내기 위해 파라메트릭 서치와 후위순회 트리 DP를 사용하여 풀이하였다. 아무래도 자식 노드가 여러 개일 경우에 반례가 있지 않을까 싶어서 여러 케이스를 설정해봤지만 끝내 반례를 찾지 못했다. 범위도 너무 광범위해서 찾기가 정말 어려운 것 같다.

     

    import java.io.BufferedReader;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.StringTokenizer;
    
    public class Main {
    
    	static int MAX = 100_001;
    	static long mid, INF = MAX*MAX*MAX; 
    	static int[] data;
    	static long[][] dp;
    	static boolean[] checked;
    	static List<Integer>[] list;
    	public static void main(String[] args) throws IOException{
    		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    		StringBuilder sb = new StringBuilder();
    		int tc = Integer.parseInt(br.readLine());
    		for(int t=0; t<tc; t++) {
    			StringTokenizer st = new StringTokenizer(br.readLine());
    			int n = Integer.parseInt(st.nextToken());
    			int k = Integer.parseInt(st.nextToken());
    			
    			st = new StringTokenizer(br.readLine());
    			data = new int[n+1];
    			long sum = 0;
    			for(int i=1; i<n+1; i++) {
    				data[i] = Integer.parseInt(st.nextToken());
    				sum += data[i];
    			}
    			
    			list = new ArrayList[n+1];
    			for(int i=1; i<n+1; i++) {
    				list[i] = new ArrayList<>();
    			}
    			
    			for(int i=0; i<n-1; i++) {
    				st = new StringTokenizer(br.readLine());
    				int a = Integer.parseInt(st.nextToken());
    				int b = Integer.parseInt(st.nextToken());
    				
    				list[a].add(b);
    				list[b].add(a);
    			}
    			
    			long left = 0;
    			long right = sum;
    			int max =0;
    			for (int i = 1; i <n+1; i++) {
    				max = Math.max(max, data[i]);
    			}
    			while(left+1 < right) {
    				mid = (left+right)/2;
    				dp = new long[n+1][2];
    				checked = new boolean[n+1];
    				solve(1,0);
    				if(max <= mid && dp[1][0] <= k) {
    					right = mid;
    				}else {
    					left= mid;
    				}
    			}
    			sb.append(right+"\n");
    		}
    		System.out.println(sb.toString());
    	}
    	
    	static void solve(int pos, int prev){
    		long now = data[pos];
    		
    		if(data[pos] <= mid) {
    			dp[pos][0] = 0; 
    			dp[pos][1] = now;
    		}else {
    			dp[pos][0] = MAX;
    			dp[pos][1] = INF;
    		}
    		
    		for(int child : list[pos]) {
    			if(child != prev) {
    				solve(child, pos);
    				 
    				if(!checked[pos]) {
    					checked[pos]= true;
    					// 1. 두 트리 합치기 
    					if(dp[pos][1]+dp[child][1] <= mid) {
    						dp[pos][0] = dp[child][0];
    						dp[pos][1] = dp[child][1] + dp[pos][1];
    					}
    					// 2. 자식 노드와 바리게이트 pos --- child
    					else if(dp[pos][1] <= mid) {
    						dp[pos][0] = dp[child][0] +1;
    						dp[pos][1] = dp[pos][1];
    					}
    					// 3. 해당 w 탐색 불가
    					else {
    						dp[pos][0] = MAX;
    						dp[pos][1] = INF;
    					}
    				}else {
    					// 해당 pos 노드 재방문 (자식노드 2개 이상)
    					if(dp[pos][1] + dp[child][1] <= mid) {
    						dp[pos][0] += dp[child][0];
    						dp[pos][1] = dp[child][1] + dp[pos][1];
    					}
    					// 자식노드 갈라야하면 가장 높은 자식노드 제거 
    					else if(now + dp[child][1] <= mid) {
    						// 이전 자식노드가 크다면 
    						if((dp[pos][1] - now) > dp[child][1]) {
    							dp[pos][0] += dp[child][0] + 1;
    							dp[pos][1] = now + dp[child][1];
    						}
    						// 현재 자식노드가 크다면 
    						else {
    							dp[pos][0] += dp[child][0] + 1;
    						}
    					}
    					// 자식노드와 바리게이트 
    					else if(dp[pos][1] <= mid) {
    						dp[pos][0] += dp[child][0] +1;
    						dp[pos][1] = dp[pos][1];
    					}
    					else {
    						dp[pos][0] = MAX;
    						dp[pos][1] = INF;
    					}
    				}
    			}
    		}
    	}
    }

     

    📝 메모

    아무래도 2-1번 로직 쪽에 오류가 있는 것 같은데...

    아래 우선순위 큐처럼 dp로 모든 자식 노드들의 값을 탐색하면서 비교하려면

    1. 한 pos노드에 자식 노드들의 값을 들고있으려면 일단 for문 밖에서 계산이 이뤄져야 한다.

    2. 그렇다고 또 밖으로 빼내면 pos노드들을 탐색할 때 자식노드들을 또 꺼내어 탐색해야 한다.

    → 그럼 결국 for문 밖에서 dp[pos]노드의 값을 설정할 때 모든 자식 노드들을 조회해야 한다. 그럼 여기서 문제는 그 노드가 자식노드인지 상위 노드인지 모른다. 루트 노드가 주어진 것이 아니기 때문에 

     그러면 데이터 자료구조를 바꿔야 한다. 1번 노드을 루트 노드라고 가정하고 1번 노드를 기준으로 저장하면 되려나?

    1 2

    4 2

    3 4

     정리하면, 여러 개의 자식노드를 비교할 때가 문제라고 생각하기 때문에 해당 pos노드에 값을 넣을 때 자식노드들을 모두 비교해줘야 한다. 그러러면 for문 바깥으로 갖고나와서 각 dp[child]값을 비교해줘야 한다. 그런데 양방향 트리로 데이터를 저장해줬기때문에 child구분이 가질 않는다. 그 부분을 고쳐서 풀어보자. (음 후위순회니깐 상위노드는 방문안했을테니 방문체크도 고려해보자)

     

    --- 최종 결론

    위의 메모대로 다시 풀이를 해봤지만 해당 방식으로는 풀이가 안된다는 것을 깨달았다. pos노드와 child List를 받아서 풀어보려했는데 결국 이진트리가 아니므로 여러 자식노드들 중에서 최적화된 값을 빼줘야하는데 (pos + allChild >= mid, 가장 큰 자식노드부터 제거) 이는 우선순위 큐가 제일 적당하며 매 탐색마다 큰 자식노드를 찾아서 빼주는 로직은 매우 비효율적인 방식이다. 그러므로 해당 문제는 우선순위큐로 푸는 것이 맞다.

     


    후위 순회 + 우선순위 큐 

    다른 풀이를 보니 로직은 똑같았다. 후위 순회를 하면서 mid값이 초과되면 자식 노드 중 가장 큰 값을 제거해주는 방식을 사용했다. 그런데 해당 풀이에서는 dp가 아닌 우선순위 큐를 사용했다. 

     

    탐색 과정

    1. 현재 노드 + 자식 트리 노드를 더하면 queue에 새롭게 추가한 자식 트리노드 값을 넣어준다.
    2. queue의 생명주기는 현재 노드의 탐색을 할 때 생성되고 해당 노드의 자식 트리 노드를 모두 탐색했으면 사라진다.
    3. 현재의 값이 mid보다 크면 queue에 들어있는 맨 앞의 자식 노드를 제거해주고 장벽(cnt)을 카운트 해준다.

     

    풀이 코드 

    import java.io.*;
    import java.util.*;
    
    public class Main {
    
    	static long cnt, mid;
    	static int[] data;
    	static List<Integer>[] list;
    	public static void main(String[] args) throws IOException{
    		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    		StringBuilder sb = new StringBuilder();
    		int tc = Integer.parseInt(br.readLine());
    		for(int t=0; t<tc; t++) {
    			StringTokenizer st = new StringTokenizer(br.readLine());
    			int n = Integer.parseInt(st.nextToken());
    			int k = Integer.parseInt(st.nextToken());
    			
    			st = new StringTokenizer(br.readLine());
    			data = new int[n+1];
    			long sum = 0;
    			for(int i=1; i<n+1; i++) {
    				data[i] = Integer.parseInt(st.nextToken());
    				sum += data[i];
    			}
    			
    			list = new ArrayList[n+1];
    			for(int i=1; i<n+1; i++) {
    				list[i] = new ArrayList<>();
    			}
    			
    			for(int i=1; i<n; i++) {
    				st = new StringTokenizer(br.readLine());
    				int a = Integer.parseInt(st.nextToken());
    				int b = Integer.parseInt(st.nextToken());
    				list[a].add(b);
    				list[b].add(a);
    			}
    			
    			long left = 0;
    			long right = sum;
    			int max =0;
    			for (int i = 1; i <n+1; i++) {
    				max = Math.max(max, data[i]);
    			}
    			while(left +1< right) {
    				long m = (left+right)/2;
    				mid = m; cnt = 0;
    				solve(1,1);
    				if(max <= m && cnt <= k) {
    					right = m;
    				}else {
    					left= m;
    				}
    			}
    			sb.append(right).append("\n");
    		}
    		System.out.println(sb.toString());
    	}
    	
    	static long solve(int pos, int prev){
    		long now = data[pos];
    		Queue<Long> q = new PriorityQueue<>(Collections.reverseOrder());
    		for(int nxt : list[pos]) {
    			if(nxt != prev) { 
    				long t = solve(nxt, pos);
    				now += t;
    				q.offer(t);
    			}
    		}
    		
    		while(!q.isEmpty()&& now > mid) {
    			now -= q.poll();
    			cnt++;
    		}
    		return now;
    	}
    }