[ML] ONNX 개요

Woong·2025년 2월 26일
0

Python / Machine Learning

목록 보기
27/27

ONNX (Open Neural Network Exchange)

  • 딥러닝 프레임워크 간의 호환성을 제공하기 위해 만들어진 오픈 소스 모델 교환 형식

  • pytorch, tensorflow, scikit learn, Keras 등 여러 프레임워크에서 학습된 모델을 import, export 할 수 있도록 호환성

    • 학습과 배포가 서로 다른 프레임워크일 수 있음
    • ex) 학습은 PyTorch에서 하고, 배포는 ONNX Runtime을 사용하는 경우
  • ex) pytorch SVD 모델을 ONNX 로 export
import torch
import torch.nn as nn
import torch.onnx

class SVDModel(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim=20):
        super(SVDModel, self).__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)
        self.user_bias = nn.Embedding(num_users, 1)
        self.item_bias = nn.Embedding(num_items, 1)

    def forward(self, user_id, item_id):
        user_vec = self.user_embedding(user_id)
        item_vec = self.item_embedding(item_id)
        dot_product = (user_vec * item_vec).sum(1, keepdim=True)
        return dot_product + self.user_bias(user_id) + self.item_bias(item_id)

# 모델 초기화 및 변환
num_users, num_items = 1000, 500
model = SVDModel(num_users, num_items)
model.eval()

dummy_user = torch.tensor([1], dtype=torch.long)
dummy_item = torch.tensor([10], dtype=torch.long)

torch.onnx.export(
    model, 
    (dummy_user, dummy_item), 
    "svd_model.onnx", 
    input_names=["user_id", "item_id"], 
    output_names=["rating"], 
    opset_version=11
)
print("success")

reference

0개의 댓글

관련 채용 정보

Powered by GraphCDN, the GraphQL CDN