문제
문제
Suppose you have to evaluate an expression like A*B*C*D*E where A,B,C,D and E are matrices.
Since matrix multiplication is associative, the order in which multiplications are performed is arbitrary. However, the number of elementary multiplications needed strongly depends on the evaluation order you choose.
For example, let A be a 50*10 matrix, B a 10*20 matrix and C a 20*5 matrix.
There are two different strategies to compute A*B*C, namely (A*B)*C and A*(B*C).
The first one takes 15000 elementary multiplications, but the second one only 3500.
Your job is to write a program that determines the number of elementary multiplications needed for a given evaluation strategy.
예제 입력 1
9
A 50 10
B 10 20
C 20 5
D 30 35
E 35 15
F 15 5
G 5 10
H 10 20
I 20 25
A
B
C
(AA)
(AB)
(AC)
(A(BC))
((AB)C)
(((((DE)F)G)H)I)
(D(E(F(G(HI)))))
((D(EF))((GH)I))
예제 출력 1
0
0
0
error
10000
error
3500
15000
40500
47500
15125
정답 코드
import sys
def computation(m1, m2):
## computation이 가능한지 확인해야 함
## 안 되면 error 출력
if m1[1] != m2[0]:
return (None, None)
shape = (m1[0], m2[1])
computation_count = m1[1] * m2[1] * m1[0]
## computation 가능하면 행렬 연산 횟수와 행렬 shape 반환
return computation_count, shape
def check(matrices, line):
# print(line)
stack = []
count = 0
## 행렬 하나만 주어진 경우에는 연산 횟수가 0
if len(line) == 1:
return count
for c in line:
## 수식을 읽어가면서 스택에 넣기
if c == "(" or c == ")":
stack.append(c)
else:
stack.append(matrices[c])
# print("stack:", stack)
## ) 를 만났을 때는 pop
if c == ")":
metrix = []
current = c
## (을 만나기 전까지는 꼭 두 개의 행렬이 들어간다.
while current != "(":
# print("current:", current)
# print("metrix:", metrix)
# print(stack)
current = stack.pop()
if type(current) == tuple:
metrix.append(current)
local_count, shape = computation(metrix[1], metrix[0])
if not local_count:
return "error"
else:
count += local_count
stack.append(shape)
return count
input = sys.stdin.readline
N = int(input().strip())
matrices = {}
for _i in range(N):
name, m, n = (input().strip().split())
matrices[name] = (int(m), int(n))
lines = []
while True:
line = input().strip()
if line:
lines.append(line)
else:
break
for line in lines:
print(check(matrices, line))
문제 해결 접근
- 기본적인 아이디어 : 스택으로 풀어보자!
- 한 글자씩 읽어서 스택에 넣어주되, 괄호면 ()는 그대로, 행렬일 경우에는 행렬의 shape을 넣어준다.
- line을 다 넣어주면서 )을 만나면 pop을 시작! (을 만날 때까지 한다. 그러면 연산해야하는 두 개의 행렬이 나온다. 그러면 computation 함수를 통해 연산 횟수인 count를 업데이트하고, 행렬의 shape을 다시 스택에 넣어준다. 이런 방식을 통해 괄호 단위로 우선순위 연산을 유지할 수 있다.
- 추가적으로 정해지지 않은 개수의 입력을 받는 방법
- while 문으로 받아서 line이 있을 때까지 list lines에 넣어주고, lines에 대해서 check 함수를 실행시킨다.