
[UNETR 논문 리뷰] - UNETR: Transformers for 3D Medical Image Segmentation
Step1_3D U-NET이란?
💡 <한줄요약>의료분야에서 3차원 데이터를 segmentation 할떄 가장 기본이 되는 구조가 3D U-Net임!
💡 참고 : [https://kyujinpy.tistory.com/9](https://kyujinpy.tistory.com/9)기존의 U-NET은 2D 구조에 대해 segmenation하는 모델이다.
그러나 최근들어서 의료데이터 등에 적용되면서 3D 구조에 대해 Segmentation 하는 경우가 많아졌다.
2D이미지는 (H,W,3)인 차원이면, 의료데이터와 같은 3D이미지는(H,W,S,1) 차원이다. → 따라서 3d 구조 데이터는 pixel이 아닌 voxel 개념을 갖고 있다.
voxel unit : 1 X 1 X 1
즉 pixel이 아닌 voxel 개념을 가지고 와서 의료데이터 등등에 사용하는 것이 3D U-NET

참고) U-Net의 단계는 다음과 같음.
각 과정마다 나온 output을 모두 고려하기 때문에 그대로 둔다(?).
4. N번 반복한 최종 결과를 Up-convolution 시킨다.
여기서 이용되는 기법은, Upsampling으로 ConvTranspose(Deconvolution)이라고 한다.
5. Upsampling 시켜서 나온 결과는 N-1번째 결과의 dimension과 일치한다.
6. N-1번째 결과와 Upsampling 과정의 channel을 서로 연결 시킨 후, conv+BN+ReLU+Upsampling 연산을 반복한다.
여기서 연결되는 것을 skip-connection 이라고 한다!!
U-net은 목적이 segmentation이기떄문에 input / output dimension이 일치해야함!
반드시 일치하지 않고 resize해도 되긴 하지만 정석은 same dimension임!
따라서 CNN기반의 U-NET은 어쩔 수 없이 CNN을 통과하며 feature map 크기가 줄기에 이것을 다시 high-resolution 으로 만들어줘야 한다.
아래 모델 이미지를 통해 convolution과 crop에 대해 똑같이 하는 것을 한 번 보면 이해가 될듯!

3D UNET 구조 U모양으로 생겨서 U-NET임
<Contracting Path 영역>
아래로 내려가는 건 Down-Sampling이라고 하고, convolution을 통해 서서히 낮은 차원의 채널을 가짐
<Expanding Path 영역>
위로 올라가는건 Up-Sampling이라고 하고, crop 이나 resize로 서서히 높은 차원의 채널을 가짐→ 최종 특징 맵으로 부터 보다 높은 해상도를 가진 Segmentation 이미지를 얻기위해 사용
💡 U-NET을 더 알기 위해서 FCN이라는 CNN 기반 이미지 분류 모델을 아는 것이 중요 [https://wikidocs.net/147359](https://wikidocs.net/147359) 참고!U-NET 장점
적은 양의 학습 데이터로도 Data Augmentation을 활용해 여러 Biomedical Image Segmentation 문제에서 우수한 성능을 보임
컨텍스트 정보를 잘 사용하면서도 정확히 지역화함
End-to-End 구조로 속도가 빠름
속도가 빠른 이유: 검증이 끝난 곳은 건너뛰고 다음 Patch부터 새 검증을 하기 때문.
Step2_Transformer란?
💡 Transformer 모델 : NLP 분야에서 단어들간의 연관성에 대한 RNN 문제점을 해결하고 vision Transformer등장으로 지금까지 CV 분야에서 대표적으로 사용하고 있는 모델 <구조> encoder / decoder Transformer가 CV분야에서 매우 대표적인 모델로 자리잡음.Transformer 등장 전
기존 NLP(자연어 처리) 분야에서는 RNN을 이용했는데 → 고질적인 문제가 대표적으로 2개가 존재.
→ 그래서 기존 RNN 모델은 Attention 기법을 같이 엮어서 활용했다.
그러나 Transformer는 RNN 이용하지 않고 Attention 을 이용해 NLP Task를 수행하고, SOTA 모델급의 성능을 보여주면서 여러 컴퓨터 비전 분야에 등장하는 엄청난 모델이 됨!

Attention이란?
단어를 임베딩(embedding)해주는 개념.
→ 특이한 점은 Transformer에선 단어를 형태소를 input으로 하는 것이 아니라 sentence 단위로 input으로 넣음, +) Q(쿼리), K(키), V(벨류) 라는 값을 통해 Sentence 를 구성하는 단어들 간의 연관성을 파악

Q : 단어행렬, K : 유사도행렬, V : 가중치행렬

NLP 분야 transformer 아키텍쳐
좌 : encoder
우 : decoder
Transformer 모델은 Vision Transformer 등장으로 Computer vision에서도 적용이 가능해졌다. → Computer vision 분야에서도 Transformer로 성능을 더욱 향상시킬 수 있다.
Step3_Vision Transformer (ViT)란?
💡 NLP 분야는 1차원 데이터지만 이미지나 영상같은 2D/3D 이미지를 Transformer하는 방법이 있을까? → 이미지를 쪼개서 일렬로 나열해 2D transformer의 특징인 Setence 형태처럼 만들어 input으로 넣으면 되지 않을까? → Vision Transformer 탄생! Vision Transformer method
Vision Transformer(ViT)는 이미지를 patch단위로 쪼갠다는 개념에서 시작 → sentence단위로
1. 이미지를 16x16 patch 크기로 분할 [임의로 지정] 2. 잘라진 이미지들을 linear projection 수행 : 2D → 1D로 Flatten 시킴 - flatten layer는 추출된 주요 특징을 전결합층에 전달하기 위해 1차원 자료로 바꿔주는 layer이다. 1. Positional encoding을 더해줌 - 여기에서 [class] token의 embedding을 0번쨰로 넣어줌 [처음을 가르킴] 1. Transformer encoder를 통해 embedding을 생성 2. [class]에 대한 embedding만 이용해서 MLP에 넣어주고 classification을 진행 - 자연어 처리 모델인 BERT 모델에서는 [class] token을 사용하는데, 이는 문장의 첫 시작을 알려주는 역할을 한다 → 이를 이미지에서도 똑같이 사용 장점 : image classfication이나 detecion 분야에서는 좋은 성능을 갖춤 단점 : 세밀한 영역(우리가 찾는 PVS과 같은 영역)에서의 segmentaion은 좋지 못한 성능을 가짐 → 단점을 보완해서 나온 것이 “Swin-Transformer” : detection 분야와 CV분야에서 SOTA 모델의 베이스 구조를 담당함 💡 Swin U-NETR 링크 : [https://kyujinpy.tistory.com/14](https://kyujinpy.tistory.com/14)위에 개념들을 다 알고가면 UNETR은 기본적인 model 아키텍쳐만 봐도 됨
3d 데이터셋을 3D Patched로 쪼갠 후에 1차원으로 flatten시킴(ViT 원리)
3D U-NET 모델에 넣어 segmentation 진행
배경
U-NETR 나오기 전에는 FCNN(Full Convolution Neural Network)이 10년 동안 쓰였음.
→FCNN의 컨볼류션 레이어의 지역성은 장거리 범위의 학습을 제한함
→Transformer 개념 중 ViT가 이를 해결해줌 → 3D U-NET과 결합해서 쓰니 성능 SOTA급으로 잘나옴
→ 범용적으로 BM 분야에서 3d 이미지 segment를 위해 UNETR 을 많이쓰게 됨


EnhancingTumor core
3d 데이터셋을 3D Patched로 쪼갠 후에 Linear projection 진행해서 1차원으로 flatten시킴(ViT 원리)
→ encoder에 넣고 3D-UNET 진행스타트
import os
import shutil
import tempfile
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
AsDiscrete,
EnsureChannelFirstd,
Compose,
CropForegroundd,
LoadImaged,
Orientationd,
RandFlipd,
RandCropByPosNegLabeld,
RandShiftIntensityd,
ToTensord,
ScaleIntensityRanged,
Spacingd,
Spacing,
RandRotate90d,
)
from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR
from monai.data import (
DataLoader,
CacheDataset,
load_decathlon_datalist,
decollate_batch,
)
import torch
import nibabel as nib
from skimage.transform import resize
print_config()#디렉토리 경로 가지고오기
directory = os.environ.get('ENV','/home/bmegpu02/tmp/freesurfer/subjects/DH_SEG/result')
root_dir = directory #임시로 사용할 디렉토리 tmp 내에 존재
#transforemr
#train transforms
train_transforms = Compose(
[
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
Orientationd(keys=["image", "label"], axcodes="RAS"),
Spacingd(
keys=["image", "label"],
pixdim=(1.5, 1.5, 2.0),
mode=("bilinear", "nearest"),
),
ScaleIntensityRanged(
keys=["image"],
a_min=-175,
a_max=250,
b_min=0.0,
b_max=1.0,
clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
RandCropByPosNegLabeld(
keys=["image", "label"],
label_key="label",
spatial_size=(96, 96, 96),
pos=1,
neg=1,
num_samples=4,
image_key="image",
image_threshold=0,
),
RandFlipd(
keys=["image", "label"],
spatial_axis=[0],
prob=0.10,
),
RandFlipd(
keys=["image", "label"],
spatial_axis=[1],
prob=0.10,
),
RandFlipd(
keys=["image", "label"],
spatial_axis=[2],
prob=0.10,
),
RandRotate90d(
keys=["image", "label"],
prob=0.10,
max_k=3,
),
RandShiftIntensityd(
keys=["image"],
offsets=0.10,
prob=0.50,
),
]
)
#validation transforms
val_transforms = Compose(
[
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
Orientationd(keys=["image", "label"], axcodes="RAS"),
#???
Spacingd(
keys=["image", "label"],
pixdim=(1.5, 1.5, 2.0),
mode=("bilinear", "nearest"),
),
#??
ScaleIntensityRanged(
keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True
),
CropForegroundd(keys=["image", "label"], source_key="image"),
]
)
#test transforms
test_transforms = Compose(
[
LoadImaged(keys=["image"]),
EnsureChannelFirstd(keys=["image"]),
Orientationd(keys=["image"], axcodes="RAS"),
#???
Spacingd(
keys=["image"],
pixdim=(1.5, 1.5, 2.0),
#mode=("bilinear", "nearest"),
),
#??
ScaleIntensityRanged(
keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True
),
CropForegroundd(keys=["image"], source_key="image"),
]
)#> 훈련데이터셋 10개,val 3개 + 테스트데이터셋 7개
data_dir = "./data/"
split_json = "data_seg.json"
###json 파일을 변형해서 내껄로 만들자.!!
datasets = data_dir + split_json
#json 파일이 dataset에 경로 저장되어있는지 확인
if datasets is not None:
print("Success Config PATH : {}".format(datasets)) #-> 존재함 제대로 경로 설정됨
datalist = load_decathlon_datalist(datasets, True, "training")
val_files = load_decathlon_datalist(datasets, True, "validation")
train_ds = CacheDataset(
data=datalist,
transform=train_transforms,
cache_num=13,
cache_rate=1.0,
num_workers=8,
)
train_loader = DataLoader(
train_ds, batch_size=1, shuffle=True, num_workers=8, pin_memory=True
)
val_ds = CacheDataset(
data=val_files,
transform=val_transforms,
cache_num=3,
cache_rate=1.0,
num_workers=4
)
val_loader = DataLoader(
val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True
)
print("train & val load complete")
#test 14~20
test_files = load_decathlon_datalist(datasets, True, "testing")
test_ds = CacheDataset(
data = test_files,
transform = test_transforms,
cache_num = 7,
cache_rate = 1.0,
num_workers = 4
)
#test_loader = DataLoader(
# test_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True
#)
print("test load complete")
####데이터 로드 확인
slice_map = {
"img0011.nii.gz": 79,
"img0012.nii.gz": 79,
"img0013.nii.gz": 79,
}
case_num = 0
img_name = os.path.split(val_ds[case_num]["image"].meta["filename_or_obj"])[1]
img = val_ds[case_num]["image"]
label = val_ds[case_num]["label"]
img_shape = img.shape
label_shape = label.shape
print(f"image shape: {img_shape}, label shape: {label_shape}")
plt.figure("image", (18, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(img[0, :, :, slice_map[img_name]].detach().cpu(), cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[0, :, :, slice_map[img_name]].detach().cpu())
plt.show()
test_slice_map = {
"img0014.nii.gz" : 79,
"img0015.nii.gz" : 79,
"img0016.nii.gz" : 79,
"img0017.nii.gz" : 79,
"img0018.nii.gz" : 79,
"img0019.nii.gz" : 79,
"img0020.nii.gz" : 79
}
case_num = 0
test_img_name = os.path.split(test_ds[case_num]["image"].meta["filename_or_obj"])[1]
test_img = val_ds[case_num]["image"]
img_shape = img.shape
print(f"image shape: {img_shape}")
plt.figure("image", (18, 6))
plt.subplot(1, 2, 1)
plt.title(img_name + " " + "image")
plt.imshow(img[0, :, :, test_slice_map[img_name]].detach().cpu(), cmap="gray")
plt.subplot(1, 2, 2)
plt.title(img_name + " "+ "label")
plt.imshow(label[0, :, :, test_slice_map[img_name]].detach().cpu())
plt.show()
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = UNETR(
in_channels=1,
out_channels=10, #9개의 classification 1개의 백그라운드
img_size=(96, 96, 96),
feature_size=16,
hidden_size=768,
mlp_dim=3072,
num_heads=12,
pos_embed="perceptron",
norm_name="instance",
res_block=True,
dropout_rate=0.0,
).to(device)
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
torch.backends.cudnn.benchmark = True
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)def validation(epoch_iterator_val):
model.eval()
with torch.no_grad():
for batch in epoch_iterator_val:
val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda())
val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model) val_labels_list = decollate_batch(val_labels)
val_labels_convert = [
post_label(val_label_tensor) for val_label_tensor in val_labels_list
]
val_outputs_list = decollate_batch(val_outputs)
val_output_convert = [
post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list
]
dice_metric(y_pred=val_output_convert, y=val_labels_convert)
epoch_iterator_val.set_description(
"Validate (%d / %d Steps)" % (global_step, 10.0)
)
mean_dice_val = dice_metric.aggregate().item()
dice_metric.reset()
return mean_dice_val
def train(global_step, train_loader, dice_val_best, global_step_best):
model.train()
epoch_loss = 0
step = 0
epoch_iterator = tqdm(
train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True
)
for step, batch in enumerate(epoch_iterator):
step += 1
x, y = (batch["image"].cuda(), batch["label"].cuda())
logit_map = model(x)
loss = loss_function(logit_map, y)
loss.backward()
epoch_loss += loss.item()
optimizer.step()
optimizer.zero_grad()
epoch_iterator.set_description(
"Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss)
)
if (
global_step % eval_num == 0 and global_step != 0
) or global_step == max_iterations:
epoch_iterator_val = tqdm(
val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True
)
dice_val = validation(epoch_iterator_val)
epoch_loss /= step
epoch_loss_values.append(epoch_loss)
metric_values.append(dice_val)
if dice_val > dice_val_best:
dice_val_best = dice_val
global_step_best = global_step
torch.save(
model.state_dict(), os.path.join(root_dir, "best_metric_model.pth")
)
print(
"Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
dice_val_best, dice_val
)
)
else:
print(
"Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
dice_val_best, dice_val
)
)
global_step += 1return global_step, dice_val_best, global_step_best
max_iterations = 100
eval_num = 10
post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
global_step = 0
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
while global_step < max_iterations:
global_step, dice_val_best, global_step_best = train(
global_step, train_loader, dice_val_best, global_step_best
)#best metric 확인
print(
f"train completed, best_metric: {dice_val_best:.4f} "
f"at iteration: {global_step_best}"
)
#loss 및 val mean dice 확인
plt.figure("train", (12,6))
##### Iteration Average loss plot config #####
plt.subplot(1,2,1)
plt.title("Iteration Average loss")
#x = iteration수=최대 max_iteration , y = loss
x = [eval_num * (i + 1) for i in range(len(epoch_loss_values))]
y = epoch_loss_values
plt.xlabel("Iteration")
plt.plot(x,y)
##### Val mean Dice plot config #####
plt.subplot(1,2,2)
plt.title("Val mean Dice")
# x = iteraion 수 = 최대 max_iteration , y = 검증 평균율(=Valdility Mean)
x = [eval_num * (i + 1) for i in range(len(metric_values))]
y = metric_values
plt.xlabel("Iteration")
plt.plot(x,y)
plt.show()
## val_mean_dice > 0.9 이상 평균적으로 나옴
case_num = 0
model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))
model.eval()
with torch.no_grad():
img_name = os.path.split(val_ds[case_num]["image"].meta["filename_or_obj"])[1]
img = val_ds[case_num]["image"]
label = val_ds[case_num]["label"]
val_inputs = torch.unsqueeze(img, 1).cuda()
val_labels = torch.unsqueeze(label, 1).cuda()
val_outputs = sliding_window_inference(
val_inputs, (96, 96, 96), 4, model, overlap=0.8
)
plt.figure("check", (18, 6))
plt.subplot(1, 3, 1)
plt.title("image")
plt.imshow(val_inputs.cpu().numpy()[0, 0, :, :, slice_map[img_name]], cmap="gray")
plt.subplot(1, 3, 2)
plt.title("label")
plt.imshow(val_labels.cpu().numpy()[0, 0, :, :, slice_map[img_name]])
plt.subplot(1, 3, 3)
plt.title("output")
plt.imshow(
torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, slice_map[img_name]]
)
#test4/png에 저장
plt.savefig(os.path.join(f'/home/bmegpu02/tmp/freesurfer/subjects/DH_SEG/result/test4/png/11seg.png'))
plt.show()
val_outputs.shape[UNETR++ 논문 리뷰] - UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation
[Swin Transformer 논문 리뷰] - Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
[TransUNet 논문 리뷰] - TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation