어제 리뷰한 CLAM의 코드 리뷰.
코드는 아래 깃허브 참고. (https://github.com/mahmoodlab/CLAM?tab=readme-ov-file#pre-requisites)
코드를 보면서 전체적인 pipeline을 다시 한번 살펴보고, CLAM 모델 구조에 대해서 리뷰.
전체적인 흐름에 대해서는 이해가 되는데, 왜 CLAM은 gated attention mechanism 방법을 적용했을까 아직도 잘 모르겠다.
초기 설정 및 파라미터 처리
source directory에서 WSI 파일 목록 불러오기 -> process list로 불러와 처리할 리스트 및 적용할 파라미터 설정
초기 설정 및 파라미터 처리
file_path: directory of bag (.h5 file)
output_path: directory to save computed features (.h5 file)
model: pytorch model
batch_size: batch_size for computing features in batches
verbose: level of feedback
pretrained: use weights pretrained on imagenet
custom_downsample: custom defined downscale factor of image patches
target_patch_size: custom defined, rescaled image size before embedding
데이터셋 초기화
모델 초기화 및 설정
Feature 추출 및 저장
The data used for training and testing are expected to be organized as follows:
DATA_ROOT_DIR/
├──DATASET_1_DATA_DIR/
├── h5_files
├── slide_1.h5
├── slide_2.h5
└── ...
└── pt_files
├── slide_1.pt
├── slide_2.pt
└── ...
├──DATASET_2_DATA_DIR/
├── h5_files
├── slide_a.h5
├── slide_b.h5
└── ...
└── pt_files
├── slide_a.pt
├── slide_b.pt
└── ...
└──DATASET_3_DATA_DIR/
├── h5_files
├── slide_i.h5
├── slide_ii.h5
└── ...
└── pt_files
├── slide_i.pt
├── slide_ii.pt
└── ...
└── ...
Dataset objects used for actual training/validation/testing can be constructed using the Generic_MIL_Dataset Class
1. dataset 준비: csv 파일에서 정보 읽고, 데이터 필터링 및 전처리 수행
- 'Generic_WSI_Classification_Dataset'를 통해서 csv 파일로부터 슬라이드 정보를 읽어 들이고, 필터링, 레이블 지정, 데이터 셔플링 등 사전 처리 작업 수행
train, valid, test set split
Generic_MIL_Dataset:
Generic_Split: Generic_MIL_Dataset 클래스를 상속받아, 특정 데이터 분할
WSI classification을 위해서 task에 따른 데이터 분할을 생성하는 과정.
1. tumor vs normal: binary classification
2. tumor subtyping: multiclass classification
지정된 csv 파일에서 슬라이드 데이터를 로드하고, 지정된 label dictonary를 사용해 label mapping 하여 환자 기반 분할
CLAM은 슬라이드 내의 개별 패치들로부터 추출된 특징을 통해 슬라이드 전체의 진단을 예측.
이 모델은 두 가지 주요 구성 요소를 포함: Attention Mechanism & Classifiers
Attention Mechanism:
(1) Class Attn_Net:
- attention network without gating. 2개의 fully connected layer로 구성이 됨.
- 첫번째 layer는 input feature에 nonlinear 변환을 적용하고, 두번째 layer에서 class에 대한 attention score를 계산.
- 모델은 각 인스턴스의 중요도 점수와 함께 raw feature를 다음 단계로 전달.
-> 간단한 fully connected layer를 통해서 attention score 계산. (input 'x'가 모듈을 통과하고, 모듈의 마지막 nn.Linear 층의 출력이 attention score로 사용)
(2) class Attn_Net_Gated:
-'self.attention_a': input feature를 D차원으로 변환 후 tanh function 태우고
-'self.attention_b': input feature를 D차원으로 변환. 여기는 Sigmoid 함수를 사용. Sigmoid를 사용함으로서 attention score 값을 0-1사이로 출력해 gate 역할을 함.
-'self.attention_c': attention a,b의 출력은 elementwise multiplication 통해서 결합. 그래서 sigmoid 게이트를 통과한 weight가 tanh 함수를 통과한 feature와 결합. 최종 attention score를 계산하는데 사용.
-> 근데 왜 gated attention을 사용했을까?
Similarity 기반 attention: 다른 부분들 간의 상호작용 고려해서 특정 context에서 더 중요한 정보를 강조. 모델이 더 맥락적인 이해를 바탕으로 결정을 내림.
특히 입력 시퀀스의 각 요소가 출력에 미치는 영향을 동적으로 가중치를 부여함. 모델이 전체 시퀀스를 효과적으로 이해하고 처리할 수 있게 해줌.
Gated attention: input feature에 가중치를 부여하는 방식을 사용해서 정보를 선택적으로 집중시키는 방식. 입력 데이터의 중요한 부분에 초점을 맞추고, 덜 중요한 정보는 무시해서 모델이 주어진 작업에 더 관련성 높은 특성을 학습하도록 돕는다.
계산은 한 경로에서는 일반적인 변환을 적용하고, 다른 경로에서는 입력에 대한 게이트 역할을 하는 변환을 적용하고, 이후에 두 경로의 결과를 곱해서, 중요한 부분에만 집중할 수 있도록 .
-> CLAM같은 모델은 이제 의료 데이터셋에 적용이 많이 되고, 사실 의료 데이터는 굉장히 복잡하고 고차원적이고 노이즈가 많기 때문에 모델이 좀더 주어진 문제에서 가장 관련성 높은 정보에만 더 집중할 수 있기 때문에 사용하는 것일가?
왜냐하면 의료 이미지는 패턴들이 굉장히 미묘하게 나타나거나, 배경 정보들이 혼합되어 있을수도 있기 때문에?
-> 아니면 의료이미지 데이터는 환자마다, 또한 질병의 유형마다 매우 다양할 수 있으니깐, gated attention이 이러한 복잡성과 다양성을 더 잘 처리하나?
각 입력에 대해서 모델이 attention을 독립적으로 평가하고, 데이터들의 특성들에 따라서 가중치를 조절하니깐?
그렇기 때문에 동일한 backbone으로 다양한 의료 이미지 데이터셋에 적용할 수 있으니깐, 각 데이터 특성들마다 모델이 자동으로 최적화 가능하니깐 이것을 사용하나?