본문 바로가기

Dot Algo∙ DS/PS

[BOJ] 백준 16975번 수열과 쿼리 21 (Java) - Fenwick

    #16975 수열과 쿼리 21

    난이도 : 플레 4

    유형 : 펜웍 트리 / 세그먼트 트리 + Lazy propagation

     

    16975번: 수열과 쿼리 21

    길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오. 1 i j k: Ai, Ai+1, ..., Aj에 k를 더한다. 2 x: Ax 를 출력한다.

    www.acmicpc.net

    ▸ 문제

    길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오.

    • 1 i j k: Ai, Ai+1, ..., Aj에 k를 더한다.
    • 2 x: Ax 를 출력한다.

     입력

    첫째 줄에 수열의 크기 N (1 ≤ N ≤ 100,000)이 주어진다.

    둘째 줄에는 A1, A2, ..., AN이 주어진다. (1 ≤ Ai ≤ 1,000,000)

    셋째 줄에는 쿼리의 개수 M (1 ≤ M ≤ 100,000)이 주어진다.

    넷째 줄부터 M개의 줄에는 쿼리가 한 줄에 하나씩 주어진다. 1번 쿼리의 경우 1 ≤ i ≤ j ≤ N, -1,000,000 ≤ k ≤ 1,000,000 이고, 2번 쿼리의 경우 1 ≤ x ≤ N이다. 2번 쿼리는 하나 이상 주어진다.

     출력

    2번 쿼리가 주어질 때마다 출력한다.

     

    문제 풀이  

    단순 세그먼트 트리로 풀이를 하면 최악의 경우, O(NM+KlogN)로 대략 10억이 넘는 시간이 걸려 시간초과가 발생한다. 그래서 일반적으로는 세그먼트 트리에 느리게 갱신하는 lazy propagation 기능을 추가하여 풀이를 해줘야 한다. 그러면 O((M+K)logN)으로 처리할 수 있다.

     

    그런데 이보다 더 간단한 구현으로 풀 수 있는 방법이 있는데 바로 펜윅트리를 사용하는 것이다. 펜윅트리는 구간 합을 구하는 데 특화된 알고리즘으로 lazy propagation과 비슷한 시간복잡도로 풀이가 가능하다.

    • 펜윅트리의 구간 합 업데이트, 부분 합 찾기 모두 O(logN)이 걸린다.
    펜윅트리 개념 설명은 여기를 참고해주세요.

     

    해당 문제는 펜윅 트리에 담는 원소를 좀 새롭게 바꿔서 담을 것이다. 왜냐하면 2번 쿼리에서 주어지는 점 쿼리(arr[x])를 구해야 하기 때문이다. 수열을 담은 배열을 arr이라고 하고 이를 펜윅트리에 맞춰 값을 담은 배열을 tree라고 하자.

    • tree[1] = arr[1]
    • tree[i] = arr[i] - arr[i-1]

    펜윅트리 초기값

     

    1. 구간 업데이트 - (i, j)에 k 더하기

    arr[i] + arr[i+1] + ··· + arr[j]에 k를 더해줘야 한다.  펜윅 트리에서 배열의 값을 변경하는 것은 해당 위치의 값에 숫자를 더하고 빼는 것으로 구현한다. 맨 오른쪽에 있는 1인 비트를 스스로에게 더해주는 연산을 반복하여 해당 위치를 포함하는 구간들을 모두 만날 수 있다.

     

    만약 위의 예제에서 구간 3에 +k를 한다면 다음과 같다. arr[3]을 포함하는 모든 구간의 합을 k씩 늘려주면 된다.

    • 이때 늘려줘야 할 값들은 3번 노드, 4번 노드 로 이진수로 표현하면 11(2) → 100(2)로 이동한다.
    • pos += (pos&-pos);

     

    i에서 j구간에만 포함하려면 어떻게해야 할까?

    • tree[i] → (arr[i] +k) - arr[i-1] = tree[i] + k
    • tree[i+1] → (arr[i+1] +k) - (arr[i] + k) = tree[i+1]
    • ...
    • tree[j] → (arr[j] +k) - (arr[j-1] + k) = tree[j]
    • tree[j+1] → (arr[j+1] +k) - (arr[j] + k) = tree[j+1] - k

    i+1 ~ j까지는 변화가 없고, 각 i와 j+1에 해당하는 구간에만 +k, -k를 더해주면 된다. 

     

     

    예를 들어, 3~4까지 +6을 더해주면 다음과 같이 값이 변경된다.

    • 3이 포함되는 노드들에 +6을 더해주고, 5가 포함되는 노드들에 -6을 더해주면 된다.

    [3,4]구간 +6하기

     

    2. 점 쿼리 구하기 (Point query)

    arr[x]의 값은 기존 펜윅 트리 sum()을 그대로 활용해서 구해주면 된다. 

    • tree[x] = arr[x] - arr[x-1] 
    • tree[x-1] = arr[x-1] - arr[x-2]
    • ...
    • tree[1] = arr[1] - arr[0]
    • arr[0]은 0이므로,  sum(x) = tree[x]이 된다.

     

    예를 들어 arr[3]을 구하기 위해 더해야 하는 숫자는 3에서 끝나는 구간의 합 3번 노드, 2에서 끝나는 구간의 합 b[2]이다.

    • 3번 노드, 2번 노드를 이진수로 표현하면 11(2) → 10(2)이다.
    • arr[3] - arr[2] + arr[2] - arr[0] = arr[3]

    sum(3)구하기

     

    풀이 코드 

    import java.io.*;
    import java.util.StringTokenizer;
    
    public class Main {
    
    	static long[] tree;
    	static int n;
    	static void add(int pos, int val) {
    		while(pos <= n) {
    			tree[pos] += val;
    			pos += (pos&-pos);
    		}
    	}
    	
    	static long sum(int pos) {
    		long result = 0;
    		while(pos > 0) {
    			result += tree[pos];
    			pos &= (pos-1);
    		}
    		return result;
    		
    	}
    	public static void main(String[] args) throws IOException{
    		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    		n = Integer.parseInt(br.readLine());
    		tree = new long[n+1];
    		StringTokenizer st = new StringTokenizer(br.readLine());
    		int prev = Integer.parseInt(st.nextToken());
    		int now;
    		add(1, prev);
    		for(int i=2; i<=n; i++) {
    			now = Integer.parseInt(st.nextToken());
    			add(i, now - prev);
    			prev = now;
    		}
    		StringBuilder sb = new StringBuilder();
    		int m = Integer.parseInt(br.readLine());
    		for(int i=0; i<m; i++) {
    			st = new StringTokenizer(br.readLine());
    			int op = Integer.parseInt(st.nextToken());
    			int a = Integer.parseInt(st.nextToken());
    			
    			if(op == 1) {
    				int b = Integer.parseInt(st.nextToken());
    				int k = Integer.parseInt(st.nextToken());
    				
    				add(a, k);
    				add(b+1, -k);
    			}else {
    				sb.append(sum(a)+"\n");
    			}
                
    		}
    		System.out.println(sb.toString());
    		
    	}
    }