Transfer Learning 또는 Fine Tuning 이라는 기법은 사전 학습된 모델을 활용해 적은 데이터 셋으로도 모델의 성능을 끌어올리는 기법이다.
예를 들면 VGG16 기반으로 fine tuning 하는 방법이다.
from torchvision import models
model = models.vgg16(pretrained=True)
for param in model.parameters():
param.requires_grad = False
fc = nn.Sequential(
nn.Linear(7*7*512, 256),
nn.ReLU(),
nn.Linear(256, 64),
nn.ReLU(),
nn.Linear(64, 2),
)
이렇게 구성하면 VGG를 사용하는 Backbone 단에서 피쳐를 뽑을 때에는 기존 weight를 사용할 수 있고 뒤쪽 classifier만 새로운 데이터 셋으로 학습하게 된다. 즉 이미지의 특성을 추출하는 것은 일반적으로 이미 잘 훈련된 네트워크를 사용하고, 이미지를 분류하는 부분에서는 도메인에 맞게 학습된 레이어를 사용하여 학습 효율과 정확성을 높일 수 있는 것이다.