조기 종료(early stopping)는 머신 러닝이나 딥 러닝에서 모델의 과적합을 방지하고 효율적인 훈련을 위해 사용되는 기법이다. 모델이 학습 데이터에 지나치게 적합되어 검증 데이터나 테스트 데이터에서 성능이 저하되는 과적합 상태를 방지하기 위해 주로 적용된다.
📎 작동 방식
- 훈련 및 검증 데이터:
- 학습을 위한 훈련 데이터와 모델의 성능을 평가하기 위한 검증 데이터를 준비한다.
- 훈련 및 검증:
- 모델을 훈련하면서 일정 간격마다 검증 데이터에 대한 성능을 평가한다.
- 조기 종료 조건 설정:
- 모델의 성능이 지속적으로 향상되지 않을 때를 감지하는 조건을 설정한다. 이 조건은 일반적으로 검증 데이터에 대한 손실(loss)이 감소하지 않거나 정확도(accuracy)가 일정 기간 동안 향상되지 않는 경우 등으로 설정된다.
- 학습 중단:
- 조기 종료 조건이 충족되면, 즉 모델의 성능이 개선되지 않거나 하나의 최적 지점을 지나쳤다고 판단되면 학습을 조기에 종료한다.
조기 종료는 모델이 최적의 성능을 내는 지점을 지나치지 않고 효과적으로 찾아낼 수 있도록 도와주며, 훈련 시간과 리소스를 절약할 수 있는 장점이 있다. TensorFlow, PyTorch, Scikit-learn 등의 라이브러리에서는 조기 종료를 위한 콜백(callback)이나 함수를 제공하여 쉽게 적용할 수 있다. 이러한 기능을 활용하여 모델 훈련 시에 조기 종료를 구현할 수 있다.
조기 종료를 실습을 하기 위해 Keras 라이브러리에서 제공하는 EarlyStopping 콜백을 사용할 수 있다.
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
# 가상의 분류 데이터 생성
X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 모델 구성
model = Sequential()
model.add(Dense(32, input_shape=(20,), activation='relu'))
model.add(Dense(16, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# 조기 종료 설정
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
# 모델 컴파일 및 학습
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
history = model.fit(X_train, y_train, epochs=100, batch_size=32, validation_data=(X_test, y_test), callbacks=[early_stopping])
# 모델 평가
test_loss, test_acc = model.evaluate(X_test, y_test)
print(f'Test Accuracy: {test_acc:.4f}')
monitor
매개변수는 모니터링할 지표를 선택하고, patience
는 성능이 향상되지 않을 때 얼마나 기다릴지를 지정한다.fit
함수를 호출하여 학습을 진행한다. 여기서 EarlyStopping 콜백을 콜백 리스트에 추가한다.이 예제에서 EarlyStopping 콜백은 검증 데이터의 손실(val_loss
)을 모니터링하고, 3번의 epoch 동안 성능 향상이 없을 경우 학습을 조기 종료시킨다. restore_best_weights=True
로 설정하면 최선의 성능을 내는 모델의 가중치를 복원한다.