본문 바로가기

Dot Algo∙ DS/PS

[BOJ] 백준 11658번 구간 합 구하기3 (Java)

#11658 구간 합 구하기3

난이도 : 플레 5 

유형 : 펜윅 트리 / 누적합

 

11658번: 구간 합 구하기 3

첫째 줄에 표의 크기 N과 수행해야 하는 연산의 수 M이 주어진다. (1 ≤ N ≤ 1024, 1 ≤ M ≤ 100,000) 둘째 줄부터 N개의 줄에는 표에 채워져있는 수가 1행부터 차례대로 주어진다. 다음 M개의 줄에는

www.acmicpc.net

▸ 문제

N×N개의 수가 N×N 크기의 표에 채워져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 표의 i행 j열은 (i, j)로 나타낸다.

예를 들어, N = 4이고, 표가 아래와 같이 채워져 있는 경우를 살펴보자.

1 2 3 4
2 3 4 5
3 4 5 6
4 5 6 7

여기서 (2, 2)부터 (3, 4)까지 합을 구하면 3+4+5+4+5+6 = 27이 된다. (2, 3)을 7로 바꾸고 (2, 2)부터 (3, 4)까지 합을 구하면 3+7+5+4+5+6=30 이 된다.

표에 채워져 있는 수와 변경하는 연산과 합을 구하는 연산이 주어졌을 때, 이를 처리하는 프로그램을 작성하시오.

 입력

첫째 줄에 표의 크기 N과 수행해야 하는 연산의 수 M이 주어진다. (1 ≤ N ≤ 1024, 1 ≤ M ≤ 100,000) 둘째 줄부터 N개의 줄에는 표에 채워져있는 수가 1행부터 차례대로 주어진다. 다음 M개의 줄에는 네 개의 정수 w, x, y, c 또는 다섯 개의 정수 w, x1, y1, x2, y2 가 주어진다. w = 0인 경우는 (x, y)를 c로 바꾸는 연산이고, w = 1인 경우는 (x1, y1)부터 (x2, y2)의 합을 구해 출력하는 연산이다. (1 ≤ c ≤ 1,000) 표에 채워져 있는 수는 1,000보다 작거나 같은 자연수이다.

 출력

w = 1인 입력마다 구한 합을 순서대로 한 줄에 하나씩 출력한다.

 

문제 풀이  

2차원 배열의 구간 합을 구하는 문제로 누적합으로 풀어도 되고 펜윅 트리로 풀어도 된다.

  • 누적합은 1열을 기준으로 누적합을 구하고 각 열별로 합을 구해서 더해주면 된다. 그래서 O(n*m)이 걸린다.
  • 펜윅트리는 x열 탐색하는 데 logn, y열 탐색하는데 logn씩 소요된다. 그래서 O(m*(logn)^2)이 걸린다. 참고로 n > (logn)^2이다.

 

펜윅트리

펜윅트리를 2차원 구조로 구해준다. 1차원적으로는 [1 2 3 4] → [1 3 3 10]와 같이 구해지는데, y축을 하나 더해줬다고 생각하면 된다.

2차원 펜윅트리

여기서 구간 합을 구해주는 방식은 1차원적으로는 sum(x2) - sum(x1-1)으로 더해주면 되는데, 2차원에서는 다음과 같이 중첩되는 부분은 빼주면서 구해주면 된다.

구간합 구하기

 

누적합

누적합 코드 또한 이와 똑같고 더 단순한 구조로 이루어져있다. 이는 2차원적이 아닌 단순 열 기준으로 누적합을 구해놓은 다음 각 열별로 구해주면 된다.

누적합 초기화

 

누적합의 구간합은 중복되는 부분이 없으니 뺄 부분은 없고 단순 각 열의 구간 합을 구해서 더해주면 된다.

2차원 구간합 구하기

펜윅트리 풀이 코드 

import java.io.*;
import java.util.StringTokenizer;

public class Main {

	static int n;
	static int[][] arr, tree;
	static void update(int x, int y, int val) {
		while(x <= n){
			for(int i=y; i<=n;) {
				tree[x][i] += val;
				i += i&-i;
			}
			x += x&-x;
		}
	}
	
	static int sum(int x, int y) {
		int result = 0;
		while(x > 0) {
			for(int i=y; i>0;) {
				result += tree[x][i];
				i -= i&-i;
			}
			x -= x&-x;
		}
		return result;
	}
	
	public static void main(String[] args) throws IOException{
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st = new StringTokenizer(br.readLine());
		n = Integer.parseInt(st.nextToken());
		int m = Integer.parseInt(st.nextToken());
		
		arr = new int[n+1][n+1];
		tree = new int[n+1][n+1];
		for(int i=1; i<=n; i++) {
			 st = new StringTokenizer(br.readLine());
			for(int j=1; j<=n; j++) {
				arr[i][j] = Integer.parseInt(st.nextToken());
				update(i,j,arr[i][j]);
			}
		}
		
		StringBuilder sb = new StringBuilder();
		for(int i=0; i<m; i++) {
			st = new StringTokenizer(br.readLine());
			int op = Integer.parseInt(st.nextToken());
			int x1 = Integer.parseInt(st.nextToken());
			int y1 = Integer.parseInt(st.nextToken());
			if(op == 1) {
				int x2 = Integer.parseInt(st.nextToken());
				int y2 = Integer.parseInt(st.nextToken());
				sb.append((sum(x2, y2) - sum(x2, y1-1) - sum(x1-1, y2) + sum(x1-1,y1-1))+"\n");
			}else {
				int c = Integer.parseInt(st.nextToken());
				update(x1, y1, c-arr[x1][y1]);
				arr[x1][y1] = c;
			}
		}
		System.out.println(sb.toString());
	}
}

 

누적합 풀이 코드 

import java.io.*;
import java.util.StringTokenizer;

public class Main {

	static int[][] arr, dp;
	public static void main(String[] args) throws IOException{
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st = new StringTokenizer(br.readLine());
		int n = Integer.parseInt(st.nextToken());
		int m = Integer.parseInt(st.nextToken());
		
		arr = new int[n+1][n+1];
		dp = new int[n+1][n+1];
		for(int i=1; i<=n; i++) {
			 st = new StringTokenizer(br.readLine());
			for(int j=1; j<=n; j++) {
				arr[i][j] = Integer.parseInt(st.nextToken());
				dp[i][j] = dp[i][j-1] + arr[i][j];
			}
		}
		
		StringBuilder sb = new StringBuilder();
		for(int i=0; i<m; i++) {
			st = new StringTokenizer(br.readLine());
			int op = Integer.parseInt(st.nextToken());
			int x1 = Integer.parseInt(st.nextToken());
			int y1 = Integer.parseInt(st.nextToken());
			if(op == 1) {
				int x2 = Integer.parseInt(st.nextToken());
				int y2 = Integer.parseInt(st.nextToken());
				
				int res = 0;
				for(int x=x1; x<=x2; x++) {
					res += dp[x][y2] - dp[x][y1-1];
				}
				sb.append(res+"\n");
			}else {
				int c = Integer.parseInt(st.nextToken());
				int dif = c - arr[x1][y1];
				for(int y=y1; y<=n; y++) {
					dp[x1][y] += dif; 
				}
				arr[x1][y1] = c;
			}
		}
		
		System.out.println(sb.toString());
	}
}