241221 TIL #569 AI Tech #102 FastAPI Implementation

김춘복 · 5 days ago

TIL : Today I Learned


Today I Learned

Today I implemented the model server for a side project with FastAPI.


FastAPI 구현

  • Since the model server only needs a single API when running alongside Airflow, it could later be merged into the service server. For now I'm designing it with extensibility in mind.

  • The server will be launched with Docker.

  • Dockerfile

FROM python:3.9.13-slim AS requirements-stage

WORKDIR /tmp
RUN pip install poetry
COPY ./pyproject.toml ./poetry.lock* /tmp/
RUN poetry export -f requirements.txt --output requirements.txt --without-hashes

FROM python:3.9.13-slim

WORKDIR /code
COPY --from=requirements-stage /tmp/requirements.txt /code/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
COPY . /code
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]

  • main.py
from contextlib import asynccontextmanager

from fastapi import FastAPI
from loguru import logger
from sqlmodel import SQLModel
from config import config

from api import router
from database import engine
from model import ModelOptions
from dependencies import load_model


@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        # Create database tables - switch the SQL backend later
        logger.info("Creating database tables")
        SQLModel.metadata.create_all(engine)

        # Load the model
        logger.info("Loading model")
        load_model(config.model_path)

        yield

    except Exception as e:
        logger.error(f"Startup error: {e}")
        raise
    finally:
        logger.info("Shutting down application")


app = FastAPI(lifespan=lifespan)
app.include_router(router)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
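The `lifespan` context manager above runs everything before `yield` at startup and the `finally` block at shutdown. A minimal stdlib sketch of the same pattern (no FastAPI; the `events` list and the string labels are stand-ins for the real startup steps):

```python
import asyncio
from contextlib import asynccontextmanager

events = []  # records the order of startup/serve/shutdown steps


@asynccontextmanager
async def lifespan(app):
    events.append("create tables")  # stands in for SQLModel.metadata.create_all(engine)
    events.append("load model")     # stands in for load_model(config.model_path)
    try:
        yield                       # the server handles requests while suspended here
    finally:
        events.append("shutdown")   # runs after the server stops, even on error


async def main():
    async with lifespan(app=None):
        events.append("serving")    # request handling happens inside the context


asyncio.run(main())
print(events)  # ['create tables', 'load model', 'serving', 'shutdown']
```

Both startup steps are guaranteed to finish before any request is served, and the shutdown log runs even if startup raises.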

  • schemas.py
from typing import List
from pydantic import BaseModel, Field


class PredictionRequest(BaseModel):
    user_id: int
    movie_list: List[int] = Field(min_items=1)


class PredictionResponse(BaseModel):
    user_id: int
    movie_list: List[int] = Field(min_items=10, max_items=10)
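With those `Field` constraints, the request must carry at least one movie ID and the response must carry exactly ten. A plain-Python sketch of the same contract (dataclasses instead of Pydantic, just to make the rules explicit):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class PredictionRequest:
    user_id: int
    movie_list: List[int]

    def __post_init__(self):
        # mirrors Field(min_items=1): at least one movie is required
        if len(self.movie_list) < 1:
            raise ValueError("movie_list must contain at least 1 item")


@dataclass
class PredictionResponse:
    user_id: int
    movie_list: List[int]

    def __post_init__(self):
        # mirrors Field(min_items=10, max_items=10): exactly 10 recommendations
        if len(self.movie_list) != 10:
            raise ValueError("movie_list must contain exactly 10 items")
```

In the real app Pydantic enforces this automatically and returns a 422 on violation; the sketch just shows where the boundaries are.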

  • dependencies.py
import joblib
from fastapi import HTTPException

model = None


def load_model(model_path: str):
    global model
    try:
        model = joblib.load(model_path)
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {e}") from e


def get_model():
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    return model

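The pair of functions above is a module-level singleton: `load_model` fills a global once at startup, and `get_model` fails fast if a request arrives before loading finished. The same pattern with a dummy dict standing in for the joblib model (a sketch, not the real module):

```python
model = None  # module-level singleton, populated once at startup


def load_model(path: str):
    """Load the model once and store it in the module-level global."""
    global model
    # dummy stand-in for joblib.load(path)
    model = {"path": path}


def get_model():
    """Fail fast if a request arrives before the model is loaded."""
    if model is None:
        # the real app raises HTTPException(status_code=500) here
        raise RuntimeError("Model not loaded")
    return model
```

Calling `get_model()` before `load_model()` raises; afterwards every caller shares the same loaded object, so the model file is read from disk only once.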
  • api.py
import random

from fastapi import APIRouter
from sqlmodel import Session

from schemas import PredictionRequest, PredictionResponse
from database import engine, PredictionResult
from dependencies import get_model
from config import config


router = APIRouter(prefix="/model", tags=["Model"])

model_path = config.model_path


@router.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest) -> PredictionResponse:
    # Fetch the loaded model from dependencies
    model = get_model()
    # TODO: check the parameters of the model's predict function
    # prediction = model.predict(user_id=request.user_id, movies=request.movie_list)

    # Placeholder for testing
    prediction = random.sample(range(1, 101), 10)

    # Convert the prediction into a result record
    prediction_result = PredictionResult(user_id=request.user_id, movie_list=prediction)

    with Session(engine) as session:
        session.add(prediction_result)
        session.commit()
        session.refresh(prediction_result)

        return PredictionResponse(user_id=request.user_id, movie_list=prediction)
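Until the real `model.predict` call is wired up, the placeholder returns 10 random movie IDs. Pulling that step into its own function (a hypothetical `predict_placeholder` helper, not part of the actual router) makes the response contract easy to check in isolation:

```python
import random
from typing import List


def predict_placeholder(user_id: int, movie_list: List[int]) -> List[int]:
    # Stand-in for model.predict(...): 10 distinct movie IDs drawn from [1, 100]
    return random.sample(range(1, 101), 10)


prediction = predict_placeholder(user_id=42, movie_list=[1, 2, 3])
assert len(prediction) == 10                    # matches PredictionResponse (exactly 10 items)
assert len(set(prediction)) == 10               # random.sample guarantees distinct IDs
assert all(1 <= m <= 100 for m in prediction)   # IDs stay within the sampled range
```

Because `random.sample` draws without replacement, the placeholder always satisfies the `min_items=10, max_items=10` constraint on the response schema, so swapping in the real model later only has to preserve that shape.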