오늘 공부한 내용은 이항분류 binary classification!
주어진 데이터를 두 개의 범주 or 클래스로 분류하는 작업
ex) 스팸메일 분류, 암 유무 판단 등
특징변수(Feature Variable)
= 설명변수, 독립변수. 모델이 입력으로 사용하는 데이터의 속성
여러 특징변수를 조합해 목표 변수를 예측한다.
목표변수(Target Variable)
= 종속변수, 반응변수. 모델이 예측하려는 출력 값.
이진 분류 문제에서 목표변수는 일반적으로 0,1 / T,F 같은 두 클래스를 말한다.
이진 분류 모델은 훈련 데이터를 이용해 특징변수와 목표변수 사이 관계를 학습해 최종적으로는 훈련하지 않은 데이터의 특징변수로 목표변수를 예측한다.
데이터는 DataFrame으로 불러와서 원하는 특징변수와 목표변수만 추출해서, 목표변수를 이산형 레이블로 매핑하고, 훈련-테스트 데이터를 분할해 표준화 시킨다.
해당 데이터를 model에서 사용하게 tensor로 변환해 준 뒤 Dataset과 DataLoader 클래스를 생성하고, nn.Model을 상속한 이진분류 모델을 만들어 훈련을 시킨다.
텐서로 변환하는 과정에서 목표변수의 훈련,테스트 데이터는 2차원으로 unsqueeze해주는데, 특징변수는 [데이터수, 특징수]로 이미 2차원이므로, 배치 처리 과정에서 호환성을 고려한 것이다.
DataSet & DataLoader 클래스
from torch.utils.data import Dataset, DataLoader
두 클래스로 전처리와 배치처리를 할 수 잇다.
이진분류 문제를 해결하기 위해 사용되는 머신러닝 모델.
특정 입력 데이터가 두 클래스중 하나에 속할 확률을 예측한다.
이미지 출처 : TheYoonicon
y = Sigmoid(z) = 1/(1+exp(-z))
에 넣어 결과를 0~1 사이의 클래스에 속할 확률로 변환한다. 0.5를 기준으로 1과 0으로 나눈다.이진 분류 모델에서 사용되는 손실함수
cf. 선형회귀에서는 MSE(평균제곱오차)
최대 가능도 추정(MLE)
관측된 데이터를 기반으로 가장 잘 설명하는 모수를 찾기 위해 가능도 함수를 최대화
하는 것
(미분 계수가 0이 되는 지점이 최대가 된다)
계산상의 이점을 위해 로그 가능도 함수를 사용하고 이를 최대화하는 파라미터(θ)를 찾는다.
loss_funcion = nn.BCELoss()
이진 교차 엔트로피를 사용한 이진분류 모델의 손실함수 코드
.isin(values)
import pandas as pd
# DataFrame df에서 city열의 뉴욕, 시카고 항목만 필터링한 데이터
filtered_df = df[df['City'].isin(['New York', 'Chicago'])]
filtered_df.loc[:, 'City'] = filtered_df['City'].map({'New York': 0, 'Chicago': 1})
from sklearn.model_selection import train_test_split
# 테스트를 20%로 설정한 코드. 난수는 관습적으로 42를 한다(별이유x)
f_train, f_test, t_train, t_test = train_test_split(f, t, test_size=0.2, random_state=42)
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
f_train = scaler.fit_transform(f_train)
f_test = scaler.transform(f_test)
green_img_t = torch.cat((zero_t, green_only, zero_t), dim=-1)
green_img_t = torch.stack((zero_t, green_only, zero_t), dim=-1)
y = w1x1 + w2x2 + ... wnxn + b
intercept: [22.483147], other coef: [[-0.9662663 0.694296 0.25551897 0.7085805 -1.9914881 3.1231995
-0.17710298 -3.038284 2.205488 -1.7015713 -1.9768412 1.1222284
-3.6321542 ]]