Custom Model torch.jit.script Debugging

게으른 개미개발자·2022년 11월 3일
1

model_conversion

목록 보기
11/13

커스텀 모델은 아래와 같이 만들었습니다.

분기에 따라, (커스텀 레이어) → layer1 → layer2 → flatten → fully connected

커스텀 레이어 적용 시, If문이 포워드에서 사용됩니다.

class CustomLayer(nn.Module):
    def __init__(self, first_dimension,second_dimension,third_dimension):
        super(CustomLayer, self).__init__()
        self.first_dimension,self.second_dimension,self.third_dimension = first_dimension,second_dimension,third_dimension
        weights = torch.randn(first_dimension,second_dimension,third_dimension)
        self.weights = nn.Parameter(weights)
        bias = torch.randn(first_dimension,second_dimension,third_dimension)
        self.bias = nn.Parameter(bias)
        return

    def forward(self,x):
        return self.weights * x + self.bias

class CustomModel(nn.Module):
    def __init__(self,branch_testing):
        super(CustomModel,self).__init__()
        self.branch_testing = branch_testing
        self.layer1 = nn.Sequential(
            nn.Conv2d(3,64,kernel_size=(3,3),stride=1,padding=1),
            nn.SiLU(inplace=True)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(64,32,kernel_size=(3,3),stride=1,padding=1),
            nn.SiLU(inplace=True),
            nn.MaxPool2d(16)
        )
        self.custom_layer1 = CustomLayer(3,256,256)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(8192,1024)
        self.condition_func1 = nn.Linear(1024,10)
        self.condition_func2 = nn.Linear(1024,50)
        return
    
    def forward(self,x : torch.Tensor):
        if self.branch_testing==True:
            x = self.custom_layer1(x)        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        return x

모델을 Torch Script로 실행하는 구문은 아래와 같습니다.

branch_testing = True
model_instance = model(branch_testing)
script_from_model = torch.jit.script(model_instance)

model_instance ← 커스텀 모델의 인스턴스가 담겨있습니다.

이후, torch.jit.script(model_instance) 부분을 디버깅 해보려고 합니다.

if isinstance(obj, torch.nn.Module):
        obj = call_prepare_scriptable_func(obj)
        return torch.jit._recursive.create_script_module(
            obj, torch.jit._recursive.infer_methods_to_compile
        )

torch.jit.script(model_instance) 에 담긴 model_instance 즉, torch.nn을 상속받은 인스턴스이기 때문에 해당 객체가 다음 조건문에서 적용됩니다.

def call_prepare_scriptable_func(obj):
    memo: Dict[int, torch.nn.Module] = {}
    return call_prepare_scriptable_func_impl(obj, memo)

def call_prepare_scriptable_func_impl(obj, memo):
    if not isinstance(obj, torch.nn.Module):
        return obj

    obj_id = id(obj)

    # If obj_id is in memo, obj has already been prepared or is being
    # prepared in another call up the stack.
    if obj_id in memo:
        return memo[id(obj)]

memo의 경우 일종의 메모리 주소를 반환하는 것 같았습니다. 해당 모델의 인스턴스와 시작 메모리주소를 함께 담아주어, 반환하는 것으로 보입니다.

이후 반환된 obj를 활용하여, 다음 메서드에 진입합니다.

def create_script_module(nn_module, stubs_fn, share_types=True, is_tracing=False):
    """
    Creates a new ScriptModule from an nn.Module

    Args:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        stubs_fn:  Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
        share_types:  Whether to share underlying JIT types between modules (if possible).
            NOTE: Only set to False this when we cannot guarantee type sharing will work
                correctly. This only happens today for traced modules, where the same
                module can produce different traced methods depending on the inputs.
        is_tracing: Whether this function is called during tracing or scripting. If tracing,
                we don't need to do AttributeTypeIsSupportedChecker because all the unsupported
                attributes will be baked as constant in the tracing graph. In addition,
                this check significantly slows down the traced modules when the module size is big.
    """
    assert not isinstance(nn_module, torch.jit.RecursiveScriptModule)
    check_module_initialized(nn_module)
    concrete_type = get_module_concrete_type(nn_module, share_types)
    if not is_tracing:
        AttributeTypeIsSupportedChecker().check(nn_module)
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
check_module_initialized(nn_module)

초기화가 정상적으로 되어있는지 체크하는 메서드로 파악됩니다. 모델 모듈에 대해서 _parameters , _buffers, _modules 들이 담겨있는지 확인해보는 것으로 보입니다.(포워드 함수에 대한 경로 및 파라미터 값들) 정상적으로 모듈에 대한 위의 파라미터들이 존재할 경우, None값을 최종적으로 반환합니다.

concrete_type = get_module_concrete_type(nn_module, share_types)

Gets a concrete type for nn_modules. If share_types is True, the concrete type is fetched from concrete_type_store. If it is False, a new concrete type is created without first searching concrete_type_store.

제가 만든 커스텀 모델의 경우, torch.nn.modules 에 대해서 concrete type이 정의되어 있지 않았기 때문에, concrete_type_store.get_or_create_concrete_type(nn_module) 함수를 호출하여, concrete type에 대해서 inference 하는 과정을 거치게 됩니다.

concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
    if isinstance(nn_module, (torch.nn.ModuleDict)):
        concrete_type_builder.set_module_dict()
    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)):
        concrete_type_builder.set_module_list()

만약에 해당 모델이 단순하게 torch.nnModuleDict 타입이거나 Sequential 모델 타입일 경우에는 모듈 정보를 위와 같이 정의해주는 것으로 보입니다.

그 경우가 아니라면, class_annotations = getattr(nn_module, '**annotations**', {}) getattr함수를 활용하여 모델에 대한 정보를 가져와 정의해줍니다.

# Get user-annotated ignored attributes.
    user_annotated_ignored_attributes = getattr(nn_module, "__jit_ignored_attributes__", list())
    concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes)
    ignored_properties = jit_ignored_properties(nn_module)

python으로 만들어진 torch 모델에서 데코레이터로 torch.jit.trace가 되어있거나, 모델 conversion 과정에서 필요하지 않은 부분을 사용자가 정의해놨을 때, 그 부분을 체크해서 변환되지 않도록 하는 과정으로 보입니다.

커스텀 모델에서는 따로 그런 부분을 처리하지 않았기 때문에, 따로 무시되는 속성값들은 없습니다.

for name, item in nn_module._parameters.items():
        if name in user_annotated_ignored_attributes:
            continue

        assert item is None or isinstance(item, torch.Tensor)
        attr_type, _ = infer_type(name, item)
        # We currently have the invariant in various places in our code
        # that parameters must be Tensors. However, the nn.Module API also
        # allows NoneType parameters. These parameters are not returned as
        # part of `parameters()` and its variants, but are available
        # through direct attribute access.
        concrete_type_builder.add_attribute(name, attr_type.type(), True, False)
        added_names.add(name)

모델 layer에서 pyTorch Tensor로 이루어진 값들만 added_names 라는 딕셔너리에 담아주게 됩니다. 대표적으로 weight 가 담기게 되는데, 만약, forward 함수가 존재하는 하나의 클래스일 경우에는,

담기지 않습니다. 위에 나와있는 주석에 나와있는 내용처럼 NoneType일 경우, 어떻게 처리해야 되는지 결정하기 위해 위와 같은 구문을 거치는 것으로 추측됩니다.

아래 infer_type 이라는 메서드를 거치며, Tensor의 정보를 확인하게 됩니다. 포워드 함수가 포함되어 있는 모델 레이어들도 다 이 함수를 거치게 되는데, 포워드 함수가 존재하지 않는 타입만 해당 구문해서 타입을 유추하기 위해 사용됩니다.

def infer_type(name, item):
        # The forward function from Module is special; never use this annotations; we
        # need to infer type directly using JIT.  I originally wanted to write
        # this test as isinstance(class_annotations[name], Callable) but
        # isinstance on typing things doesn't seem to work: isinstance(list, Callable)
        # is also true!
        inferred = False
        try:
            if name in class_annotations and class_annotations[name] != torch.nn.Module.__annotations__["forward"]:
                ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], fake_range())
                attr_type = torch._C.InferredType(ann_to_type)
            elif isinstance(item, torch.jit.Attribute):
                ann_to_type = torch.jit.annotations.ann_to_type(item.type, fake_range())
                attr_type = torch._C.InferredType(ann_to_type)
            else:
                attr_type = torch._C._jit_try_infer_type(item)
                inferred = True
        except RuntimeError as re:
            raise RuntimeError(
                "Error inferring type for {name}: {item}: {re}".format(name=name, item=item, re=re)
            )

        return attr_type, inferred

파이토치 모델에서 모델 및 포워드 함수에 사용되는 파라미터를 단순하게 Torch.Tensor로 담아주는 것이 아니라 typing.Union 과 같은 typing 모듈을 활용하여 파라미터의 타입을 선언하는 경우가 많습니다. 아래 포스팅을 보시면 왜 typing을 사용하는지, typing에서 어떤 함수들이 존재하는지 대략적으로 파악하실 수 있습니다.

[파이썬] typing 모듈로 타입 표시하기

결국 typing안에 있는 값들이 어떤 값들이며, 이 값의 데이터 타입이 torch.Tensor인지 확인하기 위해 위와 같은 과정을 거친다고 볼 수 있습니다.

위의 과정을 통해서 얻을 수 있는 사실은 파이썬 기반의 파이토치 모델을 torch.jit.script 타입으로 변경해주기위해, 관계 정보가 담겨있는 테이블을 확인하고, 테이블을 참조하여 torch.jit.script타입으로 매핑시켜주는 작업을 진행하는 것으로 보였습니다.

결과적으로 <class 'torch.jit._script.RecursiveScriptModule'> 에 담아주어 파이토치 jit.script의 RecursiveScriptModule로 반환해줍니다.

profile
특 : 미친듯한 게으름과 부지런한 생각이 공존하는 사람

0개의 댓글