torch save 아키텍처 파라미터 저장하기

게으른 개미개발자·2022년 9월 22일
0

model_conversion

목록 보기
1/13
post-thumbnail

yolov7 torch script 활용하여 Onnx Export

Tensorflow 모델의 경우 tf hub라는 곳이 있어, 텐서플로우 모델들을 일일히 저장하지 않고, 다양한 형태의 모델들을 받아서 onnx로 변환해보거나 최종적으로 TensorRT inference까지 해볼 수 있었다.

TensorFlow Hub

하지만 pyTorch 모델의 경우, torch hub가 존재하기는 하지만 모델이 tensorflow만큼 다양하지 않았다. 따라서 torchvision 라이브러리에 존재하는 모델을 제외하고, 다른 github나 paper에 존재하는 pyTorch DL code에서 직접 모델을 저장하고 onnx 및 trt engine으로 변환해보는 작업을 하게 되었다.

pyTorch 모델의 경우, 모델 학습의 결과를 저장할 수 있는 방법이 존재한다.

model.save & model.load 를 활용하면 pyTorch 모델을 저장할 수 있다.

저장할 수 있는 유형은 크게 보면 2가지이다.

  1. 모델 학습의 결과를 저장(parameter, weights, bias…)
  2. 모델 형태(architecture)와 모델 학습 결과(parameter)를 동시에 저장

일반적으로 많은 사람들이 알고 사용하는 것이 1번이다. 실제로 모델을 학습하고 다른 프레임워크로 변환하는 일은 크게 많지 않고, 이전에 학습했던 가중치를 활용하는 경우가 많기 때문에, 1번 형태의 파일을 많이 사용한다.

확장자는 .pt 를 사용하며, 딕셔너리 형태로 파일이 저장된다.

state_dict 이라고 부르며, 모델의 파라미터를 갖고 있다.

## 모델의 파라미터만 저장
torch.save(model.state_dict(), os.path.join(MODEL_PATH, 'model.pt'))

하지만 나는 2번의 형태가 필요하다.

그래서 구글링하고, pyTorch 레퍼런스도 많이 확인해본 결과,

torch.save(model, os.path.join(MODEL_PATH, 'model.pt'))

이런식으로 저장하게 되면 된다고한다.

하지만 torch.load() 형태로 불러오게되면, 모델 architecture와 parameter가 껍데기만 담고 있게된다. 결국 코드가 존재하지 않는다면, 텐서플로우의 saved-model 처럼 protobuf 파일만을 갖고 onnx를 통해서 inference를 진행할 수 없게 된다.

그렇다면 어떻게 해야될까…?

(여기서부터,,, 멘탈 바사삭)

여기서 갑자기 등장한 torch.jit.tracetorch.jit.script

텐서플로우 및 ONNX와 TensorRT 모델들 입문하면서 살짝살짝 들었던 trace와 script 개념이 등장하였다.

TorchScript - PyTorch 1.12 documentation

Trace 방식


pyTorch에서 model(input)과 같이 torch.nn.Module을 상속하는 모델에게 input을 넘겨주면 해당 모델 클래스의 forward 함수가 실행되게 된다. 이 때, forward 함수 내부의 다양한 함수 및 모듈들이 Tensor 연산을 호출하게 되고, 종속적인 python script들도 모두 실행되게 된다. 이렇게 forward가 한 번 수행하는 동안 execution path에 존재하는 모든 연산들이 기록되게 된다.

하지만 Trace 방식으로 저장할 경우, 문제가 되는것이 forward 함수 내부에 dynamic control flow가 존재한다면, trace를 생성하기 위해 한 번 forward가 호출되었을 때, 호출되지 않은 path는 trace되지 않는다는 점이다. 예를 들어, 모델 내부에 if else와 같은 분기가 존재한다고 했을 때, 한 번의 iteration에 해당되는 분기만을 통과하고 다른 분기는 trace되지 않고 static하게 저장된다.

또한, trace 시에는 forward 함수 호출을 위한 input이 필요하다. 따라서 input shape에 따라 그래프가 static하게 되기 때문에, trace에 들어간 input 값과 동일한 shape의 input이 필요하다.

script 방식


Trace의 dynamic control flow에서 발생하는 문제점을 해결할 수 있는 방법이다. 컴파일하고 빌드를 하는 C,C++,JAVA와 같은 프레임워크처럼 code 전체를 컴파일하여 사용하는 방법이다. 따라서 forward propagation 시 실행될 전체 코드에 대해서 컴파일을 진행하고 TorchScriptCode인 ScriptModule 인스턴스를 생성한다. 전체 코드를 보고 컴파일하기 때문에 if else문제도 해결할 수 있고, input도 따로 필요하지 않다.

Yolo v7 활용하여 TensorRT Inference


Yolo v7 의 pyTorch version과 Onnx version, TensorRT(FP32) version을 벤치마킹하여, 어떻게 성능이 좋아지게 되는지 검증을 해볼 것이다.

우선 Yolo v7 export.py에 기본적으로 Onnx 및 TensorRT로 export할 수 있는 코드가 존재한다. 해당 코드를 활용하면 쉽게, Onnx와 TensorRT 프레임워크로 export 후 Inference 진행해 볼 수 있다.

하지만 직접 conversion 툴을 커스터마이징하고 있는 입장에서, 해당 툴을 사용하지 않고, 변환을 시도해보려고한다. 일단, 첫번째로 실행했던 방법에 대해서 말해보려고 한다.

  1. torch.save()

    → 저장은 됐다. pt파일로 저장되었으며 type을 찍어보니, nn.module을 상속받은 forward가 있는 클래스였다. 하지만 모델의 클래스 코드가 필요했으며, 경로정보가 필요했다.

    class Model(torch.nn.Module):
        pass
    
    class Detect(torch.nn.Module):
        pass

    pt 파일을 로드하기 위해서 위와 같이 class를 선언해줄 필요도 존재했다. 또한, 코드 경로가 필요하기 때문에, pt파일이 단독적으로 사용되지는 않았다. 그리고 netron으로 시각화해보았는데, 껍데기 정보만 담겨있을 뿐 그래프 architecture가 전혀 담겨있지 않았다. 결국 다른 방법이 필요했다. ㅠㅠㅠㅠ

  2. torch.jit.save() ← Trace방식

     dummy_input = torch.randn(32, 3, 256, 256, device="cuda")
     trace_cell =torch.jit.trace(model,(dummy_input))
     torch.jit.save(trace_cell, os.getcwd()+'/yolo_v7_trace.pt')

    → 해당 방식으로 export했을 때, 여러 에러가 발생하기는 했지만 그래도 pt파일이 생성되었다. 이후 netron을 활용하여 모델을 시각화해봤을 때도, 정상적인 그래프 모습으로 보이기는 했다.

    ![](https://velog.velcdn.com/images/sjj995/post/72c31797-506e-46f0-807a-6cb5a8b2e3d8/image.png)

    export한 pt파일 시각화한 결과, 얼추 그래프의 모습을 띈다.

    이후 torch.onnx.export() 모듈을 활용하여 onnx로 변환해주었다.

    torch.onnx.export(model4, dummy_input, "custom_yolo.onnx", verbose=False,
    opset_version=13)

    → export 결과, 자잘자잘한 warning이 존재하기는 했지만, 정상적으로 onnx파일이 생성되었다. 이게 정상이 아닌건가…

모델의 구조가 비슷해보였지만, 아래로 내려갈수록 조금 많이 달랐다. onnx에서 지원되지 않는 operation들이 많아서 대체되어 변환되었다고 추측하기는 했는데 살짝 불안했다.

혹시 몰라서 직접 만든 onnx parser와 inference 툴로 inference를 진행해보았다.

vector 인덱스 에러가 뜨는데, 어디서 터지는지는 정확히 모르겠다. ㅠㅠㅠㅠ 혹시 몰라서 polygraphy와 trtexec 툴을 활용하여, 똑같이 변환 과정을 거쳐봤는데 동일한 에러가 뜬다. 위 사진은 trtexec를 활용했을 때의 결과이다.

일단, trace방식은 예상대로 잘 변환이 되지 않는 것 같다.

  1. torch.jit.save() ← script 방식
    script_cell =torch.jit.script(model,dummy_input)
        with torch.no_grad():
            torch.jit.save(script_cell, os.getcwd()+'/yolo_v7_script.pt')
    script 방식으로 pt파일을 생성해보았을 때, 해당 에러가 존재했다.

Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.

구글링 결과, torch.cat 부분에 들어가는 x 값이 list형태로 들어가서 생기는 오류였다…

많은 구글링 결과 문제를 해결해볼 수 있었다.

 from typing import List
 
 class Concat(nn.Module):
     def __init__(self, dimension=1):
         super(Concat, self).__init__()
         self.d = dimension
 
     def forward(self,x:List[torch.Tensor]):
         return torch.cat(x, dim=self.d)

x:List[torch.Tensor] 로 x를 리스트로 인식할 수 있도록 감싸주었더니, 정상적으로 작동했다.

하지만 다른 에러가 존재했다…

 self.training |= self.export
 #self.training = True
 #self.export = False

비트연산자가 torch.scipt에서든 되지 않는다는 추측이 든다. 그래서 일단은 넘어가려고 아래와 같이 바꿔주었다.

 self.training = self.export

이후 또 다른 에러가 생겼다…

하나 잡으면 하나 터지고,,, 결국 해결하지 못했다.

단순하게 torch.jit.script 사용하는 방법을 통해서 못했기 때문에,export.py 를 디버깅해보면서 파악해볼 예정이다.

아무래도 object dection 모델이기 때문에, 여러 argument들이 필요하긴 하다고 생각했는데, 그 값들이 없거나, 데이터가 없으면 pt 파일을 생성하기는 어려워진다.

이어서 해결해보도록 하겠습니다…

profile
특 : 미친듯한 게으름과 부지런한 생각이 공존하는 사람

0개의 댓글