K-최근접 이웃 알고리즘(K-Nearest Neighbors)

Yoon1013·2023년 4월 25일
0
post-thumbnail

머신러닝이란 데이터를 바탕으로 컴퓨터가 스스로 데이터에 대한 규칙을 찾는 것이다. 어떤 모델을 통해 컴퓨터를 학습시킬지는 머신러닝 엔지니어가 결정한다.
따라서 머신러닝 엔지니어는 주어진 데이터를 분석하여 적절한 모델을 선택해야 하므로 데이터분석 프로세스에 맞춰 알고리즘을 설명한다.

❓ 문제 정의

분석의 목표

도미와 빙어를 분류하자
한빛 마켓의 물류센터에서의 배송 지연을 해결하기 위해 도미와 빙어를 자동으로 분류해주는 머신러닝을 만들자!

🤔 이진 분류(binary classification)

데이터를 여러개의 그룹으로 구별하는 것을 분류(classification) 문제라고 한다.
특히 지금 예시처럼 두개의 클래스(class)로 구분하는 것을 이진분류라고 한다.
cf) 다중분류

classificationclustering
입력(data)과
출력(답)을 다 줌
입력(data)만 줌
지도학습비지도학습

가설 설정

생선의 특징을 알면 쉽게 구분할 수 있을 것 같다!

💾 데이터 수집

물류 창고에서 도미와 빙어의 무게(g)와 길이(cm) 데이터를 수집

📍 특성(feature)

데이터의 속성을 의미한다.
속성, column, attribute 등등으로 부르기도 한다.
이 예시에서는 생선의 무게와 길이가 특성에 해당한다.

데이터 불러오기

우선은 실제로 데이터를 입력했다고 생각하자
도미 데이터 입력

# 도미 데이터 준비: 리스트 형태

bream_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0, 
                31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0, 
                35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0]
bream_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0, 
                500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0, 
                700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0]

bream_length와 bream_weight에서 같은 인덱스의 데이터가 한 묶음!

도미 데이터가 몇개일까?

#데이터 개수
print(len(bream_length), len(bream_weight))

출력

[해석]

  • 도미 데이터는 총 35개이다
  • 길이와 무게 데이터의 개수가 같으니 누락된 값 없이 데이터를 잘 입력했다

빙어 데이터 입력

# 빙어 데이터 준비
smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]
smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]

도미 데이터와 마찬가자로 같은 인덱스가 한 묶음!

빙어 데이터는 몇개일까?

print(len(smelt_length), len(smelt_weight))

출력

[해석]

  • 빙어 데이터는 14개이다
  • 길이와 무게 데이터의 개수가 같으니 누락된 값 없이 데이터를 잘 입력했다

⚙️ 데이터 준비

데이터 준비 단계는 데이터를 표로 변환하고, 필요한 속성을 선택하는 등 데이터를 모델에 입력하기 편하게끔 준비하는 단계를 말한다.
여기서는 스몰데이터에 사용하려는 속성이 두가지 밖에 없으므로 데이터 준비 단계는 생략한다.

📈 데이터 분석

데이터 분포 파악(그래프 그리기)

📍 scatter plot

수치형 데이터의 분포를 파악하기에 적합한 그래프이다.
matplotlib의 pyplot 함수나 seaborn 라이브러리를 사용해서 그릴 수 있다.

도미 데이터 분포 파악

#산점도 그리기
import matplotlib.pyplot as plt #matplotlib의 pyplot 함수를 plt로 줄여서 사용

plt.scatter(bream_length, bream_weight) #x축은 length, y축은 weight
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

출력

해석

  • 길이와 무게에는 양의 상관관계가 있다
  • 그래프가 일직선에 가까운 형태이므로 선형(linear)적 관계라고 볼 수 있다
  • 상관성이 높으므로 두 특성 중에 한가지만 사용하는 것이 더 바람직하다고 볼 수도 있겠으나 현재 예시에서는 특성이 두가지밖에 없으므로 모두 사용한다

빙어 데이터 분포 파악

plt.scatter(smelt_length, smelt_weight)
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

출력

해석

  • 도미 데이터와 마찬가지로 길이와 무게에는 양의 상관관계 존재한다
  • 빙어가 도미에 비해 무게가 적게 나가는 경향이 있다
    → 추후 같은 스케일로 정규화 시켜야 함!!!: 후에 포스팅 예정

도미 데이터와 빙어 데이터를 한 그래프로 그려보자

plt.scatter(bream_length, bream_weight)
plt.scatter(smelt_length, smelt_weight)
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

출력

해석

  • 파란색이 도미 데이터, 주황색이 빙어 데이터이다
  • 빙어 데이터는 도미에 비해 무게가 길이에 영향을 덜 받는다
    (그래프가 더 완만함!)
  • 현재 스케일로는 빙어 데이터의 무게 차가 매우 적어 보이기 때문에 정규화 과정이 필요하다

통계적 요약정보 확인

통계 정보는 numpy 라이브러리를 이용하거나 dataframe 형식을 이용하면 직접 계산하지 않아도 쉽게 구할 수 있다.
아래 방법은 numpy 라이브러리를 이용한 방법이다.

도미 데이터의 통계 정보

#도미 데이터의 평균, 분산, 표준편차
import numpy as np

#파이썬 리스트를 numpy 배열로 변환
bream_length_numpy = np.array(bream_length)
bream_weight_numpy = np.array(bream_weight)

#길이 평균, 표준편차, 중앙값
print("평균: ", bream_length_numpy.mean(), "표준편차: ", bream_length_numpy.std(), "중앙값: ", np.median(bream_length_numpy))
#무게 평균, 표준편차, 중앙값
print("평균: ", bream_weight_numpy.mean(), "표준편차: ", bream_weight_numpy.std(), "중앙값: ", np.median(bream_weight_numpy))

출력

빙어 데이터의 통계 정보

#파이썬 리스트를 numpy 배열로 변환
smelt_length_numpy = np.array(smelt_length)
smelt_weight_numpy = np.array(smelt_weight)

#길이 평균, 표준편차, 중앙값
print("평균: ", smelt_length_numpy.mean(), "표준편차: ", smelt_length_numpy.std(), "중앙값: ", np.median(smelt_length_numpy))
#무게 평균, 표준편차, 중앙값
print("평균: ", smelt_weight_numpy.mean(), "표준편차: ", smelt_weight_numpy.std(), "중앙값: ", np.median(smelt_weight_numpy))

출력

해석

  • 도미와 빙어 모두 길이의 중앙값과 평균이 비슷하다: 길이의 경우 분표가 고르다
  • 무게의 경우 도미와 빙어 모두 평균과 중앙값에 차이가 있다: 무게는 표본의 range가 크다
  • 특히 도미의 경우 표준편차가 매우 크기 때문에 무게의 분포가 매우 다양하다고 유추할 수 있다
  • 도미와 빙어의 무게에 차이가 매우 크다: 추후 정규화 시킬 필요

🤖 머신러닝 프로그램: K-최근접 이웃 알고리즘

📝 K-최근접 이웃 알고리즘(K-Nearest Neighbors)

분류(Classification) 알고리즘의 하나로, K-NN, KNN 등으로도 부른다.
같은 클래스에 속하는 데이터는 같은 범주로 묶일 수 있을 것이다라는 가정하에 사용한다.
즉, "같은 클래스로 묶일 수 있는 데이터는 그들끼리 뭉쳐있을 것이다" 라는 가정 하에 사용한다.

예를 들어 위와 같은 그림에서 빨간색 삼각형 데이터는 근처에 초록 원들이 많이 분포하고 있으므로 초록 원일 것이다!라고 추측하는 것이다.
다시 한번 설명하자면 주위 K개의 데이터를 보고 다수를 차지하는 것을 답으로 사용한다.
"주위에 있다"는 의미는 거리가 가까운 데이터를 의미하고 거리는 유클리드 거리를 사용한다.
이 알고리즘은 예측하려는 타겟 데이터로부터 나머지 데이터의 직선거리를 모두 계산하여 가장 가까운 K개의 데이터를 골라내야 하므로 데이터가 많은 경우 사용하기 어렵다.

도미 데이터와 빙어 데이터 하나로 합치기

# 두 데이터를 합쳐 하나의 리스트로 만들기
length = bream_length + smelt_length
weight = bream_weight + smelt_weight

길이 데이터와 무게 데이터를 2차원 리스트로 만들기

# 데이터를 2차원 리스트로 만들기 -> 사이킷런 사용을 위해
fish_data = [[l,w] for l, w in zip(length, weight)]
print(fish_data)


2차원 리스트 형태!

🌟 왜 2차원 리스트로 만드는가?!
사용하려는 사이킷런(scikit-learn)머신러닝 패키지는 입력 데이터 형태가 2차원 리스트여야 함!!

정답 데이터 준비
우리는 도미 데이터 35개, 빙어 데이터 14개를 단순하게 합쳤으므로 앞에 35개 데이터는 도미, 그 다음 나머지 14개 데이터가 빙어이다.
도미와 빙어를 구분하기 위해서 컴퓨터에게 어느 것이 도미 데이터이고 어느 것이 빙어 데이터인지 알려줘야하므로 정답 리스트를 만들어야 한다.
보통 이진분류에서는 찾으려는 대상을 1, 그 외에는 0으로 설정하지만 이 경우 도미 1, 빙어 0으로 설정하려고 한다. 당연히 도미 0, 빙어 1로 설정해도 된다.

# 정답 리스트 만들기
# tip) 곱셈 연산자를 사용하면 간단하게 같은 데이터를 반복할 수 있다
fish_target = [1] * 35 + [0] * 14
print(fish_target)


훈련시킬 모델 준비

from sklearn.neighbors import KNeighborsClassifier
# 객체 만들기: 이 객체 훈련시킬거임!
kn = KNeighborsClassifier()

KNN 모델은 sklearn.neighbors 패키지 안에 들어있다.
생성한 KNeighborsClassifier() 객체에 데이터를 입력하여 훈련(train)시킬 것이다!

모델 훈련

kn.fit(fish_data, fish_target)

✅ fit() 메서드

머신러닝 모델을 학습시킬 때 위에서 만든 데이터와 정답(여기서는 fish_data와 fish_target)을 전달하여 분류 기준을 알려주는데, 이러한 과정을 훈련(train) 이라고 한다. 사이킷 런에서는 fit() 메서드를 사용하여 데이터를 전달한다.

모델 평가

kn.score(fish_data, fish_target)

💯 score() 메서드

모델을 평가하는 메서드이다. 0~1 사이의 값을 반환한다. 이 값은 score()의 매개변수로 전달하는 데이터를 얼마나 맞췄는지 정확도(accuracy)를 의미한다. 위에서는 전달한 fish_data를 0과 1로 판단하고 이 값을 fish_target과 비교하여 얼마나 맞췄는지 알려준다.


반환값이 1.0이므로 모델의 정확도는 100%이다!

예측

🤷‍♀️ predict() 메서드

결국 우리가 알고 싶은거는 "그래서 이 데이터는 도미이니, 빙어이니?" 아닐까. predict() 메서드는 새로운 데이터를 매개변수로 전달하여 이 데이터에 대한 정답을 예측한다. 마찬가지로 2차원 리스트의 값을 매개변수로 갖는다. 2차원 리스트가 어렵다면 리스트 안에 리스트가 들어가 있는 형태([[]])라고 생각하면 된다.

kn.predict([[30,600]])

"자, 길이가 30cm이고 무게가 600g인 이 생선은 도미니, 빙어니?"

리스트를 매개변수로 전달했으므로 결과로 리스트로 리턴된다.
우리는 도미를 1로 설정했으므로 이 모델은 길이가 30cm이고 무게가 600g인 이 생선을 도미로 예측했다.

객체의 X, Y 속성 알아보기

print(kn._fit_X)


잘 보면 fish_data를 그대로 갖고 있음을 알 수 있다

print(kn._y)


마찬가지로 fish_target을 그대로 갖고 있다

🌀 K에 따라 변하는 결과

기본적으로 K의 기본값은 5이다. 즉 주위의 5개의 데이터를 바탕으로 클래스를 분류하는데 당연하게도 K값을 어떻게 설정하냐에 따라서 결과는 바뀔 수 있다.

kn49 = KNeighborsClassifier(n_neighbors=49) #참고 데이터를 49개로 설정
kn49.fit(fish_data, fish_target)
kn49.score(fish_data, fish_target)

K값은 n_neighbors 속성을 통해 설정할 수 있다.
위의 예시의 경우 전체 fish_data의 개수인 49로 설정한 예시이다.
이 경우, 어떤 데이터를 넣어도 무조건 절대 다수인 도미로 예측한다

따라서 정확도는 35/49와 같은 위 값으로 나오게 된다.

📚 Reference

혼자 공부하는 머신러닝+딥러닝, 박해선, 한빛미디어

profile
Data Science & AI

0개의 댓글