Python code 를 사용한 onnx graph node 삭제 과정
onnx graph node 중 quantized linear를 추출하여 실험해야했기에 python code를 사용해 생성하려고 했다.
사용한 quantize model은 onnx model zoo 에서 가져온 vision/classification/alexnet qdq onnx model이다
아래 그래프 중 나는 QuantizeLinear 노드 한 개만 추출하려고 한다.
import onnx, os
import onnx_graphsurgeon as gs
path = 'onnxmodels'
model_name = 'bvlcalexnet-12-qdq.onnx'
accepted = 'data_0_Conv_nc_rename_0_QuantizeLinear'
onnx_model = onnx.load(os.path.join(path,model_name))
graph = onnx_model.graph
# Get the nodes of the graph
nodes = graph.node
# print(len(nodes))
# print([for node in nodes if nodes.])
# Delete all nodes except the first node
for i in range(len(nodes)-1, 0, -1) :
if nodes[i].name != accepted: graph.node.remove(nodes[i])
# graph.node.add
for output in graph.output:
graph.output.remove(output)
prob_info = onnx.helper.make_tensor_value_info('data_0_quantized', onnx.TensorProto.UINT8, [1, 3, 224, 224])
graph.output.insert(1,prob_info)
print(graph.node)
# Save the modified model
onnx.save(onnx_model, os.path.join(path, 'result.onnx'))
위 코드에서 주목해야할 부분은 2가지다.