DeepLab V3+ comes with two encoder (backbone) choices, and the overall architecture of the model is shown in the figure above.
This is ASPP (Atrous Spatial Pyramid Pooling), a method proposed in DeepLab V2: several atrous convolutions with different dilation rates (r) are applied to the feature map in parallel, and their outputs are then merged.
As shown in the figure, the dilation rate r is varied from 6 to 24 so that the module covers a variety of receptive fields; a minimal sketch of this idea follows below.
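The following is a minimal sketch of the ASPP idea, assuming a 3x3 kernel in every branch and merging the parallel outputs by summation as in DeepLab V2; the class name SimpleASPP and the channel sizes are made up for illustration and this is not the repository's build_aspp:

import torch
import torch.nn as nn

class SimpleASPP(nn.Module):
    def __init__(self, in_channels, out_channels, rates=(6, 12, 18, 24)):
        super().__init__()
        # Each 3x3 branch uses padding == dilation, so the spatial size is
        # preserved while the receptive field grows with the dilation rate.
        self.branches = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size=3,
                      padding=r, dilation=r, bias=False)
            for r in rates
        ])

    def forward(self, x):
        # Apply the parallel atrous convolutions and merge the outputs by summation.
        return sum(branch(x) for branch in self.branches)

feat = torch.randn(1, 2048, 32, 32)    # e.g. a backbone feature map
out = SimpleASPP(2048, 256)(feat)      # -> torch.Size([1, 256, 32, 32])

The full DeepLab V3+ model definition below is from the pytorch-deeplab-xception repository linked at the end; it wires the backbone, the ASPP module, and the decoder together.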
import torch
import torch.nn as nn
import torch.nn.functional as F
from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
from modeling.aspp import build_aspp
from modeling.decoder import build_decoder
from modeling.backbone import build_backbone
class DeepLab(nn.Module):
    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=True, freeze_bn=False):
        super(DeepLab, self).__init__()
        if backbone == 'drn':
            output_stride = 8

        if sync_bn == True:
            BatchNorm = SynchronizedBatchNorm2d
        else:
            BatchNorm = nn.BatchNorm2d

        # Encoder: backbone features + ASPP; Decoder: refines them with low-level features.
        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)

        self.freeze_bn = freeze_bn

    def forward(self, input):
        x, low_level_feat = self.backbone(input)
        x = self.aspp(x)
        x = self.decoder(x, low_level_feat)
        # Bilinear upsampling back to the input resolution.
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, SynchronizedBatchNorm2d):
                m.eval()
            elif isinstance(m, nn.BatchNorm2d):
                m.eval()

    def get_1x_lr_params(self):
        # Parameters of the (pretrained) backbone, trained with the base learning rate.
        modules = [self.backbone]
        for i in range(len(modules)):
            for m in modules[i].named_modules():
                if self.freeze_bn:
                    if isinstance(m[1], nn.Conv2d):
                        for p in m[1].parameters():
                            if p.requires_grad:
                                yield p
                else:
                    if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
                            or isinstance(m[1], nn.BatchNorm2d):
                        for p in m[1].parameters():
                            if p.requires_grad:
                                yield p

    def get_10x_lr_params(self):
        # Parameters of the ASPP module and decoder, trained with a 10x learning rate.
        modules = [self.aspp, self.decoder]
        for i in range(len(modules)):
            for m in modules[i].named_modules():
                if self.freeze_bn:
                    if isinstance(m[1], nn.Conv2d):
                        for p in m[1].parameters():
                            if p.requires_grad:
                                yield p
                else:
                    if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
                            or isinstance(m[1], nn.BatchNorm2d):
                        for p in m[1].parameters():
                            if p.requires_grad:
                                yield p
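A usage sketch under assumptions: it presumes the modeling package from the linked repository is importable, and the learning rate, momentum, and weight decay values below are placeholders for illustration rather than the repository's exact training settings. get_1x_lr_params / get_10x_lr_params supply two parameter groups so that the pretrained backbone is trained with a smaller learning rate than the ASPP module and decoder.

model = DeepLab(backbone='resnet', output_stride=16, num_classes=21,
                sync_bn=False, freeze_bn=False)

base_lr = 0.007  # assumed value, for illustration only
train_params = [{'params': model.get_1x_lr_params(),  'lr': base_lr},
                {'params': model.get_10x_lr_params(), 'lr': base_lr * 10}]
optimizer = torch.optim.SGD(train_params, momentum=0.9, weight_decay=5e-4)

image = torch.rand(2, 3, 513, 513)
output = model(image)    # expected shape: torch.Size([2, 21, 513, 513])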
https://medium.com/hyunjulie/2%ED%8E%B8-%EB%91%90-%EC%A0%91%EA%B7%BC%EC%9D%98-%EC%A0%91%EC%A0%90-deeplab-v3-ef7316d4209d
https://wikidocs.net/143446
https://github.com/jfzhang95/pytorch-deeplab-xception