ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [WEEK02/DAY05] 백준 문제 : 행렬 제곱
    카테고리 없음 2022. 10. 4. 02:50

    https://www.acmicpc.net/problem/10830

     

    10830번: 행렬 제곱

    크기가 N*N인 행렬 A가 주어진다. 이때, A의 B제곱을 구하는 프로그램을 작성하시오. 수가 매우 커질 수 있으니, A^B의 각 원소를 1,000으로 나눈 나머지를 출력한다.

    www.acmicpc.net

    제법 어려운 문제들이 출현하고 있다.

    다른 사람들의 정답 코드를 많이 참고하게 되는 것 같다.

     

     

    문제 해석 :

    N*N 크기의 정사각형 행렬 A가 주어지는데, A행렬을 B제곱 한 결과를 출력해야 한다.

    그런데 그대로 출력하는 것이 아니라 각 원소를 1000으로 나눈 나머지를 출력해야 한다.

     

    접근 :

    이 문제를 해결하기 위해 필요한 과정을 다음과 같이 나눠 보았다.

     

    1. 행렬간의 곱셈을 구현하는 것

    2. 행렬을 큰 수로 거듭제곱 하게 되면 계산이 오래 걸리므로, 분할하여 해결하는 것.

    3. 1000으로 나눈 나머지 출력을 구현하는 것.

     

    이 문제를 풀기 위해서는 나머지 연산 분배법칙이 필요하다.

    나머지 연산(Modular Arithmetic)에 대해, 위와 같은 항등식이 성립한다.

    나머지 연산을 위와 같이 분배할 수 있다는 것인데, 다행스럽게도 행렬 간의 곱셈에도 적용되는 개념이었다.

    위 개념을 사용하면 거듭제곱을 분할해서 계산할 수 있게 된다.

    예를 들면,

    (A**10) % 1000 을

    ( ((A**5) % 1000) * ((A**5) % 1000) ) % 1000 과 같이

    풀어서 계산할 수 있는 것이다.

     

    그래... 익숙해.

    익숙한 형태가 보인다. 재귀를 써서 거듭제곱 하는 횟수를 줄이면

    시간복잡도 O(N)을 O(logN)으로 줄일 수 있을 것이다.

     

    그러면 이제 재귀를 이용해 문제를 분할했으니

    행렬간의 곱셈과 나머지 연산을 구현하기만 하면 된다.

     

    행렬간의 곱셈은 3중 for 문을 이용해서 구현했다.

     

     

    import sys
    
    def main():
        n, b = map(int, sys.stdin.readline().split())
        arr = [[0 for _ in range(n)] for _ in range(n)]
        for i in range(n):
            arr[i] = list(map(int, sys.stdin.readline().split()))
        
        arr = square(arr, b, n)
        for row in arr:
            print(*row)
    
    
    def square(arr, b, n):
        if b == 1:
            for i in range(n):
                for j in range(n):
                    arr[i][j] %= 1000
            return arr
    
        else:
            tmp = square(arr, b//2, n)
            if b%2 == 0:
                return mul(tmp, tmp, n)
            else:
                return mul(mul(tmp, tmp, n), arr, n)
    
    
    def mul(arr1, arr2, n):
        new_arr = [[0]*n for _ in range(n)]
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    new_arr[i][j] += arr1[i][k]*arr2[k][j]
                new_arr[i][j] %= 1000
        return new_arr
    
    
    if __name__ == '__main__':
        main()

     

     

    여기서 시간초과가 나서 애먹었던 부분이 있다.

    square() 함수에서 return 할 때 원래는 아래와 같이 구현했었다.

    if b%2 == 0:
        return mul(square(arr, b//2, n), square(arr, b//2, n), n)
    else:
        return mul(mul(square(arr, b//2, n), square(arr, b//2, n), n), arr, n)

    이 방법의 문제점은 완전히 동일한 두 번의 재귀호출을 한다는 것인데,

    재귀에서 불필요한 반복은... 특히 재귀함수 자체가 반복되는 것은 기하급수적인 시간복잡도 상승을 유발한다.

    앞으로는 이런 디테일을 신경쓸 수 있도록 해야겠다.

    댓글

Copyright 2022. ProdYou All rights reserved.