[CS231n 2020 Lecture note 7]
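CNN에서 N개의 이미지로 이루어진 batch에 대하여 각 배열의 shape은 다음과 같습니다.
- input image: `x.shape` $= (N, C, H, W)$
- filter: `filter.shape` $= (F, C, HH, WW)$
- output: `output.shape` $= (N, F, H', W')$, 여기서 $H' = \left\lfloor \frac{H + 2P - HH}{S} \right\rfloor + 1$, $W' = \left\lfloor \frac{W + 2P - WW}{S} \right\rfloor + 1$ ($P$: padding, $S$: stride)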
참고하여 아래 코드를 이해해 봅시다.
import numpy as np
def conv(X, filters, stride=1, pad=0):
    """Naive forward pass of a 2-D convolution layer (cross-correlation, no bias).

    Parameters
    ----------
    X : np.ndarray, shape (N, C, H, W)
        Batch of N input images with C channels each.
    filters : np.ndarray, shape (F, C, HH, WW)
        F filters; each must have the same channel count C as ``X``.
    stride : int
        Step between neighbouring filter positions; must be >= 1.
    pad : int
        Number of zero rows/columns added on every side of the H and W axes.

    Returns
    -------
    np.ndarray, shape (N, F, H', W'), dtype float64
        ``H' = (H + 2*pad - HH) // stride + 1`` and
        ``W' = (W + 2*pad - WW) // stride + 1``.

    Raises
    ------
    ValueError
        If ``stride < 1``, if ``X`` and ``filters`` disagree on the channel
        count, or if the filter does not fit inside the padded input.
    """
    n, c, h, w = X.shape
    n_f, c_f, filter_h, filter_w = filters.shape

    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")
    if c != c_f:
        # Without this check NumPy broadcasting can silently accept, e.g., a
        # 1-channel filter on a 3-channel image and produce a wrong result.
        raise ValueError(
            f"channel mismatch: X has {c} channels but filters have {c_f}"
        )

    # Floor division gives the standard output-size formula using exact
    # integer arithmetic.
    out_h = (h + 2 * pad - filter_h) // stride + 1
    out_w = (w + 2 * pad - filter_w) // stride + 1
    if out_h < 1 or out_w < 1:
        raise ValueError(
            f"filter ({filter_h}x{filter_w}) does not fit the padded input "
            f"({h + 2 * pad}x{w + 2 * pad})"
        )

    # Zero-pad only the spatial axes (H and W), never batch or channel;
    # in_X is therefore larger than X along H and W.
    in_X = np.pad(X, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')

    out = np.zeros((n, n_f, out_h, out_w))
    for i in range(n):                  # each image in the batch
        for f in range(n_f):            # each filter -> one output channel
            for row in range(out_h):
                h_start = row * stride  # stride applied to the window start
                h_end = h_start + filter_h
                for col in range(out_w):
                    w_start = col * stride
                    w_end = w_start + filter_w
                    # Multiply the (C, HH, WW) window by the filter
                    # element-wise and sum over all channels and positions.
                    out[i, f, row, col] = np.sum(
                        in_X[i, :, h_start:h_end, w_start:w_end] * filters[f]
                    )
    return out
# Input batch: two 3-channel 5x5 images, so X.shape == (2, 3, 5, 5).
image_1 = [
    [[1, 2, 9, 2, 7], [5, 0, 3, 1, 8], [4, 1, 3, 0, 6], [2, 5, 2, 9, 5], [6, 5, 1, 3, 2]],
    [[4, 5, 7, 0, 8], [5, 8, 5, 3, 5], [4, 2, 1, 6, 5], [7, 3, 2, 1, 0], [6, 1, 2, 2, 6]],
    [[3, 7, 4, 5, 0], [5, 4, 6, 8, 9], [6, 1, 9, 1, 6], [9, 3, 0, 2, 4], [1, 2, 5, 5, 2]],
]
image_2 = [
    [[7, 2, 1, 4, 2], [5, 4, 6, 5, 0], [1, 2, 4, 2, 8], [5, 9, 0, 5, 1], [7, 6, 2, 4, 6]],
    [[5, 4, 2, 5, 7], [6, 1, 4, 0, 5], [8, 9, 4, 7, 6], [4, 5, 5, 6, 7], [1, 2, 7, 4, 1]],
    [[7, 4, 8, 9, 7], [5, 5, 8, 1, 4], [3, 2, 2, 5, 2], [1, 0, 3, 7, 6], [4, 5, 4, 5, 5]],
]
X = np.asarray([image_1, image_2])
print('Images:', X.shape)

# Filter bank: three 3-channel 3x3 kernels, so filters.shape == (3, 3, 3, 3).
kernel_1 = [
    [[1, 0, 1], [0, 1, 0], [1, 0, 1]],
    [[3, 1, 3], [1, 3, 1], [3, 1, 3]],
    [[1, 2, 1], [2, 2, 2], [1, 2, 1]],
]
kernel_2 = [
    [[5, 1, 5], [2, 1, 2], [5, 1, 5]],
    [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
    [[2, 0, 2], [0, 2, 0], [2, 0, 2]],
]
# NOTE: kernel_3 holds exactly the same values as kernel_2, which is why the
# second and third output channels in the printed result are identical.
kernel_3 = [
    [[5, 1, 5], [2, 1, 2], [5, 1, 5]],
    [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
    [[2, 0, 2], [0, 2, 0], [2, 0, 2]],
]
filters = np.asarray([kernel_1, kernel_2, kernel_3])
print('Filters:', filters.shape)

# stride 2 on a 5x5 input with 3x3 filters: (5 - 3) // 2 + 1 = 2 -> 2x2 maps.
out = conv(X, filters, stride=2, pad=0)
print('Output:', out.shape)
print(out)
>>>
Images: (2, 3, 5, 5)
Filters: (3, 3, 3, 3)
Output: (2, 3, 2, 2)
[[[[174. 191.]
[130. 122.]]
[[197. 244.]
[165. 159.]]
[[197. 244.]
[165. 159.]]]
[[[168. 171.]
[153. 185.]]
[[188. 178.]
[168. 200.]]
[[188. 178.]
[168. 200.]]]]