Malaria / cam_utils.py
mgbam's picture
Create cam_utils.py
e0076d2 verified
import torch, torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np, io, base64
def _normalize_cam(cam):
    """Min-max rescale ``cam`` so its smallest value becomes 0.

    The minimum is subtracted first, then the result is divided by its own
    maximum plus ``1e-6``. The small constant keeps the division defined when
    every element of ``cam`` is equal (the output is then all zeros) and means
    the largest output value sits just below 1.
    """
    shifted = cam - cam.min()
    # Epsilon guards against 0/0 on a constant map.
    peak = shifted.max() + 1e-6
    return shifted / peak
def grad_cam(model, img: Image.Image, img_size=224, target_layer=None, device="cpu"):
    """Compute a Grad-CAM heat map for the model's top-scoring class on ``img``.

    The image is resized, centre-cropped to ``img_size`` and passed through
    ``model``. The gradient of the winning logit with respect to the output of
    ``target_layer`` is averaged per channel and used to weight that layer's
    activation channels; the weighted sum (after ReLU) is upsampled to the
    input resolution and min-max normalised.

    Args:
        model: Module called as ``model(x)`` whose output is indexed as
            ``(batch, classes)`` logits. It is put in eval mode and left
            there. NOTE(review): the input tensor is moved to ``device`` but
            the model is not, so the caller presumably places the model on
            ``device`` beforehand.
        img: PIL image to explain.
        img_size: Side length of the square crop fed to the model and of the
            returned maps.
        target_layer: Module whose output is explained. When ``None``,
            ``model.features[-1][0]`` is used (chosen for EfficientNet-B0's
            last block).
        device: Device the input tensor is moved to.

    Returns:
        dict with keys:
            ``"pred"``: index of the arg-max class (int).
            ``"probs"``: softmax of the logits for the single input, as a
                NumPy array.
            ``"overlay"``: ``0.6 * input + 0.4 * jet(cam)``, clipped to
                ``[0, 1]``.
            ``"input_image"``: the preprocessed input as an ``(H, W, C)``
                NumPy array, min-max rescaled.
            ``"cam"``: ``(img_size, img_size)`` NumPy heat map.
    """
    model.eval()
    preprocess = transforms.Compose([
        transforms.Resize(int(img_size * 1.15)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ])
    x = preprocess(img).unsqueeze(0).to(device)
    x.requires_grad_(True)

    if target_layer is None:  # EfficientNet-B0 last block
        target_layer = model.features[-1][0]

    activations, grads = [], []

    def fwd_hook(_module, _inputs, out):
        activations.append(out)

    def bwd_hook(_module, _grad_in, grad_out):
        grads.append(grad_out[0])

    # Register inside ``try`` and always detach in ``finally`` so a failing
    # forward/backward pass cannot leave stale hooks on the caller's model.
    handles = []
    try:
        handles.append(target_layer.register_forward_hook(fwd_hook))
        handles.append(target_layer.register_full_backward_hook(bwd_hook))

        logits = model(x)
        pred = int(logits.argmax(dim=1).item())
        score = logits[0, pred]
        model.zero_grad(set_to_none=True)
        score.backward()
    finally:
        for handle in handles:
            handle.remove()

    # Detach before the CAM arithmetic: no gradient is needed past this point,
    # so there is no reason to build an autograd graph for it.
    A = activations[-1].detach()  # (B,C,h,w) typical
    if A.dim() == 4:
        A = A[0]  # (C,h,w)
    elif A.dim() != 3:  # 3-D is taken as already (C,h,w)
        A = A.mean(dim=0)

    G = grads[-1].detach()
    if G.dim() == 4:
        G = G[0]  # (C,h,w)

    if G.shape[0] == A.shape[0]:
        weights = G.mean(dim=(1, 2))  # (C,)
        cam = (weights[:, None, None] * A).sum(0)  # (h,w)
    else:
        cam = A.mean(dim=0)  # safe fallback

    cam = F.relu(cam)[None, None, ...]  # (1,1,h,w)
    cam = F.interpolate(
        cam, size=(img_size, img_size), mode='bilinear', align_corners=False
    )[0, 0]
    cam = _normalize_cam(cam).detach().cpu().numpy()  # (H,W)

    # Channel-last copy of the model input, rescaled to [0, 1] for display.
    img_np = x[0].detach().cpu().permute(1, 2, 0).numpy()
    img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-6)

    # Imported lazily so matplotlib is only required when a CAM is rendered.
    import matplotlib.cm as cm
    heat = cm.jet(cam)[..., :3]  # drop the alpha channel
    overlay = np.clip(0.6 * img_np + 0.4 * heat, 0, 1)

    probs = torch.softmax(logits, dim=1)[0].detach().cpu().numpy()
    return {"pred": pred, "probs": probs, "overlay": overlay, "input_image": img_np, "cam": cam}