본문 바로가기

Dot Algo∙ DS/PS

[BOJ] 백준 1208번 부분수열의 합 2 (Java)

    #1208 부분수열의 합2

    난이도 : 골드 2

    유형 : 이진 탐색 / 투 포인터

     

    1208번: 부분수열의 합 2

    첫째 줄에 정수의 개수를 나타내는 N과 정수 S가 주어진다. (1 ≤ N ≤ 40, |S| ≤ 1,000,000) 둘째 줄에 N개의 정수가 빈 칸을 사이에 두고 주어진다. 주어지는 정수의 절댓값은 100,000을 넘지 않는다.

    www.acmicpc.net

    ▸ 문제

    N개의 정수로 이루어진 수열이 있을 때, 크기가 양수인 부분수열 중에서 그 수열의 원소를 다 더한 값이 S가 되는 경우의 수를 구하는 프로그램을 작성하시오.

     입력

    첫째 줄에 정수의 개수를 나타내는 N과 정수 S가 주어진다. (1 ≤ N ≤ 40, |S| ≤ 1,000,000) 둘째 줄에 N개의 정수가 빈 칸을 사이에 두고 주어진다. 주어지는 정수의 절댓값은 100,000을 넘지 않는다.

     출력

    첫째 줄에 합이 S가 되는 부분수열의 개수를 출력한다.

     

    문제 풀이 

    N개의 정수로 이루어진 수열의 부분수열을 구한 다음 모든 합을 고려해주어야 한다. 연속된 부분수열이 아님을 주의하자. 부분수열의 총 갯수는 2^N-1인데, 해당 문제는 공집합(0)도 포함하므로 총 2^N개를 고려해줘야 한다. 2^40은 13자리를 가지는 정수이므로 범위를 초과한다.

     

    그래서 N/2, N-(N/2)를 가지는 두 수열로 나눈 다음 각 부분수열을 구하게되면 최대 2^20 = 104,8576이므로 충분히 커버가 가능하다. 이제 다음과 같이 비트마스킹을 사용하여 수열의 부분집합을 구하면 된다. 

    • 만약, n==5이면 3, 2의 크기를 가지는 부분수열을 가진다.
      • 크기가 3인 부분수열은 총 8개의 부분수열을 얻게 된다.
      • 크기가 2인 부분수열은 총 4개의 부분수열을 얻게 된다.  
    for(int i=0; i<(1<<n-size); i++) {
    	for(int j=0; j<n-size; j++) {
    		if((i&(1<<j))==(1<<j)) {
    			a[i] +=arr[j];
    		}
    	}
    }
    
    for(int i=0; i<(1<<size); i++) {
    	for(int j=0; j<size; j++) {
    		if((i&(1<<j))==(1<<j)) {
    			b[i]+= arr[j+(n-size)];
    		}
    	}
    }

     

     

    투 포인터

    두 부분수열의 집합을 정렬한 다음 a는 처음부터, b는 끝부터 시작하여 가운데로 모이는 투포인터 방식을 사용해주면 된다.

    1. a배열 포인터(ap)와 b배열 포인터(bp)를 0과 b.length-1로 설정한다.
    2. a[ap] + b[bp]의 합이 s인 곳을 찾는다.
      1. 해당 값이 중복되는 구간을 찾는다. ac++, bc++
      2. 그 구간의 곱해서 카운트해준다. cnt += ac*bc;
    3. a[ap] + b[bp] < s  크기가 s보다 작으면 a의 포인터를 증가시킨다.
    4. a[ap] + b[bp] > s  크기가 s보다 크면 b의 포인터를 감소시킨다.

     

    이진탐색

    이진탐색은 upper_bound와 lower_bound를 사용하여 합이 s가 되는 구간을 구해서 카운트를 해주면 된다.

    1. 하나의 부분수열의 집합을 기준으로 값을 설정한다. for i:0~a.length-1  v = a[i];
    2. a배열에서 v인 구간과 b배열에서 s-v인 구간을 찾는다.
      1. long aTerm = upper_bound(a, av) -lower_bound(a, av);
      2. long bTerm = upper_bound(b, s-av)-  lower_bound(b, s-av);

     

    마지막으로 a배열과 b배열에 공집합이 각각 있으므로 만약 s가 0이라면 cnt-1을 해주어야한다.

    투 포인터 풀이 코드 

    import java.io.*;
    import java.util.*;
    
    public class Main {
    
    	public static void main(String[] args) throws IOException{
    		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    		StringTokenizer st = new StringTokenizer(br.readLine());
    		
    		int n = Integer.parseInt(st.nextToken());
    		int s = Integer.parseInt(st.nextToken());
    		
    		int[] arr = new int[n];
    		st = new StringTokenizer(br.readLine());
    		for(int i=0; i<n; i++){
    			arr[i] = Integer.parseInt(st.nextToken());
    		}
    		
    		int size = n/2;
    		int[] a = new int[1<<(n-size)];
    		int[] b = new int[1<<(size)];
    		for(int i=0; i<(1<<n-size); i++) {
    			for(int j=0; j<n-size; j++) {
    				if((i&(1<<j))==(1<<j)) {
    					a[i] +=arr[j];
    				}
    			}
    		}
    		for(int i=0; i<(1<<size); i++) {
    			for(int j=0; j<size; j++) {
    				if((i&(1<<j))==(1<<j)) {
    					b[i]+= arr[j+(n-size)];
    				}
    			}
    		}
    		
    		Arrays.sort(a);
    		Arrays.sort(b);
    		
    		int ap =0;
    		int bp = b.length-1;
    		long cnt = 0;
    		while(ap<a.length && bp>-1){
    			int av = a[ap], bv = b[bp];
    			if(av+bv==s) {
    				long ac=0, bc=0;
    				while(ap<a.length && av == a[ap]) {
    					ac++;
    					ap++;
    				}
    				
    				while(bp>-1&& bv == b[bp]) {
    					bc++;
    					bp--;
    				}
    				cnt += ac*bc;
    			}
    			
    			if(av+bv < s) {
    				ap++;
    			}else if(av+bv>s) {
    				bp--;
    			}
    		}
    		if(s==0) cnt--;
    		System.out.println(cnt);
    	}
    }

     

    이진탐색 풀이 코드

    import java.io.*;
    import java.util.*;
    
    public class Main {
    
    	public static void main(String[] args) throws IOException{
    		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    		StringTokenizer st = new StringTokenizer(br.readLine());
    		
    		int n = Integer.parseInt(st.nextToken());
    		int s = Integer.parseInt(st.nextToken());
    		
    		int[] arr = new int[n];
    		st = new StringTokenizer(br.readLine());
    		for(int i=0; i<n; i++){
    			arr[i] = Integer.parseInt(st.nextToken());
    		}
    		
    		int size = n/2;
    		int[] a = new int[1<<(n-size)];
    		int[] b = new int[1<<(size)];
    		for(int i=0; i<(1<<n-size); i++) {
    			for(int j=0; j<n-size; j++) {
    				if((i&(1<<j))==(1<<j)) {
    					a[i] +=arr[j];
    				}
    			}
    		}
    		for(int i=0; i<(1<<size); i++) {
    			for(int j=0; j<size; j++) {
    				if((i&(1<<j))==(1<<j)) {
    					b[i]+= arr[j+(n-size)];
    				}
    			}
    		}
    		
    		Arrays.sort(a);
    		Arrays.sort(b);
    		long cnt=0;
    		for(int i=0; i<a.length;) {
    			int av = a[i];
    			long aTerm = upper_bound(a, av) -lower_bound(a, av);
    			long bTerm = upper_bound(b, s-av)-  lower_bound(b, s-av);
    			cnt+= aTerm*bTerm;
    			i+=aTerm;
    		}
    		if(s==0) cnt--;
    		System.out.println(cnt);
    	}
    	
    	static int upper_bound(int[] arr, int t) {
    		int left = 0, right =arr.length; 
    		while(left<right) {
    			int mid = (left+right)/2;
    			if(t >= arr[mid]) {
    				left = mid+1;
    			}else {
    				right =mid;
    			}
    		}
    		return right;
    	}
    	
    	static int lower_bound(int[] arr, int t) {
    		int left = 0, right =arr.length; 
    		while(left<right) {
    			int mid = (left+right)/2;
    			if(t <= arr[mid]) {
    				right =mid;
    			}else {
    				left = mid+1;
    			}
    		}
    		return right;
    	}
    }

     

    실행결과

     

    투포인터 실행결과
    이진탐색 실행결과