본문 바로가기

Dot Algo∙ DS/알고리즘 개념

[알고리즘] 펜윅 트리(Fenwick Tree) 빠르고 간단한 구간 합 구하기 (Java)

    펜윅 트리(Fenwick Tree) :  빠르고 간단한 구간 합

    세그먼트 트리의 가장 흔한 사용 예는 바로 구간 합을 빠르게 구하는 것이다. 이 경우 세그먼트 트리 대신 쓸 수 있는 세그먼트 트리의 궁극적인 진화 형태로 펜윅 트리(Fenwick Tree) 혹은 이진 인덱스 트리(binary indexed tree)라고 불리는 것이 있다.

     

    펜윅 트리가 사용하는 중요한 아이디어는 구간 합 대신 부분 합만을 빠르게 계산할 수 있는 자료구조로 만들어도 구간 합을 계산할 수 있다는 것이다.

     

    arr[]의 위치 pos에 대해 배열의 부분합 psum[pos] = arr[0] + arr[1] + ··· + a[pos]를 빠르게 계산할 수 있다고 하자.

    • 그러면 [i, j] 구간의 합은 psum[j] - psum[i-1]로 계산할 수 있다.

     

    부분 합만을 계산한다고 생각해 보면, 구간 트리가 미리 계산헤 저장하는 정보의 상당수는 필요가 없다. 다음 그림을 보면 구간 트리가 계산한 각 구간을 보여준다.

    • 이때 [8, 15] 구간의 구간 합은 사실 부분 합만을 구한다면 필요가 없다. psum[15]를 구한다면 어차피 루트에 있는 값을 사용하면 되고, 다른 위치의 부분 합을 구할 때는 이 값을 쓸 수가 없기 때문에 굳이 저장해 둘 필요가 없다.

    세그먼트 트리가 저장하는 각 구간들

     

    필요없는 구간 지우기

    같은 원리로, 하나의 긴 구간 밑에 두 개의 작은 구간이 있을 때 이 두 구간 중 오른쪽 구간은 항상 지워도 된다.

    • 남은 구간의 갯수는 정확히 n개가 된다. (8 + 4 + 2 + 1 = 15)
    • 오른쪽 끝 원소들을 보면 모두 값이 다른 것을 알 수 있다. 그래서 이 대응을 이용해 1차원 배열에 각 구간의 합을 저장할 수 있다.
    • tree[i] = 오른쪽 끝 위치가 arr[i]인 구간의 합

     

    필요없는 구간을 지워서 펜윅 트리로 만들기

     

    구간 합 구하기

    이젠 arr[pos] 까지의 구간 합 psum[pos]를 구하고 싶으면 위 그림에서 pos에서 끝나는 구간의 합 tree[pos]를 답에 더한다. 그리고 남은 부분들을 왼쪽에서 찾아 더하면 된다.

    • 예를 들어, psum[12] = tree[12] + tree[11] + tree[7]이다.
    • 어떤 부분 합을 구하든 O(logN)개의 구간 합만 있으면 된다.

     

    그럼 이제 pos에서 끝나는 구간 다음으로 더해야 할 구간을 어떻게 찾아야 하는지 알아내야 한다.

    • tree[12] → tree[11] → tree[7]

    psum[12] 구간합 구하기

     

    펜윅 트리(Fenwick Tree): 이진수로 표현하기

    펜웍 트리는 각 숫자의 이진수 표현을 이용해 이 문제를 해결할 수 있다.우선 이를 위해 배열 arr[]와 tree[]의 첫 원소의 인덱스를 1로 바꾸자. 모든 원소의 인덱스에 1을 더해주면 된다. 그러고 나면 특정 부분 합을 구하기 위해 더해야 할 구간 합들을 쉽게 찾을 수 있다.

    다음 그림을 보면 각 구간들의 길이는 오른쪽 끝에 있는 0의 개수가 하나 늘 때마다 두 배로 늘어나는 것을 확인할 수 있다.

    • 8의 이진수 표현은 1000(2)이고, 이 수의 오른쪽 끝에는 0이 세 개이므로 8에서 끝나는 구간의 길이는 2^3 = 8이다.
    • 10의 이진수 표현은 1010(2) 이고, 이 수의 오른쪽에는 0이 하나 있으므로 10에서 끝나는 구간의 길이는 2^1 = 2가 된다.

     

    이진수로 표현하기

     

    이진수 표현으로 부분 합 구간 찾기

    이제 부분 합을 구하기 위해 더해야 하는 구간들의 번호도 이들의 이진수 표현과 관계가 있다. 오른쪽 끝 위치의 이진수 표현에서 마지막 비트를 지우면 다음 구간을 쉽게 찾을 수 있다.

    • 예를 들어 psum[7]을 구하기 위해 더해야 하는 숫자는 7에서 끝나는 구간의 합 tree[7], 6에서 끝나는 구간의 합 tree[6], 그리고 4에서 끝나는 구간 합 tree[4]이다.
    • 이진수로 표현하면, 111(2), 110(2), 100(2)이 된다.

    이진수 표현으로 부분 합 구간 찾기

     

    이진수 표현으로 배열 값 변경하기

    펜윅 트리에서 배열의 값을 변경하는 것은 해당 위치의 값에 숫자를 더하고 빼는 것으로 구현한다. 맨 오른쪽에 있는 1인 비트를 스스로에게 더해주는 연산을 반복하여 해당 위치를 포함하는 구간들을 모두 만날 수 있다.

    • 예를 들어 arr[5]를 3늘리고 싶다고 하면, arr[5]를 포함하는 모든 구간의 합들을 3씩 늘려주면 된다.
    • 이때 늘려줘야 할 값들은 tree[5], tree[6], tree[8], tree[16]으로, 101(2), 110(2), 1000(2), 10000(2)이다.
    • 101 → 110 → 1000 → 10000 순으로 이동한다.

    이진수 표현으로 배열 값 변경하기

     

     

    일반 펜윅 트리 구현

    부분합 찾기, 값 변경 로직 모두 O(logN)의 시간복잡도를 가진다. 반복문이 수행될 때 마다 트리의 한층을 올라가는데, 트리의 높이는 항상 O(logN)이기 때문이다.

     

    1.  배열 값 업데이트 

    arr[pos]값을 +k를 해줘야 한다. 펜웍트리에서 배열 값을 변경하는 것은 해당 위치의 값에 숫자를 더하고 빼는 것으로 구현한다. 맨 오른쪽에 있는 1인 비트를 스스로에게 더해주는 연산을 반복하여 해당 위치를 포함하는 구간들을 모두 만날 수 있다.

     

    만약 구간 3에 +k를 한다면 다음과 같다. arr[3]을 포함하는 모든 구간의 합을 k씩 늘려주면 된다.

    • 이때 늘려줘야 할 값들은 3번 노드, 4번 노드 로 이진수로 표현하면 11(2) → 100(2)로 이동한다.
    • pos += (pos&-pos);

    배열 값 업데이트

     

    2. 구간 합 구하기 - psum

    i부터 j까지의 구간 합을 구하려면 psum[j] - psum[i-1]의 값을 구해주면 된다. psum은 이진수 표현에서 마지막 비트를 지우면서 다음 구간을 찾아가서 더해주면 된다. 오른쪽 끝 위치의 이진수 표현에서 마지막 비트를 지우면 다음 구간을 쉽게 찾을 수 있다.

     

    예를 들어 psum[3]을 구하기 위해 더해야 하는 숫자는 3에서 끝나는 구간의 합 3번 노드, 2에서 끝나는 구간의 합 tree[2]이다.

    • 3번 노드, 2번 노드를 이진수로 표현하면 11(2) → 10(2)이다.
    • pos &= (pos-1);

    구간 합 구하기

     

    다음 구현을 보면 펜윅 트리 구현은 엄청 간단함을 알 수 있다. 때문에 계속 변하는 배열의 구간 합을 구할 때는 세그먼트 트리보다 펜윅 트리를 훨씬 자주 쓰게 된다.

    public class FenwickTree {
    	static int[] tree;
    
    	public FenwickTree(int size) {
    		tree = new int[size+1];
    	}
    
    	long sum(int pos){
    		long result = 0;
    		while(pos > 0){
    			result += tree[pos];
    			pos &= (pos-1);
    		}
    		return result;
    	}
    
    	void add(int pos, int val){
    		while(pos < tree.length){
    			tree[pos] += val;
    			pos += (pos & -pos);
    		}
    	}
    }

     

     

    구간 업데이트,  점 쿼리가 가능한 펜윅 트리

    여기서 좀 더 업그레이드하면 구간 업데이트와 구간 합을 구할 수 있는 Fenwick 트리도 설계할 수 있다. 여기서는 배열 값을 인접한 값들의 차로 다음과 같이 설정한다. 

    • b[1] = arr[1]
    • b[i] = arr[i] - arr[i-1]

     

    1. 구간 업데이트 - (i, j)에 k 더하기

    arr[i] + arr[i+1] + ··· + arr[j]에 k를 더해줘야 한다.  펜윅 트리에서 배열의 값을 변경하는 것은 해당 위치의 값에 숫자를 더하고 빼는 것으로 구현한다. 1번에서 업데이트하는 방식만 약간 변형해주면 된다.

     

    i에서 j구간에만 포함하려면 어떻게해야 할까?

    • b[i] → (arr[i] +k) - arr[i-1] = tree[i] + k
    • b[i+1] → (arr[i+1] +k) - (arr[i] + k) = b[i+1]
    • ...
    • b[j] → (arr[j] +k) - (arr[j-1] + k) = b[j]
    • b[j+1] → (arr[j+1] +k) - (arr[j] + k) = b[j+1] - k

    [i+1,  j]까지는 변화가 없고, 각 i와 j+1에 해당하는 구간에만 +k, -k를 더해주면 된다. 

     

    예를 들어, 3~4까지 +6을 더해주면 다음과 같이 값이 변경된다. (구현 코드)

    • 3이 포함되는 노드들에 +6을 더해주고, 5가 포함되는 노드들에 -6을 더해주면 된다.

    구간 업데이트

     

    2. 점 쿼리 구하기 (Point query)

    arr[x]의 값은 sum()을 그대로 이용하여 구해주면 된다. 

    • b[x] = arr[x] - arr[x-1] 
    • b[x-1] = arr[x-1] - arr[x-2]
    • ...
    • b[1] = arr[1] - arr[0]
    • arr[0]은 0이므로,  sum(x) = arr[x]이 된다.

     

    예를 들어 arr[3]을 구하기 위해 더해야 하는 숫자는 3에서 끝나는 구간의 합 3번 노드, 2에서 끝나는 구간의 합 b[2]이다.

    • 3번 노드, 2번 노드를 이진수로 표현하면 11(2) → 10(2)이다.
    • arr[3] - arr[2] + arr[2] - arr[0] = arr[3]

     

    점 쿼리 구하기

     

    이를 통해 백준 16975번 수열과 쿼리 21를 풀이하면 다음과 같이 구현할 수 있다.

    import java.io.*;
    import java.util.StringTokenizer;
    
    public class Main {
    
    	static long[] tree;
    	static int n;
    	static void add(int pos, int val) {
    		while(pos <= n) {
    			tree[pos] += val;
    			pos += (pos&-pos);
    		}
    	}
    	
    	static long sum(int pos) {
    		long result = 0;
    		while(pos > 0) {
    			result += tree[pos];
    			pos &= (pos-1);
    		}
    		return result;
    		
    	}
    	public static void main(String[] args) throws IOException{
    		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    		n = Integer.parseInt(br.readLine());
    		tree = new long[n+1];
    		StringTokenizer st = new StringTokenizer(br.readLine());
    		int prev = Integer.parseInt(st.nextToken());
    		int now;
    		add(1, prev);
    		for(int i=2; i<=n; i++) {
    			now = Integer.parseInt(st.nextToken());
    			add(i, now - prev);
    			prev = now;
    		}
    		StringBuilder sb = new StringBuilder();
    		int m = Integer.parseInt(br.readLine());
    		for(int i=0; i<m; i++) {
    			st = new StringTokenizer(br.readLine());
    			int op = Integer.parseInt(st.nextToken());
    			int a = Integer.parseInt(st.nextToken());
    			
    			if(op == 1) {
    				int b = Integer.parseInt(st.nextToken());
    				int k = Integer.parseInt(st.nextToken());
    				
    				add(a, k);
    				add(b+1, -k);
    			}else {
    				sb.append(sum(a)+"\n");
    			}
                
    		}
    		System.out.println(sb.toString());
    		
    	}
    }