Feature가 많은 경우, target에 적합한 feature를 가공하는 것을 특성 공학(feature engineering)이라고 한다.
SelectKBest는 target와 상관 관계가 높은 feature K개를 선택하는 전처리 방법이다.
이를 통해 feature가 많아서 발생하는 과적합을 줄일 수 있고, 모델의 성능을 높일 수 있다. 또한, feature가 감소하니 모델의 훈련 시간도 감수시킬 수 있다.
# X_trian, X_test, y_train, y_test 데이터셋이 존재
from sklearn.feature_selection import SelectKBest
selector = SelectKBest(score_func=f_regression, k)
# score_func는 상관 관계를 분석하는 방법을 나타내고, 공식문서를 통해 확인 가능하다
# defalut = f_classif
X_trian_sel = selector.fit_transform(X_trian, y_train)
X_test_set = selector.transform(X_test)
# select 된 feature 확인
all_col = X_train.columns
all_col[selector.get_support()]
# get_support() : select feature=True, Unselect feature=False로 반환
# selected feature 들의 영향력(score)을 확인
df_score = pd.DataFrame(selector.scores_,
selector.feature_names_in_)
.sort_values(col, ascending=False)
# 적절한 K를 어떻게 선택해야 할까?
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, r2_score
training = []
testing = []
ks = range(1, len(X_train.columns)+1)
# 1 부터 특성 수 만큼 사용한 모델을 만들어서 MAE 값을 비교 합니다.
for k in range(1, len(X_train.columns)+ 1):
print(f'{k} features') # 몇 개 feature
selector = SelectKBest(score_func=f_regression, k=k)
X_train_selected = selector.fit_transform(X_train, y_train)
X_test_selected = selector.transform(X_test)
all_col = X_train.columns
selected_names = all_col[selector.get_support()]
print('Selected names: ', selected_names) # select 된 feature name
model = LinearRegression()
model.fit(X_train_selected, y_train)
y_pred = model.predict(X_train_selected)
mae = mean_absolute_error(y_train, y_pred)
training.append(mae)
y_pred = model.predict(X_test_selected)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
testing.append(mae)
print(f'Test MAE: ${mae:,.0f}')
print(f'Test R2: {r2} \n')
# 시각화를 통해 mae 차이의 변화를 보고 적절한 k를 선택할 수 있음.
plt.plot(ks, training, label='Training Score', color='b')
plt.plot(ks, testing, label='Testing Score', color='g')
plt.ylabel("MAE ($)")
plt.xlabel("Number of Features")
plt.title('Validation Curve')
plt.legend()
plt.show()