Android serving을 위한 pth to tflite convert

너 오늘 코드 짰니?·2023년 7월 20일
0

Versioning

python==3.8.5
torch==1.7.1
onnx==1.7.0
onnx_tf==1.6.0
tensorflow==2.2.0

versioning guide

https://github.com/onnx/onnx-tensorflow/blob/main/Versioning.md

위 readme를 참고해서 버전을 맞추었습니다.

onnx가장 최신버전을 사용하는게 좋아보이는데, 제가 모델을 개발한 torch 버전에서 가장 높게 사용가능한 onnx 가 1.7.0 이어서 1.7.0 사용했습니다.

.pth to .onnx

"""
.pth to .onnx
"""
import torch
from model import EAST
import onnx
from onnxsim import simplify

pth_path = "{model_directory_path}/best.pth"
onnx_path = '{target_directory_path}/best.onnx'


model = CustomModel()
model.load_state_dict(torch.load(pth_path, map_location='cpu'))
model.eval()

# 모델을 ONNX로 변환
input_sample = torch.randn(1, 3, 1024, 1024)  # 입력 샘플 생성
torch.onnx.export(model, input_sample, onnx_path, opset_version=11)

# load your predefined ONNX model
model = onnx.load(onnx_path)

# convert model
model_simp, check = simplify(model)

assert check, "Simplified ONNX model could not be validated"

# use model_simp as a standard ONNX model object
onnx.save(model_simp, "best_simplified.onnx")

pytorch는 모델의 가중치만 따로 저장해서 모델에 load하여 사용하기 때문에 그래프의 형태가 pth 파일에 존재하지 않습니다. 따라서 pytorch에서 onnx모델로 export 하기 위해서는 입력 shape에 맞는 sample input을 넣어주어야 합니다.

또한 opset_version은 사용하는 onnx 라이브러리의 버전에 맞는 opset_version을 넣어주어야 합니다.

simplify(model)을 통하여 onnx변환된 모델을 좀 더 단순하게 줄일 수 있습니다.

.onnx to .pb (frozenGraph)

"""
.onnx to .pb
"""
import onnx 
import torch
from onnx_tf.backend import prepare

onnx_model_path = "{onnx_model_directory}/best.onnx"
tf_model_path = "{target_directory}/best.pb"

onnx_model = onnx.load(onnx_model_path)

tf_rep = prepare(onnx_model)
    
tf_rep.export_graph(tf_model_path)

onnx모델을 frozenGraph형식의 단일 .pb 파일로 변환합니다.

.pb to .tflite

"""
.pb to .tflite
"""
import tensorflow as tf

# pb 모델 경로
pb_model_path = "{pb_model_directory}/best.pb"
# TFLite 모델 저장 경로
tflite_model_path = "{target_directory}/best.tflite"

# TensorFlow Lite 포맷으로 모델 변환
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(pb_model_path, #TensorFlow freezegraph .pb model file
                                                      input_arrays=['input.1'], # name of input arrays as defined in torch.onnx.export function before.
                                                      output_arrays=['268', '257']  # name of output arrays defined in torch.onnx.export function before.
                                                      )

converter.optimizations = [tf.lite.Optimize.DEFAULT]	# 최적화
tflite_model = converter.convert()	# tflite로 변환

# TensorFlow Lite 모델 저장
with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)

input_arrays와 output_arrays에는 입출력 노드의 이름을 적어주어야 합니다.

각 노드의 이름을 모를 경우에는 netron과 같은 모델 시각화 툴을 사용해서 확인할 수 있습니다.

위 링크를 타고 netron에 변환 전의 pb 모델을 업로드하여 입출력 노드의 이름을 확인하고 input_arrays와 output_arrays 파라미터에 배열로 넣어주시면 됩니다.

Validation

모델 변환이 완료되었으면, 제대로 변환이 되었는지 입력을 넣어서 출력을 확인해보는 작업이 필요합니다.

"""
torch model validation
"""
import torch
from model import EAST
import numpy as np

# Load the .pth file
pth_path = "/opt/ml/input/code/trained_best/best.pth"
model = EAST()
model.load_state_dict(torch.load(pth_path, map_location='cpu'))
# model = torch.load("/opt/ml/input/code/trained_medical_finanace_6000_gaussian/latest.pth")
model.eval()  # Set the model to evaluation mode

# Prepare input data
input_data = torch.full((1, 3, 1024, 1024), 0.5)  # Example input data

# Make predictions
with torch.no_grad():
    output = model(input_data)

# Print the output
print(output[0].type())
print(torch.max(output[0]), torch.min(output[0]))
print(torch.max(output[1]), torch.min(output[1]))
"""
tensorflow model validation
"""
import tensorflow as tf
import numpy as np

# FrozenGraph 모델 경로
frozen_graph_path = '/opt/ml/input/code/trained_best/best.pb'

# TensorFlow 2.x에서 FrozenGraph 로드
with tf.io.gfile.GFile(frozen_graph_path, "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

# 로드한 모델을 기반으로 TensorFlow 그래프 생성
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")

# 모델 실행
with tf.compat.v1.Session(graph=graph) as sess:
    # 입출력 노드 정의
    input_tensor = graph.get_tensor_by_name("input.1:0")  # 입력 텐서의 이름
    # input_tensor = tf.compat.v1.placeholder(tf.float32, shape=(1, 3, 1024, 1024))  # 입력 텐서 생성
    output_tensor1 = graph.get_tensor_by_name("268:0")  # 첫 번째 출력 텐서의 이름
    output_tensor2 = graph.get_tensor_by_name("257:0")  # 두 번째 출력 텐서의 이름

    # 입력 데이터
    input_data = np.full((1, 3, 1024, 1024), 0.5).astype(np.float32)
    output_data1, output_data2 = sess.run([output_tensor1, output_tensor2], feed_dict={input_tensor: input_data})

    print(np.max(output_data1), np.min(output_data1))
    print(np.max(output_data2), np.min(output_data2))

"""
tflite model validation
"""
import tensorflow as tf
import numpy as np

# TFLite 모델 파일 경로
model_path = "/opt/ml/input/code/trained_best/best.tflite"

# TFLite 모델 로드
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# 입력 텐서와 출력 텐서 정보 가져오기
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# 입력 데이터 생성
input_data = np.full((1, 3, 1024, 1024), 0.5).astype(np.float32)

# 입력 데이터 설정
interpreter.set_tensor(input_details[0]['index'], input_data)

# 모델 실행
interpreter.invoke()

# 출력 데이터 가져오기
output_data1 = interpreter.get_tensor(output_details[0]['index'])
output_data2 = interpreter.get_tensor(output_details[1]['index'])

print(np.max(output_data1), np.min(output_data1))
print(np.max(output_data2), np.min(output_data2))

.pth, .pb, .tflite 모델을 각각 load 하여 동일한 입력을 넣고 출력결과를 확인합니다.

결과는 아래와 같습니다.

# .pth 모델
tensor(125.5052) tensor(-0.2398)
tensor(0.0046) tensor(4.0554e-05)

# .pb 모델
125.50522 -0.23980959
0.004594922 4.0555362e-05

#.tflite 모델
124.553635 -0.23529902
0.0049833357 4.3301996e-05

pth와 pb 모델은 어느정도 비슷한 출력이 나오는데 tflite에서 오차가 조금 있어 보이네요.

나중에 각 framework에 대한 이해가 좀 더 생기면 분석해보면 좋을것 같습니다. 그래도 비슷하게 출력이 나오는것을 보니 변환이 어느정도 잘 되었다 결론짓고 마무리하도록 하겠습니다.

혹시라도 변환과정에서 어려움을 겪는분 혹은 미래의 저를 위해 과정을 깔끔하게 정리해둔 글인데, 이 과정에서 겪은 각종 시행착오는 삽질 카테고리에 따로 작성해두었습니다.
https://velog.io/@renovatio_hyuns/.pth-.tflite-%EB%A1%9C%EC%9D%98-%EC%97%AC%EC%A0%95

profile
안했으면 빨리 백준하나 풀고자.

1개의 댓글

comment-user-thumbnail
2023년 7월 20일

많은 도움이 되었습니다, 감사합니다.

답글 달기