지난 이야기

게으른 개미개발자·2022년 11월 3일
0

model_conversion

목록 보기
13/13

기존에 계속해서 공부했던 내용을 정리하는 회고록입니다...

기존에 If문 및 loop가 변환이 안 되는 것이 포워드 함수 내부의 파라미터가 Tensor인지 아닌지에 따라서 변환이 되고, 안되는 것이 달라진다고 생각했었습니다.

위의 가설은 반은 맞고, 반은 틀린 내용이었습니다.

우선 아래와 같이 해당 모델이 존재한다고 가정해보겠습니다.

class CustomModel(nn.Module):
    def __init__(self,branch_testing):
        super(CustomModel,self).__init__()
        self.branch_testing = branch_testing
        self.layer1 = nn.Sequential(
            nn.Conv2d(3,64,kernel_size=(3,3),stride=1,padding=1),
            nn.SiLU(inplace=True)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(64,32,kernel_size=(3,3),stride=1,padding=1),
            nn.SiLU(inplace=True),
            nn.MaxPool2d(16)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(8192,1024)
        return
		def forward(self,x : torch.Tensor):
        if self.branch_testing:
            x = x * 2
				else:
						x = x * 3
				x = self.layer1(x)
        x = self.layer2(x)
			  x = self.flatten(x)
        x = self.fc1(x)
        return x

torch.jit.script를 활용하여 RecursiveScriptModule 형태로 바꿔서 PT파일이나 ONNX 파일로 export할 수 있습니다.

모델 인스턴스를 아래와 같이 생성했을 때, jit의 script모듈을 활용하여, RecursiveScriptModule

객체로 변환이 가능합니다.

# 모델 인스턴스 객체 생성
model_instance = model()
# RecursiveScriptModule 객체 생성
script_from_model = torch.jit.script(model_instance)

위와 같이 모델 인스턴스를 생성해줄 때, 모델의 init부분에 선언되어 있는 인스턴스 변수를 사용하기 위해서 파라미터로 입력을 받습니다. 일반적인 모델에서는 yaml파일이나 config파일로 입력을 받은 후, 파싱하여 모델 인스턴스 객체를 생성줍니다.

위의 간단한 모델의 경우, branch_testing이라는 파라미터가 모델 인스턴스 변수로 존재하기 때문에, 아래와 같이 입력을 받아야합니다.

branch_testing = False
# 모델 인스턴스 객체 파라미터 입력받아 생성
model_instance = model(branch_testing)

이후, 모델에 필요한 파라미터까지 담긴 객체를 RecursiveScriptModule로 변환해줍니다.

script_from_model = torch.jit.script(model_instance)

즉, static하게 고정되어 있는 모델 객체를 script파일 또는 Onnx(binary)파일로 변환해주기 때문에, 모델 객체의 파라미터는 static하게 됩니다.

if self.branch_testing:
      x = x * 2
else:

다시말해, 파이토치의 pt나 onnx의 경우 모델 객체가 static하게 존재하기 때문에 branch_testing이라는 값은 객체를 생성해줄 때, 이미 정해져있고, 그에 따라 if문의 경우도 정해지게 됩니다.

따라서, if 분기가 나타나지 않는 것이었습니다.

그렇다면 IF나 Loop에 존재하는 Path가 담기는 경우는?

def forward(self,branch_testing : bool, x : torch.Tensor):
        if branch_testing
            x = x * 2
				else:
						x = x * 3
				x = self.layer1(x)
        x = self.layer2(x)
			  x = self.flatten(x)
        x = self.fc1(x)
        return x

위의 포워드 함수의 내부 파라미터를 보면 알 수 있습니다.

def forward(self,branch_testing : bool, x : torch.Tensor):

이번에는 branch_testing이 모델 클래스의 변수가 아닌, forward 내부의 지역변수입니다. 이 경우에는, 모델 객체를 생성할 때, 인스턴스 변수로 입력을 받을 필요가 없습니다.

# 모델 인스턴스 객체 생성
model_instance = model()

즉, 모델 객체를 생성해줄 때, 다음과 같이 입력을 받지 않아도 됩니다.

하지만 파이토치나 Onnx로 추론을 진행하려고 할 때, forward 함수에 대한 Input이 필요합니다.

(번외로 forward 내부 함수의 파라미터의 경우, 타입을 선언해주지 않으면, 기본적으로 텐서로 인식합니다. 따라서 텐서가 아닌 경우, 타입을 선언해주어야합니다.)

결국, 추론을 진행하는 시점에서 해당 파라미터를 입력받게 됩니다.

inference_input = torch.rand(1,3,256,256)
torch_output = custom_model(branch_testing,inference_input)
print(torch_output)

파이토치의 경우, 위와 같이 입력하면 됩니다. 일반적인 모델을 추론하는 과정과 동일합니다.

Onnx 모델의 경우, 아래와 같이 입력을 받아 추론하게 됩니다.

torch.onnx.export(script_v, (tmp,inference_input), "sample.onnx")
onnx_model = onnx.load('sample.onnx')
onnx_bytes = onnx_model.SerializeToString()

torch.manual_seed(1)
inference_input = torch.rand(1,3,256,256)
onnx_output = do_onnx_inference(onnx_bytes,(branch_testing,inference_input))
print(onnx_output)
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

def do_onnx_inference(onnx_model,inference_inputs):
    ort_inputs = {}
    session = onnxruntime.InferenceSession(
        onnx_model, providers=['CPUExecutionProvider'])

    ort_inputs['input_name'] = to_numpy(inference_inputs)

    output_name = ['output_name']
        
    ort_outs = session.run(output_name,ort_inputs)
    return ort_outs

위의 내용과 같이 추론을 진행할 경우, 파이토치 결과값과 Onnx의 결과값이 동일했으며, 모델 내부 구조를 시각화해보거나, 파싱해보아도 If문이나 Loop가 정상적으로 담겨있는 것을 확인할 수 있었습니다.

Forward 함수에서 Inference를 진행했을 때, 내부에 아래의 function들이 존재할 경우, 모든 Path가 정상적으로 변환되는가?

function / 변수 타입클래스 변수인스턴스 변수포워드 함수 내부 변수
IFX(지정된 경로만)X(지정된 경로만)O(가변이기 때문에)
LoopX(지정된 경로만)X(지정된 경로만)O(가변이기 때문에)

코드의 대략적인 흐름도는 아래와 같습니다.

profile
특 : 미친듯한 게으름과 부지런한 생각이 공존하는 사람

0개의 댓글