Titanic 모델을 API로 배포하여 웹에서 사용하기(feat. FastAPI)

김은호·2023년 8월 20일
1

들어가며

머신러닝 및 딥러닝을 공부하면서 여러 모델들은 공부했지만 정작 배포를 하여 사용한 적은 없었다. 그래서 학습한 모델을 API로 배포하여 사용하는 방법을 정리해보았다.

데이터 전처리 과정은 생략했다.

1. ML

1-1 모델 학습하기

Titanic 모델은 간단하게 LogisticRegression으로 학습을 했다.

from sklearn import linear_model
model = linear_model.LogisticRegression()
model.fit(X_train, y_train)

1-2 scaler, model 저장하기

학습한 모델과 scaler을 저장해야 새로운 input에 대해서도 scaler을 적용하고 모델에 적용할 수 있다.

import joblib
joblib.dump(scaler, 'scaler.pkl')
joblib.dump(model, 'model.pkl')

그러면 폴더에 pkl로 모델과 scaler가 저장된 것을 볼 수 있다.

2. FastAPI

from fastapi import FastAPI, Request
import joblib
import pandas as pd
from starlette.middleware.cors import CORSMiddleware

app = FastAPI()
# 저장한 scaler, model 불러오기
joblib_in = open('./model.pkl', 'rb')
joblib_sca = open('./scaler.pkl', 'rb')
model = joblib.load(joblib_in)
scaler = joblib.load(joblib_sca)

origins = [ "*" ]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware

@app.get("/")
async def index():
    return {
        'message': '타이타닉 생존 예측'
    }

@app.post('/predict')
async def predict(data: Request):
    data = await data.json()
    pClass = data['PClass']
    gender = data['Sex']
    age = data['Age']
    sibsp = data['SibSp']
    parch = data['Parch']
    fare = data['Fare']
    embarked = data['Embarked']
    familySize = sibsp + parch + 1
    isAlone = 0
    
    if(familySize == 1):
        isAlone = 1
    if embarked == 'S':
        embarked = 0
    elif embarked == 'C':
        embarked = 1
    else:
        embarked = 2
        
    df = pd.DataFrame({
        'Pclass': [pClass],
        'Sex': [gender],
        'Age': [age],
        'SibSp': [sibsp],
        'Parch': [parch],
        'Fare': [fare],
        'Embarked': [embarked],
        'FamilySize': [familySize],
        'IsAlone': [isAlone]
    })

    df_scaled = scaler.transform(df)
    result = model.predict(df_scaled)

    if(result[0] == 0):
        return {
            'message': '사망하였습니다.'
        }
    return {
        'message': '생존하였습니다.'
    }

서로 다른 포트에서 서버를 돌리고 있어 CORS 문제가 발생한다. 이를 해결하기 위해 FastAPI에서는 다음과 같이 해결한다.

from starlette.middleware.cors import CORSMiddleware

origins = [ "*" ] # 허용할 url 주소, *이면 모든 url에 대해 허용

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware

저장한 model과 scaler을 불러와 scaling을 하고 모델에 데이터를 적용한다.

joblib_in = open('./model.pkl', 'rb')
joblib_sca = open('./scaler.pkl', 'rb')
model = joblib.load(joblib_in)
scaler = joblib.load(joblib_sca)

df_scaled = scaler.transform(df)
result = model.predict(df_scaled)

그 후 서버를 실행시킨다.

uvicorn main:app --port 8000 --reload

3. Web


기본적으로 HTML은 위와 같이 구성했다.

data를 구성하고 API 요청을 한다.

const data = {
  PClass: parseInt(pclass_value),
  Sex: parseInt(sex_value),
  Age: parseInt(age_value),
  SibSp: parseInt(sibsp_value),
  Parch: parseInt(parch_value),
  Fare: parseFloat(fare_value),
  Embarked: embarked_value,
};

const result = fetch('http://localhost:8000/predict', {
  method: 'POST',
  body: JSON.stringify(data),
})
  .then((response) => response.json())
  .then((data) => {
    console.log(data);
  })
  .catch((err) => {
    console.log(err);
  });

잘 뜨는 모습!

0개의 댓글