(2021)UNETR: Transformers for 3D Medical Image Segmentation

Gyuha Park·2021년 10월 14일
0

Paper Review

목록 보기
27/34
post-thumbnail
post-custom-banner

0. Abstract

Medical image segmentation 분야에서 contracting, expanding paths를 이용한 FCNNs(Fully Convolutional Neural Networks)이 좋은 결과를 가져다 줬다. 하지만 여전히 FCNNs의 locality라는 특징은 long-range의 spatial dependency를 제한하는 한계가 있다.

최근에 NLP(Natural Language Processing) 분야에서 transformers의 성공적인 결과에 영향을 받아 본 논문에서는 3D medical image segmentation task를 sequence-to-sequence 문제로 접근하였다.

본 논문에서는 UNETR(UNEt TRansformers)라는 새로운 architecture를 제안하였으며 encoder에서 sequence representations를 학습함과 동시에 UNet과 같이 encoder, decoder 형식을 따른다.

BTCV, MSD dataset에 성능을 검증하였으며 특히 BTCV dataset에서 SOTA를 달성하였다.

1. Introduction

Image segmentation 분야에서는 FCNNs 방식 중 UNet 기반의 architecture들이 좋은 결과를 가져다 줬었다. 그러나 비록 FCNNs 기반의 방식이 좋은 representation을 학습하지만 long-lange의 dependency가 local한 receptive fields에 의해 제한된다는 한계가 있다.

FCNNs의 한계를 극복하기 위해 ViT(Vision Transformer)와 같은 transformer 기반의 model들이 제안 되었다. 이에 영향을 받아 본 논문에서는 새로운 architecture UNETR을 제안하였다. ViT와 같이 embedded input patch로 부터 representation을 뽑기 위해 encoder에 transformer 구조를 사용하였다. 그리고 encoder로 부터 얻은 representation을 CNN 기반의 decoder에 skip connection으로 연결해 최종 segmentation output을 얻는 구조이다.

UNETR은 서로 다른 BTCV, MSD 3D segmentation task에서 검증 과정을 거쳤으며 BTCV dataset에서 SOTA를 달성하였다.

CNN 기반의 architecture는 2D, 3D medical image segmentation에서 SOTA를 달성하였다.

특히 3D segmentation에서는 3가지 view의 slice를 이용하는 2.5D 방식과 반대로 3D volume을 그대로 사용하는 방식이 있다.

하지만 global context와 long-range spatial dependencies를 학습하는데 있어 한계가 있다.

이러한 문제를 해결하기 위해 transformer 기반의 많은 model들이 제안 되었다. 하지만 앞 선 model들과는 달리 UNETR은 volumetric data를 바로 입력으로 넣는다는 차이점이 있고, 두 번째로 transformer를 main encoder로 사용하고 CNN 기반의 decoder에 skip connection으로 연결한다는 차이점이 있다. 마지막으로 input sequence를 생성하기 위해 CNN backbone을 사용하지 않으며 tokenized patches를 바로 사용한다.

3. Methodology

1) Architecture

UNETR은 stacked transformer에서 skip connection을 이용해 decoder로 연결되는 구조를 갖고 있다.

3D input volume xRH×W×D×C\text{x}\in\mathbb{R}^{H\times W\times D\times C}는 non-overlapping patches인 xvRN×(P3C)\text{x}_v\in\mathbb{R}^{N\times(P^3\cdot C)}로 나눠진다. 이 때 N=(H×W×D)/P3N=(H\times W\times D)/P^3이다.

xv\text{x}_v는 다시 linear layer를 거쳐 KK dimensional embedding space로 project된다.

ER(P3C)×K\text{E}\in\mathbb{R}^{(P^3\cdot C)\times K}

그리고 spatial information을 보존하기 위해 1D의 learnable positional embedding EposRN×K\text{E}_{pos}\in\mathbb{R}^{N\times K}를 더한다.

z0=[xv1E;xv2E;;xvNE]+Epos\text{z}_0=[\text{x}_v^1\text{E};\text{x}_v^2\text{E};\ldots;\text{x}_v^N\text{E}]+\text{E}_{pos}

UNETR의 backbone은 semantic segmentation을 위해 만들어졌기 때문에 class token은 따로 사용하지 않는다.

Embedding layer 후, MSA(Multi-head Self Attention)와 MLP로 구성된 LL개의 stacked transformer block을 지나게 된다.

zi=MSA(Norm(zi1))+zi1,   (i=1L)\text{z}'_i=\text{MSA}(\text{Norm}(\text{z}_{i-1}))+{z}_{i-1},\ \ \ (i=1\ldots L)

zi=MLP(Norm(zi))+zi,   (i=1L)\text{z}_i=\text{MLP}(\text{Norm}(\text{z}'_i))+\text{z}'_i,\ \ \ (i=1\ldots L)

UNETR은 U-Net에 영감을 받았다. encoder로 부터 얻은 representation zi   (i{3,6,9,12})\text{z}_i\ \ \ (i\in\{3,6,9,12\})는 Deconv layer를 거쳐 Conv layer인 decoder에 skip connection으로 연결되어 최종 segmentation output을 얻는다.

2) Loss Function

Loss function은 soft dice loss와 cross-entropy loss를 조합하여 사용하였다.

L(G,Y)=12Jj=1Ji=1IGi,jYi,ji=1IGi,j2+i=1IYi,j21Ii=1Ij=1JGi,jlogYi,jL(G,Y)=1-\cfrac{2}{J}\sum\limits_{j=1}^J\cfrac{\sum_{i=1}^IG_{i,j}Y_{i,j}}{\sum_{i=1}^IG_{i,j}^2+\sum_{i=1}^IY_{i,j}^2}-\cfrac{1}{I}\sum\limits_{i=1}^I\sum\limits_{j=1}^JG_{i,j}\log Y_{i,j}

II는 voxel의 개수, JJ는 class 개수, Yi,jY_{i,j}Gi,jG_{i,j}는 class jj와 voxel ii에서 예측된 probability와 ground truth를 나타낸다.

4. Experiments

1) Datasets

  • BTCV (CT)
    BTCV dataset은 30 subjects로 구성되어 있으며 13개의 장기가 annotation되어 있다. 각각의 CT scan은 8080~255255 slice, 512×512512\times512 pixel의 조영 CT이다. thickness는 11~6 mm6\ mm이다. 모든 image는 1.0 mm1.0\ mm의 voxel space로 resample되었다.
  • MSD (MRI/CT)
    MSD dataset에서 brain tumoer segmentation task는 484개의 multi-modal, multi-site MRI data (FLAIR, T1w, T1gd, T2w)이다. 그리고 necrotic/active tumor와 oedema가 annotation되어 있다. voxel space는 1.0×1.0×1.0 mm31.0 \times1.0\times1.0\ mm^3로 사용되었다. Spleen segmentation task는 41개의 CT volume으로 구성되어 있으며 spleen body가 annotation되어 있다.

2) Evaluation Metrics

Dice score와 95% Hausdorff Distance(HD)를 사용하였다.

Dice(G,P)=2i=1IGiPii=1IGi+i=1IPi\text{Dice}(G,P)=\cfrac{2\sum_{i=1}^IG_iP_i}{\sum_{i=1}^IG_i+\sum_{i=1}^IP_i}

HD(G,P)=max{maxgGminpPgp, maxpPmingGpg}\text{HD}(G',P')=\max\{\max\limits_{g'\in G'}\min\limits_{p'\in P'}||g'-p'||,\ \max\limits_{p'\in P'}\min\limits_{g'\in G'}||p'-g'||\}

3) Implementation Details

  • Batch size: 6
  • Optimizer: AdamW
  • Learning rate: 0.0001
  • Iteration: 20000
  • Backbone: ViT-B16
  • L=12, K=768, P=16×16×1616\times16\times16
  • Augmentation: Random rotation(90°, 180°, 270°90\degree,\ 180\degree,\ 270\degree), Random flip(axial, sagittal, coronal views), Random scale, Shift intensity
  • Ensemble: Five-fold cross-validation

4) Quantitative Evaluations

UNETR은 BTCV dataset에서 SOTA를 달성했다.

UNETR은 MSD dataset에서도 좋은 성능을 보여주고 있다.

5) Qualitative Results

UNETR은 abdomen organs task에서 좋은 성능을 보여준다. row 3에서 nnUNet이 liver를 stoach tissues로 착각한 반면에 UNETR은 organs의 경계를 잘 구분하고 있다.

row 2에서는 UNETR만 유일하게 kidney와 adrenal glands를 잘 구분하고 있다.

MSD dataset에도 UNETR만이 tumor의 fine-grained details를 잘 capture하고 있다.

5. Discussion

실험 결과, UNETR은 CNN과 transformer 기반의 model들 보다 뛰어난 성능을 얻었다. 특히 UNETR은 global, local dependencies를 잘 capture함으로 더 좋은 segmentation 성능을 얻었다.

UNETR은 BTCV dataset에서 SOTA를 달성함으로 효과를 입증했으며 gallbladder, adrenal glands와 같이 작은 organs에서도 좋은 성능을 보여주고 있다.

6. Ablation

1) Decoder Choice

Encoder는 모두 UNETR의 encoder를 사용하고 decoder를 NUP(Naive UPsampling), PUP(Progressive UPsampling), MLA(MuLti-scale Aggregation)으로 대체하여 비교하였다.

실험 결과, UNETR의 decoder가 가장 좋은 성능을 얻었다.

2) Patch Resolution

Patch의 resolution을 32에서 16으로 줄인 결과 약간의 성능의 향상이 있었다.

3) Model and Computational Complexity

Parameter size는 UNETR이 가장 크지만 FLOPs에서 CNN 기반의 model들 보다도 성능이 뛰어나면서 적당한 model complexity를 갖고 있다. 게다가 inference time은 다른 transformer 기반의 model들 보다도 빠르다.

7. Conclusion

본 논문은 semantic segmentation task에서 새로운 transformer 기반의 architecture UNETR을 제안하였다.

UNETR은 encoder에 transformer를 사용함으로 model의 long-range dependencies를 학습하는 능력을 올렸으며 효과적으로 global contextual representation을 capture 할 수 있다.

BTCV, MSD dataset에서 좋은 성능을 얻었으며 BTCV dataset에서는 SOTA를 달성하였다.

제안 된 architecture는 transformer 기반의 medical image segmentation model에 새로운 foundation이 될 것이다.

post-custom-banner

0개의 댓글