SVM 알고리즘(분류)

밤비나·2023년 5월 25일
0

머신러닝

목록 보기
3/7

SVM의 이해와 활용

서포트 벡터 머신(SVM) : 데이터 분류를 위하여 마진(margin)이 최대가 되는 결정 경계선(decision boundary)을 찾아내는 머신러닝 방법

SVM 알고리즘 관련 용어의 이해

  • 결정 경계선(Decision boundary) : 한강은 도시가 강북인지, 강남인지를 구분하는 결정 경계선. 서로 다른 분류 값을 결정하는 경계.
  • 서포트 벡터(Supoort vector) : 서포트 벡터라는 단어에서 벡터는 2차원 공간 상에 나타난 데이터 포인트를 의미. 결정 경계선과 가장 가까이 맞닿은 데이터 포인트
  • 마진(Margin) : 서포트 벡터와 결정 경계 사이의 거리
  • 비용(Cost) : 얼마나 많은 데이터 샘플이 다른 클래스에 놓이는 것을 허용하는지 결정. 비용이 낮을수록 -> 마진을 최대한 높이고, 학습 에러율을 증가시키는 방향으로 결정 경계선 만듦, 비용이 높을수록 -> 마진은 작아지고, 학습 에러율은 감소하는 방향으로 결정 경계선 만듦.
  • 커널 트릭(Kernel Trick) : 실제로 데이터를 고차원으로 보내진 않지만, 보낸 것과 동일한 효과를 줘서 매우 빠른 속도로 결정 경계선을 찾는 방법(3차원 공간으로 옮겨진 데이터의 결정 경계선을 2차원 공간에 표현함)

SVM 알고리즘의 장단점

장점 :

  • 특성이 다양한 데이터 분류에 강함
  • 파라미터를 조정해서 과대/과소적합에 대응 가능
  • 적은 학습 데이터로도 정확도가 높은 분류 성능

단점 :

  • 데이터 전처리 과정이 매우 중요
  • 특성이 많을 경우, 결정 경계 및 데이터의 시각화가 어려움

SVM 알고리즘 활용하기

  1. 사이킷런의 그리드 서치(gridSearch)를 사용하여 간편하게 최적의 비용과 감마를 알아내기
# 패키지 및 데이터셋 추가
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC 
import numpy as np
import pandas as pd
df = pd.read_csv("https://raw.githubusercontent.com/wikibook/machinelearning/2.0/data/csv/basketball_stat.csv")
train, test = train_test_split(df, test_size=0.2
  1. 개발자가 부여한 비용과 감마 후보를 모두 조합해서 최적의 비용과 감마 조합 찾아내기
# 최적의 SVM 파라미터 찾기
def svc_param_selection(X, y, nfolds): 
svm_parameters = [{'kernel': ['rbf'], 
'gamma': [0.00001, 0.0001, 0.001, 0.01, 0.1, 1], 
'C': [0.01, 0.1, 1, 10, 100, 1000]}]
clf = GridSearchCV(SVC(), svm_parameters, cv=nfolds)
clf.fit(X, y) 
print(clf.best_params_)
return clf
X_train = train[['3P', 'BLK']]
y_train = train[['Pos']]
clf = svc_param_selection(X_train, y_train.values.ravel(), 10)
{'C': 0.1, 'gamma': 1, 
'kernel': 'rbf'}
  1. 그리드서치를 통해 얻은 C와 감마를 사용해 학습된 모델 테스트
# 모델 테스트
X_test = test[['3P', 'BLK']]
y_test = test[['Pos']]
y_true, y_pred = y_test, clf.predict(X_test)
print(classification_report(y_true, y_pred))
print()
print("accuracy : "+ str(accuracy_score(y_true, y_pred)) )
              precision    recall  f1-score   support

           C       1.00      0.89      0.94         9
          SG       0.92      1.00      0.96        11

    accuracy                           0.95        20
   macro avg       0.96      0.94      0.95        20
weighted avg       0.95      0.95      0.95        20


accuracy : 0.95
# 실제 예측값 확인하기
comparison = pd.DataFrame({'prediction': y_pred, 'ground_truth': y_true.values.ravel()})
comparison
	prediction	ground_truth
0	SG	        SG
1	SG	        SG
2	C	        C
3	SG 	        SG
4	C	        C
5	C	        C
6	SG	        SG
7	SG	        SG
8	SG	        C
9	SG	        SG
10	SG	        SG
11	SG	        SG
12	C	        C
13	C	        C
14	SG	        SG
15	C	        C
16	C	        C 
17	C	        C
18	SG	        SG
19	SG	        SG
profile
씨앗 데이터 분석가.

0개의 댓글