본문 바로가기

Dot Algo∙ DS/PS

[BOJ] 백준 14438번 수열과 쿼리 17 (Java)

#14438 수열과 쿼리 17

난이도 : 골드 1

유형 : 세그먼트 트리

 

14438번: 수열과 쿼리 17

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오. 1 i v : Ai를 v로 바꾼다. (1 ≤ i ≤ N, 1 ≤ v ≤ 109) 2 i j : Ai, Ai+1, ..., Aj에서 크기가 가장 작은 값을

www.acmicpc.net

▸ 문제

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오.

  • 1 i v : Ai를 v로 바꾼다. (1 ≤ i ≤ N, 1 ≤ v ≤ 10^9)
  • 2 i j : Ai, Ai+1, ..., Aj에서 크기가 가장 작은 값을 출력한다. (1 ≤ i ≤ j ≤ N)

수열의 인덱스는 1부터 시작한다.

 입력

첫째 줄에 수열의 크기 N이 주어진다. (1 ≤ N ≤ 100,000)

둘째 줄에는 A1, A2, ..., AN이 주어진다. (1 ≤ Ai ≤ 10^9)

셋째 줄에는 쿼리의 개수 M이 주어진다. (1 ≤ M ≤ 100,000)

넷째 줄부터 M개의 줄에는 쿼리가 주어진다.

 출력

2번 쿼리에 대해서 정답을 한 줄에 하나씩 순서대로 출력한다.

 

문제 풀이  

구간의 최솟값을 구하는 문제로, 선형 탐색으로 풀이를하면 시간초과가 발생하므로 세그먼트 트리를 사용하여 풀이해야한다.

 

세그먼트 트리 높이

크기가 N인 배열이 존재할 때 세그먼트 트리 크기는 다음과 같이 구한다.

  1. 트리의 높이 = ceil(log2(N))
  2. 세그먼트 트리의 크기 = ( 2^(트리의 높이 + 1) )
static int getTreeSize() {
	int h = (int)Math.ceil(Math.log(n)/Math.log(2))+1;
	return (int)Math.pow(2, h)-1;
}

 

그런데 그냥 위의 방법보다는 덜 구체적이기는 하지만 배열 크기 * 4로 설정해줘도 된다. 

 

세그먼트 트리 초기화

부모 노드에 자식 노드들 중 가장 작은 값을 선택하는 즉, 구간에서 가장 작은 값을 골라주는 세그먼트 트리를 만들어준다.

static int init(int s, int e, int node) {
	if(s == e) return tree[node] = arr[s];
	
	int mid = (s+e)/2;
	
	return tree[node] = Math.min(init(s, mid, node*2), init(mid+1, e, node*2+1)); 
}

 

예제 [5, 4, 3, 2, 1] 수열로 최솟값 세그먼트 트리를 만들면 다음과 같다.

[5 4 3 2 1] 세그먼트 트리

 

쿼리 1번: 값 업데이트

특정 배열 값을 업데이트한 다음 세그먼트 트리를 다시 업데이트 해준다. 위의 초기화 로직과 다른 점은 idx 범위를 벗어나는 부분은 탐색을 종료해줘야 한다는 점이다.

  • if(s > idx || e < idx ) return tree[node];
static int update(int s, int e, int node, int idx) {
	if(s > idx || e < idx ) return tree[node];
	if(s == e)  return tree[node] = arr[idx];
	
	int mid = (s+e)/2;
	return tree[node] = Math.min(update(s, mid, node*2, idx), update(mid+1, e, node*2+1, idx));
}

 

위의 세그먼트 트리에서 5번 배열 값을 3으로 업데이트하면 다음과 같다.

배열 값 업데이트하기

 

쿼리 2번: 구간 최솟값

l에서 r구간에서 최솟값을 찾아 출력해주면 된다.

  • 범위 밖은 자료형 최댓값을 반환하고, 범위 안에서는 tree[node]를 반환해준다.
static int getMin(int s, int e, int node, int l, int r) {
	if(r < s || l > e ) return Integer.MAX_VALUE;
	if(l <= s && e <= r) return tree[node];
	
	int mid = (s+e)/2;
	return Math.min(getMin(s, mid, node*2, l, r), getMin(mid+1, e, node*2+1, l, r));
}

 

위의 예제에서 [1, 3] 구간의 최솟값을 구하는 과정은 다음의 노드를 탐색하여 3을 출력한다.

[1, 3] 최솟값 찾아내기

 

풀이 코드

import java.io.*;
import java.util.StringTokenizer;

public class Main {

	static int[] arr, tree;
	static int n;

	public static void main(String[] args) throws IOException{
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		n = Integer.parseInt(br.readLine());
		arr = new int[n];
		tree = new int[getTreeSize()];
		StringTokenizer st = new StringTokenizer(br.readLine());
		for(int i=0; i<n; i++) {
			arr[i] = Integer.parseInt(st.nextToken());
		}
		
		init(0, n-1, 1);
		
		StringBuilder sb = new StringBuilder();
		int m = Integer.parseInt(br.readLine());
		for(int i=0; i<m; i++) {
			st = new StringTokenizer(br.readLine());
			int op = Integer.parseInt(st.nextToken());
			int a = Integer.parseInt(st.nextToken())-1;
			int b = Integer.parseInt(st.nextToken())-1;
			if(op == 1) {
				arr[a] = b+1;
				update(0, n-1, 1, a);
			}else {
				sb.append(getMin(0, n-1, 1, a, b)+"\n");
			}
		}
		System.out.println(sb.toString());
		
	}
    
	static int getTreeSize() {
		int h = (int)Math.ceil(Math.log(n)/Math.log(2))+1;
		return (int)Math.pow(2, h)-1;
	}
	
	static int init(int s, int e, int node) {
		if(s == e) return tree[node] = arr[s];
		
		int mid = (s+e)/2;
		return tree[node] = Math.min(init(s, mid, node*2), init(mid+1, e, node*2+1)); 
	}
	
	static int update(int s, int e, int node, int idx) {
		if(s > idx || e < idx ) return tree[node];
		if(s == e) return tree[node] = arr[idx];
		
		int mid = (s+e)/2;
		return tree[node] = Math.min(update(s, mid, node*2, idx), update(mid+1, e, node*2+1, idx));
	}
	
	static int getMin(int s, int e, int node, int l, int r) {
		if(r < s || l > e ) return Integer.MAX_VALUE;
		if(l <= s && e <= r) return tree[node];
		
		int mid = (s+e)/2;
		return Math.min(getMin(s, mid, node*2, l, r), getMin(mid+1, e, node*2+1, l, r));
	}
	
}