Transfer Learning

Xpert·2024년 3월 5일
0

머신러닝

목록 보기
8/11

Transfer Learning 또는 Fine Tuning 이라는 기법은 사전 학습된 모델을 활용해 적은 데이터 셋으로도 모델의 성능을 끌어올리는 기법이다.

예를 들면 VGG16 기반으로 fine tuning 하는 방법이다.

from torchvision import models

model = models.vgg16(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
    
fc = nn.Sequential(
    nn.Linear(7*7*512, 256),
    nn.ReLU(), 
    nn.Linear(256, 64), 
    nn.ReLU(), 
    nn.Linear(64, 2),
)
  1. torchvision에서 VGG 모델을 가져오고 pretrained=True를 설정하여 가중치도 같이 가져온다.
  2. 모든 파라미터의 가중치를 업데이트 하지 않도록 한다. (Frozen 상태로 만듬)
  3. VGG 모델의 출력인 7x7x512=25088과 256을 Linear하게 연결한다.

이렇게 구성하면 VGG를 사용하는 Backbone 단에서 피쳐를 뽑을 때에는 기존 weight를 사용할 수 있고 뒤쪽 classifier만 새로운 데이터 셋으로 학습하게 된다. 즉 이미지의 특성을 추출하는 것은 일반적으로 이미 잘 훈련된 네트워크를 사용하고, 이미지를 분류하는 부분에서는 도메인에 맞게 학습된 레이어를 사용하여 학습 효율과 정확성을 높일 수 있는 것이다.

profile
Python, CV, ML, Backend

0개의 댓글

관련 채용 정보