오늘은 사이드 프로젝트에 쓸 모델 서버를 FastAPI로 구현해봤다.
우선 모델 서버에 들어갈 API는 Airflow와 병행하면 하나뿐이기 때문에, 추후 서비스 서버와 통합해서 만들 수도 있다. 확장성을 고려해서 일단 설계 중이다.
실행은 도커로 서버를 띄울 예정이다.
Dockerfile
# Stage 1: turn the Poetry-managed dependencies into a plain requirements.txt.
FROM python:3.9.13-slim as requirements-stage
WORKDIR /tmp
# NOTE(review): Poetry is unpinned here; recent Poetry releases ship `export`
# as a separate plugin — verify that the export step below still works.
RUN pip install poetry
# The trailing * makes poetry.lock optional for this COPY.
COPY ./pyproject.toml ./poetry.lock* /tmp/
RUN poetry export -f requirements.txt --output requirements.txt --without-hashes
# Stage 2: runtime image; only requirements.txt is carried over from stage 1.
FROM python:3.9.13-slim
WORKDIR /code
COPY --from=requirements-stage /tmp/requirements.txt /code/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
COPY . /code
# Serve the `app` object from the `main` module on port 80 inside the container.
# NOTE(review): the script entry point (`if __name__ == "__main__"`) runs
# uvicorn on port 8000 instead — confirm which port the deployment expects.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]
from contextlib import asynccontextmanager
from fastapi import FastAPI
from loguru import logger
from sqlmodel import SQLModel
from config import config
from api import router
from database import engine
from model import ModelOptions
from dependencies import load_model
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests and log on shutdown.

    Startup creates the tables registered on ``SQLModel.metadata`` and loads
    the model from ``config.model_path``. Control is then yielded back to the
    application until it shuts down.

    Args:
        app: The FastAPI application this lifespan is attached to (not used
            inside the body).

    Raises:
        Exception: Any startup failure is logged together with its traceback
            and re-raised unchanged.
    """
    try:
        # Only the startup steps sit under the "Startup error" handler, so a
        # failure raised at the yield point is not mislabelled as a startup
        # problem.
        try:
            # Create database tables - the SQL backend is expected to change later.
            logger.info("Creating database tables")
            SQLModel.metadata.create_all(engine)

            # Load the model.
            logger.info("Loading model")
            load_model(config.model_path)
        except Exception as e:
            # logger.exception records the traceback along with the message.
            logger.exception(f"Startup error: {e}")
            raise

        yield
    finally:
        # Logged on normal shutdown and also when startup failed.
        logger.info("Shutting down application")
# ASGI application; startup and shutdown are delegated to the lifespan
# context manager, and the API routes come from the imported router.
app = FastAPI(lifespan=lifespan)
app.include_router(router)

if __name__ == "__main__":
    # Direct-execution entry point: serve on all interfaces, port 8000.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)
from typing import List
from pydantic import BaseModel, Field
# Input payload for a prediction: a user id plus a non-empty list of movie ids.
class PredictionRequest(BaseModel):
    # Identifier of the user the prediction is requested for.
    user_id: int
    # Movie ids supplied by the caller; required, at least one entry.
    movie_list: List[int] = Field(..., min_items=1)
# Output payload for a prediction: a user id plus exactly ten movie ids.
class PredictionResponse(BaseModel):
    # Identifier of the user the prediction belongs to.
    user_id: int
    # Predicted movie ids; required, constrained to exactly 10 entries.
    movie_list: List[int] = Field(..., max_items=10, min_items=10)
from fastapi import HTTPException
# Module-level holder for the loaded model; stays None until load_model() succeeds.
model = None


def load_model(model_path: str):
    """Load the model at ``model_path`` with joblib and store it in ``model``.

    Args:
        model_path: Path handed to ``joblib.load``.

    Raises:
        RuntimeError: If joblib cannot be imported or the load fails. The
            original exception is attached as ``__cause__``.
    """
    global model
    try:
        # Imported inside the try so that a missing joblib installation is
        # reported as a load failure as well.
        import joblib

        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code. Only point model_path at trusted artifacts.
        model = joblib.load(model_path)
    except Exception as e:
        # Chain explicitly so the root cause stays visible to callers.
        raise RuntimeError(f"Failed to load model: {e}") from e
def get_model():
    """Return the module-level ``model``.

    Raises:
        HTTPException: Status 500 with detail "Model not loaded" when
            ``model`` is still ``None``.
    """
    # Read-only access to the module global, so no `global` statement is needed.
    if model is not None:
        return model
    raise HTTPException(status_code=500, detail="Model not loaded")
import os
from typing import List
from fastapi import APIRouter, HTTPException
from sqlmodel import Session, select
from loguru import logger
import random
from schemas import PredictionRequest, PredictionResponse
from database import engine, PredictionResult
from dependencies import get_model
from config import config
# Router for the model endpoints; its routes are registered under the /model prefix.
router = APIRouter(prefix="/model", tags=["Model"])
# Model path taken from the shared config.
# NOTE(review): not referenced by the code visible here — confirm it is still needed.
model_path = config.model_path
@router.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest) -> PredictionResponse:
    """Generate a movie list for the requesting user, store it, and return it.

    The prediction is currently a placeholder: ten distinct integers sampled
    at random from 1..100. The result is saved as a ``PredictionResult`` row
    before the response is built.

    Args:
        request: Validated request body carrying ``user_id`` and ``movie_list``.

    Returns:
        A ``PredictionResponse`` with the request's ``user_id`` and the
        generated movie list.
    """
    # Fetch the model from the dependencies module. The placeholder below does
    # not use it yet.
    model = get_model()

    # TODO: confirm the parameters of the model's predict function, then
    # replace the test-only sampling below with the real call
    # (presumably something like model.predict(user_id=..., movies=...)).
    recommended_ids = random.sample(range(1, 101), 10)

    # Convert the prediction into a database record and save it.
    record = PredictionResult(user_id=request.user_id, movie_list=recommended_ids)
    with Session(engine) as db:
        db.add(record)
        db.commit()
        db.refresh(record)

    return PredictionResponse(user_id=request.user_id, movie_list=recommended_ids)