YOLOv8 모델 loss function 수정

이지해·2023년 8월 31일
2

object detection

목록 보기
2/2
post-thumbnail

처음 YOLOv8 모델을 개선하려고 했을 때,
자료가 너무 부족해서 어느 파일의 어떤 부분을 고쳐야 하는지 찾기 힘들었다 😂

그래서, 어떻게 수정해야 하는지를 공유하고자 한다.

이번 포스트에서는 모델의 loss function을 수정하는 방법에 대해 설명할 것이다.


1. YOLOv8 패키지 클론

먼저, YOLOv8을 개발한 ultralytics의 깃허브에서 코드를 클론해준다.

https://github.com/ultralytics/ultralytics

코드를 보면 ultralytics 안에 또다른 ultratlytics 폴더가 있을 것이다.
들어가보자.

여기서 우리가 주목해야 할 곳은 바로 utils 폴더이다.


2. loss function 수정

utils 폴더에 들어가보자.
이것저것 파일이 많은데 그 중 loss.py에 들어가준다.

이 파일이 바로 yolo의 손실함수가 정의되어있는 파일이다.

여러 클래스 중 V8DetectionLoss 클래스가,
기본 YOLOv8 모델이 loss를 계산하는 부분이다.

이 부분을 수정하기 전 yolov8의 loss function 구조를 뜯어보자.


yolov8의 loss function은 다음과 같이 되어있다.

-> 위의 세 가지 loss 값에 각 가중치를 곱해주면 그것이 바로 loss값이다.

V8DetectionLoss 클래스에서는
세 가지 loss값을 각각
loss[0], loss[1], loss[2]로 저장하는데, 수정하기 원하는 부분을 고쳐주면 된다.


나는 class loss를 기존 BCE에서 Focal Loss로 수정해주려고 한다.


이 부분이 바로 class loss를 계산하는 부분인데 bce함수를 사용하고 있다.
이 부분을 고쳐주자.

1) 먼저, 내가 사용할 loss function 클래스를 정의해준다.

다행히, 내가 사용하려고 하는 focal loss 클래스는 loss.py 파일내에 정의가 되어있었다.

만약, 새로운 loss function을 추가하고 싶다면,
파이토치 프레임워크 형태로 정의해주면 된다.

2) 그 다음, loss function 오브젝트를 V8DetectionLoss 클래스의 __init__ 함수에 초기화해준다.

사용하려고 하는 loss function을 밑줄친 것과 같이 초기화해준다.

3) 이제 기존 loss function 대신 새로운 loss function을 사용하도록 __call__ 함수를 수정해준다.

기존 코드를 주석처리하고 새로운 loss function을 사용해 계산하도록 수정해주면 된다.

이제 이대로 저장하고, 모델을 학습시키면 된다!


이렇게 해서 간단하게 yolov8의 loss function을 수정하는 방법에 대해 알아보았다.

profile
한 줄 두 줄 기록하는 내 맘대로 블로그

0개의 댓글

관련 채용 정보