YOLO - Classification 실습

dumbbelldore·2025년 1월 26일
0

zero-base 33기

목록 보기
84/97

1. 데이터셋 확보

!pip install roboflow
!pip install ultralytics

from IPython.displays import clear_output
clear_output()

from roboflow import Roboflow
rf = Roboflow(api_key="API KEY")
project = rf.workspace("joseph-nelson").project("rock-paper-scissors")
version = project.version(1)
dataset = version.download("folder")

2. 모델 훈련

  • train() 함수 이용 시 데이터셋 폴더 내 train, val, test 이미지를 자동으로 인식함
from ultralytics import YOLO

model = YOLO("yolov8n-cls.pt")
res = model.train(
    data="/content/Rock-Paper-Scissors-1",
    epochs=10,
    imgsz=300
)
  • Train & Validation 과정 확인
import matplotlib.pyplot as plt
fpath = "runs/classify/train/results.png"
img = plt.imread(fpath)
plt.imshow(img)
plt.axis("off")
plt.show()

3. 예측

  • 훈련 과정 중 가장 좋은 성능을 보인 모델의 가중치를 불러온 뒤, predict() 함수로 예측 실시
best_model = YOLO("runs/classify/train/weights/best.pt")
pred = best_model.predict(source="sample.jpg", save=True)

# 주먹을 주먹으로 잘 예측하였음
fpath = "runs/classify/predict/sample.jpg"
img = plt.imread(fpath)
plt.imshow(img)
plt.axis("off")
plt.show()

*이 글은 제로베이스 데이터 취업 스쿨의 강의 자료 일부를 발췌하여 작성되었습니다.

profile
데이터 분석, 데이터 사이언스 학습 저장소

0개의 댓글