torch.jit.script if,else 파악하기

게으른 개미개발자·2022년 9월 28일
0

model_conversion

목록 보기
4/13

지난번에 알아보았던, torch.jit.script 로 pt 파일을 내보낼 때, torch.eq 모듈이 지원되지 않는다며, 정상적으로 netron에서 열리지 않는 문제와 Onnx로 변환하였을 때, if,else와 같은 분기가 존재할 경우, Condition에 해당되는 분기만을 포워딩하는 것을 확인할 수 있었다.

따라서 해당 문제를 살펴보려고 한다.

결론부터 말하자면 해당 문제가 발생했던 이유는, 파이토치 모델 클래스를 만들 때, 사용하는 forward() 함수가 문제였다.

클래스에서 인스턴스를 생성할 때, __init__ 이라는 생성자 함수가 호출된다. 마찬가지로 인스턴스가 호출될 때, __call__ 이라는 호출자 함수가 호출된다.

파이토치에서 forward 함수 역시, torch.nn.module 클래스를 사용하게 되면, 호출 함수로 사용하게 된다.

https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

해당 내용은 1178 라인에 존재한다.

forward 함수가 실행되는데, 한 가지 의문스러운 점이 있었다.

https://pytorch.org/docs/master/jit_language_reference.html#id2

def forward(self, x, y, z):
        # type: (Optional[int], Optional[int], Optional[int]) -> int
        if x is None:
            x = 1
            x = x + 1

        # Refinement for an attribute by assigning it to a local
        z = self.z
        if y is not None and z is not None:
            x = y + z

        # Refinement via an `assert`
        assert z is not None
        x += z
        return x

파이토치 레퍼런스에 존재하는 torch.jit.script 의 일부분을 발취해서 가져왔다.

다른 예제들도 그렇고 항상 input 텐서 값들을 기준으로 if 조건을 걸어서 활용하지, boolean값과 같은 다른 상태값으로 분기를 나누지는 않았다.

def forward(self,x):
        x = self.custom_layer(x)
        x = self.custom_layer2(x)
        if self.branch_testing==True:
            x = self.custom_layer2(x)
            x = self.layer1(x)
        else:
            x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return x

내가 직접 커스터마이징했던 포워딩 구문이다. 나는 input값을 기준으로 분기를 나눴던 것이 아니라, 어떠한 condition에 따라, 분기를 나눴었다.

aten::eq.str_list(str[] a, str[] b) -> (bool):
  Expected a value of type 'List[str]' for argument 'a' but instead found type 'NoneType'.

  eq(float a, Tensor b) -> (Tensor):
  Expected a value of type 'float' for argument 'a' but instead found type 'NoneType'.

  eq(int a, Tensor b) -> (Tensor):
  Expected a value of type 'int' for argument 'a' but instead found type 'NoneType'.

The original call is:
  File "forward_if_simple_script.py", line 57
    def forward(self,x):
        if self.branch_testing==True:
           ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            x = self.custom_layer1(x)
        else:

혹시 몰라서 boolean input값을 주지 않고, None 값을 넣어줬을 때, 다음과 같은 오류가 발생했다. 이해가 되지 않았던 점은 boolean값으로 비교해주고 싶었는데, Tensor로 비교하려고 하는 것 같았다.

def forward(self,x):
        if x[0][0][0][0] % 2 == 0:
            x = self.custom_layer(x)
        x = self.custom_layer(x)
        x = self.layer1(x)    
        x = self.layer2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.condtion_func(x)
        return x

그래서 약간 허접하지만, 텐서값을 기준으로 분기를 설정해보았다.

두둥, 그러고 pyTorch pt파일과 Onnx파일로 export해보았다.

graph(%self : __torch__.CustomModel,
      %x.1 : Tensor):
  %3 : int = prim::Constant[value=0]() # forward_if_simple_script.py:67:13
  %21 : int = prim::Constant[value=2]() # forward_if_simple_script.py:67:27
  %14 : Tensor = aten::select(%x.1, %3, %3) # forward_if_simple_script.py:67:11
  %16 : Tensor = aten::select(%14, %3, %3) # forward_if_simple_script.py:67:11
  %18 : Tensor = aten::select(%16, %3, %3) # forward_if_simple_script.py:67:11
  %20 : Tensor = aten::select(%18, %3, %3) # forward_if_simple_script.py:67:11
  %22 : Tensor = aten::remainder(%20, %21) # forward_if_simple_script.py:67:11
  %23 : Tensor = aten::eq(%22, %3) # forward_if_simple_script.py:67:11
  %25 : bool = aten::Bool(%23) # forward_if_simple_script.py:67:11
  %x : Tensor = prim::If(%25) # forward_if_simple_script.py:67:8
    block0():
      %custom_layer.1 : __torch__.CustomLayer = prim::GetAttr[name="custom_layer"](%self)
      %x.9 : Tensor = prim::CallMethod[name="forward"](%custom_layer.1, %x.1) # forward_if_simple_script.py:68:16
      -> (%x.9)
    block1():
      -> (%x.1)
  %custom_layer : __torch__.CustomLayer = prim::GetAttr[name="custom_layer"](%self)
  %x.23 : Tensor = prim::CallMethod[name="forward"](%custom_layer, %x) # forward_if_simple_script.py:69:12
  %layer1 : __torch__.torch.nn.modules.container.Sequential = prim::GetAttr[name="layer1"](%self)
  %x.27 : Tensor = prim::CallMethod[name="forward"](%layer1, %x.23) # forward_if_simple_script.py:70:12
  %layer2 : __torch__.torch.nn.modules.container.___torch_mangle_1.Sequential = prim::GetAttr[name="layer2"](%self)
  %x.31 : Tensor = prim::CallMethod[name="forward"](%layer2, %x.27) # forward_if_simple_script.py:71:12
  %flatten : __torch__.torch.nn.modules.flatten.Flatten = prim::GetAttr[name="flatten"](%self)
  %x.35 : Tensor = prim::CallMethod[name="forward"](%flatten, %x.31) # forward_if_simple_script.py:72:12
  %fc1 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="fc1"](%self)
  %x.39 : Tensor = prim::CallMethod[name="forward"](%fc1, %x.35) # forward_if_simple_script.py:73:12
  %condtion_func : __torch__.torch.nn.modules.linear.___torch_mangle_2.Linear = prim::GetAttr[name="condtion_func"](%self)
  %x.43 : Tensor = prim::CallMethod[name="forward"](%condtion_func, %x.39) # forward_if_simple_script.py:74:12
  return (%x.43)

graph를 찍어보았을 때, 정상적으로 일단 분기가 나뉜 것을 확인할 수 있었다.

한 가지 기억해야될 점은 Onnx file을 netron으로 시각화해보았을 때도 마찬가지로 정상적으로 분기가 나뉘었다. torch.jit.script 로 만든 pt파일의 경우, 시각화하였을 때, x라는 Input을 넣었을 때, True 조건에 대해서만 간선이 연결된 시각화가 이루어진다. 즉 else 정보도 갖고는 있지만, forward 그래프를 표기할 때, If에 대해서 Onnx처럼 하나의 노드로 보는 것이 아니라, python이나 pytorch의 기본 메서드라고 생각하는 것에 가까운 것 같다.



위쪽이 pytorch pt 모델이고, 아래쪽이 Onnx까지 export했을 때의 경우이다.

그래프 중간에 If문이 포함되어 있는것을 확인할 수 있었다. 이 if가 얼마나 반가운지,,,

두 개의 branch로 나눠지는 것을 확인할 수 있었다.

결론, 3줄 요약

  1. forward안에, 다양한 종류의 분기가 존재할 수 있으나, pt나 Onnx와 같은 static한 형태로 export할 때, Tensor를 기준으로는 정상적으로 export가 진행된다.
  2. boolean값과 같은 형태로 분기를 나눌 경우, (그 기준이 텐서값을 토대로 한 boolean값이 아닌 경우), 정상적으로 torch.jit.script 에서는 export되지 않는다. trace와 동일하게 export된다.
  3. 코드에서 분기에 따라 컨디션이 다른 경우에는, 코드를 분기의 컨디션에 맞게 여러 버전으로 만들어주고, static하게 구성해주면 된다.(약간 좀 많이 번거로워짐,,,)
profile
특 : 미친듯한 게으름과 부지런한 생각이 공존하는 사람

0개의 댓글