DDeP in MMSeg 전략

우병주·2024년 11월 26일
0

runner는 유지 할 것
model을 새로 정의해서, loss를 다르게 동작하도록 하기
pipeline을 수정해서, data_samples에 GT 노이즈가 들어가도록 하기

I. Dataset 전처리 관련

1. Dataloader

  • ImageNet-1k를 __get_item__(idx) 으로 이미지만 순차적으로 내뱉는 단순한 로더 새로 정의하기
  • CityscapesDataset의 get_item 그대로 사용해도 될지도?
  • 또는 I-JEPA의 dataloader 참고해보기.
  • pipeline 동작방식과 관련 있을수도 있음을 참고하기
  • 아니면 여기서 노이즈를 미리 만들어 GT로 사용한다??

2. Pipeline

  • pipeline에 custom pipeline을 추가하여 Denoising을 위한 Input을 만들어야 함.
  • AddGaussian: img -> noised img , noise
  • ToTensorNormalize: gaussian 더하기 전 미리 normalize하기
  • PackDenoiseInput: inputs,data_samples를 정의하기

II. 모델 전략

BaseSegmentor와 EncoderDecoder이 작성된 걸 참고해서
BaseDenoiser과 DDePEncoderDecoder를 작성하기
m2f를 그대로 output layer만 달아서 디노이징 해보기

1. BaseModule (mmengine)

  • nn.module상속
  • init_cfg로 init함.
  • init_weight 관련 기능 제공

2. BaseModel (mmengine)

  • BaseModule 상속
  • init_cfg와 data_preprocessor(dict) 로 init함 (super)
    • Rein은 SegDataPreProcessor type을 썼음
    • None일시 BaseDataPreProcessor가 됨.
  • BaseModel 상속한 모델은 @MODELS.register_module() 가능하며, forward 함수만 정의해주면 됨.
  • train_step, val_step, test_step: data를 받아서 data_preprocess를 거치고 _run_forward를 실행시키는 함수들. train일시 loss mode, val과 test는 predict 모드
  • _run_forward는 data를 inputsdata_samples로 쪼개서 단순히 forward를 호출하는 함수임.
  • train_step:
  • with optim_wrapper.optim_context(self):
        data = self.data_preprocessor(data, True)
        losses = self._run_forward(data, mode='loss')  # type: ignore
    parsed_losses, log_vars = self.parse_losses(losses)  # type: ignore
    optim_wrapper.update_params(parsed_losses)
    return log_vars
  • (1) data_preprocessor를 거치고 loss로 forward (2) backward에 쓸 parsed_losses와 logger에 줄 log_vars를 얻음 (3) optim_wrapper의 update_params가 모델 업데이트를 진행 (4) log_vars 리턴

3. BaseSegmentor (mmseg)

  • BaseModel을 상속
  • abstract method: extract_feat, encode_decode, loss, predict, _forward (mode='tensor')
  • forward가 모드에 따라 loss, predict, _forward를 호출
  • postprocess_result: seg_logits (B, C, H, W)을 후처리 함. 만약 train이라면, data_samples에 있는 정보로 패딩, 플립, 리사이징을 함. C차원이 2이상이면 argmax, 1이면 thresholding을 한 seg_pred또한 얻으며, seg_logits와 seg_pred를 data_samples에 PixelData 포맷으로 추가함.

4. EncoderDecoder (mmseg)

  • BaseSegmentor를 상속
  • init에서 super 및 backbone, neck, decode_head를 build, init
  • extract_feat: neck(backbone(x))
  • encode_decode: decode_head.predict(extract_feat(x, test_cfg))
  • decode_head_forward_train: 디코더의 loss를 호출 및 리턴
  • loss: extract_feat + decode_head_forward_train
  • predict: postprocess 있는 inference
  • _forward: postprocess 없는 predict: extract_feat + decode_head.forward
  • inference: whole이나 slide inference

여기서 전략
Denoising의 경우 evaluation을 필요로 하지 않음. 어차피 호출이 안될거니까, DDePEncoderDecoder는 Rein이 정의한 FrozenBackboneEncoderDecoder 를 상속받아도 괜찮음. 대신, 이 경우 FPN을 neck으로 분리해야함. positional embedding은 꼭 튜닝 안해도 될듯.

Mask2FormerHead만 Mask2FormerDenoiseHead로 바꾸면 됨. 여기서, forward_train만 돌아갈 수 있게 loss 함수 동작에만 문제 없으면 됨.
구현상 transformer decoder는 무시한채로, transformer encoder만 load 하면 좋을듯

patch 16x16으로 해야 할 듯..? 그리고 fpn1 형태의 noise_predictor를 달아서 128,128,256을 512, 512, 3으로 만들어야 할듯

5. Mask2FormerDenoiseHead (custom)

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule

try:
    from mmdet.models.dense_heads import \
        Mask2FormerHead as MMDET_Mask2FormerHead
except ModuleNotFoundError:
    MMDET_Mask2FormerHead = BaseModule

from mmengine.structures import InstanceData
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.structures.seg_data_sample import SegDataSample
from mmseg.utils import ConfigType, SampleList


@MODELS.register_module()
class Mask2FormerDenoiseHead(MMDET_Mask2FormerHead):
    """
    Note that inference is not implemented.
    Note that transformer decoder is not used.
    """

    def __init__(self,
                 num_classes,
                 align_corners=False,
                 ignore_index=255,
                 **kwargs):
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.align_corners = align_corners
        self.out_channels = num_classes
        self.ignore_index = ignore_index

        feat_channels = kwargs['feat_channels']
        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.

    def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
        """Perform forward propagation to convert paradigm from MMSegmentation
        to MMDetection to ensure ``MMDET_Mask2FormerHead`` could be called
        normally. Specifically, ``batch_gt_instances`` would be added.

        Args:
            batch_data_samples (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_sem_seg`.

        Returns:
            tuple[Tensor]: A tuple contains two lists.

                - batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                    gt_instance. It usually includes ``labels``, each is
                    unique ground truth label id of images, with
                    shape (num_gt, ) and ``masks``, each is ground truth
                    masks of each instances of a image, shape (num_gt, h, w).
                - batch_img_metas (list[dict]): List of image meta information.
        """
        batch_img_metas = []
        batch_gt_instances = []

        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            gt_sem_seg = data_sample.gt_sem_seg.data
            classes = torch.unique(
                gt_sem_seg,
                sorted=False,
                return_inverse=False,
                return_counts=False)

            # remove ignored region
            gt_labels = classes[classes != self.ignore_index]

            masks = []
            for class_id in gt_labels:
                masks.append(gt_sem_seg == class_id)

            if len(masks) == 0:
                gt_masks = torch.zeros(
                    (0, gt_sem_seg.shape[-2],
                     gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
            else:
                gt_masks = torch.stack(masks).squeeze(1).long()

            instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
            batch_gt_instances.append(instance_data)
        return batch_gt_instances, batch_img_metas

    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Perform forward propagation and loss calculation of the decoder head
        on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_sem_seg`.
            train_cfg (ConfigType): Training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        # batch SegDataSample to InstanceDataSample
        batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
            batch_data_samples)

        # forward
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)

        # loss
        losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
                                   batch_gt_instances, batch_img_metas)

        return losses

    def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType) -> Tuple[Tensor]:
        """Test without augmentaton.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_img_metas (List[:obj:`SegDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_sem_seg`.
            test_cfg (ConfigType): Test config.

        Returns:
            Tensor: A tensor of segmentation mask.
        """
        batch_data_samples = [
            SegDataSample(metainfo=metainfo) for metainfo in batch_img_metas
        ]

        all_cls_scores, all_mask_preds = self(x, batch_data_samples)
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]
        if 'pad_shape' in batch_img_metas[0]:
            size = batch_img_metas[0]['pad_shape']
        else:
            size = batch_img_metas[0]['img_shape']
        # upsample mask
        mask_pred_results = F.interpolate(
            mask_pred_results, size=size, mode='bilinear', align_corners=False)
        cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]
        mask_pred = mask_pred_results.sigmoid()
        seg_logits = torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred)
        return seg_logits

2. mask2former_denoise.py

  • 모델 build는 EncoderDecoder로 그대로 사용하면서, Decoder 파트를 새로 정의한 친구로 바꿀 것임.
  • 모델 구조는 mask2former랑 동일함. mask2formerhead를 상속해서, loss를 다 제거하고 forward를 재정의할듯?
  • pixel decoder 마지막단에 Conv-GELU-LayerNorm-Conv layer 달아서 [144, 144, 256] -> [504, 503, 3] 으로 만들어주고, noise 예측으로

3. converter

  • 학습한 체크포인트 엔드투엔드로 그대로 전이해서 fine-tuning 가능할 듯.
  • 디코더 pre-trained 값을 어떻게 집어넣을 수 있는지 찾아봐야 할 듯

0개의 댓글

관련 채용 정보