본문 바로가기

Dot Algo∙ DS/PS

[BOJ] 백준 10999번 구간 합 구하기2 (Java)

    #10999 구간 합 구하기2

    난이도 : 플레 4

    유형 : 자료 구조 / 세그먼트 트리 / lazy propagation

     

    10999번: 구간 합 구하기 2

    첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

    www.acmicpc.net

    ▸ 문제

    어떤 N개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 만약에 1,2,3,4,5 라는 수가 있고, 3번째부터 4번째 수에 6을 더하면 1, 2, 9, 10, 5가 되고, 여기서 2번째부터 5번째까지 합을 구하라고 한다면 26을 출력하면 되는 것이다. 그리고 그 상태에서 1번째부터 3번째 수에 2를 빼고 2번째부터 5번째까지 합을 구하라고 한다면 22가 될 것이다.

     입력

    첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그리고 N+2번째 줄부터 N+M+K+1번째 줄까지 세 개의 정수 a, b, c 또는 a, b, c, d가 주어지는데, a가 1인 경우 b번째 수부터 c번째 수에 d를 더하고, a가 2인 경우에는 b번째 수부터 c번째 수의 합을 구하여 출력하면 된다.

    입력으로 주어지는 모든 수는 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.

     출력

    첫째 줄부터 K줄에 걸쳐 구한 구간의 합을 출력한다. 단, 정답은 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.

     

    문제 풀이  

    주어진 배열을 갱신하면서 구간합을 구하는 문제이다. 일반적인 구간 합 문제는 세그먼트 트리를 사용하여 그냥 풀면 되지만 갱신하는 부분이 난이도가 올라간 문제이다. 

    📚 조건

    • 수의 개수 N ( 1  <= N <= 1,000,000 ), 수의 범위 -2^63 ~ 2^63-1
    • 업데이트 횟수 M ( 1 <= M <= 10,000 )
    • 구간 합 구하는 횟수 K ( 1<= K <= 10,000 )

    어떤 한 개의 리프노드만을 업데이트시키면 해당 부분만 탐색해주면 되는데 해당 문제는 [a,b]의 구간에 있는 노드를 동시에 업데이트 시켜줘야한다.

     

    만약 단순 세그먼트 트리만으로 풀이를 하게 되면 업데이트 처리를 할 때, 최악의 경우 트리 전체를 업데이트 해줘야하기 때문에 O(NM + KlogN)의 시간이 걸린다. 범위를 보면 알 수 있듯이 이 풀이는 시간초과가 발생하게 된다.

     

    그래서 lazy propagation을 사용해줘야 한다.

    lazy propagation은 처리를 미뤘다가 필요할 때 갱신을 해주는 방식이다.

     

    빠른 처리를 위해 필요한 노드만 업데이트 처리를 해주고 나머지 자식 노드는 lazy상태를 걸어두었다가 나중에 필요해질 때(구간 합 구할 때) 업데이트를 처리해주는 방식이다. 이렇게 되면 최적화 방식으로 줄일 수 있기 때문에 일반 구간합을 구하는 시간복잡도와 마찬가지로

    O((M+K)logN)로 쿼리를 해결할 수 있다.

     

    이제 세그먼트 트리 로직에 lazy propagation을 어떻게 입히는지 살펴보자

    propagate

    1. 해당 노드 lazy가 등록되어있으면 업데이트해준다.
      1. 리프노드인 경우 물려줌은 패스
      2. 리프노드가 아닌 경우 자식 노드들에게 lazy를 물려준다. (나중에 필요하게 될 때가 오면 갱신; 구간 합 구하는 로직)
    2. 해당 노드는 이제 업데이트 처리를 해준다. 노드 += 업데이트 값 * 해당 자식노드의 총 갯수 
      1. ex) update 1 1 3 -3 : 1~3구간에 속하는 노드가 -3으로 업데이트된다면 1~3구간에 속하는 노드는 3*-3 = -9를 갱신해줘야한다.
    3. 업데이트 완료 후 해당 노드 lazy값을 삭제한다. lazy[node] =0;

    코드

    static void propagate(int start, int end, int node) {
    	// 값이 들어있는 경우 업데이트 처리 시작 
    	if(lazy[node] !=0){ 
    		// 리프노드가 아닌 경우 자식에게 lazy 물려줌(자식은 나중에 갱신)
    		if(start != end) { 
    			lazy[node*2] += lazy[node];
    			lazy[node*2 +1] += lazy[node];
    		}
    		tree[node] += lazy[node] * (end-start+1); // tree[node] = 업데이트 값 * 해당 자식노드의 총 갯수
    		lazy[node] = 0; // 업데이트 완료 후 초기화 
    	}
    }

     

    update

        1. 조회하는 해당 노드에 lazy가 걸려있으면 propagate 메소드로 이동하여 갱신해준다. propagate(start, end, node);
        2. 업데이트 구간에 들어오게 되면 lazy를 등록해준다. if(left <= start && end <= right) { lazy[node] = dif; }
          1. 여기서 중요! 해당 노드만 바로 업데이트 처리해주고 바로 리턴한다. 자식 노드들은 나중에 필요할 때 갱신하게 lazy만 걸어주고 내버려두는게 포인트이다.  propagate(start, end, node); return; 
        3. 1~2번의 업데이트 로직이 모두 일어난 후 갱신된 노드들의 부모 노드에 반영해준다. tree[node] = tree[node*2]+tree[node*2+1];
      1.  

    코드

    static void update(int start, int end, int node, int left, int right, long dif) {
    	propagate(start, end, node); // 해당 노드에 갱신할 값이 있다면 업데이트 
    	if(end < left || right < start) return;
    	if(left <= start && end <= right) {
    		lazy[node] = dif; // 범위에 속할 경우 업데이트 lazy값 갱신  
    		propagate(start, end, node); // 해당 노드에 갱신할 값이 있다면 업데이트 
    		return;
    	}
    	
    	int mid = (start+end)/2;
    	update(start, mid, node*2, left, right, dif);
    	update(mid+1, end, node*2+1, left, right, dif);
    	
    	// update가 다 일어난 후 부모 노드에 반영
    	tree[node] = tree[node*2]+tree[node*2+1];
    }

     

    pSum

    구간합을 구하는 데 방문하는 노드가 lazy가 걸려있다면 이 때 업데이트 처리를 해주게 된다.

    1. 조회하는 해당 노드에 lazy가 걸려있으면 propagate 메소드로 이동하여 갱신해준다. propagate(start, end, node);
    2. 업데이트된 노드로 구간합을 구해준다.

    코드

    static long pSum(int start, int end, int node, int left, int right) {
    	propagate(start, end, node); // 남은 lazy update 처리
    	if(end < left || right < start) return 0;
    	if(left <= start && end <= right) {
    		return tree[node];
    	}
    	int mid = (start+end)/2;
    	return pSum(start, mid, node*2, left, right) + pSum(mid+1, end, node*2+1, left, right);
    }

     

     

    풀이 코드 

    시뮬레이션 보러가기

    import java.io.*;
    import java.util.StringTokenizer;
    
    public class Main {
    	static int n;
    	static long[] elements, tree, lazy;
    	public static void main(String[] args) throws IOException{
    		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    		BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
    		StringTokenizer st = new StringTokenizer(br.readLine());
    		
    		n = Integer.parseInt(st.nextToken());
    		int m = Integer.parseInt(st.nextToken());
    		int k = Integer.parseInt(st.nextToken());
    		
    		elements = new long[n];
    		for(int i=0; i<n; i++) {
    			elements[i] = Long.parseLong(br.readLine());
    		}
    		
    		tree = new long[n*4];
    		lazy = new long[n*4];
    		
    		init(0,n-1,1);
    		for(int i=0; i<m+k; i++) {
    			st = new StringTokenizer(br.readLine());
    			int op = Integer.parseInt(st.nextToken());
    			int l = Integer.parseInt(st.nextToken())-1;
    			int r = Integer.parseInt(st.nextToken())-1;
    			if(op==1) {
    				long dif = Long.parseLong(st.nextToken());
    				update(0,n-1, 1, l, r, dif);
    			}else {
    				bw.write(pSum(0, n-1, 1, l, r)+"\n");
    			}
    		}
    		bw.flush();
    		bw.close();
    	}
    	
    	static long init(int start, int end, int node) {
    		if(start == end) return tree[node] = elements[start];
    		
    		int mid = (start+end)/2;
    		return tree[node] =  init(start, mid, node*2) +init(mid+1, end, node*2+1);
    	}
    	
    	static void propagate(int start, int end, int node) {
    		if(lazy[node] !=0){
    			if(start != end) {
    				lazy[node*2] += lazy[node];
    				lazy[node*2 +1] += lazy[node];
    			}
    			tree[node] += lazy[node] * (end-start+1);
    			lazy[node]=0;
    		}
    		
    	}
    	
    	static void update(int start, int end, int node, int left, int right, long dif) {
    		propagate(start, end, node);
    		if(end < left || right < start) return;
    		if(left <= start && end <= right) {
    			lazy[node] = dif;
    			propagate(start, end, node);
    			return;
    		}
    		
    		int mid = (start+end)/2;
    		update(start, mid, node*2, left, right, dif);
    		update(mid+1, end, node*2+1, left, right, dif);
    		
    		tree[node] = tree[node*2]+tree[node*2+1];
    	}
    	
    	
    	static long pSum(int start, int end, int node, int left, int right) {
    		propagate(start, end, node);
    		if(end < left || right < start) return 0;
    		if(left <= start && end <= right) return tree[node];
    		
    		int mid = (start+end)/2;
    		return pSum(start, mid, node*2, left, right) + pSum(mid+1, end, node*2+1, left, right);
    	}
    }

    처음에 시간초과로 꽤나 고생했다. 기존 세그먼트 트리 문제와의 차이점은 업데이트 부분이 한 개의 노드에서 구간으로 확장되었다는 것을 인지해서 이 부분을 고쳐야하는 건 알겠는데 연산을 나중에 처리하는 방식으로 해결해야 되는 것인 줄은 생각도 못했다. lazy propagation에 대한 개념이 없으면 접근하기 어려울 것 같다.