PyTorch Project Template

DONGJIN IM·2022년 4월 23일
0
post-thumbnail

PyTorch Project Template

Template 설명

이미 OOP를 통한 완벽한 모듈화로 Template를 미리 형성해 놓은 것이 GitHub에 많이 공유되고 있다.
이런 Template에서는 config 파일(형식 : json 등)을 내가 원하는 값으로 변경시키고, train.py 파일에 내가 원하는 모델을 입력함으로써 다른 부분은 건드리지 않고 내가 원하는 학습을 수행시킬 수 있다.
(이런 코드를 통해 실행 환경을 저장하고, 깃허브 충돌을 막을 수 있다)

  • 실행 방법 : python train.py -c config.json
  • Visual Studio Code를 통해 Python 코드 보기 : code .(cmd에서)

Template 구조 뜯어보기

나중에 내가 Code를 짜다 보면 Template과 유사하게 짜지는 것을 알 수 있다.
따라서, 이미 전문가들에 의해 만들어진 완벽한 코드 Template의 구조를 잘 살펴보고, 훗날 활용하도록 해보자

train.py

Python 코드를 수행시켜주는 주체이다.
parse_config.py의 from_args 메서드에서 객체를 반환받아 main 메서드를 수행하는 방식으로 학습이 진행된다.

  • args : Temrinal에서 Argumen셍 대한 설정을 지정해 준 Part
    • -c : 설정 값을 불러오는 argument. 주로 config.json에 설정값을 저장

parse_config.py(ConfigParser 객체)

  • from_args : Argument로 불러온 파일을 해석(config.json)

    • Factory 패턴 활용 : 재료(Arugmnet)를 넣어주면 객체를 도출하는 패턴
    • Argument로 받은 값들을 활용해 객체를 형성하고, 이 객체를 train.py에서 활용
  • __getitem__ : index 값을 넣어주면 index에 존재하는 값을 불러오는 메서드

  • init_obj(A,B) : A 폴더 속에 존재하는 B.py 파일에서 모듈을 불러오기 위한 메서드

    import data_loader.data_loaders as module_data
    
    init_obj('data_loader', 'module_data')
    • data_loader 폴더 접속 & data_laoder.data_loaders.py에 접속
    • 이후, data_loaders.py에 저장된 객체를 가지고 와 학습에 활용
    • config.json 파일에 불러올 모듈에 대한 정보가 저장되어 있으므로, data_loaders.py에 모듈이 여러개 존재하더라도 어떤 모듈을 가져와야 할지 미리 알 수가 있음
    "data_loader":{
       "type" : "want_to_call_module",
       "args":{
                ~~~~~~
       }
    }
    • 위 코드에서는 json 파일에서 want_to_call_module이라는 이름의 모듈을 불러옴
    • 만약 "type" 부분의 value부분만 바꾸면, 가져올 Module을 빠르게 변경할 수 있음. OOP를 통한 완벽한 모듈화로 인해 가능한 방법이다.

util.py

  • read_jons : json 파일을 읽어와서 (Ordered) Dict type Data로 만들어줌
  • trainer.py : 핵심 코드. 학습이 이뤄지는 공간
    • base_trainer.py : Loggin에 대한 코드가 저장되어 있고, 특히 어느 정도 Epoch이 진행되면 자동으로 학습 모델을 중간 저장하는 코드가 입력되어 있다
    • Parameter를 통해 Loss Function, Data 등에 대한 정보를 입력 받음
    • 아래 코드는 훗날 다시 자세히 살펴 볼 것이지만, 형태는 지금 알아두자
self.optimizer.zero_grad() # Gradient 초기화
output = self.model(data)  # y의 예측값 생성
loss = self.criterion(output, target) 
# y의 예측값인 output과 실제 target 사이 loss값 측정

loss.backward() # Backward Propagation 수행
self.optimizer.step() # Gradient Update

# 위 과정 전체가 Batch 1개가 1 Epoch 수행되는 과정임

test.py

학습한 모델에 대한 정확도 측정 코드로써, Inference를 수행해주는 파일이다.

profile
개념부터 확실히!

0개의 댓글

관련 채용 정보