[Pytorch] Yolo v3 논문 리뷰 및 모델 구현

도룩·2024년 2월 8일
0
post-thumbnail

목적

Yolo v3을 이해하고 Pytorch로 모델과 Loss Function을 구현할 수 있다.

Architecture

  • Network architecture
    Feature extractor로 사용된 Darknet-53의 구조이다. Output layer에 해당하는 부분을 제거하고 detection을 위한 layer들을 추가하였다.

    Yolo v3의 전체 네트워크 구조이다. (num_classes = 20)
    각 output을 예측하는 마지막 Conv의 채널 수가 75인 이유는 예측하고자 하는 class가 20개이고, 각 output tensor당 anchor가 3 개씩 할당되기 때문이다. (5+classes)3\rightarrow (5 + classes) * 3
    \\
    마지막 Conv를 통과한 후 Output shape을 맞춰주어야 한다.
    (Batch_size x (num_classes + 5) * 3 x H x W)
    \rightarrow (Batch_size x 3 x H x W x (num_classes + 5))
    \\
    모델 측면에서 Yolo v2와 비교해보았을 때 큰 차이점은 다음과 같다.
    \\
    \\
    1. Feature extractor를 변경 (Darknet 19 \rightarrow Darknet 53)
      Yolo v3에서는 ResNet(2015)을 모방하여 만든 모델인 Darknet 53을 feature extractor로 사용하였다.
      \\
      \\
    2. Output tensor의 수가 늘어났다. (1 \rightarrow 3)

      이미지 출처: https://blog.paperspace.com/how-to-implement-a-yolo-object-detector-in-pytorch/
      \\
      여러 Conv 연산을 거치는 CNN 모델 특성상 깊어짐에 따라 하단의 feature map은 receptive field가 커지게 된다. 즉, CNN 모델 하단의 feature map은 매우 큰 receptive field의 정보가 작은 사이즈로 압축되어 있기 때문에 물체가 구체적으로 어디에 있는지는 알기 어렵다. 이러한 특징은 classification에는 유리할지 몰라도 localization에는 불리할 것이다.
      \\
      Yolo v2에서 작은 물체를 잘 검출하지 못하는 문제를 완화하기 위해 Yolo v3에서는 feature extractor 앞단의 작은 receptive field를 갖는 feature map 들을 활용하여 서로 다른 세 가지 resolution 으로 물체를 검출하고자 했다.
      \\
      \rightarrow Input image size 416x416 기준: 13x13, 26x26, 52x52

\\
\\

특징 및 결과

Yolo v3의 특징을 Yolo v2와 비교해보며 살펴보자.
차이점의 대부분이 architecture에 녹아있기 때문에 이해가 쉬울 것이다.

Bounding box를 예측하는 부분은 Yolo v2와 동일하므로 생략하겠다.

Feature Extractor

  • Darknet-53
    Yolo v3에서는 ResNet의 residual block을 적용한 Darknet-53을 feature extractor로 사용하였다.
    각 모델별 ImageNet에 대한 결과표이다. Darknet-53이 Yolo v2의 feature extractor 였던 Darknet-19보다 성능이나 속도를 와전히 압도하면서, ResNet-101보다 근소하게 성능이 좋고, ResNet-152에 비해서는 성능이 약간 떨어지지만 (Top-1 acc) FPS가 두 배 이상 빠른 것을 확인 할 수 있다.
    \\
    \\

Class Prediction

  • Mutli-label classification
    예를 들어 class 중 Man, Person이 포함되어 있다고 할 때 Multi-class classification 문제로 접근하여 Softmax 함수를 쓴다고 가정해보자. 만약 man의 확률이 높게 나왔다면 person은 man을 포함하는 단어임에도 불구하고 낮은 확률로 나올 수 밖에 없다. 따라서 Yolo v3에서는 활용한 데이터 셋 특성상 이러한 문제를 완화하기 위해 multilabel classification 문제로 접근하여 class별 확률을 구할 때 각 class별 독립적인 logistic classifier를 이용했다. (학습 시에는 binary cross-entropy 사용)
    \\
    \\

Predictions Across Scales

  • Yolo v3 predicts boxes at 3 different scales
    Upsampling을 이용해 feature map을 resolution을 증가시키고,(Upsampling(2)) feature extractor의 중간에서 feature map을 떼어와 채널축으로 포함해 (concat(dim = 1)) detecting 하는데 사용하였다.
    (Input 이미지 사이즈 416x416 기준으로 13x13, 26x26, 52x52 사이즈의 feature map에서 box를 예측하였음.)
    \\
    Yolo v3에서는 총 9개의 Anchor box를 사용하였다. (Yolo v2와 마찬가지로 k-means clustering 방식을 사용해서 anchors box를 구하였음.) Anchor box를 적용할 때는 Yolo v2와는 다르게 각 scale 마다 3 개의 anchor box만 사용하였다.
    \\
    ex)
    13x13 scale에서는 anchor box 1, 2, 3만 사용
    26x26 scale에서는 anchor box 4, 5, 6만 사용
    52x52 scale에서는 anchor box 7, 8, 9만 사용
    \\
    \\

Result

  • inference time VS mAP50_{50} 그래프

    그래프가 좀 이상하지만 여러 detection 모델보다 Yolo v3가 더 적은 inference time을 가지면서 성능이 좋다는 것을 강조하고 싶었던 것 같다.

    그래프 모양에 대해 첫 번째 reviewer의 지적으로 Figure 4를 추가로 첨부하였다.
    \\

    해당 Table에서 Yolo v3는 SSD(Single shot detector)보다는 더 높은 성능을 보였으나, RetinaNet에 비해서는 낮은 성능을 보인다. (저자는 RetinaNet이 Yolo v3보다 3.8 배 이상 느리다는 것을 강조하였다.)
    \\
    Yolo v3는 위 그래프처럼 다른 detector에 비해 inference time이 매우 짧으면서도 mAP50 기준으로 다른 SOTA모델들에 비해 뒤지지 않는 성능을 보인다.
    그러나 iou threshold가 높아질수록 다른 detector에 비해 꽤 낮은 성능을 보이는 것을 알 수 있다.
    \\
    이에 대해서는 논문 마지막 부분인 Rebuttal에서 COCO metric에서 높은 IoU threshold에 대한 mAP를 지적한다. 높은 IoU threshold를 갖는 mAP로 성능을 비교하는 것은 해당 box가 얼마나 잘 class label을 잘 부여하는지 보다 box를 얼마나 잘 치는지에 초점이 맞춰져있다. (어떤 물체를 예측할 때 Groud truth의 box의 IoU가 IoU threshold를 넘지 못하면 object가 없다고 간주하기 때문에 박스에 대한 분류를 했는지 안 했는지는 상관이 없게 되어버림.) 저자는 Box는 어느 수준 이상만 잘 치면되고 box보다는 class label을 얼마나 올바르게 예측하는지를 더 중요하다고 주장한다.
    \\
    그리고 mAP metric 자체의 문제점을 지적한다. (Figure 5)

    직관적으로 보기에 Detector #1이 Detector #2보다 더 잘 예측한 것처럼 보인다. 하지만 두 Detector 모두 해당 이미지에서 mAP는 1인 것을 알 수 있다. Detector #2도 mAP가 1인게 이상하게 느껴질 수도 있겠다. 한 번 살펴보자.
    \\
    mAP는 각 class별 AP를 평균낸 것이고, AP는 precision-recall 곡선을 보정한 area의 넓이다. 이 때문에 recall이 1일 때 precision이 1이 된다면 해당 class에 대한 AP값은 1이 된다.
    \\
    예를 들어, Dog에 대한 AP를 각각의 Detector에서 구해보자.
    먼저 Detector #1에서 Dog에 대한 AP를 구해보면, confidence threshold가 99% 이상 일 때 recall이 1이고, precision이 1이다. recall이 1일 때 precisioneh 1이기 때문에 AP는 1이다.
    \\
    Detector #2에서 Dog에 대한 AP도 구해보자.
    Confidence threshold 가 90% 이상 일 때 Precision은 1이고, recall도 1이다. 마찬가지로 AP값은 1 이다. 물론 Confidence threshold가 89% 이상일 때는 precision이 1/2이고, recall은 1이다. 하지만 이미 recall이 1일 때 precision이 1이 나왔기 때문에 저 값은 AP를 구하는 과정에서 그래프가 보정되므로 AP값에 영향이 없다.
    \\
    이런 식으로 각 Detector에서 class별 AP를 구하고 평균을 내면 (=mAP) 1이 나온다.

Code

환경

  • python 3.8.16
  • pytorch 2.1.0
  • torchinfo 1.8.0

구현

import torch
from torch import nn
from torchinfo import summary
# Build Model
class BasicConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride = 1):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = (kernel_size - 1) // 2, bias = False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1, inplace = True)
        )
    
    def forward(self, x):
        return self.conv(x)

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()

        self.residual = nn.Sequential(
            BasicConv(channels, channels // 2, 1),
            BasicConv(channels // 2, channels, 3),
        )
    
    def forward(self, x):
        return self.residual(x) + x

class DarkNet53(nn.Module):
    def __init__(self):
        super().__init__()

        self.first_conv_block = BasicConv(3, 32, 3)

        self.residual_block_01 = nn.Sequential(
            BasicConv(32, 64, 3, stride = 2),
            ResidualBlock(64),
        )
        self.residual_block_02 = nn.Sequential(
            BasicConv(64, 128, 3, stride = 2),
            nn.Sequential(*[ResidualBlock(128) for _ in range(2)]),
        )
        self.residual_block_03 = nn.Sequential(
            BasicConv(128, 256, 3, stride = 2),
            nn.Sequential(*[ResidualBlock(256) for _ in range(8)]),
        )
        self.residual_block_04 = nn.Sequential(
            BasicConv(256, 512, 3, stride = 2),
            nn.Sequential(*[ResidualBlock(512) for _ in range(8)]),
        )
        self.residual_block_05 = nn.Sequential(
            BasicConv(512, 1024, 3, stride = 2),
            nn.Sequential(*[ResidualBlock(1024) for _ in range(4)]),
        )
    
    def forward(self, x):
        x = self.first_conv_block(x)
        x = self.residual_block_01(x)
        x = self.residual_block_02(x)
        feature_map_01 = self.residual_block_03(x)
        feature_map_02 = self.residual_block_04(feature_map_01)
        feature_map_03 = self.residual_block_05(feature_map_02)
        return feature_map_01, feature_map_02, feature_map_03

class YoloBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.route_conv = nn.Sequential(
            BasicConv(in_channels, out_channels, 1),
            BasicConv(out_channels, out_channels * 2, 3),
            BasicConv(out_channels * 2, out_channels, 1),
            BasicConv(out_channels, out_channels * 2, 3),
            BasicConv(out_channels * 2, out_channels, 1),
        )
        
        self.output_conv = BasicConv(out_channels, out_channels * 2, 3)

    def forward(self, x):
        route = self.route_conv(x)
        output = self.output_conv(route)
        return route, output

class DetectionLayer(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.num_classes = num_classes
        self.pred = nn.Conv2d(2 * in_channels, (num_classes + 5) * 3, 1)

    def forward(self, x):
        output = self.pred(x)
        output = output.view(x.size(0), 3, self.num_classes + 5, x.size(2), x.size(3))
        output = output.permute(0, 1, 3, 4, 2)
        return output

class Upsampling(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.upsample = nn.Sequential(
            BasicConv(in_channels, out_channels, 1),
            nn.Upsample(scale_factor = 2)
        )

    def forward(self, x):
        return self.upsample(x)

class Yolov3(nn.Module):
    def __init__(self, num_classes = 20):
        super().__init__()

        self.num_classes = num_classes

        self.darknet53 = DarkNet53()

        self.yolo_block_01 = YoloBlock(1024, 512)
        self.detectlayer_01 = DetectionLayer(512, num_classes)
        self.upsample_01 = Upsampling(512, 256)

        self.yolo_block_02 = YoloBlock(512 + 256, 256)
        self.detectlayer_02 = DetectionLayer(256, num_classes)
        self.upsample_02 = Upsampling(256, 128)

        self.yolo_block_03 = YoloBlock(256 + 128, 128)
        self.detectlayer_03 = DetectionLayer(128, num_classes)
    
    def forward(self, x):

        self.feature_map_01, self.feature_map_02, self.feature_map_03 = self.darknet53(x)

        x, output_01 = self.yolo_block_01(self.feature_map_03)
        output_01 = self.detectlayer_01(output_01)
        x = self.upsample_01(x)

        x, output_02 = self.yolo_block_02(torch.cat([x, self.feature_map_02], dim = 1))
        output_02 = self.detectlayer_02(output_02)
        x = self.upsample_02(x)

        x, output_03 = self.yolo_block_03(torch.cat([x, self.feature_map_01], dim = 1))
        output_03 = self.detectlayer_03(output_03)


        return output_01, output_02, output_03
x = torch.randn((1, 3, 416, 416))
model = Yolov3(num_classes = 20)
out = model(x)
print(out[0].shape) # torch.Size([1, 3, 13, 13, 25])
print(out[1].shape) # torch.Size([1, 3, 26, 26, 25])
print(out[2].shape) # torch.Size([1, 3, 52, 52, 25])
summary(model, input_size = (2, 3, 416, 416), device = "cpu")
#### OUTPUT ####
==============================================================================================================
Layer (type:depth-idx)                                       Output Shape              Param #
==============================================================================================================
Yolov3                                                       [2, 3, 13, 13, 25]        --
├─DarkNet53: 1-1                                             [2, 256, 52, 52]          --
│    └─BasicConv: 2-1                                        [2, 32, 416, 416]         --
│    │    └─Sequential: 3-1                                  [2, 32, 416, 416]         928
│    └─Sequential: 2-2                                       [2, 64, 208, 208]         --
│    │    └─BasicConv: 3-2                                   [2, 64, 208, 208]         18,560
│    │    └─ResidualBlock: 3-3                               [2, 64, 208, 208]         20,672
│    └─Sequential: 2-3                                       [2, 128, 104, 104]        --
│    │    └─BasicConv: 3-4                                   [2, 128, 104, 104]        73,984
│    │    └─Sequential: 3-5                                  [2, 128, 104, 104]        164,608
│    └─Sequential: 2-4                                       [2, 256, 52, 52]          --
│    │    └─BasicConv: 3-6                                   [2, 256, 52, 52]          295,424
│    │    └─Sequential: 3-7                                  [2, 256, 52, 52]          2,627,584
│    └─Sequential: 2-5                                       [2, 512, 26, 26]          --
│    │    └─BasicConv: 3-8                                   [2, 512, 26, 26]          1,180,672
│    │    └─Sequential: 3-9                                  [2, 512, 26, 26]          10,498,048
│    └─Sequential: 2-6                                       [2, 1024, 13, 13]         --
│    │    └─BasicConv: 3-10                                  [2, 1024, 13, 13]         4,720,640
│    │    └─Sequential: 3-11                                 [2, 1024, 13, 13]         20,983,808
├─YoloBlock: 1-2                                             [2, 512, 13, 13]          --
│    └─Sequential: 2-7                                       [2, 512, 13, 13]          --
│    │    └─BasicConv: 3-12                                  [2, 512, 13, 13]          525,312
│    │    └─BasicConv: 3-13                                  [2, 1024, 13, 13]         4,720,640
│    │    └─BasicConv: 3-14                                  [2, 512, 13, 13]          525,312
│    │    └─BasicConv: 3-15                                  [2, 1024, 13, 13]         4,720,640
│    │    └─BasicConv: 3-16                                  [2, 512, 13, 13]          525,312
│    └─BasicConv: 2-8                                        [2, 1024, 13, 13]         --
│    │    └─Sequential: 3-17                                 [2, 1024, 13, 13]         4,720,640
├─DetectionLayer: 1-3                                        [2, 3, 13, 13, 25]        --
│    └─Conv2d: 2-9                                           [2, 75, 13, 13]           76,875
├─Upsampling: 1-4                                            [2, 256, 26, 26]          --
│    └─Sequential: 2-10                                      [2, 256, 26, 26]          --
│    │    └─BasicConv: 3-18                                  [2, 256, 13, 13]          131,584
│    │    └─Upsample: 3-19                                   [2, 256, 26, 26]          --
├─YoloBlock: 1-5                                             [2, 256, 26, 26]          --
│    └─Sequential: 2-11                                      [2, 256, 26, 26]          --
│    │    └─BasicConv: 3-20                                  [2, 256, 26, 26]          197,120
│    │    └─BasicConv: 3-21                                  [2, 512, 26, 26]          1,180,672
│    │    └─BasicConv: 3-22                                  [2, 256, 26, 26]          131,584
│    │    └─BasicConv: 3-23                                  [2, 512, 26, 26]          1,180,672
│    │    └─BasicConv: 3-24                                  [2, 256, 26, 26]          131,584
│    └─BasicConv: 2-12                                       [2, 512, 26, 26]          --
│    │    └─Sequential: 3-25                                 [2, 512, 26, 26]          1,180,672
├─DetectionLayer: 1-6                                        [2, 3, 26, 26, 25]        --
│    └─Conv2d: 2-13                                          [2, 75, 26, 26]           38,475
├─Upsampling: 1-7                                            [2, 128, 52, 52]          --
│    └─Sequential: 2-14                                      [2, 128, 52, 52]          --
│    │    └─BasicConv: 3-26                                  [2, 128, 26, 26]          33,024
│    │    └─Upsample: 3-27                                   [2, 128, 52, 52]          --
├─YoloBlock: 1-8                                             [2, 128, 52, 52]          --
│    └─Sequential: 2-15                                      [2, 128, 52, 52]          --
│    │    └─BasicConv: 3-28                                  [2, 128, 52, 52]          49,408
│    │    └─BasicConv: 3-29                                  [2, 256, 52, 52]          295,424
│    │    └─BasicConv: 3-30                                  [2, 128, 52, 52]          33,024
│    │    └─BasicConv: 3-31                                  [2, 256, 52, 52]          295,424
│    │    └─BasicConv: 3-32                                  [2, 128, 52, 52]          33,024
│    └─BasicConv: 2-16                                       [2, 256, 52, 52]          --
│    │    └─Sequential: 3-33                                 [2, 256, 52, 52]          295,424
├─DetectionLayer: 1-9                                        [2, 3, 52, 52, 25]        --
│    └─Conv2d: 2-17                                          [2, 75, 52, 52]           19,275
==============================================================================================================
Total params: 61,626,049
Trainable params: 61,626,049
Non-trainable params: 0
Total mult-adds (G): 65.43
==============================================================================================================
Input size (MB): 4.15
Forward/backward pass size (MB): 1229.50
Params size (MB): 246.50
Estimated Total Size (MB): 1480.15
==============================================================================================================
# Anchors
ANCHORS = [ 
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)], 
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)], 
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)], 
]

GRID_SIZE = [13, 26, 52] 
scaled_anchors = torch.tensor(ANCHORS) / ( 
    1 / torch.tensor(GRID_SIZE).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2) 
) 
print(scaled_anchors, scaled_anchors.shape)
"""
tensor([[[ 3.6400,  2.8600],
         [ 4.9400,  6.2400],
         [11.7000, 10.1400]],

        [[ 1.8200,  3.9000],
         [ 3.9000,  2.8600],
         [ 3.6400,  7.5400]],

        [[ 1.0400,  1.5600],
         [ 2.0800,  3.6400],
         [ 4.1600,  3.1200]]]) torch.Size([3, 3, 2])
"""
# Loss function
# 출처: https://www.geeksforgeeks.org/yolov3-from-scratch-using-pytorch/

def iou(box1, box2, is_pred = True):
    if is_pred:
        # IoU score for prediction and label
        # box1 (prediction) and box2 (label) are both in [x, y, width, height] format

        box1_x_center = box1[..., 0:1]; box2_x_center = box2[..., 0:1]
        box1_y_center = box1[..., 1:2]; box2_y_center = box2[..., 1:2]
        box1_width = box1[..., 2:3]; box2_width = box2[..., 2:3]
        box1_height = box1[..., 3:4]; box2_height = box2[..., 3:4]

        # Box coordinates for prediction
        box1_xmin = box1_x_center - box1_width / 2
        box1_ymin = box1_y_center - box1_height / 2
        box1_xmax = box1_x_center + box1_width / 2
        box1_ymax = box1_y_center + box1_height / 2

        # Box coordinates for ground truth
        box2_xmin = box2_x_center - box2_width / 2
        box2_ymin = box2_y_center - box2_height / 2
        box2_xmax = box2_x_center + box2_width / 2
        box2_ymax = box2_y_center + box2_height / 2

        # Get the coordinates of the intersection rectangle
        its_xmin = torch.max(box1_xmin, box2_xmin)
        its_ymin = torch.max(box1_ymin, box2_ymin)
        its_xmax = torch.min(box1_xmax, box2_xmax)
        its_ymax = torch.min(box1_ymax, box2_ymax)

        # Calculate Intersection area (min: 0)
        intersection_area = (its_xmax - its_xmin).clamp(min = 0) * (its_ymax - its_ymin).clamp(min = 0)

        # Calculate the union area
        box1_area = abs(box1_width * box1_height)
        box2_area = abs(box2_width * box2_height)
        union = box1_area + box2_area - intersection_area

        # Calculate the IoU score
        epsilon = 1e-6
        iou_score = intersection_area / (union + epsilon)

        return iou_score
    
    else:
        # IoU score based on width and height of bounding boxes (If the two boxes have the same center coordinates)

        box1_width = box1[..., 0]; box2_width = box2[..., 0]
        box1_height = box1[..., 0]; box2_height = box2[..., 1]

        # Calculate interaction area
        intersection_area = torch.min(box1_width, box2_width) * torch.min(box2_width, box2_height)

        # Calculate union area
        box1_area = box1_width * box1_height
        box2_area = box2_width * box2_height
        union_area = box1_area + box2_area - intersection_area

        # Calculate the IoU score
        iou_score = intersection_area / union_area

        # Return IoU score
        return iou_score
    
def convert_cells_to_bboxes(predictions, anchors, s, is_predictions = True):
    batch_size = predictions.shape[0]
    num_anchors = len(anchors)
    box_predictions = predictions[..., 1:5]

    # If the input is predictions then we will pass the x and y coordinate
    # through sigmoid function and width and height to exponent function and
    # calculate the score and best class.
    if is_predictions:
        anchors = anchors.reshape(1, len(anchors), 1, 1, 2)
        box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
        box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:] * anchors)
        scores = torch.sigmoid(predictions[..., 0:1])
        best_class = torch.argmax(predictions[..., 5:], dim = 1).unsqueeze(-1)
    
    # Else we will just calculate scores and best class.
    else:
        scores = predictions[..., 0:1]
        best_class = predictions[..., 5:6]
    
    # Calculate cell indices
    cell_indices = (
        torch.arange(s)
        .repeat(predictions.shape[0], 3, s, 1)
        .unsqueeze(-1)
        .to(predictions.device)
    )

    # Calculate x, y, width and height with proper scaling
    x = 1 / s * (box_predictions[..., 0:1] + cell_indices)
    y = 1 / s * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))

    width_height = 1 / s * box_predictions[..., 2:4]

    # Concatinating the values and reshaping them in
    # (BATCH_SIZE, num_anchors * S * S, 6) shape
    converted_bboxes = torch.cat(
        (best_class, scores, x, y, width_height), dim = -1
    ).reshape(batch_size, num_anchors * s * s, 6)

    # Returning the reshaped and converted bounding box list
    return converted_bboxes.tolist()

class YoloLoss(nn.Module):
    def __init__(self):
        super().__init__()

        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.ce = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, pred, target, anchors):

        # Identifying which cells in target have objects and which have no objects
        obj = target[..., 0] == 1
        no_obj = target[..., 0] == 0

        # Calculating No object loss
        no_object_loss = self.bce(
            (pred[..., 0:1][no_obj]), (target[..., 0:1][no_obj]),
        )

        # Reshaping anchors to match predictions
        anchors = anchors.reshape(1, 3, 1, 1, 2)

        # Box predict confidence
        box_preds = torch.cat([self.sigmoid(pred[..., 1:3]),
                               torch.exp(pred[..., 3:5]) * anchors], dim = -1)
        
        # Calculating IoU for prediction and target
        ious = iou(box_preds[obj], target[..., 1:5][obj]).detach()

        # Calculating Object loss
        object_loss = self.mse(self.sigmoid(pred[..., 0:1][obj]),
                               ious * target[..., 0:1][obj])
        
        # Predicted box coordinates
        pred[..., 1:3] = self.sigmoid(pred[..., 1:3])

        # Target box coordinates
        target[..., 3:5] = torch.log(1e-6 + target[..., 3:5] / anchors)

        # Calculating box coordinates
        box_loss = self.mse(pred[..., 1:5][obj], target[..., 1:5][obj])

        # Calculating class loss
        class_loss = self.ce((pred[..., 5:][obj]), target[..., 5:][obj].long())
        
        # Total loss
        return (
            box_loss
            + object_loss
            + no_object_loss
            + class_loss
        )

0개의 댓글

관련 채용 정보