딥러닝 프레임워크 간의 호환성을 제공하기 위해 만들어진 오픈 소스 모델 교환 형식
pytorch, tensorflow, scikit learn, Keras 등 여러 프레임워크에서 학습된 모델을 import, export 할 수 있도록 호환성
import torch
import torch.nn as nn
import torch.onnx
class SVDModel(nn.Module):
def __init__(self, num_users, num_items, embedding_dim=20):
super(SVDModel, self).__init__()
self.user_embedding = nn.Embedding(num_users, embedding_dim)
self.item_embedding = nn.Embedding(num_items, embedding_dim)
self.user_bias = nn.Embedding(num_users, 1)
self.item_bias = nn.Embedding(num_items, 1)
def forward(self, user_id, item_id):
user_vec = self.user_embedding(user_id)
item_vec = self.item_embedding(item_id)
dot_product = (user_vec * item_vec).sum(1, keepdim=True)
return dot_product + self.user_bias(user_id) + self.item_bias(item_id)
# 모델 초기화 및 변환
num_users, num_items = 1000, 500
model = SVDModel(num_users, num_items)
model.eval()
dummy_user = torch.tensor([1], dtype=torch.long)
dummy_item = torch.tensor([10], dtype=torch.long)
torch.onnx.export(
model,
(dummy_user, dummy_item),
"svd_model.onnx",
input_names=["user_id", "item_id"],
output_names=["rating"],
opset_version=11
)
print("success")