lstm 10. pytorch best model 저장 방법

행동하는 개발자·2022년 12월 19일
0

RNN

목록 보기
9/14

state_dict

state_dict는 간단히 말해 각 계층 매개변수 텐서로 매핑되는 파이썬 사전 객체다. 모델의 학습 가능한 매개변수들은 모델의 매개변수에 포함되어 있다. state_dict는 간단히 말해 각 계층을 매개변수 텐서로 매핑되는 dict 객체이다. 이 때, 학습 가능한 매개변수를 갖는 계층 및 등록된 버퍼들만이 모델의 state_dict 항목을 가진다.

저장하기

torch.save(model.state_dict(), PATH)

불러오기


model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

모델의 state_dict를 저장하는 것이 나주에 모델을 사용할 때 가장 유연하게 사용할 수 있는 모델 저장 시 권장하는 방법이다.

일반 체크포인트 저장하기 & 불러오기

저장하기

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

불러오기

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

일반 체크포인트를 저장할 때는 반드시 모델의 state_dict보다 많은 것을 저장해야 한다. state_dict만을 저장하는 것보다 파일의 크기 자체는 2~3배 정도 커진다.

profile
끊임없이 뭔가를 남기는 사람

0개의 댓글