torch model weight 저장 / 불러오기

jaeha_lee·2021년 6월 4일
0
def save_parameters(network,save_loc, optimizer):
    # Generator
    state_dict = dict()
    state_dict['network'] = network.state_dict()
    state_dict['optimizer'] = optimizer.state_dict()
    # state_dict['step'] = step
    # print(state_dict)
    torch.save(state_dict, save_loc)
    # state_dict['inr_function'] = self.inr_function.state_dict()

def load_parameters(network,optimizer,path_):
    data = torch.load(path_)
    layer,op = data['network'],data['optimizer']
    network.load_state_dict(layer,strict=True)
    optimizer.load_state_dict(op)

save_parameters 함수는 weight와 optimizer를 불러온다.
load_parameters 함수는 저장한 파일에서 load를 한다

0개의 댓글