TTA가 뭘까?

기린이·2021년 1월 22일
0

Data Augmentation

기존 데이터셋을 회전/반전/줄이기/늘이기/노이즈 등을 줘서 데이터셋을 늘리는 것

TTA

Test Time Augmentation, data augmentaion은 학습데이터를 늘리는 것이었다면 TTA는 테스트 데이터를 늘리는 것.
한 장의 테스트 사진이 있을 때 이것을 Augmentation한 사진으로 테스트해서 나온 예측값들의 평균으로 최종 예측

train_datagen = ImageDataGenerator(validation_split = 0.2,
                                horizontal_flip = True, #좌우 반전 여부
                                vertical_flip = True, #상하 반전 여부
                                rotation_range = 30, #회전 제한 각도
                                width_shift_range = 0.1, #좌우 이동가능 비율
                                height_shift_range = 0.1, #상하 이동가능 비율
                                fill_mode = 'nearest',
                                zoom_range = 0.2) #확대축소 비율

케라스의 ImageDataGenerator를 사용하면 간단하다.
위 코드에선 train data를 augmentation한 것이지만 이를 val/test set에 적용하면 TTA를 할 수 있다.

tta_steps = 10

test_datagen = ImageDataGenerator(
        shear_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True,
        rotation_range=10.,
        fill_mode='reflect', 
        width_shift_range = 0.1, 
        height_shift_range = 0.1)

predictions = []

for i in tqdm(range(tta_steps)):
    preds = model.predict_generator(test_datagen.flow(x_val, batch_size=bs, shuffle=False), steps = len(x_val)/bs)
    predictions.append(preds)
    
final_pred = np.mean(predictions, axis=0) #TTA한 사진 10장의 pred 평균


print(f'Accuracy with TTA: {np.mean(np.equal(np.argmax(y_val, axis=-1), np.argmax(final_pred, axis=-1)))}') # label과 비교 후 정확도

참고자료

profile
중요한 것은 속력이 아니라 방향성

0개의 댓글