#11505 구간 곱 구하기
난이도 : 골드 1
유형 : 자료 구조 / 세그먼트 트리
▸ 문제
어떤 N개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 곱을 구하려 한다. 만약에 1, 2, 3, 4, 5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 곱을 구하라고 한다면 240을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 곱을 구하라고 한다면 48이 될 것이다.
▸ 입력
첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 곱을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그리고 N+2번째 줄부터 N+M+K+1 번째 줄까지 세 개의 정수 a,b,c가 주어지는데, a가 1인 경우 b번째 수를 c로 바꾸고 a가 2인 경우에는 b부터 c까지의 곱을 구하여 출력하면 된다.
입력으로 주어지는 모든 수는 0보다 크거나 같고, 1,000,000보다 작거나 같은 정수이다.
▸ 출력
첫째 줄부터 K줄에 걸쳐 구한 구간의 곱을 1,000,000,007로 나눈 나머지를 출력한다.
문제 풀이
세그먼트 트리를 사용하여 구간 곱을 구해야 한다. 일반적으로는 구간 합을 구하는데 사용되는데 로직을 곱셉의 성질을 신경써서 살짝만 변형시켜주면 구현이 가능하다. 해당 문제를 풀기 위해서는 세그먼트 트리에 대한 개념은 필수이다.
📚 조건
- 수의 개수 N ( 1 <= N <= 1,000,000 )
- 수의 변경이 일어나는 횟수 M ( 1 <= M <= 10,000 )
- 구간의 곱을 구하는 횟수 K ( 1 <= K <= 10,000 )
일반 선형탐색으로 구하게 되면 수 변경 O(1), 구간 곱 구하기 O(NK)로 100억번의 연산을 해야하므로 세그먼트 트리를 사용하여 구해야한다. 그러면 O(KlogN)으로 줄어든다.
구상
- 구간 곱 트리 사이즈를 구한 다음 구간 곱 트리를 초기화 시켜준다.
- a==1, 구간 곱 트리를 업데이트 시켜준다.
- a==2, [l~r] 구간 곱을 구해준다.
로직 1번. 구간 곱 트리 사이즈, 초기화
세그먼트 트리는 이진 트리이므로, 사이즈는 모든 노드가 꽉차있는 완전 이진트리일 경우를 생각해서 크기를 구해주면 된다.
static int getTreeSize() {
int h = (int)Math.ceil(Math.log(n)/Math.log(2));
return (int) Math.pow(2, h+1);
}
각 원소의 최댓값은 1,000,000이므로 구간 곱을 저장할 트리는 long으로 선언하고 각 구간의 곱을 트리에 저장해주면 된다.
static long init(int start, int end, int node) {
if(start == end) return tree[node] = elements[start];
int mid = (start+end)/2;
return tree[node] = (init(start,mid,node*2) * init(mid+1, end, node*2+1))%MOD;
}
로직 2번. 구간 곱 트리 업데이트
구간 합 업데이트 방식은 Top-down으로 루트노드를 시작으로 변경이 발생하는 해당 리프노드까지 내려가면서 업데이트를 해준다. 그런데 구간 곱은 그렇게하면 안된다. 왜냐하면 만약 값이 0에서 6으로 변경되었을 때는 루트노드부터 탐색하게 되면 값이 올바르게 바뀌지 않는다.
그래서 구간 곱 업데이트 방식은 Bottom-up으로 리프노드로부터 시작하여 해당 리프노드가 속해있는 루트 노드까지 탐색하면서 업데이트를 해줘야한다.
static long update(int start, int end, int node, int idx, int dif) {
if(end < idx || idx< start) return tree[node];
if(start == end) return tree[node] = dif;
int mid = (start+end)/2;
return tree[node] = (update(start, mid, node*2, idx, dif)*update(mid+1, end, node*2+1, idx, dif))%MOD;
}
로직 3번. ( left~right ) 구간 곱 구하기
구간 곱 업데이트 방식과 똑같다. 해당 구간에 해당하는 노드들의 곱을 구해주면 된다.
static long pMul(int start, int end, int node, int left, int right) {
if(end < left || right < start) return 1;
if(left <= start && end <= right) {
return tree[node];
}
int mid = (start+end)/2;
return (pMul(start, mid, node*2, left, right)* pMul(mid+1, end, node*2+1, left, right))%MOD;
}
먄약 3~5의 구간 곱을 구한다고 하면 빨간색 동그라미의 노드를 조회하게 되어 해당 노드들의 곱을 출력해준다.
풀이 코드
import java.io.*;
import java.util.StringTokenizer;
public class Main {
static int n, MOD = 1_000_000_007;
static int[] elements;
static long[] tree;
public static void main(String[] args) throws IOException{
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringBuilder sb = new StringBuilder();
StringTokenizer st = new StringTokenizer(br.readLine());
n = Integer.parseInt(st.nextToken());
int m = Integer.parseInt(st.nextToken());
int k = Integer.parseInt(st.nextToken());
elements = new int[n+1];
for(int i=1; i<n+1; i++) {
elements[i] = Integer.parseInt(br.readLine());
}
tree = new long[getTreeSize()];
init(1, n, 1);
for(int i=0; i<m+k; i++) {
st = new StringTokenizer(br.readLine());
int op = Integer.parseInt(st.nextToken());
if(op ==1) {
int idx = Integer.parseInt(st.nextToken());
int dif = Integer.parseInt(st.nextToken());
elements[idx] = dif;
update(1, n, 1, idx, dif);
}else {
int left = Integer.parseInt(st.nextToken());
int right = Integer.parseInt(st.nextToken());
sb.append(pMul(1, n, 1, left, right) % MOD+"\n");
}
}
System.out.println(sb.toString());
}
static int getTreeSize() {
int h = (int)Math.ceil(Math.log(n)/Math.log(2));
return (int) Math.pow(2, h+1);
}
static long init(int start, int end, int node) {
if(start == end) return tree[node] = elements[start];
int mid = (start+end)/2;
return tree[node] = (init(start,mid,node*2) * init(mid+1, end, node*2+1))%MOD;
}
static long update(int start, int end, int node, int idx, int dif) {
if(end < idx || idx< start) return tree[node];
if(start == end) return tree[node] = dif;
int mid = (start+end)/2;
return tree[node] = (update(start, mid, node*2, idx, dif)*update(mid+1, end, node*2+1, idx, dif))%MOD;
}
static long pMul(int start, int end, int node, int left, int right) {
if(end < left || right < start) return 1;
if(left <= start && end <= right) {
return tree[node];
}
int mid = (start+end)/2;
return (pMul(start, mid, node*2, left, right)* pMul(mid+1, end, node*2+1, left, right))%MOD;
}
}
'Dot Algo∙ DS > PS' 카테고리의 다른 글
[BOJ] 백준 9935번 문자열 폭발 (Java) (1) | 2021.07.04 |
---|---|
[BOJ] 백준 3986번 좋은 단어 (Java) (0) | 2021.07.03 |
[BOJ] 백준 7662번 이중 우선순위 큐 (Java) (0) | 2021.07.01 |
[BOJ] 백준 1939번 중량제한 (Java) (0) | 2021.06.30 |
[BOJ] 백준 13460번 구슬 탈출 2 (Java) (0) | 2021.06.29 |