DeepLabv3+

박광욱·2023년 7월 15일
0

Image Segmentation

목록 보기
2/6
post-thumbnail

📕 DeepLabv3+ Architecture


DeepLab V3+ 는 2개의 Encoder를 제시하는데, 전체적인 모델의 구조는 위 그림과 같다.

✍ DeepLab V3를 인코더로 사용한 모델

  1. Encoder에서 나온 최종 feature map에 대해 4배 bilinear upsample을 수행합니다.
  2. Encoder 중간에서 나온 feature map (Low-Level Features)을 1x1 convolution을 적용하여 channel을 줄입니다. (2번 과정을 통해 1번과 2번 feature map이 concatenation이 가능해집니다.)
  3. 1번과 2번 feature map을 concatenate합니다.
  4. 3x3 convolution layers를 거친 후 마지막 1x1 convolution을 거쳐 output을 뽑아냅니다.
  5. 4배 bilinear upsample을 수행하여 원래 input size로 복원된 최종 segmentated data를 출력합니다.

✍ 변형된 Xception 을 인코더로 사용한 모델

  1. 더 깊고
  2. Max Pooling을 stride가 있는 Depthwise Separable Convolution로 바꾸었고
  3. 3x3 Depthwise Convolution 후에 Batch Normalization 과 ReLU 를 더해주었습니다.

📗 Atrous Spatial Pyramid Pooling (ASPP)

DeepLab V2에서 제안된 방법으로, feature map으로부터 확장비율 (dilation rate, r)가 다른 Atrous convolution을 병렬로 적용한 뒤 다시 합쳐주는 방법.
그림처럼 확장비율 (dilation rate, r)을 6 ~ 24 까지 다양하게 변화하면서 다양한 receptive field를 볼 수 있도록 적용.

📘 DeepLabv3+ Code (Pytorch)

import torch
import torch.nn as nn
import torch.nn.functional as F
from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
from modeling.aspp import build_aspp
from modeling.decoder import build_decoder
from modeling.backbone import build_backbone

class DeepLab(nn.Module):
    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=True, freeze_bn=False):
        super(DeepLab, self).__init__()
        if backbone == 'drn':
            output_stride = 8

        if sync_bn == True:
            BatchNorm = SynchronizedBatchNorm2d
        else:
            BatchNorm = nn.BatchNorm2d

        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)

        self.freeze_bn = freeze_bn

    def forward(self, input):
        x, low_level_feat = self.backbone(input)
        x = self.aspp(x)
        x = self.decoder(x, low_level_feat)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, SynchronizedBatchNorm2d):
                m.eval()
            elif isinstance(m, nn.BatchNorm2d):
                m.eval()

    def get_1x_lr_params(self):
        modules = [self.backbone]
        for i in range(len(modules)):
            for m in modules[i].named_modules():
                if self.freeze_bn:
                    if isinstance(m[1], nn.Conv2d):
                        for p in m[1].parameters():
                            if p.requires_grad:
                                yield p
                else:
                    if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
                            or isinstance(m[1], nn.BatchNorm2d):
                        for p in m[1].parameters():
                            if p.requires_grad:
                                yield p
    def get_10x_lr_params(self):
        modules = [self.aspp, self.decoder]
        for i in range(len(modules)):
            for m in modules[i].named_modules():
                if self.freeze_bn:
                    if isinstance(m[1], nn.Conv2d):
                        for p in m[1].parameters():
                            if p.requires_grad:
                                yield p
                else:
                    if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
                            or isinstance(m[1], nn.BatchNorm2d):
                        for p in m[1].parameters():
                            if p.requires_grad:
                                yield p

Reference

https://medium.com/hyunjulie/2%ED%8E%B8-%EB%91%90-%EC%A0%91%EA%B7%BC%EC%9D%98-%EC%A0%91%EC%A0%90-deeplab-v3-ef7316d4209d
https://wikidocs.net/143446
https://github.com/jfzhang95/pytorch-deeplab-xception

profile
Vancouver

0개의 댓글