본격적인 실습에 앞서 Nearest Neighbor을 알기위해서 아래의 사이트를 공부해보았다.
http://aikorea.org/cs231n/classification/
이미지 분류 문제는 특정 이미지를 미리 정해진 클래스로 분류하는 문제이다. 사람의 입장에서는 별일이 아니지만 컴퓨터의 관점에서 이미지를 보고 판별한다는 것은 매우 어려운 일이 될 것 같다.
이 방법은 실용성이 크게 없는 분류기라고 한다. 하지만 기본적인 접근 방법을 알기 위해 공부해보기로 한다. 코드를 따라가 보자.
Xtr, Ytr, Xte, Yte = load_CIFAR10('data/cifar10/') # 제공되는 함수
# 모든 이미지가 1차원 배열로 저장된다.
Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3) # Xtr_rows는 50000 x 3072 크기의 배열.
Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3) # Xte_rows는 10000 x 3072 크기의 배열.
nn = NearestNeighbor() # Nearest Neighbor 분류기 클래스 생성
nn.train(Xtr_rows, Ytr) # 학습 이미지/라벨을 활용하여 분류기 학습
Yte_predict = nn.predict(Xte_rows) # 테스트 이미지들에 대해 라벨 예측
# 그리고 분류 성능을 프린트한다
# 정확도는 이미지가 올바르게 예측된 비율로 계산된다 (라벨이 같을 비율)
print 'accuracy: %f' % ( np.mean(Yte_predict == Yte) )
import numpy as np
class NearestNeighbor(object):
def __init__(self):
pass
def train(self, X, y):
""" X is N x D where each row is an example. Y is 1-dimension of size N """
# nearest neighbor 분류기는 단순히 모든 학습 데이터를 기억해둔다.
self.Xtr = X
self.ytr = y
def predict(self, X):
""" X is N x D where each row is an example we wish to predict label for """
num_test = X.shape[0]
# 출력 type과 입력 type이 갖게 되도록 확인해준다.
Ypred = np.zeros(num_test, dtype = self.ytr.dtype)
# loop over all test rows
for i in range(num_test):
# i번째 테스트 이미지와 가장 가까운 학습 이미지를
# L1 거리(절대값 차의 총합)를 이용하여 찾는다.
# broadcasting 연산을 이해하는 것이 필요하다.
# Xtr이 50000 * 3072이고 X[i]가 1 * 3072이다. 따라서 연산을 하면 50000 * 3072 행렬이 나온다.
# 여기에서 sum을 axis=1방향으로 해주었기에
# 각 행을 더한 값들이 나열된 1 * 50000의 array가 나온다.
distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)
min_index = np.argmin(distances) # 가장 작은 distance를 갖는 인덱스를 찾는다.
Ypred[i] = self.ytr[min_index] # 가장 가까운 이웃의 라벨로 예측
return Ypred
위 예시에서는 L1 distance를 사용했지만 L2 distance를 사용할 수도 있다.
#Xtr - X[i]는 50000 * 3072
#np.square는 각 요소의 값을 제곱한다.
#따라서 (a-b)^2의 요소들이 50000 * 3072이다.
distances = np.sqrt(np.sum(np.square(self.Xtr - X[i,:]), axis = 1))
1-NN의 가장 큰 문제점은 training error가 낮은대신, test error가 높다는 데에 있다. 어떤 이미지를 분류하기 위해서 투표를 한다고 가정해보자. 그 이미지분류에서 가장 높은 투표를 받은게 dog인데 2, 3, 4, 5등으로 받은 투표가 cat이라면 그 이미지를 dog로 분류하는게 틀릴 가능성이 높다는 것이다.
Hyperparameter 튜닝을 위한 검증 셋 (Validation set)으로 교차 검증을 활용할 수 있다.
이제 학습한 내용을 바탕으로 하는 본격적인 실습 과정은 프로젝트 벨로그에 정리해두었다.