.pth -> .tflite 로의 여정

너 오늘 코드 짰니?·2023년 7월 20일
0

삽질 한스푼

목록 보기
2/2

변환 과정

pth -> onnx -> pb -> tflite 순서로 변환을 하였다.

pth -> onnx 과정은 매우 순조로웠다. torch에 내장된 torch.onnx.export 함수에 내가 사용할 onnx 버전에 맞는 opset_version 변수만 잘 넣어줘서 변환에 성공했다.

onnx -> pb 과정에서 약간의 고생을 하긴 했지만 다시 생각해보면 크게 어려운점은 없었던것 같다.

  • onnx : onnx 모델을 load
  • onnx_tf : load된 onnx 모델을 tensorflow 모델로 변환

위 두 개의 라이브러리를 사용했는데, onnx_tf 를 import 하는 과정에서 versioning 이 맞지 않아 import error가 났었다. 그래서 호환되는 버전을 잘 확인해서 설치하여 해결하였다.

나에게 닥친 시련

문제는 pb에서 tflite 모델로 변환할 때 발생하였다.

# pb 모델 경로
pb_model_path = "/opt/ml/input/code/trained_best/best.pb"
# TFLite 모델 저장 경로
tflite_model_path = "/opt/ml/input/code/trained_best/best.tflite"

# Converting a GraphDef from session.
converter = tf.compat.v1.lite.TFLiteConverter.from_session(
  sess, in_tensors, out_tensors)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)

위와 같이 pb모델이 저장된 파일경로를 지정해서 변환을 시도했는데 내가 지정해준 pb_model_path에 해당하는 경로가 아닌 이상한 경로가 뜨면서 savedModel 이 없다고 에러가 났다.

File "/opt/conda/lib/python3.8/site-packages/tensorflow/python/saved_model/loader_impl.py", line 110, in parse_saved_model
    raise IOError("SavedModel file does not exist at: %s/{%s|%s}" %
OSError: SavedModel file does not exist at: /opt/ml/input/code/trained_medical_finanace_6000_gaussian/latest.pb/{saved_model.pbtxt|saved_model.pb}

아니 나는 분명 /opt/ml/input/code/trained_medical_finanace_6000_gaussian/latest.pb 까지만 해서 saved_model path를 지정해서 넘겼는데 자꾸 내가 지정해준 경로 뒤에 뭐 이상한 경로가 추가되서 /opt/ml/input/code/trained_medical_finanace_6000_gaussian/latest.pb/{saved_model.pbtxt|saved_model.pb}

이런 경로에 saved model 이 없다고 떠들어댔다. 그래서 loader_impl.py 구현체 부분을 따고 들어가봤더니 이런 코드가 있었다.

saved_model = saved_model_pb2.SavedModel()
  if file_io.file_exists(path_to_pb):
    try:
      file_content = file_io.FileIO(path_to_pb, "rb").read()
      saved_model.ParseFromString(file_content)
      return saved_model
    except message.DecodeError as e:
      raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e)))
  elif file_io.file_exists(path_to_pbtxt):
    try:
      file_content = file_io.FileIO(path_to_pbtxt, "rb").read()
      text_format.Merge(file_content.decode("utf-8"), saved_model)
      return saved_model
    except text_format.ParseError as e:
      raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e)))
  else:
    raise IOError("SavedModel file does not exist at: %s/{%s|%s}" %
                  (export_dir,
                   constants.SAVED_MODEL_FILENAME_PBTXT,
                   constants.SAVED_MODEL_FILENAME_PB))

주목해야할 부분은 맨 마지막 else부분의 raise IOError 부분.

if문에서 path_to_pb에 적절한 파일이 있는지 체크해서 없으면 해당 에러를 뱉어내게 구현이 되어 있었는데 더 위쪽을 찾아보니까 구현체 코드 내에서 내가 지정해서 넘겨준 파일경로 export_dir에 뭘 추가해서 새로운 경로로 바꿔버리는 코드가 있더라.

# Build the path to the SavedModel in pbtxt format.
  path_to_pbtxt = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
  # Build the path to the SavedModel in pb format.
  path_to_pb = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))

그래서 저거 join하는부분 싹다 지워버리고 내가 지정해준 export_dir 경로 그대로 들어가게 한다음 돌렸는데 무슨 saved Model 형식이 아니라고 에러가 났다. 이때 부터 슬슬 빡이 치기 시작. 아니 .pb 모델로 저장해서 경로도 이제 제대로 들어가게 고쳤는데 왜 savedModel이 아니라는거지?

그래서 savedModel이 뭔지 좀 찾아봤다. 그 결과 알게된것 -> .pb 형식의 파일은 savedModel형식인 경우와 아닌경우 두가지로 존재할 수 있는데

  • SavedModel : TensorFlow 2.x 버전부터 도입된 모델 저장 형식으로 모델의 가중치, 그래프 구조, 변수, 연산 등을 포함하는 디렉토리 형태로 저장됨. 그러니까 SavedModel은 saved_model.pb 파일과 해당 디렉토리 안에 있는 변수 및 리소스 파일들로 구성된거다.
  • FrozenGraph : TensorFlow 1.x 버전에서 주로 사용되던 모델 저장 형식으로 단일 .pb파일로 구성되고 TensorFlow 그래프 구조와 가중치를 포함하고 있으며, 일반적으로 Protobuf 형식으로 직렬화되어 있다.

아!!!! 내가 저장한 방식은 FrozenGraph였는데, 위에서 SavedModel 형식을 tflite 파일로 바꾸려는 코드를 사용했기 때문에 에러가 났나보다. 그래서 export_dir에 saved_model이 저장되어있는 디렉토리 경로를 주면 라이브러리 구현체 안에서 추가적으로 경로 수정해서 파일 참조하게 작성되어 있던것이다. 이제 퍼즐이 좀 맞춰진다.

아 그럼 frozenGraph를 tflite로 변환하는 코드를 찾아봐야겠네.

구글링 했었을 때 죄다 from_session이나 from_saved_model 과 같은 함수만 나와서 삽질을 했었는데

https://www.tensorflow.org/api_docs/python/tf/compat/v1/lite/TFLiteConverter

docs에서 찾아보니까 from_frozen_graph 라는 함수가 있었다!!

frozen_graph 로부터 tflite로 변환하는 코드

# TensorFlow Lite 포맷으로 모델 변환
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph('/opt/ml/input/code/trained_medical_finanace_6000_gaussian/latest.pb', #TensorFlow freezegraph .pb model file
                                                      input_arrays=['input.1'], # name of input arrays as defined in torch.onnx.export function before.
                                                      output_arrays=['268', '257']  # name of output arrays defined in torch.onnx.export function before.
                                                      )

converter.optimizations = [tf.lite.Optimize.DEFAULT]	# 최적화 시키기
tflite_model = converter.convert()

# TensorFlow Lite 모델 저장
with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)

이 때 파라미터로 input_arrays와 output_arrays를 넣어주어야 하는데 이게 뭐냐면 tensorflow에는 각 노드마다 이름이 붙어있다. 그래서 입력노드와 출력노드에 해당하는 이름을 적어주어야 하는데, 내가 첨부터 tensorflow로 모델을 짠게 아니라 pytorch로 짠 모델을 강제로 변환했으니 노드이름을 알 수 가 없다.

그래서 netron 사이트에서 .pb 모델 올려서 시각화 한다음에 이름을 확인했다.


이렇게 왼쪽에 시각화된 그래프의 입출력 혹은 각 노드들을 누르면 오른쪽에 name으로 노드 이름이 나온다.

그런데 왜 savedModel이 아니라 frozenGraph로 변경된거지?

이것도 좀 궁금해서 찾아봤는데 솔직히 못찾았다 ㅠㅠ 확실친 않지만 대충 파악한 바로는 onnx와 tensorflow간의 이동을 할 수 있게 해주는 라이브러리가 두 개가 있는데

  1. onnx_tf : onnx모델을 tensorflow 1.x버전의 모델(frozenGraph)로 바꾸는 라이브러리
  2. tf2onnx : tensorflow 2.x 버전의 모델(savedModel)을 onnx모델로 바꾸는 라이브러리

나는 onnx_tf를 썼는데 이게 1.x 버전의 모델인 frozenGraph로 바꾸는건지 모르고 그냥 썼다. 아니 frozenGraph니 savedModel이니 그런거를 몰랐다 애초에. 아니 같은 .pb 파일인데 다른방식인게 말이 되냐고 진짜 :<

그럼 tf2onnx로 역변환 하면 onnx를 savedModel로 바꿀 수 있는거 아닌가? 싶어서 찾아보긴 했는데 역변환은 안되는것 같다. 그래서 그냥 frozenGraph -> tflite로 노선 최종변경해서 결국 tflite 뽑는데 성공했다.

https://github.com/onnx/onnx-tensorflow/blob/main/Versioning.md

아니근데 이건 뭔데 하... 이거보고 tensorflow 2.x 버전으로 맞춰서 한건데 뭐지..?

이 포스팅은 TroubleShooting 을 다룬 글이므로 pth 부터 tflite로 변환하는 과정을 깔끔하게 정리한 글을 보고싶으신 분은 아래 링크를 타고 확인해주세요

https://velog.io/@renovatio_hyuns/Android-serving%EC%9D%84-%EC%9C%84%ED%95%9C-pth-to-tflite-convert

profile
안했으면 빨리 백준하나 풀고자.

1개의 댓글

comment-user-thumbnail
2023년 7월 20일

가치 있는 정보 공유해주셔서 감사합니다.

답글 달기