PyTorch Lightning에서 CAM 시각화 구현

BSH·2023년 5월 23일
0

저가 못 찾은 건지 모르겠지만 pytorch lightning으로 cam을 만드는 코드를 찾을 수 없었니다. 물론 pytorch로 쉽게 구현할 수 있지만 pl로 만든 모델이라 pytorch로 포팅하는 것 보다 predict step이나 test step에 넣어서 보는게 더 편하다고 생각했습니다.

CAM(class activation maps)

CAM 논문
CAM은 워낙 잘 설명 되어있는 블로그나 영상이 많아서 패스하겠습니다

PL code

데이터는 데이콘의 도배하자 분류 유형 대회 데이터 셋을 사용했습니다.

코드를 고치기 쉽도록 좀 분리해서 작성했습니다. 코드를 다 적기에는 너무 길어 일부분만 가져왔습니다. 저는 cam을 만드는 로직을 predict step에 넣었습니다. 그리고 코드에서 cam 이미지는 wandb logger와 연결해 저장해주는 방식입니다.

train.py

...
def plot_cam(CONFIG: Config):
  data_module = DataModule(CONFIG)
  data_module.setup()
  
  model = LightningModel.load_from_checkpoint(".\\ckpt\\resnext_101.ckpt", CONFIG=CONFIG)
  
  trainer = pl.Trainer(
    precision=16,
    accelerator="gpu",
  )
  
  trainer.predict(
    model,
    data_module
  )
...

utils.py

...
def predict_step(self, batch, batch_idx: int):
    # CAM 이미지 저장
    imgs, labels = batch
    labels = labels.view(labels.size(0), -1) # B -> B 1
    _, _map = self(imgs)
    weight_fc = self.model.fc.weight.data.T
    
    W = torch.stack([weight_fc[:, labels[i]] for i in range(len(labels))])
    W = W.unsqueeze(dim=-1)
    cam = torch.mul(_map, W)
    cam = torch.sum(cam, dim=1) # B h' w'
    cam = cam.detach().cpu().numpy()
    
    for i in range(len(batch)):
      fig, ax = plt.subplots(figsize=(15, 15))
      origin = imgs[i].detach().cpu().numpy()
      origin = origin * 0.5 + 0.5 # denorm
      origin = np.transpose(origin, (1, 2, 0))
      label = labels[i].detach().cpu().data
      final_cam = cv2.resize(cam[i], dsize=(224, 224), interpolation=cv2.INTER_CUBIC)
      ax.imshow(origin)
      ax.imshow(final_cam, alpha=0.4, cmap="jet")
      fig.savefig(f"./camp/{batch_idx}_{i}_{label}.png")
      plt.close() # 생성된 plot figure가 메모리를 점유하고 있기 때문에 저장후 삭제 필요 ...

models.py

class ResNext_101(nn.Module):
  def __init__(self):
    super().__init__()
    self.feature = timm.create_model("resnext101_64x4d", pretrained=True)
    self.gap = self.feature.global_pool
    self.fc = nn.Linear(in_features=2048, out_features=19, bias=True)
    
    self.feature.global_pool = nn.Identity()
    self.feature.fc = nn.Identity()
  
  def forward(self, x):
    _map = self.feature(x)
    x = self.gap(_map)
    x = self.fc(x)
    return x, _map

아래는 결과 사진입니다.

전체 코드는 깃허브 링크에서 볼 수 있습니다.

profile
컴공생

0개의 댓글