논문리뷰_UNETR: Transformers for 3D Medical Image Segmentation

Greatwise.ch·2023년 3월 5일

논문리뷰

목록 보기
2/3

  • 참고문헌

[UNETR 논문 리뷰] - UNETR: Transformers for 3D Medical Image Segmentation

필수적으로 알고가기

  • Step1_3D U-NET이란?

    💡 <한줄요약>

    의료분야에서 3차원 데이터를 segmentation 할떄 가장 기본이 되는 구조가 3D U-Net임!

    💡 참고 : [https://kyujinpy.tistory.com/9](https://kyujinpy.tistory.com/9)

    기존의 U-NET은 2D 구조에 대해 segmenation하는 모델이다.

    그러나 최근들어서 의료데이터 등에 적용되면서 3D 구조에 대해 Segmentation 하는 경우가 많아졌다.

    • 2D이미지는 (H,W,3)인 차원이면, 의료데이터와 같은 3D이미지는(H,W,S,1) 차원이다. → 따라서 3d 구조 데이터는 pixel이 아닌 voxel 개념을 갖고 있다.

    • voxel unit : 1 X 1 X 1

      즉 pixel이 아닌 voxel 개념을 가지고 와서 의료데이터 등등에 사용하는 것이 3D U-NET

      참고) U-Net의 단계는 다음과 같음.

    1. Input data를 convolution 연산과 batch Normalization, 그리고 ReLU를 적용한다.
    • 이 과정에서는 feature map의 크기는 줄지 않고, channel의 개수가 늘어난다.
    1. 그리고 Max pooling을 적용해서 feature map의 사이즈를 줄인다.
    2. 위의 과정을 N번 반복한다.
    • 각 과정마다 나온 output을 모두 고려하기 때문에 그대로 둔다(?).

      4. N번 반복한 최종 결과를 Up-convolution 시킨다.

    • 여기서 이용되는 기법은, Upsampling으로 ConvTranspose(Deconvolution)이라고 한다.

      5. Upsampling 시켜서 나온 결과는 N-1번째 결과의 dimension과 일치한다.

      6. N-1번째 결과와 Upsampling 과정의 channel을 서로 연결 시킨 후, conv+BN+ReLU+Upsampling 연산을 반복한다.

    • 여기서 연결되는 것을 skip-connection 이라고 한다!!

    1. 최종적으로, dimension을 원상복귀 시킨 후, channel을 동일하게 만들어서 결과물을 얻는다.
    • U-net은 목적이 segmentation이기떄문에 input / output dimension이 일치해야함!

    • 반드시 일치하지 않고 resize해도 되긴 하지만 정석은 same dimension임!

    • 따라서 CNN기반의 U-NET은 어쩔 수 없이 CNN을 통과하며 feature map 크기가 줄기에 이것을 다시 high-resolution 으로 만들어줘야 한다.

      아래 모델 이미지를 통해 convolution과 crop에 대해 똑같이 하는 것을 한 번 보면 이해가 될듯!

      3D UNET 구조 U모양으로 생겨서 U-NET임

      <Contracting Path 영역>

      아래로 내려가는 건 Down-Sampling이라고 하고, convolution을 통해 서서히 낮은 차원의 채널을 가짐

      <Expanding Path 영역>

      위로 올라가는건 Up-Sampling이라고 하고, crop 이나 resize로 서서히 높은 차원의 채널을 가짐→ 최종 특징 맵으로 부터 보다 높은 해상도를 가진 Segmentation 이미지를 얻기위해 사용

      💡 U-NET을 더 알기 위해서 FCN이라는 CNN 기반 이미지 분류 모델을 아는 것이 중요 [https://wikidocs.net/147359](https://wikidocs.net/147359) 참고!

      U-NET 장점

    • 적은 양의 학습 데이터로도 Data Augmentation을 활용해 여러 Biomedical Image Segmentation 문제에서 우수한 성능을 보임

    • 컨텍스트 정보를 잘 사용하면서도 정확히 지역화함

    • End-to-End 구조로 속도가 빠름

    • 속도가 빠른 이유: 검증이 끝난 곳은 건너뛰고 다음 Patch부터 새 검증을 하기 때문.

  • Step2_Transformer란?

    💡 Transformer 모델 : NLP 분야에서 단어들간의 연관성에 대한 RNN 문제점을 해결하고 vision Transformer등장으로 지금까지 CV 분야에서 대표적으로 사용하고 있는 모델 <구조> encoder / decoder Transformer가 CV분야에서 매우 대표적인 모델로 자리잡음.

    💡 참고 https://kyujinpy.tistory.com/2

  1. Transformer 등장 전

    기존 NLP(자연어 처리) 분야에서는 RNN을 이용했는데 → 고질적인 문제가 대표적으로 2개가 존재.

    1. Long-term dependecy problems(앞에 있는 객체의 영향력이 뒤에 까지 온전히 전달되지 못함)가 존재
    2. sequence(문장)에서 단어 간의 연관성이 RNN에서는 각 단어에 대한 전파력이 앞으로만 전달되니까 모든 단어들 간의 연관성을 파악하기가 쉽지 않다.

    → 그래서 기존 RNN 모델은 Attention 기법을 같이 엮어서 활용했다.

    그러나 Transformer는 RNN 이용하지 않고 Attention 을 이용해 NLP Task를 수행하고, SOTA 모델급의 성능을 보여주면서 여러 컴퓨터 비전 분야에 등장하는 엄청난 모델이 됨!

  1. Attention이란?

    단어를 임베딩(embedding)해주는 개념.

    → 특이한 점은 Transformer에선 단어를 형태소를 input으로 하는 것이 아니라 sentence 단위로 input으로 넣음, +) Q(쿼리), K(키), V(벨류) 라는 값을 통해 Sentence 를 구성하는 단어들 간의 연관성을 파악

    Q : 단어행렬, K : 유사도행렬, V : 가중치행렬

    1. NLP Transformer란?

    NLP 분야 transformer 아키텍쳐
    좌 : encoder
    우 : decoder

    Transformer 모델은 Vision Transformer 등장으로 Computer vision에서도 적용이 가능해졌다. → Computer vision 분야에서도 Transformer로 성능을 더욱 향상시킬 수 있다.

    • Transformer가 CV분야에서 매우 대표적인 모델로 자리잡음.
  • Step3_Vision Transformer (ViT)란?

    💡 NLP 분야는 1차원 데이터지만 이미지나 영상같은 2D/3D 이미지를 Transformer하는 방법이 있을까? → 이미지를 쪼개서 일렬로 나열해 2D transformer의 특징인 Setence 형태처럼 만들어 input으로 넣으면 되지 않을까? → Vision Transformer 탄생! ![](https://velog.velcdn.com/images/dablro12/post/8abdea8d-dac6-458e-9211-573681f990ba/image.png)

    Vision Transformer method

    Vision Transformer(ViT)는 이미지를 patch단위로 쪼갠다는 개념에서 시작 → sentence단위로

    1. 이미지를 16x16 patch 크기로 분할 [임의로 지정] 2. 잘라진 이미지들을 linear projection 수행 : 2D → 1D로 Flatten 시킴 - flatten layer는 추출된 주요 특징을 전결합층에 전달하기 위해 1차원 자료로 바꿔주는 layer이다. 1. Positional encoding을 더해줌 - 여기에서 [class] token의 embedding을 0번쨰로 넣어줌 [처음을 가르킴] 1. Transformer encoder를 통해 embedding을 생성 2. [class]에 대한 embedding만 이용해서 MLP에 넣어주고 classification을 진행 - 자연어 처리 모델인 BERT 모델에서는 [class] token을 사용하는데, 이는 문장의 첫 시작을 알려주는 역할을 한다 → 이를 이미지에서도 똑같이 사용 장점 : image classfication이나 detecion 분야에서는 좋은 성능을 갖춤 단점 : 세밀한 영역(우리가 찾는 PVS과 같은 영역)에서의 segmentaion은 좋지 못한 성능을 가짐 → 단점을 보완해서 나온 것이 “Swin-Transformer” : detection 분야와 CV분야에서 SOTA 모델의 베이스 구조를 담당함 💡 Swin U-NETR 링크 : [https://kyujinpy.tistory.com/14](https://kyujinpy.tistory.com/14)

위에 개념들을 다 알고가면 UNETR은 기본적인 model 아키텍쳐만 봐도 됨

본문 UNETR : UNEt TRansformers

💡 한줄 요약 : 3d input이미지를 3D U-NET + ViT 인 Method로 seg했을떄 ⇒ U-NETR 💡
  1. 3d 데이터셋을 3D Patched로 쪼갠 후에 1차원으로 flatten시킴(ViT 원리)

  2. 3D U-NET 모델에 넣어 segmentation 진행

  • 배경

  • U-NETR 나오기 전에는 FCNN(Full Convolution Neural Network)이 10년 동안 쓰였음.

    →FCNN의 컨볼류션 레이어의 지역성은 장거리 범위의 학습을 제한함

    →Transformer 개념 중 ViT가 이를 해결해줌 → 3D U-NET과 결합해서 쓰니 성능 SOTA급으로 잘나옴

    → 범용적으로 BM 분야에서 3d 이미지 segment를 위해 UNETR 을 많이쓰게 됨

    1. Overview of UNETR

    1. Architecture of UNETR & Method

    ![Untitled](https://s3-us-west-![업로드중..](blob:https://velog.io/4a1d97bc-e71b-4dbf-bbe3-f63bf777d94a)
    EnhancingTumor core

    1. input data preprocessing [3d 모델이므로 voxel size : H W D * 4]

    3d 데이터셋을 3D Patched로 쪼갠 후에 Linear projection 진행해서 1차원으로 flatten시킴(ViT 원리)

    → encoder에 넣고 3D-UNET 진행스타트

    1. Normalization과 Multi-Head-Attention을 적용한 후 residual Network 구조를 통해 합쳐진 데이터를 MLP 에 다시 넣는다.
    2. 이걸 총 Normalization과 Attention을 N=12번 반복
    3. Transformer연산이 나오면, 나온 embedding(가중치값)을 reshape과정을 통해 3D voxel형식으로 만듬
    4. scale이 작으므로 Upsampling( resize이나 Spacingd)하면서 N = 9,6,3,9에서 나온 Feature들과 skip-connection 진행
    5. 마지막 레이어에서 최종 dimension과 channel을 input data와 똑같게 조절한 후 result 출력

    UNETR 적용한 코드리뷰

    • setup
      import os
      import shutil  
      import tempfile 
      import numpy as np
      import matplotlib.pyplot as plt 
      from tqdm import tqdm 
      
      from monai.losses import DiceCELoss
      from monai.inferers import sliding_window_inference
      from monai.transforms import (
          AsDiscrete,
          EnsureChannelFirstd,
          Compose,
          CropForegroundd,
          LoadImaged,
          Orientationd,
          RandFlipd,
          RandCropByPosNegLabeld,
          RandShiftIntensityd,
          ToTensord,
          ScaleIntensityRanged,
          Spacingd,
          Spacing,
          RandRotate90d,
      )
      
      from monai.config import print_config
      from monai.metrics import DiceMetric
      from monai.networks.nets import UNETR
      
      from monai.data import (
          DataLoader,
          CacheDataset,
          load_decathlon_datalist,
          decollate_batch,
      )
      import torch
      import nibabel as nib
      from skimage.transform import resize
      print_config()
    • PATH setup 및 transformer
      #디렉토리 경로 가지고오기 
      directory = os.environ.get('ENV','/home/bmegpu02/tmp/freesurfer/subjects/DH_SEG/result')
      root_dir = directory #임시로 사용할 디렉토리 tmp 내에 존재
      
      #transforemr
      #train transforms
      train_transforms = Compose(
          [
              LoadImaged(keys=["image", "label"]),
              EnsureChannelFirstd(keys=["image", "label"]),
              Orientationd(keys=["image", "label"], axcodes="RAS"),
              Spacingd(
                  keys=["image", "label"],
                  pixdim=(1.5, 1.5, 2.0),
                  mode=("bilinear", "nearest"),
              ),
              ScaleIntensityRanged(
                  keys=["image"],
                  a_min=-175,
                  a_max=250,
                  b_min=0.0,
                  b_max=1.0,
                  clip=True,
              ),
              CropForegroundd(keys=["image", "label"], source_key="image"),
              RandCropByPosNegLabeld(
                  keys=["image", "label"],
                  label_key="label",
                  spatial_size=(96, 96, 96),
                  pos=1,
                  neg=1,
                  num_samples=4,
                  image_key="image",
                  image_threshold=0,
              ),
              RandFlipd(
                  keys=["image", "label"],
                  spatial_axis=[0],
                  prob=0.10,
              ),
              RandFlipd(
                  keys=["image", "label"],
                  spatial_axis=[1],
                  prob=0.10,
              ),
              RandFlipd(
                  keys=["image", "label"],
                  spatial_axis=[2],
                  prob=0.10,
              ),
              RandRotate90d(
                  keys=["image", "label"],
                  prob=0.10,
                  max_k=3,
              ),
              RandShiftIntensityd(
                  keys=["image"],
                  offsets=0.10,
                  prob=0.50,
              ),
          ]
      )
      #validation transforms
      val_transforms = Compose(
          [
              
              LoadImaged(keys=["image", "label"]),
              EnsureChannelFirstd(keys=["image", "label"]),
              Orientationd(keys=["image", "label"], axcodes="RAS"),
              #???
              Spacingd(
                  keys=["image", "label"],
                  pixdim=(1.5, 1.5, 2.0),
                  mode=("bilinear", "nearest"),
              ),
              #??
              ScaleIntensityRanged(
                 keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True
              ),
              CropForegroundd(keys=["image", "label"], source_key="image"),
          ]
      )
      
      #test transforms
      test_transforms = Compose(
          [
              
              LoadImaged(keys=["image"]),
              EnsureChannelFirstd(keys=["image"]),
              Orientationd(keys=["image"], axcodes="RAS"),
              #???
              Spacingd(
                  keys=["image"],
                  pixdim=(1.5, 1.5, 2.0),
                  #mode=("bilinear", "nearest"),
              ),
              #??
              ScaleIntensityRanged(
                 keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True
              ),
              CropForegroundd(keys=["image"], source_key="image"),
          ]
      )
    • data preprocessing
      #> 훈련데이터셋 10개,val 3개 +  테스트데이터셋 7개
      data_dir = "./data/"
      split_json = "data_seg.json"
      ###json 파일을 변형해서 내껄로 만들자.!!
      datasets = data_dir + split_json
      
      #json 파일이 dataset에 경로 저장되어있는지 확인
      if datasets is not None:
          print("Success Config PATH : {}".format(datasets)) #-> 존재함 제대로 경로 설정됨 
      
      datalist = load_decathlon_datalist(datasets, True, "training")
      val_files = load_decathlon_datalist(datasets, True, "validation")
      train_ds = CacheDataset(
          data=datalist,
          transform=train_transforms,
          cache_num=13,
          cache_rate=1.0,
          num_workers=8,
      )
      train_loader = DataLoader(
          train_ds, batch_size=1, shuffle=True, num_workers=8, pin_memory=True
      )
      val_ds = CacheDataset(
          data=val_files,
          transform=val_transforms,
          cache_num=3,
          cache_rate=1.0,
          num_workers=4
      )
      val_loader = DataLoader(
          val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True
      )
      print("train & val load complete")
      
      #test 14~20
      test_files = load_decathlon_datalist(datasets, True, "testing")
      test_ds = CacheDataset(
          data = test_files,
          transform = test_transforms,
          cache_num = 7,
          cache_rate = 1.0,
          num_workers = 4
      )
      #test_loader = DataLoader(
      #    test_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True
      #)
      print("test load complete")
      
      ####데이터 로드 확인 
      slice_map = {
          "img0011.nii.gz": 79,
          "img0012.nii.gz": 79,
          "img0013.nii.gz": 79,
      }
      case_num = 0
      img_name = os.path.split(val_ds[case_num]["image"].meta["filename_or_obj"])[1]
      img = val_ds[case_num]["image"]
      label = val_ds[case_num]["label"]
      img_shape = img.shape
      label_shape = label.shape
      print(f"image shape: {img_shape}, label shape: {label_shape}")
      plt.figure("image", (18, 6))
      plt.subplot(1, 2, 1)
      plt.title("image")
      plt.imshow(img[0, :, :, slice_map[img_name]].detach().cpu(), cmap="gray")
      plt.subplot(1, 2, 2)
      plt.title("label")
      plt.imshow(label[0, :, :, slice_map[img_name]].detach().cpu())
      plt.show()
      
      test_slice_map = {
          "img0014.nii.gz" : 79,
          "img0015.nii.gz" : 79,
          "img0016.nii.gz" : 79,
          "img0017.nii.gz" : 79,
          "img0018.nii.gz" : 79,
          "img0019.nii.gz" : 79,
          "img0020.nii.gz" : 79
      }
      case_num = 0
      test_img_name = os.path.split(test_ds[case_num]["image"].meta["filename_or_obj"])[1]
      test_img = val_ds[case_num]["image"]
      img_shape = img.shape
      print(f"image shape: {img_shape}")
      plt.figure("image", (18, 6))
      plt.subplot(1, 2, 1)
      plt.title(img_name + " " + "image")
      plt.imshow(img[0, :, :, test_slice_map[img_name]].detach().cpu(), cmap="gray")
      plt.subplot(1, 2, 2)
      plt.title(img_name + " "+ "label")
      plt.imshow(label[0, :, :, test_slice_map[img_name]].detach().cpu())
      plt.show()
      
    • model 만들기, optimizer, loss function 만들기
      os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      print(device)
      model = UNETR(
          in_channels=1,
          out_channels=10, #9개의 classification 1개의 백그라운드
          img_size=(96, 96, 96),
          feature_size=16,
          hidden_size=768,
          mlp_dim=3072,
          num_heads=12,
          pos_embed="perceptron",
          norm_name="instance",
          res_block=True,
          dropout_rate=0.0,
      ).to(device)
      
      loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
      torch.backends.cudnn.benchmark = True
      optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    • pytorch 훈련 * UNET기반
      def validation(epoch_iterator_val):
      model.eval()
      with torch.no_grad():
      for batch in epoch_iterator_val:
      val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda())
      val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)
              val_labels_list = decollate_batch(val_labels)
              val_labels_convert = [
                  post_label(val_label_tensor) for val_label_tensor in val_labels_list
              ]
              val_outputs_list = decollate_batch(val_outputs)
              val_output_convert = [
                  post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list
              ]
              dice_metric(y_pred=val_output_convert, y=val_labels_convert)
              epoch_iterator_val.set_description(
                  "Validate (%d / %d Steps)" % (global_step, 10.0)
              )
      
          mean_dice_val = dice_metric.aggregate().item()
          dice_metric.reset()
      return mean_dice_val
      
      def train(global_step, train_loader, dice_val_best, global_step_best):
      model.train()
      epoch_loss = 0
      step = 0
      epoch_iterator = tqdm(
      train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True
      )
      for step, batch in enumerate(epoch_iterator):
      step += 1
      x, y = (batch["image"].cuda(), batch["label"].cuda())
      logit_map = model(x)
      loss = loss_function(logit_map, y)
      loss.backward()
      epoch_loss += loss.item()
      optimizer.step()
      optimizer.zero_grad()
      epoch_iterator.set_description(
      "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss)
      )
      if (
      global_step % eval_num == 0 and global_step != 0
      ) or global_step == max_iterations:
      epoch_iterator_val = tqdm(
      val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True
      )
      dice_val = validation(epoch_iterator_val)
      epoch_loss /= step
      epoch_loss_values.append(epoch_loss)
      metric_values.append(dice_val)
      if dice_val > dice_val_best:
      dice_val_best = dice_val
      global_step_best = global_step
      torch.save(
      model.state_dict(), os.path.join(root_dir, "best_metric_model.pth")
      )
      print(
      "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
      dice_val_best, dice_val
      )
      )
      else:
      print(
      "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
      dice_val_best, dice_val
      )
      )
      global_step += 1
      return global_step, dice_val_best, global_step_best
      
      max_iterations = 100
      eval_num = 10
      post_label = AsDiscrete(to_onehot=14)
      post_pred = AsDiscrete(argmax=True, to_onehot=14)
      dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
      global_step = 0
      dice_val_best = 0.0
      global_step_best = 0
      epoch_loss_values = []
      metric_values = []
      while global_step < max_iterations:
      global_step, dice_val_best, global_step_best = train(
      global_step, train_loader, dice_val_best, global_step_best
      )
    • best metric 확인 및 model 검증
      #best metric 확인 
      print(
          f"train completed, best_metric: {dice_val_best:.4f} "
          f"at iteration: {global_step_best}"
      )
      
      #loss 및 val mean dice 확인
      plt.figure("train", (12,6))
      
      ##### Iteration Average loss plot config #####
      plt.subplot(1,2,1)
      plt.title("Iteration Average loss")
      #x = iteration수=최대 max_iteration , y = loss 
      x = [eval_num * (i + 1) for i in range(len(epoch_loss_values))]
      y = epoch_loss_values
      
      plt.xlabel("Iteration")
      plt.plot(x,y)
      
      ##### Val mean Dice plot config #####
      plt.subplot(1,2,2)
      plt.title("Val mean Dice")
      # x = iteraion 수 = 최대 max_iteration , y = 검증 평균율(=Valdility Mean) 
      x = [eval_num * (i + 1) for i in range(len(metric_values))] 
      y = metric_values
      
      plt.xlabel("Iteration")
      plt.plot(x,y)
      plt.show()
      
      ## val_mean_dice > 0.9 이상 평균적으로 나옴
      
    • best model load 및 save output image (to png, nifti files)(미완)
      case_num = 0
      model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))
      model.eval()
      with torch.no_grad():
          img_name = os.path.split(val_ds[case_num]["image"].meta["filename_or_obj"])[1]
          img = val_ds[case_num]["image"]
          label = val_ds[case_num]["label"]
          val_inputs = torch.unsqueeze(img, 1).cuda()
          val_labels = torch.unsqueeze(label, 1).cuda()
          val_outputs = sliding_window_inference(
              val_inputs, (96, 96, 96), 4, model, overlap=0.8
          )
          plt.figure("check", (18, 6))
          plt.subplot(1, 3, 1)
          plt.title("image")
          plt.imshow(val_inputs.cpu().numpy()[0, 0, :, :, slice_map[img_name]], cmap="gray")
          plt.subplot(1, 3, 2)
          plt.title("label")
          plt.imshow(val_labels.cpu().numpy()[0, 0, :, :, slice_map[img_name]])
          plt.subplot(1, 3, 3)
          plt.title("output")
          plt.imshow(
              torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, slice_map[img_name]]
          )
          #test4/png에 저장 
          plt.savefig(os.path.join(f'/home/bmegpu02/tmp/freesurfer/subjects/DH_SEG/result/test4/png/11seg.png'))        
          plt.show()
          
      val_outputs.shape

    참고 및 의견

    • U-NETR 업그레이드 버전 논문리뷰

    [UNETR++ 논문 리뷰] - UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation

    • ViT가 세밀한 영역을 잡지 못하는 것을 보완한 모델 Swin_U-NETR 리뷰

    [Swin Transformer 논문 리뷰] - Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

    • Transform + 2D U-NET → TransUNET [UNETR 2D 버전]

    [TransUNet 논문 리뷰] - TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation

    profile
    제 노션에 더욱 유익한 정보가 많습니다. 많이 놀러와주세요.

    0개의 댓글