[딥러닝] ONNX를 활용하여 모델 변환하기

김영민·2024년 8월 14일
0

DeepLearning

목록 보기
31/33
post-thumbnail

딥러닝 모델을 학습시킬 때에는 보통 Python과 Pytorch를 활용하였다.
그러나 이 모델을 배포하거나, 다른 프레임워크와 호환되게 하려면 onnx라는 것이 필요하다.

ONNX란?

  • Open Neural Network Exchange의 줄임말.
  • 공식 문서의 말을 빌리면 다음과 같다.

    오픈 신경망 교환(ONNX)은 머신 러닝 모델을 표현하기 위한 개방형 표준 형식. torch.onnx 모듈은 PyTorch 모델을 ONNX로 내보낼 수 있다(import). 그런 다음 이 모델은 ONNX를 지원하는 여러 런타임에서 사용할 수 있습니다.

UNet from Pytorch to ONNX

  • 공식 문서에서는 AlxeNet을 활용하였지만 나는 유방 MRI의 유방 부위를 세그멘테이션 하는 학습된 UNet 모델을 직접 짜고 진행해보려 한다.

1. pretrained 된 UNet을 UNet.onnx로 export 하기

import torch
from model import breast_UNet

breast_unet = breast_UNet(1,1)
breast_unet.load_state_dict(torch.load("/***/UNet_breast.pth"))
breast_unet.cuda()
breast_unet.eval()

dummy_input = torch.randn((1,1,512,512),device="cuda")
onnx_file_path = "/***/breast_unet.onnx"

input_names = ["input_image"]
output_names = ["output_segmentation"]

torch.onnx.export(breast_unet, dummy_input, onnx_file_path, verbose=True, input_names=input_names, output_names=output_names)
  • 다 학습되고, 배포를 위해 사용한다는 가정을 하여 .eval()을 사용.
  • input_names와 output_names를 사용하여 모델의 입력과 출력의 이름을 설정 가능.
  • verbose를 True로 설정하면, 아래와 같이 상세하게 변환 과정을 출력해준다.

2. ONNX pip 설치

pip install onnx
  • ONNX를 설치해준다.

3. ONNX 파일 로드 및 확인

import onnx

# ONNX 모델 로드
onnx_model = onnx.load(onnx_file_path)

# 모델이 잘 형성되었는지 확인
onnx.checker.check_model(onnx_model)


print(onnx.helper.printable_graph(onnx_model.graph))
  • 다음과 같이 onnx_model을 로드할 수 있다.
  • print를 하면 다음과 같이 나온다.

Tracing vs Scripting

Tracing

  • torch.onnx.export() 함수가 모델을 ONNX 형식으로 내보낼 때, 전달된 모델이 ScriptModule이 아니면 torch.jit.trace()를 사용하여 모델을 한 번 실행하고, 그 동안 발생한 모든 연산을 기록.

  • 작동 방식: 모델이 주어진 입력 데이터로 실행되면서 수행되는 연산을 기록하여 모델을 변환합니다. 따라서, 모델이 동적인 구조(예: 입력 데이터에 따라 다른 경로를 따르는 경우)를 가질 경우 이러한 동적 동작을 반영하지 못함.

  • 제한 사항: 만약 모델이 입력 데이터에 따라 동적으로 동작한다면, Tracing 방법은 이러한 동작을 캡처하지 못하고, 고정된 연산 그래프를 생성합니다. 이로 인해, 루프나 조건문이 포함된 모델에서는 올바르지 않은 ONNX 모델이 생성될 수 있음.

  • 사용 예: 모델의 구조가 정적이고, 입력에 따라 변하지 않는 경우에 Tracing 방법이 유용.

  • torch.onnx.export를 사용하는 것과 동일하다.

Scripting

  • 정의: Scripting은 torch.jit.script()를 사용하여 모델을 ScriptModule로 변환하고, 이를 ONNX로 내보내는 방법.

  • 작동 방식: Scripting은 모델의 동적 제어 흐름을 보존하여, 입력 데이터에 따라 다른 동작을 하는 모델에서도 올바르게 작동.

  • 장점: Scripting 방법은 동적 제어 흐름을 가진 모델에도 유효하며, 다양한 크기의 입력에도 적응할 수 있는 모델을 생성할 수 있음.

  • 사용 예: 모델이 입력에 따라 다르게 동작하거나, 동적 제어 흐름이 포함된 경우에는 Scripting 방법이 적합함.

  • 코드 예시

import torch
script_model = torch.jit.script(breast_unet)

이렇게 pytorch로 구성되어 있고, 학습한 모델을 ONNX로 바꿔보았다.
생각보다 간단한 작업이었고, 해당 모델을 C++에서 로드하고 추론하는 것까지 진행해보아야겠다.

0개의 댓글