본문 바로가기

Dot Algo∙ DS/PS

[BOJ] 백준 7469번 K번째 수 (Java)

    #7469 K번째 수

    난이도 : 플레 3

    유형 : 세그먼트 트리 / 머지소트 트리 / 이진 탐색 / 파라메트릭 서치

     

    7469번: K번째 수

    현정이는 자료 구조 프로젝트를 하고 있다. 다른 학생들은 프로젝트 주제로 스택, 큐와 같은 기본 자료 구조를 구현하는 주제를 선택했다. 하지만, 현정이는 새로운 자료 구조를 만들었다. 현정

    www.acmicpc.net

    ▸ 문제

    현정이는 자료 구조 프로젝트를 하고 있다. 다른 학생들은 프로젝트 주제로 스택, 큐와 같은 기본 자료 구조를 구현하는 주제를 선택했다. 하지만, 현정이는 새로운 자료 구조를 만들었다.

    현정이가 만든 자료구조는 배열을 응용하는 것이다. 배열 a[1...n]에는 서로 다른 수가 n개 저장되어 있다. 현정이는 여기에 Q(i,j,k)라는 함수를 구현해 모두를 놀라게 할 것이다.

     

    Q(i,j,k): 배열 a[i...j]를 정렬했을 때, k번째 수를 리턴하는 함수

     

    예를 들어, a = (1,5,2,6,3,7,4)인 경우 Q(2,5,3)의 답을 구하는 과정을 살펴보자. a[2...5]는 (5,2,6,3)이고, 이 배열을 정렬하면 (2,3,5,6)이 된다. 정렬한 배열에서 3번째 수는 5이다. 따라서 Q(2,5,3)의 리턴값은 5이다.

    배열 a가 주어지고, Q함수를 호출한 횟수가 주어졌을 때, 각 함수의 리턴값을 출력하는 프로그램을 작성하시오.

     입력

    첫째 줄에 배열의 크기 n과 함수 Q를 호출한 횟수 m이 주어진다. (1 ≤ n ≤ 100,000, 1 ≤ m ≤ 5,000)

    둘째 줄에는 배열에 포함된 정수가 순서대로 주어진다. 각 정수는 절댓값이 109를 넘지 않는 정수이다.

    다음 m개 줄에는 Q(i,j,k)를 호출할 때 사용한 인자 i,j,k가 주어진다. (1 ≤ i ≤ j ≤ n, 1 ≤ k ≤ j-i+1)

     출력

    Q함수를 호출할 때마다 그 함수의 리턴값을 한 줄에 하나씩 출력한다. 

     

    문제 풀이  

    특정 구간의 k번째수, k보다 작은 수, 큰 수와 같은 문제는 머지소트 트리를 이용하여 풀이를 할 수 있다. 머지소트 트리는 세그먼트 트리의 형태를 지니고 있는 데 각 노드들이 리스트 구조를 있다는 점만 다르다.

     

    예를 들어, [ 1 5 2 6 3 7 4 ]의 머지소트 트리를 만들면 다음과 같이 생성된다. 각 노드의 리스트들은 정렬을 시켜준 상태이다.

     

    머지소트 트리

     

     

    [a, b]구간의 K번째 수 구하기

    이제 여기서 [a, b] 구간의 K번째 수를 찾으려면 이진탐색을 사용해야 한다. [a,b] 구간에서 K번째로 큰 수 x를 구한다면 해당 구간에 x이하의 수는 k개 미만 존재함을 알 수 있다. 그러면 문제는 다시 '[a,b] 구간에서 x보다 작은 수가 k개가 되도록 할 수 있는가?' 로 바꿔 물어볼 수 있다.

     

    그러면 다음과 같이 수열의 가장 최솟값과 최댓값을 사이로 이진탐색을 해주면서 x보다 작은 수를 쿼리로 구하고 이 수가 k개가 되는 수열 값을 출력해주면 된다. (문제에서 수열의 원소들은 중복되지 않는다고 했기 때문에 이러한 계산이 가능하다.)

    int l = -MAX, r = MAX; // MAX = 1_000_000_001;
    while(l<=r) {
    	int mid = (l+r)/2;
    	// mid보다 작은 값의 갯수 < k 
    	if(query(1,n,1,a,b,mid) < k) {
    		l = mid +1;
    	}else {
    		r = mid-1;
    	}
    }
    System.out.println(r);

     

    mid값보다 작은 수의 개수

     

    mid값보다 작은 수의 개수를 구하는 쿼리에서 또한 이진탐색 lowerbound를 사용한다. 다음 그림과 같이 [a, b] 구간에서 mid보다 작은 수를 재귀를 통해 구해주면 된다. 

     

    mid보다 작은 수의 개수 구하기

     

    이 쿼리의 시간복잡도는 m개의 쿼리를 이진탐색 2번에 트리 높이만큼의 탐색을 수행하기 때문에 O(m*log^3 n)를 가진다.

     

    풀이 코드 

    import java.io.*;
    import java.util.*;
    
    public class Main {
    	static List<Integer>[] tree;
    	static int[] arr;
    	static final int MAX = 1_000_000_001;
    	public static void main(String[] args) throws IOException{
    		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    		StringTokenizer st = new StringTokenizer(br.readLine());
    		
    		int n = Integer.parseInt(st.nextToken());
    		int q = Integer.parseInt(st.nextToken());
    		
    		tree = new ArrayList[n*4];
    		for(int i=0; i<4*n; i++) {
    			tree[i] = new ArrayList<>();
    		}
    		
    		arr = new int[n+1];
    		st = new StringTokenizer(br.readLine());
    		for(int i=1; i<=n; i++) {
    			arr[i] = Integer.parseInt(st.nextToken());
    			update(1,n, 1, i);
    		}
    		
    		
    		for(int i=0; i<4*n; i++) {
    			Collections.sort(tree[i]);
    		}
    		StringBuilder sb = new StringBuilder();
    		for(int i=0; i<q; i++) {
    			st = new StringTokenizer(br.readLine());
    			int a = Integer.parseInt(st.nextToken());
    			int b = Integer.parseInt(st.nextToken());
    			int k = Integer.parseInt(st.nextToken());
    			
    			int l = -MAX, r = MAX;
    			while(l<=r) {
    				int mid = (l+r)/2;
    				if(query(1,n,1,a,b,mid) < k) {
    					l = mid +1;
    				}else {
    					r = mid-1;
    				}
    			}
                sb.append(r+"\n");
    		}
            System.out.println(sb.toString());
    	}
    	
    	static void update(int s, int e, int node, int idx) {
    		if(idx < s|| e < idx) return;
    		
    		tree[node].add(arr[idx]);
    		if(s == e) return;
    		
    		int mid = (s+e)/2;
    		update(s, mid, node*2, idx);
    		update(mid+1, e, node*2+1, idx);
    	}
    	
    	static int query(int s, int e, int node ,int l, int r, int val) {
    		if(r < s || l > e ) return 0;
    		if(l <= s && e <= r) {
    			return lowerbound(tree[node], val);
    		}
    		
    		int mid = (s+e)/2;
    		return query(s, mid, node*2, l, r, val) +query(mid+1, e, node*2+1, l, r, val);
    	}
    	
    	static int lowerbound(List<Integer> data, int val) {
    		int s = 0;
    		int e = data.size();
    		
    		while(s < e){
    			int mid = (s+e)/2;
    			if(data.get(mid) >= val) {
    				e = mid;
    			}else {
    				s = mid+1;
    			}
    		}
    		return e;
    	}
    }