|
|
import torch, torch.nn.functional as F |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import numpy as np, io, base64 |
|
|
|
|
|
def _normalize_cam(cam, eps=1e-6):
    """Min-max rescale *cam* so its values lie in ``[0, 1)``.

    The minimum is shifted to zero and the result is divided by the new
    maximum plus ``eps``.  Works on any object that provides ``.min()``,
    ``.max()`` and elementwise arithmetic (e.g. a NumPy array or a torch
    tensor).

    Args:
        cam: Array/tensor of scores to rescale.  Not modified in place.
        eps: Small constant added to the denominator so a constant map
            (max == min) yields all zeros instead of dividing by zero.

    Returns:
        A new array/tensor of the same shape as ``cam``.
    """
    shifted = cam - cam.min()
    # After the shift the minimum is 0, so the maximum equals the value range.
    return shifted / (shifted.max() + eps)
|
|
|
|
|
def grad_cam(model, img: Image.Image, img_size=224, target_layer=None, device="cpu"):
    """Compute a Grad-CAM heat map for the model's top-1 class on one image.

    The image is resized to ``int(img_size * 1.15)``, center-cropped to
    ``img_size`` and converted to a tensor; no mean/std normalization is
    applied here.
    NOTE(review): confirm this matches the preprocessing the model was
    trained with.

    Args:
        model: A ``torch.nn.Module`` classifier.  It is switched to eval mode
            and its gradients are cleared before the backward pass; both
            side effects persist after the call.
        img: Input PIL image.
        img_size: Side length of the square crop fed to the model and of the
            returned maps.
        target_layer: Module whose output activations/gradients are used.
            Defaults to ``model.features[-1][0]`` (assumes the model exposes
            an indexable ``features`` container -- verify for new
            architectures).
        device: Device the input tensor is moved to.  The model itself is
            not moved.

    Returns:
        dict with keys:
            ``"pred"``        -- int index of the arg-max logit.
            ``"probs"``       -- NumPy array, softmax over dim 1 of the logits
                                 for the single input.
            ``"overlay"``     -- NumPy array, ``0.6 * image + 0.4 * jet(cam)``
                                 clipped to ``[0, 1]``.
            ``"input_image"`` -- NumPy array, the preprocessed image in
                                 channel-last layout, min-max rescaled.
            ``"cam"``         -- NumPy array of shape ``(img_size, img_size)``,
                                 the normalized class-activation map.
    """
    model.eval()

    preprocess = transforms.Compose([
        transforms.Resize(int(img_size * 1.15)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ])
    x = preprocess(img).unsqueeze(0).to(device)
    x.requires_grad_(True)

    if target_layer is None:
        target_layer = model.features[-1][0]

    activations, grads = [], []

    def fwd_hook(_module, _inputs, output):
        activations.append(output)

    def bwd_hook(_module, _grad_input, grad_output):
        grads.append(grad_output[0])

    # Register inside try/finally so the hooks never outlive this call, even
    # when the forward or backward pass raises.  Leaked hooks would keep
    # firing (and accumulating tensors) on every later use of the model.
    handles = []
    try:
        handles.append(target_layer.register_forward_hook(fwd_hook))
        handles.append(target_layer.register_full_backward_hook(bwd_hook))

        logits = model(x)
        pred = int(logits.argmax(dim=1).item())
        score = logits[0, pred]
        model.zero_grad(set_to_none=True)
        score.backward()
    finally:
        for handle in handles:
            handle.remove()

    # Drop the batch dimension of the captured activations; anything that is
    # neither 4-D nor 3-D is collapsed over its first dimension.
    A = activations[-1]
    if A.dim() == 4:
        A = A[0]
    elif A.dim() != 3:
        A = A.mean(dim=0)

    G = grads[-1]
    if G.dim() == 4:
        G = G[0]

    if G.shape[0] == A.shape[0]:
        # Grad-CAM: per-channel weight = spatial mean of the gradient, then a
        # weighted sum of the activation channels.
        weights = G.mean(dim=(1, 2))
        cam = (weights[:, None, None] * A).sum(0)
    else:
        # Leading sizes disagree: fall back to an unweighted channel mean.
        cam = A.mean(dim=0)

    # Keep positive evidence only, upsample to the input resolution, and
    # rescale to [0, 1).
    cam = F.relu(cam)[None, None, ...]
    cam = F.interpolate(
        cam, size=(img_size, img_size), mode='bilinear', align_corners=False
    )[0, 0]
    cam = _normalize_cam(cam).detach().cpu().numpy()

    # Channel-first tensor -> channel-last array, min-max rescaled for display.
    img_np = x[0].detach().cpu().permute(1, 2, 0).numpy()
    img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-6)

    # Imported lazily so matplotlib is only required when a CAM is rendered.
    import matplotlib.cm as cm

    heat = cm.jet(cam)[..., :3]  # drop the alpha channel of the RGBA colormap
    overlay = np.clip(0.6 * img_np + 0.4 * heat, 0, 1)

    probs = torch.softmax(logits, dim=1)[0].detach().cpu().numpy()

    return {"pred": pred, "probs": probs, "overlay": overlay, "input_image": img_np, "cam": cam}
|
|
|