pretrained 모델 매개변수 일부만 로드하기

김준오·2021년 9월 15일
0

받은 업무

기존 identification 에 쓰이고있는 resnet 모델을 verification용으로 변형시키고 학습시켜보는 업무를 받았다.

전체적인 신경망 구성과 contrastive loss를 사용한다는점은 그대로 유지하되 약간의 입출력이 바뀜에 따른 변경처리를 해보려고 한다.

loss함수 : contrastive loss를 그대로 사용하지만,
기존 3채널 이미지 2개와 pairmap 1개, 총 3개의 매개변수를 받던것을 6채널 이미지 1개, pairmap1개 해서 총 2개의 매개변수로 줄이기로 했다.

model : 그에따라 기존 3채널로 들어가서 최종적으로 1024개 feature를 출력하던것을 6채널 입력받고 확률값 1개를 출력받는 형태로 바꾸려고 한다.

그 사이에 들어가는 복잡한 내부 신경망은 기존에 작동하고있는 방식을 그대로 쓰기로 했다.

따라서 입력부분과 출력부분만 살짝 바꿔줬다.

Fine tuning

전체적인 신경망의 구조는 유지되고있기에 처음부터 학습을 시키는것 보다, 기존에 학습시켜둔 신경망을 불러와서 추가적으로 학습을 시켜보려고 한다.

이런걸 Fine Tuning 이라고 하는것 같다.

물론 신경망의 구조를 바꾸고 추가 학습을 시키는거라서 기존 상태에 누적된 학습을 시킨다고 볼 수 는없을것같지만,
그래도 대부분의 구조는 그대로 두고 입력,출력 부분에 층만 1개씩 추가해준 상태이기때문에 어느정도 기존 상태가 도움이 되지 않을까? 하는 가설로 일단 불러와서 돌려보기로 했다.

근데 이렇게 신경망의 레이어가 달라지게되니 기존 모델의 파라미터를 불러오는 부분에서 어려움을 겪어서 그부분을 정리해두려고 한다. 이왕 정리하는김에 기본적인 모델의 저장,로드 방법부터 쭉 정리해둬야겠다.

모델 저장, 로드

모델 전체 저장/로드

저장하기

torch.save(model, PATH)

불러오기

# 모델 클래스는 어딘가에 반드시 선언되어 있어야 한다
model = torch.load(PATH)
model.eval()

state_dict 사용하여 저장/불러오기

문서들을 찾아보니 일반적으로는 요렇게 state_dict로 매개변수만 저장하고 불러오는 방식을 권장하는것 같다.

저장

PATH = "state_dict_model.pt"    #저장 경로

torch.save(net.state_dict(),PATH)

불러오기

model = Net()
model.load_state_dict(torch.load(PATH))

모델의 일부 파라미터만 불러오기

사실 이부분 처리하다가 기억해두려고 이 글을 쓰게 됐다.

일반적인 모델 로드하는 방법은 쉽게 찾을 수 있었는데 이거는 바로 찾지를 못했던 내용이라 나중에 쓸일 있으면 또 써먹으려고 한다.

pretrained_dict = torch.load(PATH)  # pretrained 상태 로드

model_dict = model.state_dict() # 현재 신경망 상태 로드

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) 

# 3. load the new state dict
model.load_state_dict(pretrained_dict)

pretrained 신경망의 매개변수를 로드하고, 현재 신경망에다가 덮어씌워준다.

dict구조로 이루어져있기 때문에 python의 dict에서 지원하는 update 함수를 사용해서 덮어쓰기가 가능하다.

참고출처: https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/16

python dict update : https://code.tutsplus.com/ko/tutorials/how-to-merge-two-python-dictionaries--cms-26230

학습에 필요한 모델의 기본적인 로드, 저장방법 및 특수한 상황에서의 로드방법까지 정리해봤다.

학습을 진행하며 checkpoint를 만들어서 학습된 epoch의 상태까지 중간저장을 하는 방법이 있는데 이건 아직 안해봐서 일단 해보고나서 다음번에 정리해야겠다!

profile
jooooon

0개의 댓글