Grounding DINO로 Object Detection하기

devstone·2023년 11월 24일
3

Capstone Project

목록 보기
1/2
post-thumbnail
post-custom-banner

Intro

이번 졸업 프로젝트에서 객체 탐지를 바탕으로 특정 상황을 판단해야하는 로직이 필요하게 되었다. 그래서 여러가지 객체탐지 모델을 찾아보던 중 비교적 최근 나온 Grounding Dino를 다뤄보게 되었다.

Grounding Dino란?

Grounding Dino(Grounding Discriminator INterpolation)는 Zero-Shot Object Detection을 수행하기 위한 기술이다.
공식 깃헙 링크

Zero-Shot이 무엇인가

학습 데이터에 해당 클래스에 대한 레이블이 없는 상태에서 새로운 클래스에 대해 모델이 작동하는 능력을 의미한다. 일반적으로 Object Detection은 학습된 모델을 사용해 annotated된 data에 대한 탐지를 수행하지만, Grounding DINO와 같은 Zero-Shot Object Detection모델은 새로운 클래스에 대한 annotated data가 없어도 객체 탐지가 가능하다.

Zero-Shot은 학습 단계에서 새로운 클래스에 대한 어떠한 예시나 레이블도 제공하지 않고, 이러한 클래스를 인식하거나 처리하는 능력을 개발하는 것을 목표로 한다.

이미지에서 객체 검출하기

해당 모델을 이용해 객체를 검출해보았다.

1. 일단 사전 준비를 해준다

Grounding DINO는 NVIDIA GPU를 사용하기 때문에 해당 명령어로 관련 리소스들을 다운로드 받아줘야 한다. 로컬에서 해당 GPU 셋업을 하기 어려울 것 같아 나는 Google Colab을 이용해 해당 GPU 환경을 세팅하였다.


!nvidia-smi

그리고 관련 모듈과 디렉토리를 가져온다

import os # os모듈 가져오기 
HOME = os.getcwd() # 현재 작업 중이 디렉토리 절대 경로 반환
print(HOME)

2. Grounding DINO 설치하기

# Grunding DINO Github을 통해 소스 다운로드 
%cd {HOME}
!git clone https://github.com/IDEA-Research/GroundingDINO.git

# tourch, roboflow 를 설치 
%pip install torch
%cd {HOME}/GroundingDINO
%pip install -q -e .
%pip install -q roboflow
# 제대로 설치되었는지 확인 
import os

CONFIG_PATH = os.path.join(HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
print(CONFIG_PATH, "; exist:", os.path.isfile(CONFIG_PATH))

위처럼 Grounding DINO를 설치해 준 이후에, 아래의 코드로 해당 모델의 학습된 파라미터를 다운받아준다.

%cd {HOME}
!mkdir {HOME}/weights
%cd {HOME}/weights

!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth

import os

WEIGHTS_NAME = "groundingdino_swint_ogc.pth"
WEIGHTS_PATH = os.path.join(HOME, "weights", WEIGHTS_NAME)
print(WEIGHTS_PATH, "; exist:", os.path.isfile(WEIGHTS_PATH))

3. Grounding DINO Model 로드하기

아래 코드를 통해 Grounding DINO Model을 로드해주었다.

%cd {HOME}/GroundingDINO

from groundingdino.util.inference import load_model, load_image, predict, annotate

model = load_model(CONFIG_PATH, WEIGHTS_PATH)

4. 흉기 난동 이미지에서 사람과 흉기 검출하기

흉기난동 영상에서 흉기과 사람 객체를 검출하기에 앞서, 사진을 통해 먼저 테스트를 해보았다.

import os
import supervision as sv

# 흉기를 들고 있는 사람 이미지를 /content/data 폴더 내부에 넣어주었다.
IMAGE_NAME = "person_knife.png"
IMAGE_PATH = os.path.join(HOME, "data", IMAGE_NAME)

# 객체 인식에 사용될 임계값 설정
TEXT_PROMPT = "A person with a knife in their hand"
BOX_TRESHOLD = 0.35
TEXT_TRESHOLD = 0.25

# 이미지 로드
image_source, image = load_image(IMAGE_PATH)

# 모델을 이용해 특정 객체 인식 
boxes, logits, phrases = predict(
    model=model,
    image=image,
    caption=TEXT_PROMPT,
    box_threshold=BOX_TRESHOLD,
    text_threshold=TEXT_TRESHOLD
)

annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)

# 이미지 시각화 
%matplotlib inline
sv.plot_image(annotated_frame, (16, 16))

해당 코드를 이용해 두 가지 사진을 돌려보았는데 아래와 같이 person과 knife가 검출 되었음을 알 수 있다.

영상에서 흉기를 들고 있는 사람 탐지하기

먼저 이미지를 처리하는 함수를 구성해야 한다.

# 이미지 처리 함수 
def process_image(image_path, model, text_prompt):
    BOX_TRESHOLD = 0.35
    TEXT_TRESHOLD = 0.25

    image_source, image = load_image(image_path)

    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=text_prompt,
        box_threshold=BOX_TRESHOLD,
        text_threshold=TEXT_TRESHOLD
    )

    # Apply Gaussian blur to the image_source
    blurred_image = cv2.GaussianBlur(image_source, (25, 25), 0)

    annotated_frame = annotate(image_source=blurred_image, boxes=boxes, logits=logits, phrases=phrases)

    return annotated_frame, phrases

그리고 동영상 파일을 읽어와 각 프레임에 대해 객체를 판별하고, 동영상에 프레임을 추가하는 코드를 구성한다.

import os
import cv2
from tqdm import tqdm
import supervision as sv
import tempfile

text_prompt = "A person with a knife in their hand"
input_video_path = '/content/source_video.mp4'
output_video_path = '/content/output_video.mp4'
output_knife_path = '/content/output_knife'
output_no_knife_path = '/content/output'

# Create output directories if they don't exist
os.makedirs(output_knife_path, exist_ok=True)
os.makedirs(output_no_knife_path, exist_ok=True)

# 동영상 파일 읽기
cap = cv2.VideoCapture(input_video_path)
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

# Get the video width/height
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# 동영상 파일 만들기
video = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 30, (width, height))

# 각 프레임에 대해
for _ in tqdm(range(frame_count)):

    # Read the frame
    ret, frame = cap.read()

    # Save the frame to a temporary file
    temp_filename = tempfile.mktemp(suffix='.jpg')
    cv2.imwrite(temp_filename, frame)

    # Process the image
    processed_frame, phrases = process_image(temp_filename, model, text_prompt)

    # Remove the temporary file
    os.remove(temp_filename)

    # 칼이 포함된 프레임 이미지
    if 'knife' in phrases:
        save_path = os.path.join(output_knife_path, f"frame_{_}.jpg")
    # 칼이 포함되어 있지 않은 이미지
    else:
        save_path = os.path.join(output_no_knife_path, f"frame_{_}.jpg")

    # 이미지 저장
    cv2.imwrite(save_path, processed_frame)

    # 동영상에 프레임 추가
    video.write(processed_frame)

# 동영상 파일 닫기
video.release()
cap.release()

이렇게 Grounding DINO를 통해 영상에서 흉기를 들고 있는 사람 객체를 탐지할 수 있다.

Epilogue

해당 모델을 사용해보니 생각보다 저화질에서도 객체 감지가 잘 되는 것을 확인할 수 있었다. 그래서 이어지는 시리즈에서는 해당 모델이 좀 더 흉기 난동 상황을 더 잘 탐지할 수 있도록 파인튜닝을 진행하고, Fast API를 이용해 API로 맵핑할 예정이다.

레퍼런스

https://www.cookieparking.com/share/dPIOi8sc

profile
개발하는 돌멩이
post-custom-banner

0개의 댓글