[YOLOv5]커스텀 모델의 객체 탐지 결과 띄우기(REST API) with Flask

ywwwon01·2022년 12월 3일
0
post-thumbnail

최근, YOLOv5 를 이용한 인공지능 프로젝트를 진행할 기회가 있었습니다..

그래서 저는 현재, 사전에 학습해 둔 YOLOv5 모델이 있는 상태이며

이 모델을 이용해 해당 포스트를 작성해 보겠습니다.


🔗 ultralytics/yolov5/utils/flask_rest_api

YOLOv5의 공식 깃허브 레포지토리에도

Flask를 이용한 REST API 작성 및 YOLOv5 모델 배포 관련 내용과 디렉토리가 따로 있더군요!

README.md
REST APIs are commonly used to expose Machine Learning (ML) models to other services. This folder contains an example REST API created using Flask to expose the YOLOv5s model from PyTorch Hub.

REST API는 일반적으로 머신러닝 모델을 다른 서비스에 노출하는 데 사용되며,

해당 폴더에서는 PyTorch Hub에서 yolov5s 모델을 노출하기 위해 작성되었다고 쓰여있는 것 같습니다.

Requirements

$ pip install Flask

Flask를 설치해야 합니다만

사실 뒤에 적어둔 코드를 보시면

Flask 외에도 몇 가지 더 필요한 라이브러리가 있습니다..

저는 ModuleNotFoundError : no module named '~~'가 나올 때마다 하나씩 추가로 설치 해줬네요.. 🥲

...

아니면 이런 방법도 있었습니다

아래 .txt 파일을 준비하고

requirements.txt

flask
requests

# YOLOv5 requirements
matplotlib>=3.2.2
numpy>=1.18.5
opencv-python>=4.1.1
Pillow>=7.1.2
PyYAML>=5.3.1
requests>=2.23.0
scipy>=1.4.1
torch>=1.7.0
torchvision>=0.8.1
tqdm>=4.64.0
seaborn>=0.11.0
pandas

# Extras --------------------------------------
ipython  # interactive notebook
psutil  # system utilization
thop>=0.1.1  # FLOPs computation

터미널에서

$ pip install -r requirements.txt

하면 필요한 라이브러리들이 모두 다운로드 됩니다

requirements.txt 내용은 아래 깃허브에서 가져왔습니다.


🔗 robmarkcole / yolov5-flask

restapi.py

다음은 아주 쉽습니다.

일단 앞서 언급한 공식 깃허브에서 restapi.py를 그대로 가져와서, 모델을 가져오는 부분만 일부 수정하면 됩니다.

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Run a Flask REST API exposing one or more YOLOv5s models
"""

import argparse
import io

import torch
from flask import Flask, request
from PIL import Image

app = Flask(__name__)
models = {}

DETECTION_URL = "/v1/object-detection/<model>"


@app.route(DETECTION_URL, methods=["POST"])
def predict(model):
    if request.method != "POST":
        return

    if request.files.get("image"):
        # Method 1
        # with request.files["image"] as f:
        #     im = Image.open(io.BytesIO(f.read()))

        # Method 2
        im_file = request.files["image"]
        im_bytes = im_file.read()
        im = Image.open(io.BytesIO(im_bytes))

        if model in models:
            results = models[model](im, size=640)  # reduce size=320 for faster inference
            return results.pandas().xyxy[0].to_json(orient="records")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Flask API exposing YOLOv5 model")
    parser.add_argument("--port", default=5000, type=int, help="port number")
    parser.add_argument('--model', nargs='+', default=['yolov5s'], help='model(s) to run, i.e. --model yolov5n yolov5s')
    opt = parser.parse_args()

    for m in opt.model:
        models[m] = torch.hub.load("ultralytics/yolov5", 'custom', 'best.pt', force_reload=True, skip_validation=True)

    app.run(host="0.0.0.0", port=opt.port)  # debug=True causes Restarting with stat

line 46의 이 부분을 고쳐줘야 하는데

저는 여기서 조금 헤맸습니다..



🔗 Load YOLOv5 from PyTorch Hub ⭐ #36

torch.hub.load()를 쓰는 방법에 대한 이해가 부족했네요 🥲

커스텀 모델은 이런 식으로 불러오니 성공했습니다!

models[m] = torch.hub.load("ultralytics/yolov5", 'custom', '모델의_경로', force_reload=True, skip_validation=True)

실행

$ python restapi.py --port 5000 --model 모델명

제 모델의 이름은 hold 였습니다.

Q. 모델명이 뭔가요?

A. train--name으로 지정해준 이름 입니다.

curl -X POST -F image=@테스트이미지경로 'http://localhost:5000/v1/object-detection/모델명'

성공입니다..

저는 --model parameter를 무시하고 있었는데

참.. 그래선 안되는 거였는데.. (🥲)(아무튼)

이제 이렇게 JSON 형식으로 반환된 좌표, 즉 객체 탐지 결과는

앱이나 웹에서 네트워크 요청을 통해 사용하면 되겠습니다!!

이상입니다.

profile
생각의 전개를 공유합니다.

0개의 댓글