데이터 취업 스쿨 스터디 노트 -(87) pytorch - 전체 딥러닝 플로우 구현

테리·2024년 10월 1일
0
post-thumbnail

데이터 불러오기

  • torch.utils.data.DataLoader()
  • transforms.Compose()

데이터 확인

  • squeeze()

모델 정의

학습 로직

pytorch에서는 fit 함수를 사용하지 않고 직접 학습로직을 작성함.

tf.keras.losses.SparseCategoricalCrossentropy(label_num, one-hot) 처럼 예측값은 원 핫 이지만 실제값은 원핫으로 변환하기 귀찮은경우 pytorch에서는 model의 출력을 F.log_softmax로 받고 학습을 F.nll_loss로 하면 동일한 기능을 할 수 있음.

0개의 댓글