onnx graph node 삭제 (python api)

Lee Hyun Joon ·2023년 3월 27일
0

Python code 를 사용한 onnx graph node 삭제 과정

실험 목적

onnx graph node 중 quantized linear를 추출하여 실험해야했기에 python code를 사용해 생성하려고 했다.

사용한 quantize model은 onnx model zoo 에서 가져온 vision/classification/alexnet qdq onnx model이다

Alexnet model onnx

아래 그래프 중 나는 QuantizeLinear 노드 한 개만 추출하려고 한다.

구현 코드

import onnx, os
import onnx_graphsurgeon as gs 

path = 'onnxmodels'
model_name = 'bvlcalexnet-12-qdq.onnx'
accepted = 'data_0_Conv_nc_rename_0_QuantizeLinear'
onnx_model = onnx.load(os.path.join(path,model_name))
graph = onnx_model.graph
# Get the nodes of the graph
nodes = graph.node
# print(len(nodes))
# print([for node in nodes if nodes.])
# Delete all nodes except the first node
for i in range(len(nodes)-1, 0, -1) :
    if nodes[i].name != accepted: graph.node.remove(nodes[i])
# graph.node.add
for output in graph.output:
    graph.output.remove(output)
prob_info = onnx.helper.make_tensor_value_info('data_0_quantized', onnx.TensorProto.UINT8, [1, 3, 224, 224])

graph.output.insert(1,prob_info)
print(graph.node)
# Save the modified model
onnx.save(onnx_model, os.path.join(path, 'result.onnx'))

위 코드에서 주목해야할 부분은 2가지다.

  1. graph.node를 통해 그래프 노드들에 대한 정보를 한번에 가져온다는 것
    • 이 과정을 통해 그냥 for문으로 노드들 삭제하면 된다.
    • 앞서 기존 그래프에서 보았듯이 quantized linear 노드가 제일 첫번째에 있기 때문에 그냥 인덱스 제어를 통해 나머지 노드들을 전부 삭제해줬다.
  2. output 제어
    • 실제로 onnx graph에서 필요없는 노드들을 모두 삭제하기만 하면 완벽히 수정된 onnx graph를 만들 수 없다.
    • 그 이유는 남긴 onnx graph node의 output이 없기 때문이다. output 정보를 추가해줘야 한다.
    • 위에 prob_info code에서 볼 수 있듯이 onnx.helper를 사용해 tensor value info를 만들어준다.
    • 이후 그래프 output에 추가해주는 것을 볼 수 있다.
    • 또한 노드들을 모두 삭제했어도, 기존 그래프의 최종 output에 대한 정보가 아직 살아있기 때문에 이것 또한 삭제해주는 것을 위 코드를 통해 볼 수 있다.
    • 추가해야 하는 output의 tensor 이름은 사용자가 의도한대로 남긴 노드들 중 제일 마지막 노드에 명시된 output 이름을 주기만하면 된다. 이 정보는 netron에서 모델을 열어서 확인하면 된다.

수정 후 결과

참고 문헌

profile
우당탕탕 개발 지망생

0개의 댓글