학부 시절 구현해보았던 CAM.
다시 한 번 구현해보았다.
이게 뭐라고 그렇게 어려워했을까?
CAM, GRAD CAM을 지나 GAIN을 구현하고자 함.
Learning Deep Features for Discriminative Localization - 논문 링크
CAM-github
기존의 AI는 블랙박스라는 이야기가 많았다. 이는 특정한 판단에 대해 모델을 제안한 사람도, 코드를 짠 사람도 설명할 수 없는 이유에서 였다. 이러한 우려는 인공지능이 성장해감에 따라 커졌을 것이라 생각한다.
초기 우편번호 분류, 차량 번호판 인식 등 간단한 문제를 풀기 위해 적용되었을 때와 달리 자율 주행, 자동 수술 등 섬세하고 인간의 안전과 생명에 영향을 미치는 문제에도 ai가 자리를 차지하기 시작하면서 AI에 대한 우려가 커졌다. AI가 이런 핵심적인 영역에서 내린 결정의 근거와, 그 타당성을 제공하지 못 하고 있다는 이유에서다.
이에 따라 XAI에 대한 관심이 커졌다. 이를 통해 AI의 판단 결과를 논리적으로 설명하는 것이 가능해졌기 때문이다.
컴퓨터 비전의 XAI 모델 중 하나인 CAM에 대해 얘기하고, 구현한 코드를 공유하려 한다.
CAM이란 CNN이 입력으로 들어온 이미지를 분류할 때 "어떤 부분을 보고" 예측을 했는지를 알려주는 역할을 한다. 아래 사진을 보면 이해가 쉬울 것이다. 입력 이미지에 히트맵을 씌워 주어진 단어를 예측하는 데에 있어 중요한 부분에 가까워질수록 가시광선의 파장에서 파란색에서 빨간색(온도가 높은 = 활성화가 많이 된 = 중요도가 높은)으로 변해감을 확인할 수 있다.
기본적인 아이디어는 다음과 같다.
CNN으로 이미지 분류를 할 때 마지막 단의 출력값이 클수록 softmax를 거친 뒤 1에 가까워 지는데, 그렇다면 입력 이미지의 label에 해당하는 채널마지막 conv layer의 출력이 크게 하는 클래스에 크게 반응했단 거겠지?
위의 발상에 따라 예측한 클래스에 해당하는 Linear layer의 weight을 가져온 뒤 heatmap을 그려 "과연 어떤 부분을 보고 이런 예측을 한 걸까?"에 대해 답을 주는 그림이 바로 CAM(Class Activation Map)이다.
이런 걸 하는 친구인데, 코드를 보자!
import torch
from torchvision import models
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import optim
from torch import nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
import numpy as np
class vgg_cam(nn.Module):
def __init__(self, features, num_classes, init_weights=True):
super(vgg_cam, self).__init__()
self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.classfier = nn.Linear(512, num_classes)
def forward(self, x):
x = self.features(x)
map = x
x = self.avgpool(x)
x = torch.squeeze(x)
x = self.classfier(x)
return x, map
class VGG(nn.Module):
def __init__(
self,
features: nn.Module,
num_classes: int = 1000,
init_weights: bool = True
) -> None:
super(VGG, self).__init__()
self.features = features
# 중략
cfgs: Dict[str, List[Union[str, int]]] = {
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 16-layer model (configuration "D")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
vgg16 = models.vgg16(pretrained=True)
features = vgg16.features
model = vgg_cam(features=features,
num_classes = 10,
init_weights=True)
PATH = '/content/drive/MyDrive/vgg_cam.pth'
model.load_state_dict(torch.load(PATH))
<All keys matched successfully>
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
train_dataset = datasets.CIFAR10(root='./', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root='./', train=False, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
len_train = len(train_loader) # to be used for logging loss
len_test = len(test_loader) # to be used for logging loss
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
epochs = 10
for epoch in range(epochs):
train_loss = 0
test_loss = 0
for data_, target_ in train_loader:
data = data_.to(device)
target = target_.to(device)
output_, _ = model(data)
output = torch.squeeze(output_)
optimizer.zero_grad()
loss = criterion(output, target)
train_loss += loss.item()
loss.backward()
optimizer.step()
print("{}epoch train loss : {}".format(epoch + 1, round(train_loss / len_train, 2)))
for data_, target_ in test_loader:
data = data_.to(device)
target = target_.to(device)
output_, _ = model(data)
loss = criterion(output, target)
test_loss += loss.item()
print("{}epoch test loss : {}".format(epoch + 1, round(test_loss / len_test, 2)))
img = test_dataset[0][0]
label = test_dataset[0][1]
batch_img = img[None, :, :, :].to(device)
x, map = model(batch_img)
class_weight = model.classfier.weight[int(label)].unsqueeze(-1).unsqueeze(-1)
cam_ = torch.squeeze(map) * class_weight
512 10의 크기를 가지는 nn.Linear에서 입력 이미지의 라벨에 해당하는 벡터를 빼온 뒤, 이를 8-1에서 구한 map과 channel wise로 곱해준다.
코드 설명을 해보자면, class_weight은 512의 shape을, map은 512 7 7 의 shape을 가진다.
이를 channelwise로 곱해주려면 class_weight을 512 1 1 의 shape을 가지게 해주어야 했고, 두 번의 unsqueeze를 통해 shape을 맞추어주었다.
cam = torch.sum(cam_, axis=0)
cam = cam.detach().cpu().numpy()
# print(cam.shape) # DEBUG
1. 설명 가능한 인공지능(XAI), 왜 주목하나?
2. Learning Deep Features for Discriminative Localization
3. torchvision.models.VGG