1. Pytorch -> Torchscript
2. Pytorch -> ONNX -> TensorflowLite
PyTorch로 학습한 모델 준비
torch.jit.trace
또는 torch.jit.script
를 사용하여 모델을 TorchScript 형식으로 변환
TorchScript는 PyTorch 모델을 직렬화할 수 있는 형식
직렬화를 하는 이유?
- 모델을 효율적으로 저장하고 배포하며, 다양한 플랫폼에서 독립적으로 실행할 수 있도록 하기 위함
직렬화
- 객체의 상태를 바이트 스트림으로 변환하여 저장하거나 전송할 수 있는 형태로 만드는 과정
바이트 스트림
- 데이터를 바이트 단위로 연속적으로 처리하는 방식
- 데이터 형식에 구애받지 않고 원시 데이터로서 처리
Torchscript의 특징
- 모델이 독립적인 실행 파일로 변환
- C++ 런타임에서 작동되므로 Python 인터프리터 없이도 모델을 실행할 수 있음
- 모델 실행을 최적화할 수 있어 성능을 향상시킬 수 있음
import torch
# PyTorch 모델 로드 (예: model.pt)
model = torch.load('model.pt')
model.eval()
# 입력 예제 텐서 정의 (입력 형태와 일치해야 함)
example_input = torch.rand(1, 3, 224, 224) # 입력 형태를 실제 모델에 맞춰 변경
# TorchScript 형식으로 모델 변환
traced_model = torch.jit.trace(model, example_input)
traced_model.save('model.pth')
build.gradle
(프로젝트 수준)allprojects {
repositories {
...
maven {
url 'https://oss.sonatype.org/content/repositories/snapshots'
}
}
}
build.gradle
(앱 수준)dependencies {
...
implementation 'org.pytorch:pytorch_android_lite:1.9.0-SNAPSHOT'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0-SNAPSHOT'
}
model.pth
)을 assets
폴더에 추가assets
폴더가 없으면 생성해야 함import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import org.pytorch.torchvision.TorchVision;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import java.io.IOException;
import java.io.InputStream;
public class MainActivity extends AppCompatActivity {
private Module module;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
try {
// 모델 로드
module = Module.load(assetFilePath(this, "model.pth"));
} catch (IOException e) {
e.printStackTrace();
}
// 이미지 로드
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
// 이미지 전처리
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
// 예측
final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
// 결과 처리
final float[] scores = outputTensor.getDataAsFloatArray();
// 결과 사용 예: 가장 높은 점수를 가진 클래스 찾기
}
// 자산 파일 경로 반환 메소드
public static String assetFilePath(Context context, String assetName) throws IOException {
File file = new File(context.getFilesDir(), assetName);
try (InputStream is = context.getAssets().open(assetName);
FileOutputStream fos = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
fos.write(buffer, 0, read);
}
fos.flush();
}
return file.getAbsolutePath();
}
}
- PyTorch 모델을 ONNX로 변환
- ONNX 모델을 TensorFlow 모델로 변환
- TensorFlow 모델을 TensorFlow Lite로 변환
import torch
import torch.onnx
# PyTorch 모델 정의
model = torch.load('model.pth')
model.eval()
# 입력 텐서 정의 (예: 배치 크기 1, 채널 3, 224x224 이미지)
dummy_input = torch.randn(1, 3, 224, 224)
# 모델을 ONNX 형식으로 변환
torch.onnx.export(model, dummy_input, "model.onnx",
export_params=True, opset_version=11,
do_constant_folding=True,
input_names=['input'], output_names=['output'])
ONNX 모델을 TensorFlow 모델로 변환하기 위해서는 onnx-tf
라이브러리를 사용
이 라이브러리는 ONNX 모델을 TensorFlow 형식으로 변환하는 데 사용
먼저 onnx-tf
를 설치
pip install onnx-tf
그 다음, 변환을 수행
from onnx_tf.backend import prepare
import onnx
# ONNX 모델 로드
onnx_model = onnx.load("model.onnx")
# TensorFlow 형식으로 변환
tf_rep = prepare(onnx_model)
# TensorFlow 모델 저장
tf_rep.export_graph("model.pb")
TensorFlow 모델을 TensorFlow Lite 형식으로 변환하기 위해서는 TensorFlow Lite Converter를 사용
먼저 tensorflow
를 설치
pip install tensorflow
import tensorflow as tf
# TensorFlow 모델 로드
converter = tf.lite.TFLiteConverter.from_saved_model("model.pb")
tflite_model = converter.convert()
# TensorFlow Lite 모델 저장
with open("model.tflite", "wb") as f:
f.write(tflite_model)