나만봄 - KNN이 뭐야

esc247·2023년 3월 22일
0

AI

목록 보기
11/22

3월 초에 나온 뉴스이지만,
Meta에서 GPT3 보다 성능이 좋다고 알려진 LLAMA를 발표했다.

Dalai로 간단하게 실행 가능하다.

AI 발전 속도가 무시무시한데,
그만큼 우위를 차자하기 위한 싸움 구경도 재밌다.


ChatGPT

Q. What is KNN?

  • simple algorithm used for classification and regression problems
  • finding the K closest data points to a given input, and then predicting the output based on the most common output value among the K neighbors.
  • non-parametric algorithm, meaning it doesn't make any assumptions about the underlying data distribution.
  • lazy algorithm, meaning it doesn't do any training on the data beforehand and instead stores the entire training dataset.

Q. What are some disadvantages of using KNN?

  • Computationally expensive and slow, especially when dealing with large datasets and high-dimensional feature spaces.
  • Sensitive to the presence of noisy or irrelevant features in the data, which can negatively impact the accuracy of predictions.
  • Choosing the right value of k can be difficult, and there is often a trade-off between bias (using a low value of k) and variance (using a high value of k).

예시 코드

import numpy as np

class KNN:
    def __init__(self, k=3, distance_metric='euclidean'):
        self.k = k
        self.distance_metric = distance_metric
        
    def fit(self, X, y):
        self.X_train = X
        self.y_train = y
        
    def predict(self, X):
        predictions = []
        for i in range(len(X)):
            distances = self.compute_distances(X[i])
            top_k_indices = np.argsort(distances)[:self.k]
            top_k_labels = self.y_train[top_k_indices]
            most_common_label = np.bincount(top_k_labels).argmax()
            predictions.append(most_common_label)
        return np.array(predictions)
    
    def compute_distances(self, x):
        if self.distance_metric == 'euclidean':
            return np.sqrt(np.sum((self.X_train - x)**2, axis=1))
        elif self.distance_metric == 'manhattan':
            return np.sum(np.abs(self.X_train - x), axis=1)
        else:
            raise ValueError("Unsupported distance metric.")

개념이 간단해서 길게 쓸 게 없다.

KNN, K-Nearest Neighbor로 입력이 주어지면
그와 가까운 K개의 데이터를 보고 속할 그룹을 판단한다.

  1. 입력을 받는다.
  2. 저장된 각 데이터와의 거리를 구한다.
  3. K개의 가까운 sample을 찾는다.
  4. Majority Vote로 속할 그룹을 찾는다.

다음 순서로 알고리즘이 진행된다.

질문1. 거리를 어떻게 구할까?
Manhattan Diastance(L1), Euclidean Distance(L2), Hamming Distance, Cosine Similarity 등이 있는데 대체로 L1,L2 Distance를 많이 사용한다.
L1 : dL1(w,v)=i=1nwivid_{L1}(\mathbf{w},\mathbf{v}) = \sum_{i=1}^n |w_i-v_i|
L2 : dL2(w,v)=i=1n(wivi)2d_{L2}(\mathbf{w},\mathbf{v}) = \sqrt{\sum_{i=1}^n (w_i-v_i)^2}

질문2. K 값이 중요한데 이를 어떻게 결정할까?
K에 따라 결과도 달라지고 정확도도 달라진다.

K가 작으면 low bias and high varianceOverfitting될 수 있고
K가 크면 high bias and low varianceUnderfitting될 수 있다.

주어진 데이터셋에 대한 적절한 K 값을 구하기 위해 k-fold Cross Validation을 사용한다.
데이터셋을 test셋을 제외한 채 k개의 fold로 나누고
valid셋을 하나씩 바꿔가며 결과값을 구한다.
이렇게 하면 기존 cv에서 valid셋을 볼 수 없는 문제도 해결하고
valid가 편향될 가능성도 배제할 수 있다.

다음은 예시코드이다.

from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris

iris = load_iris()
X, y = iris.data, iris.target

k_values = list(range(1, 16))
cv_scores = []

for k in k_values:
    knn = KNeighborsClassifier(n_neighbors=k)
    
    # Perform 10-fold cross-validation
    scores = cross_val_score(knn, X, y, cv=10)
    
    # Store the mean score across all 10 folds
    cv_scores.append(scores.mean())

best_k = k_values[cv_scores.index(max(cv_scores))]

for k,cvs in enumerate(cv_scores):
    print(k+1,cvs)
print("Best k:", best_k)
'''
1 0.96
2 0.9533333333333334
3 0.9666666666666666
4 0.9666666666666666
5 0.9666666666666668
6 0.9666666666666668
7 0.9666666666666668
8 0.9666666666666668
9 0.9733333333333334
10 0.9666666666666668
11 0.9666666666666668
12 0.9733333333333334
13 0.9800000000000001
14 0.9733333333333334
15 0.9733333333333334
Best k: 13
'''

간단한 알고리즘인만큼 장단점이 뚜렷하다.

먼저 장점은 간단하다.
또한 학습을 시킬 필요 없이 데이터를 저장하고 있기만 하면 된다.
하지만 그렇기에 발생하는 단점이 있는데,
모든 데이터를 저장하고 있어야 하므로 높은 Memory usage and Computational burden이 발생한다.
또한 새로운 데이터가 입력으로 들어오면 기존 모든 데이터에 대한 거리를 계산해야 하므로 Inference 시간이 오래 걸린다.

기본적으로 우리는 Slow Trainig but Fast Testing이 필요하기에 KNN이 현재는 많이 사용되지 않는다.

profile
막상 하면 모르니까 일단 하자.

0개의 댓글