ONNX는 Pytorch, Tensorflow, TensorRT 등 다양한 모델의 프레임워크들의 중심에 있는 모델을 표현하는 프레임워크로, 각 프레임워크에서 다른 프레임워크로 변환할 때 ONNX를 거쳐서 변환할 수 있도록 지원해줍니다.
dummy_input = torch.randn(1, 3, 128, 128)
traced = torch.jit.trace(net_g, dummy_input)
torch.onnx.export(
traced,
dummy_input,
"model.onnx",
export_params=True,
do_constant_folding=True,
opset_version=13,
input_names=["input"],
output_names=["output"],
)
torch 모델을 ONNX로 변환하는 과정은 생각보다 간단했다.
변환 자체는 코드 한 줄이면 됐다.
우선 입력샘플을 준비해줘야한다. ONNX는 입력 텐서의 shape을 알아야 한다. 이 입력 텐서의 shape이 추후 inference 할 입력과 동일해야 한다. 이 입력 텐서의 크기가 다른 프레임워크에서 추론할 때도 동일하게 유지된다.
dummy_input = torch.randn(1, 3, 128, 128)1은 배치 크기, 3은 채널 수 (RGB), 128 x 128은 이미지 크기traced = torch.jit.trace(net_g, dummy_input)net_g 모델을 TorchScript의 trace 모드로 변환⚠️ 주의: trace는 if, for 등 동적인 흐름이 있는 모델에서는 부정확할 수 있으므로, 주로 CNN 계열 모델에 적합
torch.onnx.export(...)이 부분이 ONNX 변환의 핵심 함수입니다.
| 파라미터 | 의미 |
|---|---|
traced | 변환할 TorchScript 모델 |
dummy_input | ONNX 변환 시 사용할 입력 예제 |
"sr_model.onnx" | 저장될 ONNX 파일 이름 |
export_params=True | 모델 가중치를 ONNX 그래프에 포함시킴 |
do_constant_folding=True | 상수 연산을 미리 계산해 그래프 최적화 |
opset_version=13 | ONNX 연산자 세트 버전 (13 이상 권장) |
input_names=["input"] | ONNX 입력 이름 지정 |
output_names=["output"] | ONNX 출력 이름 지정 |
| dynamic_axes={ 'input': {0: 'batch_size'}, 'output': {0: 'batch_size'} } | input과 output의 0번째 차원(배치 사이즈)는 동적으로 처리하겠다는 의미 → 이렇게 설정된 모델은 배사이즈에 관계 없이 추론이 가능한 모델이 된다. |
🔧 opset_version은 PyTorch 기능과 ONNX 연산자 매핑을 정의하는 규격입니다. 최신 ONNX 기능을 사용하려면 opset_version=13 또는 16 추천됩니다.
| 단계 | 설명 |
|---|---|
dummy_input 생성 | ONNX 변환 및 trace용 입력 텐서 준비 |
torch.jit.trace() | 정적인 TorchScript IR 생성 |
torch.onnx.export() | TorchScript 모델을 ONNX 형식으로 저장 |
opset_version=13 | 최신 연산자 사용을 위해 필수 설정 |
| 이름 지정 | 입력/출력 이름을 명확히 지정 (디버깅, 배포 시 유용) |
import onnx
onnx_model = onnx.load("sr_model.onnx")
onnx.checker.check_model(onnx_model) # 유효성 검사
onnxsim을 활용하면 ONNX 모델을 단순화해 불필요한 복잡도를 줄일 수 있다. input.onnx 모델을 단순화해 output.onnx를 만든다.onnxsim input.onnx output.onnx
그렇다고 pytorch의 모델이 모두 ONNX로 변환되는 것이 아니다. 일부 Pytorch의 연산 중에 ONNX로의 변환이 지원되지 않거나 완벽하게 매핑되지 않는 경우가 있기 때문에 이런 경우에는 Custom layer를 통해 별도 처리를 해줘야합니다.
보통 이 이유는 ONNX는 그래프 기반 표현이기 때문입니다. pytorch는 python 코드 기반으로 모델 내부에서도 shape이 자유자재로 바뀔 수 있지만 ONNX는 연산 그래프 기반으로 표현하기 때문에 텐서의 input output이 명확해야 표현이 가능합니다. 그렇기 때문에 동적으로 shape을 바꾸거나 인덱싱 하는 연산은 지원이 어렵다고 합니다.
예를 들어 in-place 연산이나 masked_fill과 같이 텐서 값을 동적으로 바꾸는 경우에 지원이 안될 수도 있습니다. einsum같은 경우도 matmul, permute, reshape의 조합으로 표현해야 모델이 ONNX 변환이 이뤄질 수 있습니다.
F.pixel_shuffle 을 예로 들어보겠습니다. 이 함수는 API 형태로 되어 있어 ONNX가 내부 구조를 알고 변환하기가 어렵습니다. 따라서 nn.PixelShuffle이나 Custom function을 구성해야합니다.
def pixel_shuffle_custom(x, upscale_factor):
batch_size, c, h, w = x.size()
r = upscale_factor
out_c = c // (r ** 2)
x = x.view(batch_size, out_c, r, r, h, w)
x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
x = x.view(batch_size, out_c, h * r, w * r)
return x
pixelshuffle을 reshape과 permute등을 통해 실제로 구현하면 위와 같이 구현할 수 있고, 이렇게 함수 형태로 표현된 모듈은 ONNX로의 변환이 가능하기 때문에 문제가 생기는 경우에는 이렇게 모델을 바꿔주는 작업이 필요합니다.