Spaces:
Sleeping
Sleeping
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from constants import COLORS | |
| from utils import fig2img | |
def visualize_prediction(
    pil_img, output_dict, threshold=0.7, id2label=None, display_mask=False, mask=None
):
    """Draw thresholded detection boxes (and optionally a mask) over an image.

    Args:
        pil_img: Background image, passed directly to ``ax.imshow``.
        output_dict: Mapping with ``"scores"``, ``"boxes"`` and ``"labels"``
            entries. ``output_dict["scores"] > threshold`` must produce a
            boolean selector that can index all three entries, and the indexed
            results must support ``.tolist()``. Each kept box is unpacked as
            ``(xmin, ymin, xmax, ymax)``.
        threshold: Detections whose score is strictly greater than this are
            drawn.
        id2label: Optional mapping from label id to name. Every kept label id
            is looked up (a missing id raises ``KeyError``).
        display_mask: If true and ``mask`` is given, overlay the mask.
        mask: Array-like mask; elements with value > 0 are highlighted.

    Returns:
        The result of ``fig2img`` on the rendered figure.
    """
    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()
    if id2label is not None:
        labels = [id2label[x] for x in labels]
    # NOTE(review): ``labels`` is not drawn -- each box is annotated only with
    # its 1-based index and score. Confirm this is intended.

    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(pil_img)

    if display_mask and mask is not None:
        mask_arr = np.asarray(mask)
        # Use uint8 explicitly. zeros_like() would inherit the input dtype:
        # Image.fromarray rejects e.g. int64 arrays, and a bool array cannot
        # hold 255 (it would become a mode "1" image, ignoring the colormap).
        new_mask = np.zeros(mask_arr.shape, dtype=np.uint8)
        new_mask[mask_arr > 0] = 255
        ax.imshow(Image.fromarray(new_mask), alpha=0.5, cmap="viridis")

    for counter, (score, (xmin, ymin, xmax, ymax)) in enumerate(
        zip(scores, boxes), start=1
    ):
        # Cycle through the palette by index. The previous ``COLORS * 100``
        # inside zip() silently dropped any detections beyond that length.
        color = COLORS[(counter - 1) % len(COLORS)]
        ax.add_patch(
            plt.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                fill=False,
                color=color,
                linewidth=2,
            )
        )
        ax.text(
            xmin,
            ymin,
            f"[{counter}] {score:0.2f}",
            fontsize=8,
            bbox=dict(facecolor="yellow", alpha=0.5),
        )
    ax.axis("off")
    try:
        return fig2img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated calls accumulate figures and leak memory.
        plt.close(fig)
def visualize_attention_map(pil_img, attention_map):
    """Overlay each head of the last entry of ``attention_map`` on ``pil_img``.

    Draws one subplot per head (a 1 x n_heads grid). In each subplot the image
    is shown with that head's map blended on top (alpha 0.7, viridis).

    Args:
        pil_img: PIL image; ``pil_img.size`` is read as ``(width, height)``.
        attention_map: Indexable collection whose last element is a torch
            tensor indexed as ``[batch, head, row, col]``. Only batch element 0
            is plotted.
            NOTE(review): presumably the per-layer attentions returned by a
            model -- confirm the last two dims form a meaningful 2-D map.

    Returns:
        The result of ``fig2img`` on the rendered figure.
    """
    last_layer = attention_map[-1].detach().cpu()
    n_heads = last_layer.shape[1]
    width, height = pil_img.size
    # Same extent imshow uses by default for the background image, so the
    # (typically much smaller) attention map is stretched over the whole
    # image instead of covering only its top-left pixels.
    image_extent = (-0.5, width - 0.5, height - 0.5, -0.5)

    # squeeze=False always yields a 2-D array of Axes; otherwise a single head
    # returns a bare Axes object, which has no ``.flat``.
    fig, axes = plt.subplots(
        nrows=1, ncols=n_heads, figsize=(n_heads * 4, 4), squeeze=False
    )
    for i, ax in enumerate(axes.flat):
        # .float() so half/bfloat16 tensors can be converted to numpy.
        head_map = last_layer[0, i, :, :].squeeze().float().numpy()
        ax.imshow(pil_img)
        ax.imshow(head_map, alpha=0.7, cmap="viridis", extent=image_extent)
        ax.set_title(f"Head {i+1}")
        ax.axis("off")
    plt.tight_layout()
    try:
        return fig2img(fig)
    finally:
        # Release the figure so repeated calls do not accumulate in pyplot.
        plt.close(fig)