forward 함수에서 IF와 같은 분기가 존재할 경우, 해당 분기에서 사용하는 파라미터들이 torch.Tensor
가 아닌 다른 값일 경우, 최소한 사용자에게 warning을 띄워줄 수 있어야한다. 왜냐하면, torch.jit.trace
를 활용하여 onnx로 export할 경우에는, trace하는 방향으로만 export가 되며, torch.jit.script
를 활용하여 onnx로 export할 경우에는, forward 함수 내부에서 tensor가 기준이 되어, 분기가 되지 않는 경우에는, 정상적으로 분기가 나뉘지 않기 때문이다. 이 역시 trace방식과 비슷하게 적용이 된다.
torch 공식 소스코드에서도 확인해보면,
func (callable or torch.nn.Module): A Python function or `torch.nn.Module`
that will be run with `example_inputs`. `func` arguments and return
values must be tensors or (possibly nested) tuples that contain
tensors. When a module is passed `torch.jit.trace`, only the
``forward`` method is run and traced (see :func:`torch.jit.trace
<torch.jit.trace_module>` for details).
example_inputs (tuple or torch.Tensor): A tuple of example inputs that
will be passed to the function while tracing. The resulting trace
can be run with inputs of different types and shapes assuming the
traced operations support those types and shapes. `example_inputs`
may also be a single Tensor in which case it is automatically
wrapped in a tuple.
torch > jit > _trace.py에 jit trace에 대한 자세한 정보가 나와있다. 아무튼 tensor로 이루어진 tuple이 아닐 경우에는 forward에서 문제가 될 수 있다.
앞선, 나의 model conversion 시리즈 기록들만 보아도, 간단하게 커스터마이징한 모델의 경우도 동일하게 변환되는게 쉽지 않다.
따라서, If나 loop와 같이 정상적으로 변환되기 힘든 부분이 모델에 포함되어 있을 경우, TensorRT Engine까지 변환해주기 위해서, 사용자가 직접 그 부분을 커스터마이징하거나, 모델을 나눠서 합치는 과정이 필요할 수 있다.(추후 진행해볼 예정)
일단, 간단하게 forward 함수 내부에 사용되는 파라미터들, 모델에서 인스턴스 self변수로 사용되는 파라미터들의 데이터 타입을 파악하고, Tensor가 아닐경우에는 logging warning을 띄워주는 방법으로 접근하려고 한다. 또한, python의 inspect 모듈을 활용하여, forward 함수 내부를 체크해보는 방법으로도 시도해보려고 한다.
활용하기에 아주 좋은 파이썬 모듈이 있어서 활용하려고 한다.
파이썬은 정말 많은 라이브러리들이 존재하기 때문에, 다른 언어와 다르게 편하게 짤 수 있어서 좋다.
https://zephyrus1111.tistory.com/153?category=835757
if문의 위치를 찾는것은 추후에 작업을 해볼 것이다.
import torch
import warnings
import logging
import inspect
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List
import sys
sys.path.append('/workspace/')
from custom import test_forward_if_script as custom_script
inference_input = torch.randn(1,3,256,256)
condition_testing = False
branch_testing = False
model = custom_script.CustomModel(condition_testing,branch_testing)
# print(model(inference_input))
# print(model.forward)
# def _forward_unimplemented(self, *input: Any) -> None:
# r"""Defines the computation performed at every call.
# Should be overridden by all subclasses.
# .. note::
# Although the recipe for forward pass needs to be defined within
# this function, one should call the :class:`Module` instance afterwards
# instead of this since the former takes care of running the
# registered hooks while the latter silently ignores them.
# """
# raise NotImplementedError
logger = logging.getLogger()
logger.setLevel(logging.INFO)
_DtypeWarning = 'The layer(op) associated with (Parameter) may not be converted normally. Please check'
class Module():
def __init__(self,model) -> None:
# forward: Callable[..., Any] = _forward_unimplemented
self.model = model
self.forward = model.forward
self.caution_list = []
self.caution_dict = []
self.argument_list = []
def __len__(self,*args):
return len(args)
# forward 함수만 적용
def get_argument(self,):
self.argument_list = self.forward.__code__.co_varnames
def alert_argument(self,):
for caution in self.caution_list:
param = self.argument_list[caution]
warnings.warn('('+param+')'+' '+_DtypeWarning,UserWarning)
def _call_impl(self, *input, **kwargs):
result = self.forward(*input, **kwargs)
for i in range(self.__len__(*input)):
if type(input[i]) != torch.Tensor:
self.caution_list.append(i)
if self.__len__(*kwargs) != 0:
for k,v in kwargs.items():
if type(v) != torch.Tensor:
self.caution_dict.append(k)
self.get_argument()
self.alert_argument()
__call__ : Callable[..., Any] = _call_impl
module = Module(model)
#forward input이 1개일 경우
#module(inference_input)
#forward input이 여러개일 경우
module(inference_input,inference_input,1.2,3.3)
이렇게 forward 파라미터 중에서 Tensor 형태로 들어오지 않는 경우에는 UserWarning을 띄워주어, 사용자에게 알려주는 기능을 하려고한다. Onnx로 If나 Loop 변환은 사용자의 몫으로 추후에 더 만들어야된다.
IF문 위치까지 찾아서, 정확하게 conversion이 되지 않는 경우를 체크할 수 있는 기능을 만들어보려고한다.
*warnings에서 logging으로도 변경해서 적용할 것.