본문 바로가기

Dot Algo∙ DS/PS

[BOJ] 백준 19581번 두 번째 트리의 지름 (Java)

#19581 두 번째 트리의 지름

난이도 : 골드 1

유형 : 트리 / BFS

 

19581번: 두 번째 트리의 지름

트리에 N개의 정점이 있고, 각 정점 별로 1부터 N까지의 번호가 붙어있다. 트리에서 가장 먼 두 정점 간의 거리를 트리의 지름이라고 한다. 트리의 지름을 구하는 문제는 너무 많기 때문에 우리

www.acmicpc.net

▸ 문제

트리에 N개의 정점이 있고, 각 정점 별로 1부터 N까지의 번호가 붙어있다.

트리에서 가장 먼 두 정점 간의 거리를 트리의 지름이라고 한다. 트리의 지름을 구하는 문제는 너무 많기 때문에 우리는 두 번째 트리의 지름을 구하려고 한다.

두 번째 트리의 지름은 무엇이냐? 바로 두 번째로 가장 먼 두 정점 간의 거리를 의미한다. (두 번째 트리의 지름은 트리의 지름과 같을 수 있다.)

바로 두 번째 트리의 지름을 구해보자.

 입력

첫 번째 줄에는 정점의 개수 N(3 ≤ N ≤ 100,000)이 들어온다.

둘째 줄부터 N번째 줄까지 각 간선에 대한 정보가 들어온다. 간선에 대한 정보는 세 개의 정수로 이루어져 있다. 첫 번째 정수와 두 번째 정수는 간선과 연결된 정점 번호를 나타내고, 세 번째 정수는 간선의 가중치를 나타낸다. 간선의 가중치는 20,000 이하의 자연수이다.

 출력

첫째 줄에 두 번째 트리의 지름을 출력한다.

 

문제 풀이  

먼저 트리의 지름을 구한 다음에 트리의 지름이 되는 두 노드에서 지름을 제외한 가장 긴 거리의 노드를 찾으면 된다. 둘 중에서 더 큰 거리의 노드를 가지는 쪽이 두 번째 트리의 지름이 된다.

 

먼저 트리의 지름을 구하는 방법은 아무 노드에서 가장 거리가 먼 노드를 구한 다음에 그 노드에서 다시 가장 거리가 먼 노드를 구하면 된다. 

  • 자세한 트리의 지름을 구하는 과정은 이 글을 참고

트리 지름 구하기

 

풀이 코드 

import java.io.*;
import java.util.*;

public class Main {
	
	static class Node{
		int idx;
		int dis;
		
		Node(int idx, int dis){
			this.idx = idx;
			this.dis = dis;
		}
	}

	static int n;
	static List<Node>[] list;
	public static void main(String[] args) throws IOException{
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		n = Integer.parseInt(br.readLine());
		
		list = new ArrayList[n+1];
		for(int i=1; i<=n; i++) {
			list[i] = new ArrayList<>();
		}
		
		StringTokenizer st;
		for(int i=0; i<n-1; i++) {
			st = new StringTokenizer(br.readLine());
			int a = Integer.parseInt(st.nextToken());
			int b = Integer.parseInt(st.nextToken());
			int d = Integer.parseInt(st.nextToken());
			
			list[a].add(new Node(b,d));
			list[b].add(new Node(a,d));
		}
		
		Node f = bfs(1,0);
		Node r = bfs(f.idx,0);
		
		int dis1 = bfs(f.idx, r.idx).dis;
		int dis2 = bfs(r.idx, f.idx).dis;
		System.out.println(Math.max(dis1, dis2));
	}
	
	static Node bfs(int s, int e) {
		Queue<Node> q = new LinkedList<>();
		boolean[] visited = new boolean[n+1];
		Node endNode = new Node(s,0);
		q.add(endNode);
		visited[s] = true;
		while(!q.isEmpty()) {
			Node p = q.poll();
			
			if(p.dis > endNode.dis && p.idx != e) {
				endNode.idx = p.idx;
				endNode.dis = p.dis;
			}
			
			for(Node nxt : list[p.idx]) {
				if(visited[nxt.idx]) continue;
				visited[nxt.idx] = true;
				q.add(new Node(nxt.idx, p.dis + nxt.dis));
				
			}
		}
		return endNode;
	}
}