PyTorch에서 다른 모델 구조 간 Weight 불러오기 방법

Bean·2025년 8월 19일
0

인공지능

목록 보기
110/123

모델 학습을 하다 보면 구조는 같지만 레이어 이름이 다른 모델 혹은 부분적으로만 weight을 불러와야 하는 상황이 종종 발생합니다. 예를 들어, 동일한 UNet3D 구조라도 state_dict의 키 값(레이어 이름)이 다르면 weight을 그대로 로드할 수 없는데요. 이번 글에서는 PyTorch에서 이런 경우 어떻게 대응할 수 있는지 정리해 보겠습니다.


🔎 PyTorch의 state_dict 구조

PyTorch에서 모델 파라미터는 state_dict라는 딕셔너리로 관리됩니다.

model.state_dict()
# 예시
{
  "down1.conv1.weight": ...,
  "down1.conv1.bias": ...,
  "up2.conv_transpose.weight": ...,
  ...
}
  • load_state_dict()레이어 이름(key)텐서 shape이 모두 일치해야 정상적으로 로드됩니다.
  • 이름이 다르면 "Unexpected key" 또는 "Missing key" 에러가 발생합니다.
  • 즉, 구조가 같더라도 이름이 다르면 자동 매칭은 불가능합니다.

✅ 해결 방법

1. 레이어 이름을 강제로 맞추기

모델 정의 시 기존과 동일한 이름을 사용하면 weight을 그대로 로드할 수 있습니다.
→ 가장 깔끔하지만, 코드 수정이 필요합니다.


2. strict=False 옵션 사용

일치하는 키만 불러오고 나머지는 무시합니다.

model.load_state_dict(torch.load("checkpoint.pth"), strict=False)
  • 장점: 가장 간단한 방법
  • 단점: 일부 weight이 초기화 상태라 성능 차이가 날 수 있음

3. 키 매핑해서 로드하기

state_dict의 키 이름을 바꿔치기 후 로드합니다.

state_dict = torch.load("checkpoint.pth")
new_state_dict = {}

for k, v in state_dict.items():
    new_k = k.replace("old_prefix", "new_prefix")  # 이름 변환
    new_state_dict[new_k] = v

model.load_state_dict(new_state_dict, strict=True)

4. 필요한 레이어만 부분적으로 로드

구조가 일부 달라도 공통된 부분만 불러올 수 있습니다.

checkpoint = torch.load("checkpoint.pth")
model_dict = model.state_dict()

# 공통 키만 추출
pretrained_dict = {k: v for k, v in checkpoint.items() 
                   if k in model_dict and v.shape == model_dict[k].shape}

model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

또는 특정 모듈(예: encoder)만 불러오기:

checkpoint = torch.load("checkpoint.pth")
encoder_weights = {k: v for k, v in checkpoint.items() if k.startswith("encoder")}
model.encoder.load_state_dict(encoder_weights, strict=False)

📌 정리

  • 구조와 shape가 같으면 이름만 맞춰주면 그대로 로드 가능
  • 이름이 다르면 키 매핑을 통해 해결 가능
  • 구조가 다르다면 공통 부분만 추려서 부분 로드 가능
  • 가장 간단한 방법은 strict=False, 세밀하게 제어하려면 키 필터링이나 이름 매핑 활용

profile
AI developer

0개의 댓글