[머신러닝]Multilabel / Multioutput Classification

김태경 SMARCLE·2024년 8월 9일

머신러닝

목록 보기
1/9
post-thumbnail
  • 본 글은 다음의 교재를 참고하여 작성되었습니다. - "Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow, P108~P112"

Multilabel Classification

직역하자면 다중레이블 분류. 결과값으로 여러 개의 레이블을 가지는 시스템입니다.

1) Multiclass Classification과의 차이

앞서 수요일에 다룬 Multiclass Classification(다중클래스 분류)과는 비슷하지만 다음과 같은 차이가 있습니다.

Multiclass Classification이 1개의 레이블에 대하여 여러 종류의 클래스 중 하나를 출력하는 방식이라면, Multilabel Classification은 일반적인 Binary Classification과 같이 True / False를 출력하는 대신 여러 레이블에 대하여 출력이 가능합니다.

2) 구현

from sklearn.neighbors import KNeighborsClassifier

#7 이상의 수에 대한 레이블
y_train_large = (y_train >= 7)
#홀수에 대한 레이블
y_train_odd = (y_train % 2 == 1)
#앞 2개의 레이블을 합침
y_multilabel = np.c_[y_train_large, y_train_odd]

#knn을 이용해 모델 생성
knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)

#만들어진 모델을 이용해 MNIST 데이터 중 하나에 대해 테스트
knn_clf.predict([some_digit])

위 코드는 교재 P108의 코드로, KNN을 활용하여 Multilabel Classification을 구현하였습니다.
KNN은 각 레이블에 대해 독립적으로 예측을 할 수 있기 때문에 2개 이상의 레이블을 출력해야 하는 해당 방식에 적합합니다.

코드 내에서 넣은 some_digit는 숫자 5의 이미지이기 때문에, 결과는 다음과 같이 나옵니다.

array([[False,  True]])

3) 평가 지표 관련

교재 내에는 '다양한 평가 지표가 있으며, 자신의 프로젝트에 맞는 방식을 사용하면 됩니다' 라고만 나와있습니다.
예제로는 F1스코어 방식이 나와있네요.

그래서 구체적으로 어떤 평가 지표들이 있는지,
그리고 어떤 기준으로 선택해야 하는지 좀 더 조사해보았습니다.
https://g.co/gemini/share/0b08e393e17e
https://heytech.tistory.com/434
위 글과 제미나이 답변에 의하면 평가지표들을 크게 2종류로 나눌 수 있습니다. 두 종류에 대한 설명과, 이들에 해당하는 몇가지 평가 지표를 간단히 소개하겠습니다.

  • Example-based Evaluation : 하나의 테스트 데이터 별로 정답과 예측 간의 차를 구하고, 데이터들에 대해 그 차들의 평균을 지표로 삼는 방식들
    • F1스코어 - 계산을 각 레이블에 대해서 전부 실시한 뒤 평균을 냄
    • Hamming Loss - 전체 계산 중 틀린 레이블 수를 계산한 뒤 평균을 냄
    • Exact-Match Ratio(EMR) - 모든 레이블에 대해 정확한 예측을 한 데이터의 비율을 셈
  • Label-based Evaluation : 반대로 하나의 레이블에 대해 정답과 예측 간 차를 구해 전체 레이블에 대한 평균을 지표로 삼는 방식들
    • Macro-average - 레이블별 평가 지표 계산 후 전체 테스트 데이터 수로 평균을 구하는 방식
    • Weighted-average - 레이블별 TP+FN의 수를 고려하여 평균을 구하는 방식

교재에서는 F1스코어와 Macro average 방식을 같이 사용한 것 같네요.

Multioutput Classification

이쪽도 직역하면 다중출력 분류. 하나의 입력 데이터에 대해 여러 개의 출력 레이블을 예측하는 작업으로, 여기서 레이블의 클래스는 2개 이상일 수 있습니다.
쉽게 말해, Multilabel + Multiclass입니다.

예시를 들자면 YOLO에서 카메라에 잡히는 여러 물체마다 human / car 등 식별하는 기능이 바로 여기에 해당하겠네요.

profile
네이버 블로그 업로드 전 개념정리용

0개의 댓글