| | import os.path as osp |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import functional as F |
| | from torch.cuda.amp import GradScaler, autocast |
| | from collections import OrderedDict |
| |
|
| | from dassl.engine import TRAINER_REGISTRY, TrainerX |
| | from dassl.metrics import compute_accuracy |
| | from dassl.utils import load_pretrained_weights, load_checkpoint |
| | from dassl.optim import build_optimizer, build_lr_scheduler |
| |
|
| | from clip import clip |
| | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer |
| |
|
# Shared tokenizer instance, used to measure the token length of class names.
_tokenizer = _Tokenizer()
| |
|
| |
|
def load_clip_to_cpu(cfg):
    """Build a CLIP model on the CPU from the backbone named in the config.

    Downloads the checkpoint for ``cfg.MODEL.BACKBONE.NAME``, tries to load
    it as a TorchScript archive first (the format of the official CLIP
    releases), and falls back to a plain state dict for checkpoints saved
    without JIT.

    Returns:
        The CLIP model produced by ``clip.build_model``.
    """
    backbone_name = cfg.MODEL.BACKBONE.NAME
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # Official CLIP weights are distributed as TorchScript archives.
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None
    except RuntimeError:
        # Not a JIT archive: load as a raw state dict instead.
        state_dict = torch.load(model_path, map_location="cpu")

    # BUGFIX: design_details was previously defined only inside the except
    # branch, raising NameError whenever the JIT load succeeded.  It is
    # needed unconditionally by clip.build_model below.
    design_details = {"trainer": 'CoOp',
                      "vision_depth": 0,
                      "language_depth": 0, "vision_ctx": 0,
                      "language_ctx": 0}
    model = clip.build_model(state_dict or model.state_dict(), design_details)

    return model
| |
|
| | |
# Hand-crafted prompt template per dataset; "{}" is filled with the class
# name.  These frozen prompts produce the text-feature anchor used by the
# knowledge-guidance regularizer in CustomCLIP.forward.
CUSTOM_TEMPLATES = {
    "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordFlowers": "a photo of a {}, a type of flower.",
    "FGVCAircraft": "a photo of a {}, a type of aircraft.",
    "DescribableTextures": "a photo of a {}, a type of texture.",
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": "a photo of a {}.",
    "Food101": "a photo of {}, a type of food.",
    "SUN397": "a photo of a {}.",
    "Caltech101": "a photo of a {}.",
    "UCF101": "a photo of a person doing {}.",
    "ImageNet": "a photo of a {}.",
    "ImageNetSketch": "a photo of a {}.",
    "ImageNetV2": "a photo of a {}.",
    "ImageNetA": "a photo of a {}.",
    "ImageNetR": "a photo of a {}.",
}
| |
|
| |
|
| |
|
| |
|
class TextEncoder(nn.Module):
    """Encode pre-embedded prompts with CLIP's frozen text transformer."""

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        # Add positional information, then run the transformer in
        # sequence-first (LND) layout as CLIP expects.
        hidden = prompts + self.positional_embedding.type(self.dtype)
        hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden).type(self.dtype)

        # Pick the features at the EOT token (the highest token id in each
        # sequence) and project them into the joint embedding space.
        eot_positions = tokenized_prompts.argmax(dim=-1)
        rows = torch.arange(hidden.shape[0])
        return hidden[rows, eot_positions] @ self.text_projection
| |
|
| |
|
class PromptLearner(nn.Module):
    """Learnable soft prompt for KgCoOp.

    Holds ``n_ctx`` learnable context vectors that are concatenated between
    the frozen start-of-text embedding and the frozen class-name/EOS
    embeddings.  Also pre-computes frozen text features of hand-crafted
    prompts (via a second CLIP instance), which serve as the
    knowledge-guidance anchor read by CustomCLIP.
    """

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.TRAINER.COOP.N_CTX
        ctx_init = cfg.TRAINER.COOP.CTX_INIT
        dtype = clip_model.dtype
        # Width of the text transformer; context vectors live in this space.
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        if ctx_init:
            # Initialize the context from a fixed phrase.
            # NOTE(review): the configured CTX_INIT string is overwritten by
            # the hard-coded phrase below — confirm this is intentional.
            temp = 'a photo of a'
            ctx_init = temp.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            # Skip the SOS token at position 0; keep the n_ctx word embeddings.
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # Random initialization: one context per class (CSC) or shared.
            if cfg.TRAINER.COOP.CSC:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            # Placeholder tokens shown in logs; replaced by learned vectors.
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        # The only trainable parameter intended for optimization (the
        # trainer freezes everything whose name lacks "ctx").
        self.ctx = nn.Parameter(ctx_vectors)

        # NOTE(review): bias_vectors appears unused in this file — possibly
        # a leftover; verify before removing.
        bias_vectors = torch.empty(1, 512, dtype=dtype)
        nn.init.normal_(bias_vectors, std=0.02)
        self.bias_vectors = nn.Parameter(bias_vectors)

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        # A second, frozen CLIP encodes the hand-crafted template prompts;
        # its normalized features are the fixed "general knowledge" anchor.
        clip_model_ = load_clip_to_cpu(cfg)
        clip_model_.cuda()

        temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
        prompts_ = [temp.format(c.replace("_", " ")) for c in classnames]
        print(f"Prompts: {prompts_}")
        prompts_ = torch.cat([clip.tokenize(p) for p in prompts_])
        prompts_ = prompts_.cuda()

        with torch.no_grad():
            text_features = clip_model_.encode_text(prompts_)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Plain attribute (not a buffer/parameter): excluded from state_dict.
        self.text_features = text_features

        # NOTE(review): meta_net is defined (and saved) but never called in
        # this file's forward paths — confirm whether it is still needed.
        self.meta_net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(512, 512)),
            ("relu", nn.ReLU(inplace=True))
        ]))

        if cfg.TRAINER.COCOOP.PREC == "fp16":
            self.meta_net.half()

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # Frozen pieces of the full prompt embedding:
        # prefix = SOS token, suffix = class tokens + EOS/padding.
        self.register_buffer("token_prefix", embedding[:, :1, :])
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        # Kept so TextEncoder can locate each sequence's EOT token.
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens
        self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION

    def forward(self):
        """Assemble full prompt embeddings: [SOS] + learned ctx + class/EOS."""
        ctx = self.ctx

        if ctx.dim() == 2:
            # Generic context: broadcast the same vectors to every class.
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix

        prompts = torch.cat(
            [
                prefix,
                ctx,
                suffix,
            ],
            dim=1,
        )

        return prompts
| |
|
| |
|
class Adapter(nn.Module):
    """Bottleneck feature adapter: Linear -> ReLU -> Linear -> ReLU.

    Projects ``c_in``-dim features down by a factor of ``reduction`` and
    back up to ``c_in``, without biases.
    """

    def __init__(self, c_in, reduction=4):
        super(Adapter, self).__init__()
        hidden = c_in // reduction
        self.fc = nn.Sequential(
            nn.Linear(c_in, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, c_in, bias=False),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.fc(x)
| |
|
class CustomCLIP(nn.Module):
    """CLIP with a learnable prompt plus a knowledge-guidance score.

    ``forward`` returns ``(logits, score)`` where ``score`` is
    ``1 - mean cosine similarity`` between the learned text features and
    the frozen hand-crafted text features kept in the prompt learner.
    """

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.ori_embedding = self.prompt_learner.text_features
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
        self.meta_net = self.prompt_learner.meta_net
        self.adapter = Adapter(512, 4).to(clip_model.dtype)

    def forward(self, image):
        # Encode the image and the current learned prompts.
        learned_prompts = self.prompt_learner()
        img_feats = self.image_encoder(image.type(self.dtype))
        txt_feats = self.text_encoder(learned_prompts, self.tokenized_prompts)

        # Cosine-similarity logits over L2-normalized features.
        img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)
        txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True)
        logits = self.logit_scale.exp() * img_feats @ txt_feats.t()

        # Knowledge-guidance score: how far the learned text features have
        # drifted from the frozen hand-crafted anchor.
        anchor = self.ori_embedding
        anchor = anchor / anchor.norm(dim=-1, keepdim=True)
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-07)
        score = 1.0 - torch.mean(cos(txt_feats, anchor))

        return logits, score
| |
|
| |
|
@TRAINER_REGISTRY.register()
class KgCoOp(TrainerX):
    """Knowledge-guided Context Optimization trainer.

    Only the prompt learner's parameters are optimized; the CLIP image and
    text encoders stay frozen.  The training loss is cross-entropy plus
    ``w`` times the knowledge-guidance score returned by CustomCLIP.
    """

    def check_cfg(self, cfg):
        assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]

    def build_model(self):
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)

        if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp":
            # CLIP's default precision is fp16; AMP recasts on the fly.
            clip_model.float()

        print("Building custom CLIP")
        self.model = CustomCLIP(cfg, classnames, clip_model)
        self.w = cfg.TRAINER.COOP.W  # weight of the knowledge-guidance term

        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.model.named_parameters():
            # Only parameters whose name contains "ctx" keep gradients.
            if "ctx" not in name:
                param.requires_grad_(False)
            else:
                print(name)

        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)

        self.model.to(self.device)
        # Only the prompt learner is handed to the optimizer/checkpointer.
        self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)

        self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None

        device_count = torch.cuda.device_count()
        if device_count > 1:
            print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
            self.model = nn.DataParallel(self.model)

    def forward_backward(self, batch):
        """Run one optimization step and return a loss/accuracy summary."""
        image, label = self.parse_batch_train(batch)
        prec = self.cfg.TRAINER.COOP.PREC
        if prec == "amp":
            with autocast():
                # BUGFIX: the model returns (logits, score); the previous
                # code passed the whole tuple to cross_entropy (TypeError at
                # runtime) and dropped the w*score regularizer under AMP.
                output, score = self.model(image)
                loss = F.cross_entropy(output, label) + self.w * score
            self.optim.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optim)
            self.scaler.update()
        else:
            output, score = self.model(image)
            loss = F.cross_entropy(output, label) + self.w * score
            self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }

        # Step the LR scheduler once per epoch (after the last batch).
        if (self.batch_idx + 1) == self.num_batches:
            self.sched.step()

        return loss_summary

    def parse_batch_train(self, batch):
        """Move one training batch's image tensor and labels to the device."""
        input = batch["img"]
        label = batch["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label

    def model_inference(self, input):
        # Evaluation uses only the logits; the score is a training-time term.
        return self.model(input)[0]

    def load_model(self, directory, epoch=None):
        """Load checkpointed prompt-learner weights from ``directory``.

        Class-dependent token prefix/suffix buffers are dropped from the
        checkpoint because they are derived from the current class names,
        not learned during training.
        """
        if not directory:
            print("Note that load_model() is skipped as no pretrained model is given")
            return

        names = self.get_model_names()
        print(names)

        # By default, load the best model saved during validation.
        model_file = "model-best.pth.tar"

        if epoch is not None:
            model_file = "model.pth.tar-" + str(epoch)

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError('Model not found at "{}"'.format(model_path))

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint["epoch"]

            # Remove buffers that must be rebuilt for the current classes.
            if "token_prefix" in state_dict:
                del state_dict["token_prefix"]

            if "token_suffix" in state_dict:
                del state_dict["token_suffix"]

            if "token_midfix" in state_dict:
                del state_dict["token_midfix"]

            print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
            # strict=False because the token buffers were deleted above.
            self._models[name].load_state_dict(state_dict, strict=False)