(2021)Convolution-Free Medical Image Segmentation using Transformers

Gyuha Park·2021년 9월 24일
0

Paper Review

목록 보기
24/34
post-thumbnail
post-custom-banner

0. Abstract

본 논문에서는 convolution 연산을 사용하지 않고 이웃 image patches간의 self-attention 기반의 model을 제안하였고 경쟁력 있는 결과를 얻었다. 주어진 3D block에서 model은 n3n^3의 3D patches로 나누고 center patch의 segmentation map을 예측한다. 또한 unlabeled images로 pre-training하는 방법을 제안하였고 labeled training data가 적을 때 CNN 기반의 모델보다 좋은 결과를 얻었다.

1. Introduction

Medical image segmentation의 고전적인 방법은 region growing, atlas 기반의 deformable models, bayesian 접근법, graph cuts, clustering 등이 있었다. 현재는 deep learning 방법이 대표적인 medical image segmentation의 방법이 되었다.

Deep learning은 model의 architecture, loss function, training strategies를 고려할 필요가 있다. 놀랍게도 알려진 많은 medical image segmentation의 model은 architecture가 CNN 기반이다. 본 논문에서는 convolution 연산과 arranged되는 방식이 다르지만 convolution operation에 기반하고 있다. 이전에 recurrent, attention mechanism을 적용하려는 시도도 있었지만 CNN 기반의 model에 미치지 못했다.

NLP 분야에서 예전에는 RNN이 지배적 이였지만 transformer로 대체되었다. 그리고 최근에는 computer vision 분야에도 transformer를 적용하려는 시도들이 많이 있었다. ViT는 embeds image patches를 이용한 transformer기반으로 SOTA를 달성한 classification model이다.

본 논문의 목적은 3D medical image segmentation을 위해 attention 기반의 deep learning model의 잠재력을 탐구하는 것이다. 제안 된 model은 3D image patches의 linear embedding 간의 self-attention에 기반하고 있다.

본 논문의 contribution은 다음과 같다.

  • Convolution-free인 deep learning model을 제안하였다.
  • 서로 다른 세 가지의 medical image segmentation dataset에서 SOTA인 CNN 기반의 모델보다 더 좋은 성능을 보여줬다. Image classification과 달리 image segmentation은 20 ~ 200장의 labeled images만으로 효과적으로 학습할 수 있었다.
  • 많은 양의 unlabeled training dataset을 사용 가능할 때 pre-training 방법이 segmentation model의 성능을 높일 수 있었다. Labeled image가 적을 때 pre-training에서 CNN 기반의 model보다 좋은 결과를 얻었다.

2. Materials and Methods

1) Proposed Network

Model의 input은 3D block BRW×W×W×cB\in\mathop{\mathbb{R}}^{W\times W\times W \times c}이다. WW는 block의 범위를 나타낸다. cc는 image의 channel 수를 나타낸다.

Block BB는 겹치지 않는 n3n^3개의 3D patches {piw×w×w×c}i=1N\{p_i\in^{w\times w\times w \times c}\}_{i=1}^N로 나뉜다. w=W/nw=W/n이며 N=n3N=n^3이다. nn은 홀수만 선택한다. 실험에서는 3 또는 5를 선택하였다.

Model은 모든 NN개의 patch의 information을 사용해 center patch의 segmentation map을 예측한다.

각각의 NN개의 patches는 Rw3c\mathop{\mathbb{R}}^{w^3c} size의 vector로 flatten되며 ERD×w3cE\in\mathop{\mathbb{R}}^{D\times w^3c} size의 학습 가능한 mapping에 의해 RD\mathop{\mathbb{R}}^D로 mapping된다.

연속된 patches X0=[Ep1;;EpN]+EposX^0=[E_{p_1};\ldots;E_{p_N}]+E_{pos}는 model의 input이 된다. Matrix EposRD×NE_{pos}\in\mathop{\mathbb{R}}^{D\times N}은 학습 가능한 positional encoding이다. 학습 가능한 상태로 둔 이유는 본 model에 적절한 positional encoding을 알 수 없었기 때문이다.

Encoder는 MSA(Multi-head Self-Attention), 두 개의 FFN(Fully connected Feed-forward Network)로 구성된 KK개의 stage를 가진다.

위에서 설명 한 것처럼 position-encoded patches인 X0X^0 부터 시작하여 kthk^{th} stage의 XkX^{k}에서 Xk+1X^{k+1}로 operation하는 과정은 아래와 같다.

a) XkX^k는 MSA에서 입력 시 분리된 nhn_h개의 heads를 통과한다. ithi^{th} head라고 가정하자.

Input sequences에 대해 query, key, value를 계산한다.

Qk,i=EQk,iXk, Kk,i=EKk,iXk, Vk,i=EVk,iXkQ^{k,i}=E_Q^{k,i}X^k,\ K^{k,i}=E_K^{k,i}X^k,\ V^{k,i}=E_V^{k,i}X^k

where EQ, EK, EVRDh×D\text{where}\ E_Q,\ E_K,\ E_V\in\mathop{\mathbb{R}}^{D_h\times D}

Self-attention matrix를 계산하고 normalize한다.

Ak,i=Softmax(QTK)/DhA^{k,i}=\text{Softmax}(Q^TK)/\sqrt{D_h}

SAk,i=Ak,iVk,i\text{SA}^{k,i}=A^{k,i}V^{k,i}

b) nhn_h개의 heads는 stacked 되고 RD\mathop{\mathbb{R}}^D로 reprojected 된다.

MSAk=Ereprojk[SAk,0;;SAk,nh]T\text{MSA}^k=E_{\text{reproj}}^k[\text{SA}^{k,0};\ldots;\text{SA}^{k,n_h}]^T

where EreprojRD×Dh×nh\text{where}\ E_{\text{reproj}}\in\mathop{\mathbb{R}}^{D\times D_h\times n_h}

c) MSA의 output은 다음과 같이 계산된다.

XMSAk=MSAk+XkX_{\text{MSA}}^k=\text{MSA}^k+X^k

d) XMSAkX_{\text{MSA}}^kkthk^{th} encoder stage의 output을 얻기 위해 두 개의 FFN을 통과한다.

Xk+1=XMSAk+E2k(ReLU((E1kXMSAk+b1k))+b2kX^{k+1}=X_{\text{MSA}}^k+E_2^k(\text{ReLU}((E_1^kX_{\text{MSA}}^k+b_1^k))+b_2^k

e) Center patch의 segmentation map Y^\hat{Y}는 다음 과 같이 계산된다.

마지막 encoder stage의 output XKX^KRNnclass\mathop{\mathbb{R}}^{Nn_{class}}로 project하는 FFN을 통과하고 Rn×n×n×nclass\mathop{\mathbb{R}}^{n\times n\times n\times n_{class}}로 reshape 한다. (Binary segmentation인 경우 nclass=2n_{class}=2)

Y^=Softmax(EoutXK+bout))\hat{Y}=\text{Softmax}(E_{out}X^K+b_{out}))

2) Implementation and training

Model은 Tensorflow 1.16으로 구현되었으며 NVIDIA GeForce GTX 1080 GPU, 120 GB memroy, 16 CPU core를 사용하였다.

SOTA인 3D UNet++와 medical image segmentation 성능을 비교하였다.

Model의 parameter는 learning rate 1e-4의 Adam optimizer를 사용하여 ground-truth와 Y^\hat{Y} 사이의 DSC (Dice similarity coefficient)가 최대가 되도록 학습하였다. Learning rate는 validation loss가 감소하지 않을 때 마다 절반을 줄였다.

Labeled training images가 부족할 때 성능을 높이기 위해 denoising 또는 inpainting task로 unlabeled dataset를 학습하는 pre-train 방식을 제안하였다.

두 task에서 output layer에 softmax는 제거하였으며 대신 l2l_2 norm으로 대체하였다. Target은 noise-free의 image이다.

3) Data

위 표는 사용한 dataset을 보여준다. Training data의 1/5는 validation으로 사용하였다.

3. Results and Discussion

본 논문의 model과 UNet++의 DSC, HD95(the 95 percentile of the Hausdorff Distance), ASSD(Average Symmetric Surface Distance)를 비교하였다. 이 실험에서 parameter setting은 다음과 같다.

K=7, W=24, n=3, D=1024, Dh=256, nh=4K=7,\ W=24,\ n=3,\ D=1024,\ D_h=256,\ n_h=4

위 표에서 확인할 수 있듯이 제안 된 convolution-free 방법이 UNet++에 비해 훨씬 좋은 결과를 보여준다.

위 그림은 제안 된 방법과 UNet++의 예측 결과를 visualization한 결과이다.

또한 적은 양의 labeled training images에서 좋은 결과를 얻기 위해 pre-train 방법을 적용하였다.

Cortical plate, pancreas dataset의 5, 10, 15장의 labeled training images를 사용하였으며 cortical plate는 dHCP dataset의 500장의 unlabeled images를 pre-training에 사용하였고 pancreas는 231장의 unlabeled training images를 사용하였다.

위 그림은 pre-train 결과를 보여준다. Convolution-free model이 적은 labeled training images에 대해 UNet++보다 좋은 결과를 보여준다.

위 그림의 attention maps는 early stage에서는 넓은 attention scope를 가지고 deep한 stage일수록 특정 영역에 집중하는 것을 볼 수 있다. 그리고 서로 다른 heads에서는 다양한 패턴의 attention을 출력하고 있다. 이를 통해 MSA의 multi-head 방식이 attention 패턴을 더 잘 학습하도록 하여 segmentation의 accuracy가 높게 나옴을 추측할 수 있다.

위 표는 model design에서 몇 가지 선택 사항들이 모델의 성능에 미치는 영향을 pancreas dataset에서 보여준다. Baseline의 parameter는 다음과 같다.

K=7, W=24, n=3, D=1024, Dh=256, nh=4K=7,\ W=24,\ n=3,\ D=1024,\ D_h=256,\ n_h=4

또한 Patch의 size와 number를 늘리거나 model의 depth를 높이면 약간의 성능 향상을 볼 수 있었다. 게다가 fixed positional encoding 또는 no positional encoding을 적용한 경우 약간의 성능 하락이 있었다. 마지막으로 single-head attention을 적용한 경우 약간의 성능 하락이 있었다.

4. Conclusion

본 논문에서는 새로운 3D medical image segmentation model을 제안하였다. 최근 모든 model들은 convolution을 사용하여 block을 쌓았다. 하지만 제안 된 model은 3D patches간의 self-attention 기반의 model이다. 실험 결과 제안 된 model은 세 가지의 medical image segmentation dataset에서 SOTA를 달성하였다. 그리고 unlabeled images의 denoising, inpainting task에서의 pre-train 결과 5 ~ 15 labeled training images에서 CNN 기반의 model보다 좋은 성능을 보여줬다.

post-custom-banner

0개의 댓글