저가 못 찾은 건지 모르겠지만 pytorch lightning으로 cam을 만드는 코드를 찾을 수 없었니다. 물론 pytorch로 쉽게 구현할 수 있지만 pl로 만든 모델이라 pytorch로 포팅하는 것 보다 predict step이나 test step에 넣어서 보는게 더 편하다고 생각했습니다.
CAM 논문
CAM은 워낙 잘 설명 되어있는 블로그나 영상이 많아서 패스하겠습니다
데이터는 데이콘의 도배하자 분류 유형 대회 데이터 셋을 사용했습니다.
코드를 고치기 쉽도록 좀 분리해서 작성했습니다. 코드를 다 적기에는 너무 길어 일부분만 가져왔습니다. 저는 cam을 만드는 로직을 predict step에 넣었습니다. 그리고 코드에서 cam 이미지는 wandb logger와 연결해 저장해주는 방식입니다.
...
def plot_cam(CONFIG: Config):
data_module = DataModule(CONFIG)
data_module.setup()
model = LightningModel.load_from_checkpoint(".\\ckpt\\resnext_101.ckpt", CONFIG=CONFIG)
trainer = pl.Trainer(
precision=16,
accelerator="gpu",
)
trainer.predict(
model,
data_module
)
...
...
def predict_step(self, batch, batch_idx: int):
# CAM 이미지 저장
imgs, labels = batch
labels = labels.view(labels.size(0), -1) # B -> B 1
_, _map = self(imgs)
weight_fc = self.model.fc.weight.data.T
W = torch.stack([weight_fc[:, labels[i]] for i in range(len(labels))])
W = W.unsqueeze(dim=-1)
cam = torch.mul(_map, W)
cam = torch.sum(cam, dim=1) # B h' w'
cam = cam.detach().cpu().numpy()
for i in range(len(batch)):
fig, ax = plt.subplots(figsize=(15, 15))
origin = imgs[i].detach().cpu().numpy()
origin = origin * 0.5 + 0.5 # denorm
origin = np.transpose(origin, (1, 2, 0))
label = labels[i].detach().cpu().data
final_cam = cv2.resize(cam[i], dsize=(224, 224), interpolation=cv2.INTER_CUBIC)
ax.imshow(origin)
ax.imshow(final_cam, alpha=0.4, cmap="jet")
fig.savefig(f"./camp/{batch_idx}_{i}_{label}.png")
plt.close() # 생성된 plot figure가 메모리를 점유하고 있기 때문에 저장후 삭제 필요 ...
class ResNext_101(nn.Module):
def __init__(self):
super().__init__()
self.feature = timm.create_model("resnext101_64x4d", pretrained=True)
self.gap = self.feature.global_pool
self.fc = nn.Linear(in_features=2048, out_features=19, bias=True)
self.feature.global_pool = nn.Identity()
self.feature.fc = nn.Identity()
def forward(self, x):
_map = self.feature(x)
x = self.gap(_map)
x = self.fc(x)
return x, _map
아래는 결과 사진입니다.
전체 코드는 깃허브 링크에서 볼 수 있습니다.