모델 학습을 하다 보면 구조는 같지만 레이어 이름이 다른 모델 혹은 부분적으로만 weight을 불러와야 하는 상황이 종종 발생합니다. 예를 들어, 동일한 UNet3D 구조라도 state_dict
의 키 값(레이어 이름)이 다르면 weight을 그대로 로드할 수 없는데요. 이번 글에서는 PyTorch에서 이런 경우 어떻게 대응할 수 있는지 정리해 보겠습니다.
state_dict
구조PyTorch에서 모델 파라미터는 state_dict
라는 딕셔너리로 관리됩니다.
model.state_dict()
# 예시
{
"down1.conv1.weight": ...,
"down1.conv1.bias": ...,
"up2.conv_transpose.weight": ...,
...
}
load_state_dict()
는 레이어 이름(key) 과 텐서 shape이 모두 일치해야 정상적으로 로드됩니다."Unexpected key"
또는 "Missing key"
에러가 발생합니다.모델 정의 시 기존과 동일한 이름을 사용하면 weight을 그대로 로드할 수 있습니다.
→ 가장 깔끔하지만, 코드 수정이 필요합니다.
strict=False
옵션 사용일치하는 키만 불러오고 나머지는 무시합니다.
model.load_state_dict(torch.load("checkpoint.pth"), strict=False)
state_dict
의 키 이름을 바꿔치기 후 로드합니다.
state_dict = torch.load("checkpoint.pth")
new_state_dict = {}
for k, v in state_dict.items():
new_k = k.replace("old_prefix", "new_prefix") # 이름 변환
new_state_dict[new_k] = v
model.load_state_dict(new_state_dict, strict=True)
구조가 일부 달라도 공통된 부분만 불러올 수 있습니다.
checkpoint = torch.load("checkpoint.pth")
model_dict = model.state_dict()
# 공통 키만 추출
pretrained_dict = {k: v for k, v in checkpoint.items()
if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
또는 특정 모듈(예: encoder)만 불러오기:
checkpoint = torch.load("checkpoint.pth")
encoder_weights = {k: v for k, v in checkpoint.items() if k.startswith("encoder")}
model.encoder.load_state_dict(encoder_weights, strict=False)
strict=False
, 세밀하게 제어하려면 키 필터링이나 이름 매핑 활용