[Flask] 간단한 모델 서빙해보기

정원석·2024년 3월 24일
0

MLOps

목록 보기
14/14
post-thumbnail

간단한 모델을 만들고 직접 서빙을 해보자!

1. Flask 에서 사용할 모델 학습 및 저장

  • Flask 는 사용하기 쉽고, 간단한 기능을 가볍게 구현하기에 적합하기 때문에 대부분의 ML Model의 첫 배포 Step 으로 자주 사용하는 Framework 중 하나이다.
  • iris data를 사용한 간단한 classification model을 학습한 뒤, 모델을 pickle 파일로 저장하고, Flask를 사용해 해당 파일을 load하여 predict 하는 server를 구현할 것이다.
  • 그 이후, 해당 server를 run하여 직접 http request를 요청하여 정상적으로 response가 반환되는지 확인해보자.

1) Sample code

  • Sample python code
    • requirements

      • scikit-learn
    • 소스 코드

      import os
      import pickle
      
      from sklearn.datasets import load_iris
      from sklearn.ensemble import RandomForestClassifier
      from sklearn.metrics import accuracy_score, classification_report
      from sklearn.model_selection import train_test_split
      
      RANDOM_SEED = 1234
      
      # 1. data load
      data = load_iris()
      
      # 2. data split
      X = data['data']
      y = data['target']
      
      X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=RANDOM_SEED)
      
      # 3. train model
      model = RandomForestClassifier(n_estimator=300, random_state=RANDOM_SEED)
      model.fit(X_train, y_train)
      
      # 4. evaluate model
      print(f"Accuracy : {accuracy_score(y_test, model.predict(X_test))}")
      print(classification_report(y_test, model.predict(X_test)))
      
      # 5. save model to ./build/model.pkl
      os.makedirs("./build", exist_ok=True)
      pickle.dump(model, open('./build/model.pkl', 'wb'))

2) 모델 학습 및 저장

  • 위에 파일 실행해보자. python train.py 로 실행
    iris classification
  • 실행하면 build 폴더가 생기고, 안에 model.pkl 파일이 생긴것을 확인할 수 있다.

2. Flask server 구현

    1. 에서 학습 후 저장했던 모델(pickle 파일)을 load하여, POST /predict API를 제공하는 Flask Server를 구현해보자.
    import pickle
    
    import numpy as np
    from flask import Flask, jsonify, request
    
    # 지난 시간에 학습한 모델 파일을 불러오기
    model = pickle.load(open('./build/model.pkl', 'rb'))
    
    # Flask Server
    app = Flask(__name__)
    
    # POST /predict 라는 API 를 구현.
    @app.route('/predict', methods=['POST'])
    def make_predict():
        # API Request Body 를 python dictionary object 로 변환하기.
        request_body = request.get_json(force=True)
    
      # request body 를 model 의 형식에 맞게 변환
      X_test = [request_body['sepal_length'], request_body['sepal_width'],
                request_body['petal_length'], request_body['petal_width']]
      X_test = np.array(X_test)
      X_test = X_test.reshape(1, -1)
    
      # model 의 predict 함수를 호출하여, prediction 값 구하기
      y_test = model.predict(X_test)
    
      # prediction 값을 json화
      response_body = jsonify(result=y_test.tolist())
    
      # predict 결과를 담아 API Response Body 를 return
      return response_body
    
    if __name__ == '__main__':
      app.run(port=5000, debug=True)

3. API 테스트

  • 위의 Flask server를 run하고
    python flask_server.py

  • 해당 Flask server에 POST /predict API를 요청하여, 어떤 결과가 반환되는지 확인하자.

    curl -X POST -H "Content-Type:application/json" --data '{"sepal_length":5.9, "sepal_width": 3.0, "petal_lentgh": 5.1, "petal_width": 1.8}' http:localhost:5000/predict
    
    # {"result":[2]}

    test1

    • 0, 1, 2 중의 하나의 type으로 classification 하게 된다.
    • "sepal_length": 2.6, "sepal_width": 5.9, "petal_length":2.0, "petal_width": 5.1로 바꿔서 다시 테스트
      test2
    • {"result":[0]}
profile
이기적이타주의자

0개의 댓글