내가보기위해 만든 torch 모델 Save와 loading

이성준·2023년 3월 16일
0

Pytorch

목록 보기
1/5

최근에 23학점을 듣게 되면서 정신없이 삶을 살던중에 연구실에서 Transfer Learning을 해야하는 상황이 생겼다 전체 데이터셋에 대해서 예측을 수행해야 했는데 그렇다보니 시간이 너무 오래걸렸다.. 이전에 10000개정도 되는 데이터를 예측할때는 시간이 2~3시간 정도 걸렸는데 4만개 가까이 되는 데이터를 예측을 하니깐 시간이 급수적으로 많이 걸리게 된 것이다. 그 사이에 인터넷 연결이 끊기거나 연구실 서버가 잠시 이상해지기라도 하면 처음부터 다시 학습을 시작해야하는 상황이 벌어졌다,,,,,,

이전까지는 3-4시간 정도 지난후 결과를 확인하면 확인할 수 있었는데 이번에는 3~4시간이 지나도 결과가 나오지 않아서 연구실에서 쓰는 노트북으로 돌려놓고 집에가서 데스크 톱으로 서버에 접속을 하게 되니깐 또 모델을 처음부터 학습을 시켜야하는 경우가 생겼다.

정리하자면 두가지 문제점에 직면했다
1. 갑작스러운 인터넷 연결 이상 혹은 서버 이상
2. 모델 예측 완료까지 걸리는 시간을 모르는 경우

따라서 나는 각각 1.은 모델을 save하는 방법을 택했고 2.은 예측이 완료되면 내 메일로 이메일이 보내지도록 프로그래밍 했다. 각각 시작해보겠다.

1. torch.save, torch.load, model.load_stat_dict

torch.save(객체,저장할주소)
→ 위에서 보여준대로 내가 저장하고 싶은 모델 객체 혹은 optimizer객체 그 자체나 모델의 spec 혹은 optimizer의 spec을 저장할 수 있다. 본인의 상황에 맞춰 사용하면 되는 것이다. 예시를 보자

위의 예시는 torch.save를 이용해서 모델의 checkpoint를 만들어놓은 것이다.
각각의 저장할 내용은 위처럼 Dictionary 형태로 만들어 줄 수도 있고 checkpoint를 저장할 것이 아니라면 그냥 torch.save(model,address)로 사용하면 된다. 이때 파일의 확장자는 *.pt로 하면 된다.

바로 모델을 저장할때는 모델 클래스가 선언돼있어야 하고 저장할때 모델과 비교했을때 변해있으면 안된다.

왜 *.pt로 저장해야 하는지를 알게 되면 수정하겠다.

torch.load("파일주소")
그렇다면 저장한 것을 불러올때는 어떤 식으로 수행하면 될까?
다음의 예를 보고 확인해보자

torch.save(model.state_dict(),"./model_stat_dict.pt")
model_stat_dict = torch.load("./model_stat_dict.pt")

자 위에 같이 model의 spec에 대해서 변수에 저장을 해놓고, 어색하긴 하지만 그 model spec을 바로 불러왔다고 하자. 그러면 이렇게 불러온 모델의 spec을 실제 모델에는 어떻게 저장할까?
이때 사용하는 것이 바로
model.load_stat_dict(model_stat_dict)
내가 생성한 모델 객체가 model일 경우 다음과 같이 불러오면 된다.

이때 이상하다,, 싶은 것이 있다. 분명 내 모델에는 load_stat_dict라는 메소드를 정의해두지 않았다. 그러면 어떻게 사용할 수 있는걸까? → 이는 내가 model을 정의하면서 상속해둔 nn.Module 클래스에 저장돼 있는 것이다.

model.load_stat_dict(model_stat_dict) 
#model = 하고 선언해주면 오류가 발생한다0
#이는 optimizer에서도 동일하게 적용된다

만약 model의 spec을 저장하지않고 torch.save(model,"./model.pt") 로 모델 자체를 저장하는 경우가 있을 수 있다. 이경우는 load_stat_dict를 사용하지 않고 load를 통해 모델을 바로 사용하면 된다. 하지만 이 경우 더 많은 저장공간을 요구하게 된다.

1.1 Checkpoint의 활용방안?

우리는 위에서 보여준 사진대로 checkpoint를 만들어 저장했다.
그렇다면 checkpoint = torch.load(~~PATH~~)를 통해서 저장해놓은 checkpoint를 loading을 하면, checkpoint에는 우리가 저장한 dictionary가 저장되게 될 것이다. 이때 key값을 이용해서 코드를 수정해서 모델을 이어서 Training을 시켜주면 되는 것이다.

또 서버나 인터넷이 다운되면 ,, 코드를 올려서 수정해보겠다..

2. SMTP를 활용한 메일 보내기

import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.mime.image import MIMEImage
from email.mime.application import MIMEApplication

recipients = ["~~~~~@naver.com"]

message = MIMEMultipart();
message['Subject'] = '메일 전송 테스트'
message['From'] = "~~~~~@naver.com"
message['To'] = ",".join(recipients)

content = """
    <html>
    <body>
        <h2>{title}</h2>
        <p>메일 전송 테스트입니다</p>
    </body>
    </html>
""".format(
title = '메일.. 받으셨나요..?'
)

mimetext = MIMEText(content,'html')
message.attach(mimetext)

email_id = '~~~~~'
email_pw = '~~~~~'

server = smtplib.SMTP('smtp.naver.com',587)
server.ehlo()
server.starttls()
server.login(email_id,email_pw)
server.sendmail(message['From'],recipients,message.as_string())

위의 내용은 구글링을 통해 가져온 코드이다.
출처

파이썬 기본 패키지로 구현한 코드이다.
지금 당장에 이해할 내용은 아니니 본격적으로 이메일을 자동으로 보내야될 상황이 온다면 심도있게 라이브러리를 이해해 보겠다.

SMTP를 받기 위해서는 받는사람이 네이버 메일설정에서 SMTP를 수신한다고 설정해놔야한다.

profile
Time-Series

0개의 댓글