나는 여전히 웹개발자이지만, 회사에 부쩍 딥러닝 과제가 많아져서 딥러닝 공부가 많이 필요하게 됐다. 밑바닥부터 시작하는 딥러닝 책이 좋다고 해서 천천히 하나하나 공부해보려 한다.
신경망을 이용하여 손글씨의 숫자를 분류합니다. 문제 해결 과정은 일반적으로,
그런데 여기서는 1번은 생략하고 이미 학습된 가중치를 이용하여 분류합니다.
학습된 가중치를 이용하여 추론하는 과정을 신경망의 순전파(forward propagation) 라고 합니다.
위와 같이 생긴 손글씨 숫자 이미지가 MNIST 데이터 셋입니다.
위의 코드는 이미지를 다루는 코드입니다. 여기서 주의깊게 봐야 할 코드는
def img_show(img):
pil_img = Image.fromarray(np.uint8(img))
pil_img.show()
위의 코드로 이미지는 일정한 형식을 가진 numpy array로 변할 수 있고, 일정한 형식을 가진 numpy array는 이미지로 변할 수 있다는 것을 알게 됩니다. img_show
함수는 numpy로 구성된 숫자로 된 배열들을 PIL 라이브러리가 읽을 수 있도록 해주는 코드입니다.
load_mnist(flatten=True, normalize=False, one_hot_label=False)
여기서는 이미지를 불러오는 형식에 대해 알 수 있는데,
flatten은 이미지를 1차원 배열로 가져올 것이냐에 대한 옵션입니다. 28x28의 2차원 배열로 이루어진 이미지는 784개의 1차원 배열로 표현될 수 있습니다.
normalize는 입력 이미지의 픽셀 값을 기존의 0~255 값에서 0.0~1.0 사이의 값으로 정규화할지 결정합니다.
one_hot_label은 레이블(정답)을 원핫 인코딩 형태로 저장할지 결정합니다.
img = img.reshape(28, 28)
위의 코드는 flatten 옵션에 의해 불러와진 784개의 1차원 배열로 이루어진 이미지를 다시 28x28로 돌려놓는 함수입니다.
img_show(img)
위 코드는 이전에 설명했던 img_show 함수를 이용하여, 배열로 이루어진 이미지 정보를 pil 라이브러리를 통해 실제로 어떤 이미지인지 우리에게 시각적으로 보여주는 코드입니다.
이전의 신경망 구성 예제 코드와 별다를 것 없습니다. 가중치를 .pkl
파일에서 불러온다는 점만 조금 특이합니다.
분류이니 predict(forward)의 마지막 부분에서 softmax를 통과시키고, np.argmax로 가장 확률이 높은 원소의 인덱스를 얻어 분류합니다.
위의 코드는 학습된 가중치에 대한 정보를 알아본 코드입니다. 다음과 같이 입력 값은 계속 다음 가중치로 전달되어, 마지막에는 10개의 값만 남게 됩니다.
위의 형상을 보면 (784, )형태의 입력배열 X가 (784, 50) 형태의 가중치 W1를 만나서 각 원소마다 연산이 되고, 그 결과는 (50, ) 형태가 됩니다. 그리고 (50, 100)형태의 가중치 W2와 연산이 되어, (100, )의 형태가 되고, (100, ) 형태의 값은 (100, 10)형태의 가중치 W3와 곱해져 마지막으로 10개의 배열이 나오게 됩니다.
마지막으로 나온 10개의 배열에 np.argmax()를 수행하면 분류의 결과가 나옵니다.
위는 배치로 100개씩 처리한 소스코드입니다.
python for문의 step argument를 이용하여 구현하였습니다.
np.sum(p == t[i:i+batchsize])
부분은 numpy를 이용한 약간의 트릭이 들어갔는데, True나 False로 반환되는 값들을 numpy의 sum 함수로 모두 더하여 정확도를 계산합니다.
배치로 작성한 소스코드의 연산과정을 살펴보면 다음과 같습니다.
처음에 100개의 이미지를 넣고 y에 100개에 대한 결과를 모두 받아오는 것을 유심히 보시면 될 것 같습니다.
배치처리를 하면, 느린 I/O를 통해 데이터를 읽는 횟수가 줄어서 빠른 CPU나 GPU로 순수 계산을 수행하는 비율이 높아집니다. 큰 배열을 한번에 계산함으로써, 분할된 작은 계산을 여러번 하는 것보다 더 빠른 퍼포먼스를 끌어냅니다.
혹시 def get_data():
(x_train,t_train)......
부분에서 x_train과 t_train이 unused variable로 나오는데 어떻게 해결해야되나요..?