Data-efficient and weakly supervised computational pathology on whole-slide images (2) - code review

김종현·2024년 3월 12일
0

Paper review

목록 보기
8/9

어제 리뷰한 CLAM의 코드 리뷰.
코드는 아래 깃허브 참고. (https://github.com/mahmoodlab/CLAM?tab=readme-ov-file#pre-requisites)

코드를 보면서 전체적인 pipeline을 다시 한번 살펴보고, CLAM 모델 구조에 대해서 리뷰.

전체적인 흐름에 대해서는 이해가 되는데, 왜 CLAM은 gated attention mechanism 방법을 적용했을까 아직도 잘 모르겠다.

preprocess pipeline

  1. 초기 설정 및 파라미터 처리

  2. source directory에서 WSI 파일 목록 불러오기 -> process list로 불러와 처리할 리스트 및 적용할 파라미터 설정

    • Segmentation: 'segment' 함수 통해 segmentation 수행. 'WholeSlideImage' class의 'segmentTissue' function 사용
    • Save mask: segmentation 결과를 바탕으로 시각화된 마스크 생성 및 저장
    • Patching: 'patching' 함수를 불러와서 segmentation 결과를 기반으로 이미지 패치를 생성하고 저장. 'WholeSlideImage' class의 'process_contours' function 사용
    • Stitching: 생성된 패치들의 좌표 정보를 사용하여 전체 슬라이드 뷰를 재구성. 'StitchCoords' function 사용

extract features

  1. 초기 설정 및 파라미터 처리
    file_path: directory of bag (.h5 file)
    output_path: directory to save computed features (.h5 file)
    model: pytorch model
    batch_size: batch_size for computing features in batches
    verbose: level of feedback
    pretrained: use weights pretrained on imagenet
    custom_downsample: custom defined downscale factor of image patches
    target_patch_size: custom defined, rescaled image size before embedding

  2. 데이터셋 초기화

    • 입력받은 csv 파일에서 WSI 파일 목록 불러옴.
    • 'Dataset_All_Bags' 클래스를 사용해서 WSI 파일에 대한 데이터셋 생성
      -> WSI에서 추출함 패치 데이터를 배치 단위로 처리하기 위한 사전 준비 작업. 인자로 받은 csv 파일 경로를 사용해서 dataframe을 생성한 이후
      ,데이터 프레임의 특정 인덱스에 해당하는 슬라이드 정보를 반환하고, 이 방법을 통해 dataloader가 각 슬라이드 정보를 순차적으로 또는 랜덤으로 접근 가능.
  3. 모델 초기화 및 설정

    • 'resnet50_baseline' 함수를 사용해 resnet50 모델을 로드하고, 필요한 경우 ImageNet pretrained weights 가져옴.
    • 이후 model = model.to(device) 및 멀티 GPU 가능하면 dataparallel 실행 + model.eval() 설정
  4. Feature 추출 및 저장

    • Openslide 통해서 WSI 파일을 열고, 해당 WSI의 patch 단위로 feature 추출을 하기 위한 준비
    • "compute_w_loader" 함수를 통해서 dataloader를 통해서 배치 단위로 이미지 패치를 모델에 전달하고, feature extract
    • 추출된 feature 및 좌표들은 hdf5 파일로 저장
    • 추출된 feature들은 pytorch 텐서로 변환해서 .pt 파일 형식으로 저장

Datasets

The data used for training and testing are expected to be organized as follows:

DATA_ROOT_DIR/
├──DATASET_1_DATA_DIR/
├── h5_files
├── slide_1.h5
├── slide_2.h5
└── ...
└── pt_files
├── slide_1.pt
├── slide_2.pt
└── ...
├──DATASET_2_DATA_DIR/
├── h5_files
├── slide_a.h5
├── slide_b.h5
└── ...
└── pt_files
├── slide_a.pt
├── slide_b.pt
└── ...
└──DATASET_3_DATA_DIR/
├── h5_files
├── slide_i.h5
├── slide_ii.h5
└── ...
└── pt_files
├── slide_i.pt
├── slide_ii.pt
└── ...
└── ...

Dataset objects used for actual training/validation/testing can be constructed using the Generic_MIL_Dataset Class
1. dataset 준비: csv 파일에서 정보 읽고, 데이터 필터링 및 전처리 수행
- 'Generic_WSI_Classification_Dataset'를 통해서 csv 파일로부터 슬라이드 정보를 읽어 들이고, 필터링, 레이블 지정, 데이터 셔플링 등 사전 처리 작업 수행

  1. train, valid, test set split

    • 데이터셋을 train,valid,test로 나누어준 이후 이를 csv 파일로 저장하는 함수
  2. Generic_MIL_Dataset:

    • Generic_WSI_Classification_Dataset 클래스를 상속받아, MIL을 위한 데이터셋 클래스를 정의
  3. Generic_Split: Generic_MIL_Dataset 클래스를 상속받아, 특정 데이터 분할

Training Splits

WSI classification을 위해서 task에 따른 데이터 분할을 생성하는 과정.
1. tumor vs normal: binary classification
2. tumor subtyping: multiclass classification

지정된 csv 파일에서 슬라이드 데이터를 로드하고, 지정된 label dictonary를 사용해 label mapping 하여 환자 기반 분할

CLAM model

CLAM은 슬라이드 내의 개별 패치들로부터 추출된 특징을 통해 슬라이드 전체의 진단을 예측.
이 모델은 두 가지 주요 구성 요소를 포함: Attention Mechanism & Classifiers

Attention Mechanism:

(1) Class Attn_Net:
- attention network without gating. 2개의 fully connected layer로 구성이 됨.
- 첫번째 layer는 input feature에 nonlinear 변환을 적용하고, 두번째 layer에서 class에 대한 attention score를 계산.
- 모델은 각 인스턴스의 중요도 점수와 함께 raw feature를 다음 단계로 전달.
-> 간단한 fully connected layer를 통해서 attention score 계산. (input 'x'가 모듈을 통과하고, 모듈의 마지막 nn.Linear 층의 출력이 attention score로 사용)

(2) class Attn_Net_Gated:
-'self.attention_a': input feature를 D차원으로 변환 후 tanh function 태우고
-'self.attention_b': input feature를 D차원으로 변환. 여기는 Sigmoid 함수를 사용. Sigmoid를 사용함으로서 attention score 값을 0-1사이로 출력해 gate 역할을 함.
-'self.attention_c': attention a,b의 출력은 elementwise multiplication 통해서 결합. 그래서 sigmoid 게이트를 통과한 weight가 tanh 함수를 통과한 feature와 결합. 최종 attention score를 계산하는데 사용.

-> 근데 왜 gated attention을 사용했을까?
Similarity 기반 attention: 다른 부분들 간의 상호작용 고려해서 특정 context에서 더 중요한 정보를 강조. 모델이 더 맥락적인 이해를 바탕으로 결정을 내림.
특히 입력 시퀀스의 각 요소가 출력에 미치는 영향을 동적으로 가중치를 부여함. 모델이 전체 시퀀스를 효과적으로 이해하고 처리할 수 있게 해줌.

Gated attention: input feature에 가중치를 부여하는 방식을 사용해서 정보를 선택적으로 집중시키는 방식. 입력 데이터의 중요한 부분에 초점을 맞추고, 덜 중요한 정보는 무시해서 모델이 주어진 작업에 더 관련성 높은 특성을 학습하도록 돕는다.
계산은 한 경로에서는 일반적인 변환을 적용하고, 다른 경로에서는 입력에 대한 게이트 역할을 하는 변환을 적용하고, 이후에 두 경로의 결과를 곱해서, 중요한 부분에만 집중할 수 있도록 .

-> CLAM같은 모델은 이제 의료 데이터셋에 적용이 많이 되고, 사실 의료 데이터는 굉장히 복잡하고 고차원적이고 노이즈가 많기 때문에 모델이 좀더 주어진 문제에서 가장 관련성 높은 정보에만 더 집중할 수 있기 때문에 사용하는 것일가?
왜냐하면 의료 이미지는 패턴들이 굉장히 미묘하게 나타나거나, 배경 정보들이 혼합되어 있을수도 있기 때문에?
-> 아니면 의료이미지 데이터는 환자마다, 또한 질병의 유형마다 매우 다양할 수 있으니깐, gated attention이 이러한 복잡성과 다양성을 더 잘 처리하나?
각 입력에 대해서 모델이 attention을 독립적으로 평가하고, 데이터들의 특성들에 따라서 가중치를 조절하니깐?
그렇기 때문에 동일한 backbone으로 다양한 의료 이미지 데이터셋에 적용할 수 있으니깐, 각 데이터 특성들마다 모델이 자동으로 최적화 가능하니깐 이것을 사용하나?

  1. Classifiers: CLAM
    구성:
    (1) Attention network: gated attention or simple attention mechanism
    (2) Classifiers: attention weight를 사용해서 bag label 특성을 집계하고, 이를 바탕으로 가방이 속하는 클래스 예측
    (3) Instance classifiers: 인스턴스 레벨의 분류를 수행해서 추가적인 인스턴스 레벨의 학습 signal 제공.
    - 각 인스턴스(예: 이미지 내의 패치)를 독립적으로 분류하는 역할
    - 이러한 인스턴스 레벨 분류기의 주요 목적은 모델이 bag level에서 뿐만 아니라, 각 instance level에서도 유의미한 학습을 수행할 수 있도록 하는 것
profile
EXPLORE NEW POSSIBILITIES AT THE INTERSECTION OF AI AND MEDICAL

0개의 댓글

관련 채용 정보