본문 바로가기

Dot Algo∙ DS/PS

[BOJ] 백준 11049번 행렬 곱셈 순서 (Java)

    #11049 행렬 곱셈 순서

    난이도 : 골드 3

    유형 : DP

     

    11049번: 행렬 곱셈 순서

    첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력한다. 정답은 231-1 보다 작거나 같은 자연수이다. 또한, 최악의 순서로 연산해도 연산 횟수가 231-1보다 작거나 같

    www.acmicpc.net

    ▸ 문제

    크기가 N×M인 행렬 A와 M×K인 B를 곱할 때 필요한 곱셈 연산의 수는 총 N×M×K번이다. 행렬 N개를 곱하는데 필요한 곱셈 연산의 수는 행렬을 곱하는 순서에 따라 달라지게 된다.

    예를 들어, A의 크기가 5×3이고, B의 크기가 3×2, C의 크기가 2×6인 경우에 행렬의 곱 ABC를 구하는 경우를 생각해보자.

    • AB를 먼저 곱하고 C를 곱하는 경우 (AB)C에 필요한 곱셈 연산의 수는 5×3×2 + 5×2×6 = 30 + 60 = 90번이다.
    • BC를 먼저 곱하고 A를 곱하는 경우 A(BC)에 필요한 곱셈 연산의 수는 3×2×6 + 5×3×6 = 36 + 90 = 126번이다.

    같은 곱셈이지만, 곱셈을 하는 순서에 따라서 곱셈 연산의 수가 달라진다.

    행렬 N개의 크기가 주어졌을 때, 모든 행렬을 곱하는데 필요한 곱셈 연산 횟수의 최솟값을 구하는 프로그램을 작성하시오. 입력으로 주어진 행렬의 순서를 바꾸면 안 된다.

     입력

    첫째 줄에 행렬의 개수 N(1 ≤ N ≤ 500)이 주어진다.

    둘째 줄부터 N개 줄에는 행렬의 크기 r과 c가 주어진다. (1 ≤ r, c ≤ 500)

    항상 순서대로 곱셈을 할 수 있는 크기만 입력으로 주어진다.

     출력

    첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력한다. 정답은 2^31-1 보다 작거나 같은 자연수이다. 또한, 최악의 순서로 연산해도 연산 횟수가 2^31-1보다 작거나 같다.

     

    문제 풀이  

    행렬의 곱셈을 하기 위해서는 왼쪽 행렬의 열과 오른쪽 행렬의 행의 크기가 같아야 한다. 그러므로 인접한 행렬들 끼리만 곱셈이 가능하다.

    • 항상 순서대로 곱셈을 할 수 있는 크기만 입력으로 주어진다.

    문제에서도 입력값이 곱셉이 가능하게 주어진다고 했기 때문에 순서를 꼬는 경우의 수는 고려하지 않아도 된다.

     

    1. Bottom-up

     인접한 행렬끼리만 곱셈이 가능하기 때문에 처음에는 Bottom-up 방식으로 접근했다.

    행렬 곱셉 연산의 최솟값 구하는 과정

     

    DP에 넣을 값은 계속해서 비교해주면서 최솟값을 찾아줘야하기 때문에 i의 행렬과 j의 행렬의 곱셉 연산을 저장하도록 설계했다.

    • dp[i번째 행렬][j번째 행렬]  = 필요한 곱셈 연산 수 최솟값

     

    구상

    반복문 설계가 쉽지 않았다. 처음에는 구간 간격(i)을 설정해주고, 그 구간 길이에 따라서 구할 수 있는 행렬의 곱셈을 구하여 값을 저장해주었다.

    • i : 행렬의 개수가 n이라면 구간 간격은 2~ n까지 나눌 수 있다.
    • j : 구간 길이가 i이기 때문에 구간의 시작 행렬은 0~ n-i까지 가능하다
      • 위의 그림을 예로 n=4, i=2 라면 j는 0~1(AB), 1~2(BC), 2~3(CD)의 구간으로 나눌 수 있다. 
    • k : 구간의 시작부터 끝까지(j ~ j+i-1) 중간 지점을 설정하여 해당 구간의 곱셈 연산의 최솟값을 구한다.  
      • dp[j][k]  + dp[k+1][j+i-1] : (j ~ k), (k+1 ~ j+i-1)의 행렬의 곱셈 연산 최솟값을 더해준다.
      • (data[j]*data[k+1]*data[j+i]) : (j ~ j+i-1)행렬의 곱셈이 이루어지면서 새로운 값을 더해준다.
      • 기존값과 비교하여 작으면 값을 바꿔준다.
    for(int i=2; i<n+1; i++) { // 구간 간격 i 
    	for(int j=0; j<n-i+1; j++) { // 구간 시작점 j (0~j+i-1))
    		dp[j][j+i-1] = INF;
    		for(int k=j; k<j+i-1; k++) { // 중간 지점 k (j~ j+i-1)) 
    			int value = dp[j][k]  + dp[k+1][j+i-1] + (data[j]*data[k+1]*data[j+i]);
    			dp[j][j+i-1] = Math.min(dp[j][j+i-1], value);
    		}
    	}
    }

     

    2. Top-down 풀이

    이는 재귀 Top-down 풀이 방법이다. 반복문 설계가 쉽지않을 때 가독성이 좋은 Top-down 방식으로 전환해서 풀기도 한다. 반대로 ABCD를 하위문제로 쪼개어 최솟값을 뽑아주는 식으로 재귀문을 짜봤다.

    static int solve(int pos, int cur) {
    	if(pos == cur) return 0;
    	if(dp[pos][cur] != INF) return dp[pos][cur];
    	
    	for(int i=pos; i<cur; i++) {
    		int value = solve(pos,i) + solve(i+1, cur) + (data[pos] *data[i+1]*data[cur+1]);
    		dp[pos][cur] = Math.min(dp[pos][cur], value);
    	}
    	return dp[pos][cur];
    }

     

    풀이 코드  (반복문)

    import java.io.BufferedReader;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.util.Arrays;
    import java.util.StringTokenizer;
    
    public class Main {
    	static int n, INF = Integer.MAX_VALUE;
    	static int[] data;
    	static int[][] dp;
    	public static void main(String[] args) throws IOException{
    		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    		n = Integer.parseInt(br.readLine());
    		
    		data = new int[n+1];
    		StringTokenizer st = null;
    		for(int i=0; i<n; i++) {
    			st = new StringTokenizer(br.readLine());
    			int a = Integer.parseInt(st.nextToken());
    			int b = Integer.parseInt(st.nextToken());
    			data[i] = a; data[i+1] = b;
    		}
    		
    		dp = new int[n][n];
    		
    		for(int i=2; i<n+1; i++) { // 구간 간격  
    			for(int j=0; j<n-i+1; j++) { // 구간 시작점 j (0~n-i))
    				dp[j][j+i-1] = INF;
    				for(int k=j; k<j+i-1; k++) { // 중간 지점 k (j~ j+i-1))
    					int value = dp[j][k]  + dp[k+1][j+i-1] + (data[j]*data[k+1]*data[j+i]);
    					dp[j][j+i-1] = Math.min(dp[j][j+i-1], value);
    				}
    			}
    		}
    		System.out.println(dp[0][n-1]);
    		
    	}
    }

     

    풀이 코드  (재귀)

    import java.io.BufferedReader;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.util.Arrays;
    import java.util.StringTokenizer;
    
    public class Main {
    
    	static int n, INF = Integer.MAX_VALUE;
    	static int[] data;
    	static int[][] dp;
    	public static void main(String[] args) throws IOException{
    		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    		n = Integer.parseInt(br.readLine());
    		
    		data = new int[n+1];
    		StringTokenizer st = null;
    		for(int i=0; i<n; i++) {
    			st = new StringTokenizer(br.readLine());
    			int a = Integer.parseInt(st.nextToken());
    			int b = Integer.parseInt(st.nextToken());
    			data[i] = a; data[i+1] = b;
    		}
    		
    		dp = new int[n][n];
    		for(int i=0; i<n; i++) {
    			Arrays.fill(dp[i], INF);
    		}
    		System.out.println(solve(0,n-1));
    	}
    	static int solve(int pos, int cur) {
    		if(pos == cur) return 0;
    		if(dp[pos][cur] != INF) return dp[pos][cur];
    		
    		for(int i=pos; i<cur; i++) {
    			int value = solve(pos,i) + solve(i+1, cur) + (data[pos] *data[i+1]*data[cur+1]);
    			dp[pos][cur] = Math.min(dp[pos][cur], value);
    		}
    		
    		return dp[pos][cur];
    	}
    }