Model Conversion Method

bluein·2024년 6월 10일
0
post-thumbnail

PyTorch 모델을 안드로이드에 배포하는 방법

1. Pytorch -> Torchscript
2. Pytorch -> ONNX -> TensorflowLite


1. Pytorch -> Torchscript

1-1. PyTorch 모델 준비

  • PyTorch로 학습한 모델 준비

  • torch.jit.trace 또는 torch.jit.script를 사용하여 모델을 TorchScript 형식으로 변환

  • TorchScript는 PyTorch 모델을 직렬화할 수 있는 형식

    직렬화를 하는 이유?

    • 모델을 효율적으로 저장하고 배포하며, 다양한 플랫폼에서 독립적으로 실행할 수 있도록 하기 위함

    직렬화

    • 객체의 상태를 바이트 스트림으로 변환하여 저장하거나 전송할 수 있는 형태로 만드는 과정

    바이트 스트림

    • 데이터를 바이트 단위로 연속적으로 처리하는 방식
    • 데이터 형식에 구애받지 않고 원시 데이터로서 처리

    Torchscript의 특징

    • 모델이 독립적인 실행 파일로 변환
    • C++ 런타임에서 작동되므로 Python 인터프리터 없이도 모델을 실행할 수 있음
    • 모델 실행을 최적화할 수 있어 성능을 향상시킬 수 있음
import torch

# PyTorch 모델 로드 (예: model.pt)
model = torch.load('model.pt')
model.eval()

# 입력 예제 텐서 정의 (입력 형태와 일치해야 함)
example_input = torch.rand(1, 3, 224, 224)  # 입력 형태를 실제 모델에 맞춰 변경

# TorchScript 형식으로 모델 변환
traced_model = torch.jit.trace(model, example_input)
traced_model.save('model.pth')

1-2. PyTorch Mobile 설치 및 설정

  • 안드로이드 프로젝트에서 PyTorch Mobile을 사용하기 위해서는 Gradle 파일을 수정해야 함
build.gradle (프로젝트 수준)
allprojects {
    repositories {
        ...
        maven {
            url 'https://oss.sonatype.org/content/repositories/snapshots'
        }
    }
}
build.gradle (앱 수준)
dependencies {
    ...
    implementation 'org.pytorch:pytorch_android_lite:1.9.0-SNAPSHOT'
    implementation 'org.pytorch:pytorch_android_torchvision:1.9.0-SNAPSHOT'
}

1-3. 모델 파일 추가

  • 모델 파일(model.pth)을 assets 폴더에 추가
  • assets 폴더가 없으면 생성해야 함

1-4. 모델 로드 및 예측 코드 작성

  • 안드로이드에서 모델을 로드하고 예측하는 코드
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import org.pytorch.torchvision.TorchVision;

import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import java.io.IOException;
import java.io.InputStream;

public class MainActivity extends AppCompatActivity {

    private Module module;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        try {
            // 모델 로드
            module = Module.load(assetFilePath(this, "model.pth"));
        } catch (IOException e) {
            e.printStackTrace();
        }

        // 이미지 로드
        Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));

        // 이미지 전처리
        final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);

        // 예측
        final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();

        // 결과 처리
        final float[] scores = outputTensor.getDataAsFloatArray();
        // 결과 사용 예: 가장 높은 점수를 가진 클래스 찾기
    }

    // 자산 파일 경로 반환 메소드
    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        try (InputStream is = context.getAssets().open(assetName);
             FileOutputStream fos = new FileOutputStream(file)) {
            byte[] buffer = new byte[4 * 1024];
            int read;
            while ((read = is.read(buffer)) != -1) {
                fos.write(buffer, 0, read);
            }
            fos.flush();
        }
        return file.getAbsolutePath();
    }
}

1-5. 앱 실행 및 테스트

  • 안드로이드 스튜디오에서 앱을 빌드하고 실행하여 모델이 올바르게 로드되고 예측을 수행하는지 확인
  • 이 과정을 통해 PyTorch 모델을 안드로이드에 배포할 수 있음

2. Pytorch -> ONNX -> TensorflowLite

  • PyTorch 모델을 TensorFlow Lite로 변환하기 위해서는 PyTorch 모델을 ONNX(Open Neural Network Exchange) 형식으로 변환한 후, ONNX 모델을 TensorFlow 모델로 변환하고, 마지막으로 TensorFlow 모델을 TensorFlow Lite로 변환하는 단계가 필요
  • 이 과정은 다음과 같은 주요 단계로 나눌 수 있음
    1. PyTorch 모델을 ONNX로 변환
    2. ONNX 모델을 TensorFlow 모델로 변환
    3. TensorFlow 모델을 TensorFlow Lite로 변환

2-1. PyTorch 모델을 ONNX로 변환

  • 먼저, PyTorch 모델을 ONNX 형식으로 변환
import torch
import torch.onnx

# PyTorch 모델 정의
model = torch.load('model.pth')
model.eval()

# 입력 텐서 정의 (예: 배치 크기 1, 채널 3, 224x224 이미지)
dummy_input = torch.randn(1, 3, 224, 224)

# 모델을 ONNX 형식으로 변환
torch.onnx.export(model, dummy_input, "model.onnx", 
                  export_params=True, opset_version=11, 
                  do_constant_folding=True, 
                  input_names=['input'], output_names=['output'])

2-2. ONNX 모델을 TensorFlow 모델로 변환

  • ONNX 모델을 TensorFlow 모델로 변환하기 위해서는 onnx-tf 라이브러리를 사용

  • 이 라이브러리는 ONNX 모델을 TensorFlow 형식으로 변환하는 데 사용

  • 먼저 onnx-tf를 설치

pip install onnx-tf

그 다음, 변환을 수행

from onnx_tf.backend import prepare
import onnx

# ONNX 모델 로드
onnx_model = onnx.load("model.onnx")

# TensorFlow 형식으로 변환
tf_rep = prepare(onnx_model)

# TensorFlow 모델 저장
tf_rep.export_graph("model.pb")

2-3. TensorFlow 모델을 TensorFlow Lite로 변환

  • TensorFlow 모델을 TensorFlow Lite 형식으로 변환하기 위해서는 TensorFlow Lite Converter를 사용

  • 먼저 tensorflow를 설치

pip install tensorflow
  • 그 다음, 변환을 수행
import tensorflow as tf

# TensorFlow 모델 로드
converter = tf.lite.TFLiteConverter.from_saved_model("model.pb")
tflite_model = converter.convert()

# TensorFlow Lite 모델 저장
with open("model.tflite", "wb") as f:
    f.write(tflite_model)
profile
AI Research Engineer

0개의 댓글

관련 채용 정보