import importlib
import itertools
import os
import pickle
import re
from math import sqrt

import clip
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import yaml
from einops import rearrange
from transformers import BertModel, AutoTokenizer

from .hooks import average_text_tokens, get_vit_out, feats
from .masker import DINOTextMasker
from .model import ProjectionLayer, VisualProjectionLayer, CLIPLastLayer, DoubleMLP
from .pamr import PAMR
from .templates import get_template
from .us import normalize

# Default device for this module: the GPU when CUDA is usable, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DINOText(nn.Module):
    """Align patch features of a ViT backbone with text embeddings.

    The module bundles a visual backbone (DINOv2 / DINOv3 loaded through
    ``torch.hub``, or a ``timm`` model), a text encoder (CLIP or BERT), a
    ``ProjectionLayer`` and a ``DINOTextMasker`` that ``generate_masks`` uses
    to produce one soft mask per text embedding.

    Forward hooks registered in ``__init__`` store intermediate activations
    in ``self.feats``.
    """

    # ------------------------------------------------------------------
    # Forward hooks
    # ------------------------------------------------------------------
    def get_self_attention(self, module, input, output):
        """Forward hook: keep the hooked module's output as ``feats['self_attn']``."""
        self.feats['self_attn'] = output

    def get_clip_second_last_dense_out(self, model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
        """Forward hook: keep the output, cast to float32, as ``feats['clip_second_last_out']``."""
        # ``Tensor.to`` is not in-place, so the converted tensor is what has
        # to be stored for the cast to take effect.
        self.feats['clip_second_last_out'] = output.to(dtype=torch.float32)

    def get_all_out_tokens(self, model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
        """Forward hook: keep the hooked module's output as ``feats['clip_txt_out_tokens']``."""
        self.feats['clip_txt_out_tokens'] = output

    # ------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _extract_dinov3_name(path, n_parts=2):
        """Return the first ``n_parts`` underscore-separated parts of ``path``'s basename."""
        return "_".join(os.path.basename(path).split("_")[:n_parts])

    @staticmethod
    def _to_unit_tensor(x):
        """Divide tensor inputs by 255; convert anything else with ``ToTensor``."""
        if isinstance(x, torch.Tensor):
            return x / 255.0
        return T.ToTensor()(x)

    def __init__(
        self, model_name, resize_dim, clip_model_name, proj_class, proj_name, proj_model, avg_self_attn_token=False, disentangled_self_attn_token=True, loss=None, pre_trained=True,
        unfreeze_last_text_layer=False, unfreeze_last_image_layer=False, is_eval=True, use_avg_text_token=False, keep_cls=False, keep_end_seq=False, with_bg_clean=False, **kwargs
    ):
        """Build the backbone, the text encoder, the projection and the masker.

        Args:
            model_name: Backbone identifier. Names containing ``dinov2`` or
                ``dinov3`` are loaded with ``torch.hub``; names containing
                ``mae``, ``sam``, ``clip`` or ``dino`` are created with ``timm``.
                For ``dinov3`` the value is also passed as ``weights``.
            resize_dim: Side length images are resized to before the backbone.
            clip_model_name: Text encoder name. Names containing ``bert`` load
                a ``BertModel`` and its tokenizer, anything else ``clip.load``.
            proj_class: Must contain ``vitb_mlp_infonce`` (768-d backbone
                features) or ``vitl_mlp_infonce`` (1024-d).
            avg_self_attn_token: See ``encode_image``.
            disentangled_self_attn_token: See ``encode_image``.
            unfreeze_last_text_layer: Make the last CLIP text block,
                ``ln_final`` and ``text_projection`` trainable.
            is_eval: Also enables the self-attention hook.
            use_avg_text_token: Hook ``ln_final`` so ``build_text_embedding``
                can average all text tokens.
            keep_cls: Forwarded to ``average_text_tokens``.
            keep_end_seq: Forwarded to ``average_text_tokens``.
            with_bg_clean: Enable ``similarity_assignment_weighted`` in
                ``generate_masks``.
            proj_name, proj_model, loss, pre_trained, unfreeze_last_image_layer,
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If ``model_name`` or ``proj_class`` is not supported.
        """
        super().__init__()

        # Filled in by the forward hooks.
        self.feats = {}
        self.model_name = model_name

        # --- visual backbone -------------------------------------------
        if 'dinov2' in model_name:
            self.model_family = 'facebookresearch/dinov2'
            self.model = torch.hub.load(self.model_family, model_name)
        elif 'dinov3' in model_name:
            self.model = torch.hub.load(
                'src/dinov3', self._extract_dinov3_name(model_name), source='local', weights=model_name
            )
        elif any(tag in model_name for tag in ('mae', 'sam', 'clip', 'dino')):
            self.model = timm.create_model(
                model_name,
                pretrained=True,
                num_classes=0,
                img_size=resize_dim,
            )
            if 'sam' in model_name:
                self.model.blocks[-1].register_forward_hook(get_vit_out)
        else:
            raise ValueError("Unknown ViT model")

        # CLIP backbones use their own normalization statistics.
        if 'clip' in model_name:
            mean, std = (0.4815, 0.4578, 0.4082), (0.2686, 0.2613, 0.2758)
        else:
            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        self.image_transforms = T.Compose([
            T.Resize((resize_dim, resize_dim)),
            self._to_unit_tensor,
            T.Normalize(mean, std),
        ])

        self.model.requires_grad_(False)

        # --- text encoder ----------------------------------------------
        self.clip_model_name = clip_model_name
        if 'bert' in self.clip_model_name:
            self.clip_model = BertModel.from_pretrained(self.clip_model_name, output_hidden_states=False)
            self.tokenizer = AutoTokenizer.from_pretrained(self.clip_model_name)
        else:
            self.clip_model, _ = clip.load(clip_model_name, device='meta')
        self.clip_model.eval()
        self.clip_model.requires_grad_(False)
        if unfreeze_last_text_layer:
            for param in self.clip_model.transformer.resblocks[-1].parameters():
                param.requires_grad = True
            for param in self.clip_model.ln_final.parameters():
                param.requires_grad = True
            self.clip_model.text_projection.requires_grad = True
            # NOTE(review): logit_scale only exists on this branch, while
            # forward(return_logit_scale=True) reads it - confirm the scoping.
            self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        # --- projection -------------------------------------------------
        if 'vitb_mlp_infonce' in proj_class:
            dino_embed_dim = 768
        elif 'vitl_mlp_infonce' in proj_class:
            dino_embed_dim = 1024
        else:
            raise ValueError(f"Unsupported proj_class: {proj_class!r}")
        self.proj = ProjectionLayer.from_config({
            'act': 'tanh',
            'hidden_layer': True,
            'dino_embed_dim': dino_embed_dim,
        })

        self.masker = DINOTextMasker(similarity_type="cosine")
        self.masker = self.masker.eval()

        # Created lazily by apply_pamr.
        self.pamr = None

        # --- self-attention bookkeeping --------------------------------
        self.avg_self_attn_token = avg_self_attn_token
        self.disentangled_self_attn_token = disentangled_self_attn_token

        # Set unconditionally because encode_image reads num_global_tokens
        # whether or not the hook below is registered.
        if 'sam' in self.model_name:
            self.num_global_tokens = 0
        elif 'reg' in model_name or 'dinov3' in model_name:
            self.num_global_tokens = 5
        else:
            self.num_global_tokens = 1
        # Query scaling used by process_self_attention; equals 64 ** -0.5,
        # presumably a per-head dimension of 64 - TODO confirm.
        self.scale = 0.125

        if self.avg_self_attn_token or self.disentangled_self_attn_token or is_eval:
            self.model.blocks[-1].attn.qkv.register_forward_hook(self.get_self_attention)
            self.num_attn_heads = self.model.num_heads

        # --- averaged text tokens --------------------------------------
        self.use_avg_text_token = use_avg_text_token
        self.keep_cls = keep_cls
        self.keep_end_seq = keep_end_seq
        if self.use_avg_text_token:
            # Capture every output token of ln_final, not only the pooled one.
            self.clip_model.ln_final.register_forward_hook(self.get_all_out_tokens)

        self.with_bg_clean = with_bg_clean

    # ------------------------------------------------------------------
    # Encoding
    # ------------------------------------------------------------------
    def process_self_attention(self, output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False):
        """Turn a fused qkv projection into attention of token 0 over the non-global tokens.

        Args:
            output: Tensor reshapeable to
                ``(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads)``.
            batch_size, num_tokens, num_attn_heads, embed_dim: Sizes for that reshape.
            scale: Factor applied to the queries.
            num_global_tokens: Number of leading key tokens that are dropped.
            ret_self_attn_maps: Also return the per-head logits.

        Returns:
            ``self_attn`` of shape ``(batch, num_tokens - num_global_tokens)``:
            the head-averaged logits after a softmax over tokens. With
            ``ret_self_attn_maps`` also the raw per-head logits of shape
            ``(batch, heads, num_tokens - num_global_tokens)``.
        """
        head_dim = embed_dim // num_attn_heads
        # -> (3, batch, heads, tokens, head_dim)
        qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, head_dim).permute(2, 0, 3, 1, 4)
        query, key = qkv[0] * scale, qkv[1]
        attn_logits = query @ key.transpose(-2, -1)
        # Row 0 is the first token's query; its keys skip the global tokens.
        self_attn_maps = attn_logits[:, :, 0, num_global_tokens:]
        self_attn = self_attn_maps.mean(dim=1).softmax(dim=-1)
        if ret_self_attn_maps:
            return self_attn, self_attn_maps
        return self_attn

    def encode_text(self, tokenized_texts):
        """Return ``self.clip_model.encode_text(tokenized_texts)``."""
        return self.clip_model.encode_text(tokenized_texts)

    def encode_image(self, images):
        """Encode images with the backbone and pool the patch tokens.

        With ``avg_self_attn_token`` the patch tokens are weighted by the
        head-averaged attention and averaged; otherwise, with
        ``disentangled_self_attn_token``, one pooled embedding per attention
        head is produced.

        Returns:
            ``(x, self_attn_maps)``. ``self_attn_maps`` is ``None`` when
            neither pooling flag is set.
        """
        needs_attention = self.avg_self_attn_token or self.disentangled_self_attn_token
        self_attn_maps = None
        x = self.model(images, is_training=needs_attention)
        patch_tokens = x['x_norm_patchtokens']
        batch_size, num_tokens, embed_dim = patch_tokens.shape
        # The hooked qkv output also contains the global tokens.
        num_tokens = num_tokens + self.num_global_tokens
        if needs_attention:
            self_attn, self_attn_maps = self.process_self_attention(
                self.feats['self_attn'], batch_size, num_tokens, self.num_attn_heads,
                embed_dim, self.scale, self.num_global_tokens, ret_self_attn_maps=True,
            )
        if self.avg_self_attn_token:
            x = (self_attn.unsqueeze(-1) * patch_tokens).mean(dim=1)
        elif self.disentangled_self_attn_token:
            self_attn_maps = self_attn_maps.softmax(dim=-1)
            x = (patch_tokens.unsqueeze(1) * self_attn_maps.unsqueeze(-1)).mean(dim=2)

        return x, self_attn_maps

    def forward(self, image, text, return_logit_scale=False):
        """Embed an image batch and a tokenized text batch through the projection.

        Returns:
            ``(txt_embed, img_embed)``, followed by ``self.logit_scale`` when
            ``return_logit_scale`` is set.
        """
        with torch.no_grad():
            txt_embed = self.encode_text(text)

        img_embed, self_attn_maps = self.encode_image(image)

        # Exact type checks are kept on purpose: the projection classes may
        # be related by inheritance.
        if type(self.proj) is CLIPLastLayer:
            img_embed, txt_embed = self.proj(
                img_embed, txt_embed, ret_embeds=True, self_attn_maps=self_attn_maps,
                text_argmax=text.argmax(dim=-1),
            )
        else:
            img_embed, txt_embed = self.proj(img_embed, txt_embed, ret_embeds=True, self_attn_maps=self_attn_maps)

        if return_logit_scale:
            return txt_embed, img_embed, self.logit_scale

        return txt_embed, img_embed

    def compute_loss(self, image, text, cosine=True, ret_similarity_matrix=True):
        """Compute the contrastive loss between image and text embeddings.

        Args:
            image: Image batch passed to ``forward``.
            text: Tokenized text batch passed to ``forward``.
            cosine: L2-normalize both embeddings before taking similarities.
            ret_similarity_matrix: When False only the diagonal of the
                similarity matrix is handed to the loss.

        Returns:
            ``{'contrastive_loss': ...}`` as produced by
            ``self.contrastive_loss.compute_contrastive_loss``.

        Raises:
            AttributeError: If no ``contrastive_loss`` object was attached.
        """
        # The embeddings have to be computed first; forward returns text first.
        txt_embed, img_embed = self.forward(image, text)
        if cosine:
            img_embed = F.normalize(img_embed, p=2, dim=1)
            txt_embed = F.normalize(txt_embed, p=2, dim=1)
        sim = img_embed @ txt_embed.transpose(1, 0)
        if not ret_similarity_matrix:
            # Same elements as masking with an identity matrix, without
            # creating a mask on a possibly different device.
            sim = sim.diagonal()

        loss_module = getattr(self, 'contrastive_loss', None)
        if loss_module is None:
            raise AttributeError(
                "DINOText.compute_loss requires a 'contrastive_loss' attribute "
                "exposing compute_contrastive_loss(sim)"
            )
        return {'contrastive_loss': loss_module.compute_contrastive_loss(sim)}

    # ------------------------------------------------------------------
    # Text prompts
    # ------------------------------------------------------------------
    @torch.no_grad()
    def build_dataset_class_tokens(self, template_set, classnames):
        """Tokenize every template of ``template_set`` for every class name.

        Returns:
            The per-class token tensors stacked along a new first dimension.
        """
        templates = get_template(template_set)
        use_clip_tokenizer = 'bert' not in self.clip_model_name
        tokens = []
        for classname in classnames:
            prompts = [template.format(classname) for template in templates]
            if use_clip_tokenizer:
                tokens.append(clip.tokenize(prompts))
            else:
                tokens.append(self.tokenizer(prompts, return_tensors='pt', padding='max_length')['input_ids'])

        return torch.stack(tokens)

    @torch.no_grad()
    def build_text_embedding(self, text):
        """
        Args:
            text (torch.Tensor): [NUM_CLASSES, NUM_TEMPLATES, CONTEXT_LENGTH] text tokens

        Returns:
            text_embs: one normalized embedding per class, averaged over the
            templates and, for ``ProjectionLayer`` / ``DoubleMLP``, projected
            with ``project_clip_txt``.
        """
        text = text.to(next(self.parameters()).device)
        num_classes, num_templates = text.shape[:2]
        # Index of the largest token id per prompt; presumably the
        # end-of-text position for CLIP tokenization - TODO confirm.
        text_argmax = rearrange(text.argmax(dim=-1), 'n t -> (n t)', n=num_classes, t=num_templates)
        text = rearrange(text, 'n t l -> (n t) l', n=num_classes, t=num_templates)

        # Encode in chunks to bound memory use.
        chunk_size = 32
        starts = range(0, text.size(0), chunk_size)

        if type(self.proj) is CLIPLastLayer:
            text_embs = torch.cat([
                self.proj.project_clip_txt(
                    self.encode_text(text[i:i + chunk_size]).permute(1, 0, 2),
                    text_argmax=text_argmax[i:i + chunk_size],
                )
                for i in starts
            ])
        elif self.use_avg_text_token:
            chunks = []
            for i in starts:
                chunk = text[i:i + chunk_size]
                # Run the encoder only to trigger the ln_final hook.
                self.clip_model.encode_text(chunk)
                projected_tokens = self.feats['clip_txt_out_tokens'] @ self.clip_model.text_projection
                chunks.append(average_text_tokens(projected_tokens, chunk > 0, self.keep_cls, self.keep_end_seq))
            text_embs = torch.cat(chunks)
        elif 'bert' not in self.clip_model_name:
            text_embs = torch.cat([self.clip_model.encode_text(text[i:i + chunk_size]) for i in starts])
        else:
            text_embs = torch.cat([self.clip_model(text[i:i + chunk_size])['pooler_output'] for i in starts])

        text_embs = rearrange(text_embs, '(n t) c -> n t c', n=num_classes, t=num_templates)
        # Average over templates.
        text_embs = text_embs.mean(dim=1).float()
        if type(self.proj) is ProjectionLayer or type(self.proj) is DoubleMLP:
            text_embs = self.proj.project_clip_txt(text_embs)
        text_embs = normalize(text_embs, dim=-1)

        return text_embs

    # ------------------------------------------------------------------
    # Mask generation
    # ------------------------------------------------------------------
    def apply_pamr(self, image, mask):
        """Refine ``mask`` with PAMR, resizing ``image`` to the mask resolution.

        The ``PAMR`` module is created on first use and moved to the device
        of this module's parameters.
        """
        image = F.interpolate(image, mask.shape[-2:], mode="bilinear", align_corners=True)
        if self.pamr is None:
            pamr_iter = 10
            pamr_kernel = [1, 2, 4, 8, 12, 24]
            self.pamr = PAMR(pamr_iter, pamr_kernel)
            self.pamr.eval()
            self.pamr.to(next(self.parameters()).device)

        return self.pamr(image, mask)

    def compute_padsize(self, H: int, W: int, patch_size: int):
        """Return ``(left, right, top, bottom)`` padding to reach multiples of ``patch_size``.

        The extra pixels are split evenly; an odd remainder goes to the
        right / bottom side.
        """
        # (-x) % p is the distance from x up to the next multiple of p.
        pad_w = -W % patch_size
        pad_h = -H % patch_size
        left, top = pad_w // 2, pad_h // 2
        return left, pad_w - left, top, pad_h - top

    @torch.no_grad()
    def generate_masks(
        self, image, img_metas, text_emb, classnames, text_is_token=False, apply_pamr=False, background_func="weighted_average_sigmoid", lambda_bg=0.2,
    ):
        """Generate masks for each text embeddings

        Args:
            image [B, 3, H, W]
            text_emb: text embeddings, one row per class.
            apply_pamr: refine the upsampled masks with PAMR.
            lambda_bg: weight of the attention term when ``with_bg_clean`` is set.
            img_metas, classnames, text_is_token, background_func: accepted
                for interface compatibility; not used here.

        Returns:
            softmask [B, N, H, W]: softmasks for each text embeddings,
            together with the similarity map returned by the masker.
        """
        H, W = image.shape[2:]

        # Reverse the channel order; presumably BGR -> RGB - TODO confirm.
        image = image[:, [2, 1, 0], :, :]
        ori_image = image.clone()

        img_preprocessed = self.image_transforms(image).to(next(self.parameters()).device)
        if 'dinov2' in self.model_name or 'dinov3' in self.model_name:
            image_feat = self.model.forward_features(img_preprocessed)['x_norm_patchtokens']
        elif 'mae' in self.model_name or 'clip' in self.model_name or 'dino' in self.model_name:
            # Drop the first token and keep the remaining ones.
            image_feat = self.model.forward_features(img_preprocessed)[:, 1:, :]
        elif 'sam' in self.model_name:
            # The features are collected by the get_vit_out hook.
            self.model.forward_features(img_preprocessed)
            vit_out = feats['vit_out']
            image_feat = vit_out.reshape(vit_out.shape[0], vit_out.shape[1] ** 2, vit_out.shape[-1])
        else:
            raise ValueError("Unknown ViT model")

        batch_size, num_tokens, embed_dim = image_feat.shape
        if type(self.proj) is VisualProjectionLayer:
            image_feat = self.proj.project_dino(image_feat.float())
        elif type(self.proj) is DoubleMLP:
            image_feat = self.proj.project_visual(image_feat.float())

        # Lay the patch tokens out on a square grid: (B, C, side, side).
        b, num_patches, c = image_feat.shape
        side = int(sqrt(num_patches))
        image_feat = image_feat.reshape(b, side, side, c).permute(0, 3, 1, 2)

        mask, simmap = self.masker.forward_seg(image_feat, text_emb, hard=False)

        if self.with_bg_clean:
            # The attention maps are only needed for the background cleaning.
            _, self_attn_maps = self.process_self_attention(
                self.feats['self_attn'], batch_size, num_tokens + self.num_global_tokens,
                self.num_attn_heads, embed_dim, self.scale, self.num_global_tokens,
                ret_self_attn_maps=True,
            )
            mask = self.similarity_assignment_weighted(mask, image_feat, self_attn_maps, text_emb, lambda_bg)

        # Back to the input resolution.
        mask = F.interpolate(mask, (H, W), mode='bilinear', align_corners=True)

        if apply_pamr:
            # Refine a limited number of class channels per PAMR call.
            pamr_chunk = 30
            for start in range(0, mask.shape[1], pamr_chunk):
                stop = start + pamr_chunk
                mask[:, start:stop] = self.apply_pamr(ori_image, mask[:, start:stop])

        assert mask.shape[2] == H and mask.shape[3] == W, f"shape mismatch: ({H}, {W}) / {mask.shape}"

        return mask, simmap

    def similarity_assignment_weighted(self, mask, image_feat, self_attn_maps, text_emb, lambda_bg=0.2):
        """Blend each class mask with a class-specific self-attention map.

        For every attention head an attention-weighted mean feature is
        compared with the text embeddings; the softmax of these similarities
        weights the heads into one attention map per class. That map is
        rescaled to the value range of ``mask`` and mixed in with weight
        ``lambda_bg``. All statistics are taken per sample, so every batch
        item is processed independently of the others.

        Args:
            mask: ``(bs, num_classes, h, w)`` soft masks.
            image_feat: ``(bs, c, h, w)`` patch features.
            self_attn_maps: ``(bs, num_heads, h * w)`` per-head attention.
            text_emb: ``(num_classes, c)`` text embeddings.
            lambda_bg: Weight of the attention term.

        Returns:
            ``(bs, num_classes, h, w)`` tensor
            ``(mask + lambda_bg * attention) / (1 + lambda_bg)``.
        """
        bs, c = image_feat.shape[:2]
        _, num_classes, h, w = mask.shape
        hw = self_attn_maps.shape[-1]
        flat_feat = image_feat.reshape(bs, c, hw)

        # Attention-weighted mean feature per head: (bs, heads, c).
        head_embed = (self_attn_maps.unsqueeze(2) * flat_feat.unsqueeze(1)).mean(dim=-1)
        head_embed = head_embed / head_embed.norm(dim=-1, keepdim=True)

        # Head weights per class: (bs, num_classes, heads).
        head_text_sim = (text_emb.unsqueeze(0) @ head_embed.permute(0, 2, 1)).softmax(dim=-1)
        head_text_sim_sum = head_text_sim.sum(dim=-1)

        # Weighted average of the head maps: (bs, num_classes, hw).
        weighted_maps = self_attn_maps.unsqueeze(1) * head_text_sim.unsqueeze(-1)
        attn_per_class = weighted_maps.sum(dim=2) / head_text_sim_sum.unsqueeze(-1)
        attn_per_class = attn_per_class.softmax(dim=-1)

        # Shift by the minimum and divide by max(max, max - min), per sample.
        attn_min = attn_per_class.amin(dim=(1, 2), keepdim=True)
        attn_max = attn_per_class.amax(dim=(1, 2), keepdim=True)
        attn_per_class = (attn_per_class - attn_min) / torch.maximum(attn_max, attn_max - attn_min)

        # Map the attention onto the value range of the mask.
        flat_mask = mask.reshape(bs, num_classes, hw)
        mask_min = flat_mask.amin(dim=(1, 2), keepdim=True)
        mask_max = flat_mask.amax(dim=(1, 2), keepdim=True)
        attn_per_class = attn_per_class * (mask_max - mask_min) + mask_min

        mask_output = (flat_mask + lambda_bg * attn_per_class).reshape(bs, num_classes, h, w) / (1 + lambda_bg)
        return mask_output