Plan:
- Keep the runner as-is.
- Define a new model so that the loss behaves differently.
- Modify the pipeline so that the GT noise is carried in data_samples.
Dataset and pipeline:
- Define a new, simple loader whose __getitem__(idx) just yields images sequentially.
- AddGaussian: img -> noised img, noise.
- ToTensorNormalize: normalize in advance, before the gaussian is added.
- PackDenoiseInput: define inputs and data_samples (see the sketch after this list).

Model:
- Referring to how BaseSegmentor and EncoderDecoder are written, write BaseDenoiser and DDePEncoderDecoder.
- Try denoising with m2f as-is, only attaching an output layer.
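A minimal sketch of the loader and the two packing transforms, assuming mmcv-style transforms that pass a `results` dict along the pipeline. AddGaussian and PackDenoiseInput are the names from this plan; everything else (PlainImageDataset, sigma, the gt_noise field) is a placeholder choice, and ToTensorNormalize is assumed to have already run:

```python
import os
import os.path as osp

import mmcv
import numpy as np
import torch
from mmcv.transforms import BaseTransform
from mmengine.structures import PixelData
from mmseg.registry import TRANSFORMS
from mmseg.structures import SegDataSample
from torch.utils.data import Dataset


class PlainImageDataset(Dataset):
    """Simple loader: __getitem__(idx) yields one image through the pipeline."""

    def __init__(self, img_dir: str, pipeline):
        self.paths = sorted(osp.join(img_dir, f) for f in os.listdir(img_dir))
        self.pipeline = pipeline  # e.g. an mmengine Compose of the transforms below

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = mmcv.imread(self.paths[idx])
        return self.pipeline({'img': img, 'img_path': self.paths[idx]})


@TRANSFORMS.register_module()
class AddGaussian(BaseTransform):
    """img -> noised img; keeps the sampled noise as the GT target.

    Assumes the image was already normalized (ToTensorNormalize runs first).
    """

    def __init__(self, sigma: float = 0.25):  # sigma is a placeholder value
        self.sigma = sigma

    def transform(self, results: dict) -> dict:
        img = results['img'].astype(np.float32)
        noise = np.random.randn(*img.shape).astype(np.float32) * self.sigma
        results['img'] = img + noise   # noised input for the model
        results['noise'] = noise       # GT noise, packed below
        return results


@TRANSFORMS.register_module()
class PackDenoiseInput(BaseTransform):
    """Define `inputs` (noised image) and `data_samples` (GT noise)."""

    def transform(self, results: dict) -> dict:
        # HWC -> CHW tensors
        inputs = torch.from_numpy(results['img'].transpose(2, 0, 1).copy())
        gt = torch.from_numpy(results['noise'].transpose(2, 0, 1).copy())
        data_sample = SegDataSample(
            metainfo={k: results.get(k)
                      for k in ('img_path', 'ori_shape', 'img_shape')})
        # reuse SegDataSample, but carry the noise in a PixelData field
        data_sample.set_data({'gt_noise': PixelData(data=gt)})
        return {'inputs': inputs, 'data_samples': data_sample}
```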
Data preprocessor: the segmentation configs use the SegDataPreProcessor type; for denoising it becomes BaseDataPreProcessor.

The step functions (train_step, val_step, test_step) are what execute _run_forward: loss mode during training, predict mode for val and test. _run_forward itself just splits data into inputs and data_samples and calls forward. For reference, the core of mmengine's BaseModel.train_step:

```python
with optim_wrapper.optim_context(self):
    data = self.data_preprocessor(data, True)
    losses = self._run_forward(data, mode='loss')  # type: ignore
parsed_losses, log_vars = self.parse_losses(losses)  # type: ignore
optim_wrapper.update_params(parsed_losses)
return log_vars
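Putting that together, a minimal sketch of BaseDenoiser's mode dispatch (BaseDenoiser is this plan's own class, not an existing mmseg API; it mirrors the BaseSegmentor/BaseModel mode contract, and the concrete loss/predict/_forward land in DDePEncoderDecoder):

```python
from abc import abstractmethod
from typing import Optional

import torch
from mmengine.model import BaseModel
from mmseg.utils import SampleList


class BaseDenoiser(BaseModel):
    """Same mode contract as BaseSegmentor, minus the evaluation machinery."""

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[SampleList] = None,
                mode: str = 'tensor'):
        if mode == 'loss':       # train_step -> _run_forward(mode='loss')
            return self.loss(inputs, data_samples)
        elif mode == 'predict':  # val/test_step; never called for denoising
            return self.predict(inputs, data_samples)
        elif mode == 'tensor':   # bare forward, e.g. for export/debugging
            return self._forward(inputs, data_samples)
        raise RuntimeError(f'Invalid mode "{mode}". '
                           'Only supports loss, predict and tensor mode')

    @abstractmethod
    def loss(self, inputs, data_samples) -> dict:
        """Compute the denoising loss; implemented by DDePEncoderDecoder."""
```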
DDePEncoderDecoder needs the EncoderDecoder surface: extract_feat, encode_decode, loss, predict, and _forward (mode='tensor'). Prediction results are attached to the data samples in PixelData format.
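For reference, attaching a result in PixelData format looks like this (the pred_noise field name is hypothetical, mirroring how mmseg attaches pred_sem_seg):

```python
import torch
from mmengine.structures import PixelData
from mmseg.structures import SegDataSample

sample = SegDataSample()
pred = torch.zeros(3, 512, 512)  # stand-in for the predicted noise map
# attach the prediction the same way mmseg attaches pred_sem_seg
sample.set_data({'pred_noise': PixelData(data=pred)})
print(sample.pred_noise.data.shape)  # torch.Size([3, 512, 512])
```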
Strategy:
- Denoising needs no evaluation. Since predict will never be called anyway, DDePEncoderDecoder may inherit from Rein's FrozenBackboneEncoderDecoder. In that case, however, the FPN has to be split out as a neck. The positional embedding probably does not need tuning.
- Only Mask2FormerHead has to be swapped for Mask2FormerDenoiseHead. It is enough that forward_train runs, i.e. only the loss path has to work.
- Implementation-wise, it is probably best to ignore the transformer decoder and load only the transformer encoder weights.
- The patch size should probably be 16x16, and an fpn1-style noise_predictor should be attached to turn the (128, 128, 256) features into a (512, 512, 3) noise map (see the sketch below).
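A minimal sketch of that fpn1-style noise_predictor, assuming 256-channel features at 128x128; the layer choices are guesses modeled on the fpn1 upsampling stacks used with ViT-style backbones, not a fixed design:

```python
import torch
import torch.nn as nn

# 256ch @ 128x128  ->  3ch @ 512x512 (two 2x transposed-conv upsamples)
noise_predictor = nn.Sequential(
    nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),  # 128 -> 256
    nn.GroupNorm(32, 128),
    nn.GELU(),
    nn.ConvTranspose2d(128, 3, kernel_size=2, stride=2),    # 256 -> 512
)

feats = torch.randn(1, 256, 128, 128)
print(noise_predictor(feats).shape)  # torch.Size([1, 3, 512, 512])
```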
```python
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
try:
from mmdet.models.dense_heads import \
Mask2FormerHead as MMDET_Mask2FormerHead
except ModuleNotFoundError:
MMDET_Mask2FormerHead = BaseModule
from mmengine.structures import InstanceData
from torch import Tensor
from mmseg.registry import MODELS
from mmseg.structures.seg_data_sample import SegDataSample
from mmseg.utils import ConfigType, SampleList
@MODELS.register_module()
class Mask2FormerDenoiseHead(MMDET_Mask2FormerHead):
"""
Note that inference is not implemented.
Note that transformer decoder is not used.
"""
def __init__(self,
num_classes,
align_corners=False,
ignore_index=255,
**kwargs):
super().__init__(**kwargs)
self.num_classes = num_classes
self.align_corners = align_corners
self.out_channels = num_classes
self.ignore_index = ignore_index
feat_channels = kwargs['feat_channels']
self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        # TODO: attach the noise_predictor here (fpn1-style upsampling head,
        # (256, 128, 128) features -> (3, 512, 512) noise map; see notes above)
def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
"""Perform forward propagation to convert paradigm from MMSegmentation
to MMDetection to ensure ``MMDET_Mask2FormerHead`` could be called
normally. Specifically, ``batch_gt_instances`` would be added.
Args:
batch_data_samples (List[:obj:`SegDataSample`]): The Data
Samples. It usually includes information such as
`gt_sem_seg`.
Returns:
tuple[Tensor]: A tuple contains two lists.
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``labels``, each is
unique ground truth label id of images, with
shape (num_gt, ) and ``masks``, each is ground truth
masks of each instances of a image, shape (num_gt, h, w).
- batch_img_metas (list[dict]): List of image meta information.
"""
batch_img_metas = []
batch_gt_instances = []
for data_sample in batch_data_samples:
batch_img_metas.append(data_sample.metainfo)
gt_sem_seg = data_sample.gt_sem_seg.data
classes = torch.unique(
gt_sem_seg,
sorted=False,
return_inverse=False,
return_counts=False)
# remove ignored region
gt_labels = classes[classes != self.ignore_index]
masks = []
for class_id in gt_labels:
masks.append(gt_sem_seg == class_id)
if len(masks) == 0:
gt_masks = torch.zeros(
(0, gt_sem_seg.shape[-2],
gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
else:
gt_masks = torch.stack(masks).squeeze(1).long()
instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
batch_gt_instances.append(instance_data)
return batch_gt_instances, batch_img_metas
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
train_cfg: ConfigType) -> dict:
"""Perform forward propagation and loss calculation of the decoder head
on the features of the upstream network.
Args:
x (tuple[Tensor]): Multi-level features from the upstream
network, each is a 4D-tensor.
batch_data_samples (List[:obj:`SegDataSample`]): The Data
Samples. It usually includes information such as
`gt_sem_seg`.
train_cfg (ConfigType): Training config.
Returns:
dict[str, Tensor]: a dictionary of loss components.
"""
# batch SegDataSample to InstanceDataSample
batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
batch_data_samples)
# forward
all_cls_scores, all_mask_preds = self(x, batch_data_samples)
# loss
losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
batch_gt_instances, batch_img_metas)
return losses
def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict],
test_cfg: ConfigType) -> Tuple[Tensor]:
"""Test without augmentaton.
Args:
x (tuple[Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
batch_img_metas (List[:obj:`SegDataSample`]): The Data
Samples. It usually includes information such as
`gt_sem_seg`.
test_cfg (ConfigType): Test config.
Returns:
Tensor: A tensor of segmentation mask.
"""
batch_data_samples = [
SegDataSample(metainfo=metainfo) for metainfo in batch_img_metas
]
all_cls_scores, all_mask_preds = self(x, batch_data_samples)
mask_cls_results = all_cls_scores[-1]
mask_pred_results = all_mask_preds[-1]
if 'pad_shape' in batch_img_metas[0]:
size = batch_img_metas[0]['pad_shape']
else:
size = batch_img_metas[0]['img_shape']
# upsample mask
mask_pred_results = F.interpolate(
mask_pred_results, size=size, mode='bilinear', align_corners=False)
cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]
mask_pred = mask_pred_results.sigmoid()
seg_logits = torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred)
        return seg_logits
```
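Since only the loss path has to work, the segmentation-specific loss above would ultimately be replaced by something like this (a sketch, assuming the gt_noise field from PackDenoiseInput and the noise_predictor head; none of this is existing mmseg API):

```python
import torch
import torch.nn.functional as F
from torch import Tensor

from mmseg.utils import SampleList


def denoise_loss(self, x: Tensor, batch_data_samples: SampleList) -> dict:
    """Replacement loss: regress the GT noise instead of class masks."""
    pred_noise = self.noise_predictor(x)  # (B, 3, 512, 512)
    gt_noise = torch.stack(
        [ds.gt_noise.data for ds in batch_data_samples])  # (B, 3, 512, 512)
    return {'loss_denoise': F.mse_loss(pred_noise, gt_noise)}
```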