torch script 에 대해서

너 오늘 코드 짰니?·2023년 7월 20일
0

torch script가 뭐지?

pytorch로 모델을 개발해서 안드로이드에 이식하기 위해 onnx로 변환하는 과정에서 공식문서를 많이 보게 되었는데 그중에 아래와 같은 말이 있었다.

Internally, torch.onnx.export() requires a torch.jit.ScriptModule rather than a torch.nn.Module. If the passed-in model is not already a ScriptModule, export() will use tracing to convert it to one:

이전에 잘 알지 못하면서 torch.onnx.export 함수 안에 파라미터로 sample input 데이터를 넣긴 했었다. 그땐 그냥 넣으라길래 왜넣는지도 모르고 넣었는데, 알고보니 이게 자동으로 torch script로 변환해주기 위한 torch.jit.trace를 위한 input data였던 것이다.

torch script가 뭐냐면 파이토치(PyTorch)의 모델을 최적화하고, 직렬화하며, 배포하기 위한 도구이고 pytorch에서 작성한 Module의 중간표현식 (IR)이라고 할 수 있다. C++ 같은 고성능 환경에서 실행될 수 있도록 모델의 속도를 향상시키고, 모바일 기기 및 임베디드 시스템에서 모델을 실행하는 데 필요한 최소한의 종속성을 갖게하는 효과가 있다고 한다.

torch script를 만드는 방법

torch script를 만들기 위해서는 2가지 방법이 존재한다.

  • torch.jit.trace() : pytorch module에 입력값을 넣어서 모델 안에서 동작하는 흐름에 따라 모델 구조를 파악한 뒤 jit 컴파일러가 모델을 기록한다. flow가 기록되기 때문에 statically fix된 그래프이다. 그러나 Module안에 조건문이나 반복문 같은 contrlo flow가 있으면 해당 부분이 제대로 기록되지 않는다.
  • torch.jit.script() : torchscript 컴파일러가 모듈을 분석해서 직접 컴파일을 진행한다. 따라서 조건문이나 반복문 같은 control flow가 포함되어 있어도 이를 반영해서 컴파일을 하지만,... 솔직히 깔끔하게 변환되지 않는 경우가 허다해서 지원되지 않는 operation이 뭐가 끼어있는지 제대로 확인하고 쓰는게 중요해 보인다.

torch script를 만들 때 주의할 점

1. Numpy나 파이썬 자체 built-in type을 피하자

쉽게말해서 torchscript로 export할 계획이 있는 프로젝트이면, 모델을 디자인 할 때 torch에 있는데이터 타입만 쓰는게 깔끔하다는 이야기이다. torch.jit.trace 과정에서 Numpy나 파이썬 내장 데이터 type 등은 모두 상수로 변환되어버리기 때문에 제대로 convert 되지 않을 수 있다. 따라서 torch에 관련된 데이터 타입만 사용하는게 깔끔하게 변환된다.

# Bad! Will be replaced with constants during tracing.
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)

# Good! Tensor operations will be captured during tracing.
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)

그리고 추가적으로 주의해야할 점이 torch.item()과 같은 함수를 사용하면 안된다. 왜냐하면 torch.item()이 torch의 built-in 함수이긴 하지만 그 결과로 torch 데이터타입을 python 상수로 반환하기 때문이다. 따라서

# Bad! y.item() will be replaced with a constant during tracing.
def forward(self, x, y):
    return x.reshape(y.item(), -1)
    
# Good! y will be preserved as a variable during tracing.
def forward(self, x, y):
    return x.reshape(y, -1)

위 코드와 같이 torch.item()을 사용하지 않고 return 한 다음 외부에서 상수로 바꿔서 처리하던가 해야할거 같다.

모델 안에서 상수로 바꿔서 return 하지 않도록!

모델 안에서는 꼭 torch 관련 데이터타입으로만 동작하도록 모델링 해야 한다.

2. tensor.shape에 대해 inplace 연산을 사용하지 말자

tracing mode에서 tensor.shape로 얻어진 shape는 tensor로써 추적되고 같은 메모리를 참조하게 된다. 따라서 이들이 inplace연산으로 인해 같은 메모리를 참조하고 수정된다면 옳지 않은 출력결과를 얻을 수 있다.

class Model(torch.nn.Module):
  def forward(self, states):
      batch_size, seq_length = states.shape[:2]
      real_seq_length = seq_length
      real_seq_length += 2
      return real_seq_length + seq_length

위와 같은 경우에서 real_seq_length와 seq_length는 inplace 연산으로 묶여있기 때문에 같은 메모리를 참조하게 되고 동시에 변경이 된다. 따라서 아래와 같은 코드로 변경해야 한다.

real_seq_length = real_seq_length + 2

참고로 inplace연산이란 a += 1 와 같이 a라는 변수를 직접 수정해서 다시 대입시키는 연산을 의미한다.

다음 Step으로 할 것은...

보통 파이토치로 모델을 개발해서 모바일 안드로이드 기기에 이식하기 위해서는 tensorflow lite로 변환을 거쳐야 한다.

즉 pytorch -> onnx -> tensorflow -> tensorflow lite 모델로 변환을 거쳐서 안드로이드 기기에 이식하게 되는데 이 과정은 이전에 포스팅을 했었다.

그 과정이 상당히 복잡하고 특히 변환할 때 라이브러리간 호환이 맞지 않으면 잘 변환이 되지 않고 tensorflow 버전따라서 저장 방식도 다르고 그래서 굉장히 스트레스를 많이 받았었다.

다음에는 torchscript를 활용하여 pytorch mobile 라이브러리를 통해 안드로이드에 바로 이식하는 실습을 해볼 예정이다.

이렇게 했을 때 장점은

  1. pytorch 모델을 강제로 tflite 모델로 변환시키는 수고가 덜 들어간다.
  2. pytorch는 [batch, channel, height, width] shape의 input을 가지는데 tflite는 [batch, height, width, channel] shape의 input을 가지기 때문에 이 두 차이를 고려하여 모델링 했어야 했지만 torchscript를 바로 이식하면 이런 고민도 사라진다.
  3. pytorch에서 quantization을 거친 경량화 모델을 바로 이식할 수 있다.

3번이 아마 굉장한 장점이 되지 않을까 예상해 보는데, 모델을 경량화 하기 위해서 pytorch에서 시도해볼 수 있는 기법이 상당히 있다. 하지만 양자화를 거친 모델을 onnx로 변환하는게 거의 상당히 부자연스러워 보였기 때문에 그림의 떡이었던것 같다.

pytorch에서 모델을 최적화 하고 영혼끝까지 경량화 한다음 torchscript로 변환해서 바로 모바일 기기에 이식하면 최고의 성능을 낼 수 있지 않을까..?

profile
안했으면 빨리 백준하나 풀고자.

1개의 댓글

comment-user-thumbnail
2023년 7월 20일

너무 좋은 글이네요. 공유해주셔서 감사합니다.

답글 달기