# Copyright by HQ-SAM team
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import argparse
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import cv2
import random
from typing import Dict, List, Tuple

from segment_anything_training import sam_model_registry
from segment_anything_training.modeling import TwoWayTransformer, MaskDecoder

from utils.dataloader import get_im_gt_name_dict, create_dataloaders, RandomHFlip, Resize, LargeScaleJitter
from utils.loss_mask import loss_masks
import utils.misc as misc
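
# LayerNorm2d below is a channel-wise LayerNorm for NCHW feature maps: it
# normalizes each spatial location across the channel dimension and then applies
# a learnable per-channel scale and shift.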
class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
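
# MLP is a small stack of Linear layers with ReLU in between; in this file it is
# used as the hypernetwork head that maps the HQ-Output token to per-pixel mask
# weights (see MaskDecoderHQ.hf_mlp below).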
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = torch.sigmoid(x)
        return x
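
# MaskDecoderHQ keeps the original SAM mask decoder frozen (its weights are loaded
# from pretrained_checkpoint/*.pth and requires_grad is set to False) and adds the
# learnable HQ components: an extra output token (hf_token), its hypernetwork MLP
# (hf_mlp), and two small convolutional branches that fuse the early ViT feature
# with the encoder embedding into a high-quality feature map.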
class MaskDecoderHQ(MaskDecoder):
    def __init__(self, model_type):
        super().__init__(transformer_dim=256,
                         transformer=TwoWayTransformer(
                             depth=2,
                             embedding_dim=256,
                             mlp_dim=2048,
                             num_heads=8,
                         ),
                         num_multimask_outputs=3,
                         activation=nn.GELU,
                         iou_head_depth=3,
                         iou_head_hidden_dim=256)
        assert model_type in ["vit_b", "vit_l", "vit_h"]

        checkpoint_dict = {"vit_b": "pretrained_checkpoint/sam_vit_b_maskdecoder.pth",
                           "vit_l": "pretrained_checkpoint/sam_vit_l_maskdecoder.pth",
                           "vit_h": "pretrained_checkpoint/sam_vit_h_maskdecoder.pth"}
        checkpoint_path = checkpoint_dict[model_type]
        self.load_state_dict(torch.load(checkpoint_path))
        print("HQ Decoder init from SAM MaskDecoder")
        for n, p in self.named_parameters():
            p.requires_grad = False

        transformer_dim = 256
        vit_dim_dict = {"vit_b": 768, "vit_l": 1024, "vit_h": 1280}
        vit_dim = vit_dim_dict[model_type]

        self.hf_token = nn.Embedding(1, transformer_dim)
        self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
        self.num_mask_tokens = self.num_mask_tokens + 1

        self.compress_vit_feat = nn.Sequential(
            nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim),
            nn.GELU(),
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2))

        self.embedding_encoder = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            nn.GELU(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
        )

        self.embedding_maskfeature = nn.Sequential(
            nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1),
            LayerNorm2d(transformer_dim // 4),
            nn.GELU(),
            nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1))
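
    # For the default 1024x1024 SAM input, image_embeddings is expected to be
    # B x 256 x 64 x 64 and interm_embeddings[0] (an early ViT block output) is
    # B x 64 x 64 x vit_dim; both are upsampled 4x here to form the 256x256
    # hq_features consumed by predict_masks. (Shapes noted as an assumption for
    # the standard SAM configuration.)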
    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        hq_token_only: bool,
        interm_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.

        Returns:
          torch.Tensor: batched predicted hq masks
        """
        # early-layer ViT feature, after 1st global attention block in ViT
        vit_features = interm_embeddings[0].permute(0, 3, 1, 2)
        hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)

        batch_len = len(image_embeddings)
        masks = []
        iou_preds = []
        for i_batch in range(batch_len):
            mask, iou_pred = self.predict_masks(
                image_embeddings=image_embeddings[i_batch].unsqueeze(0),
                image_pe=image_pe[i_batch],
                sparse_prompt_embeddings=sparse_prompt_embeddings[i_batch],
                dense_prompt_embeddings=dense_prompt_embeddings[i_batch],
                hq_feature=hq_features[i_batch].unsqueeze(0),
            )
            masks.append(mask)
            iou_preds.append(iou_pred)
        masks = torch.cat(masks, 0)
        iou_preds = torch.cat(iou_preds, 0)

        # Select the correct mask or masks for output
        if multimask_output:
            # mask with highest score
            mask_slice = slice(1, self.num_mask_tokens - 1)
            iou_preds = iou_preds[:, mask_slice]
            iou_preds, max_iou_idx = torch.max(iou_preds, dim=1)
            iou_preds = iou_preds.unsqueeze(1)
            masks_multi = masks[:, mask_slice, :, :]
            masks_sam = masks_multi[torch.arange(masks_multi.size(0)), max_iou_idx].unsqueeze(1)
        else:
            # single mask output, default
            mask_slice = slice(0, 1)
            masks_sam = masks[:, mask_slice]

        masks_hq = masks[:, slice(self.num_mask_tokens - 1, self.num_mask_tokens), :, :]
        if hq_token_only:
            return masks_hq
        else:
            return masks_sam, masks_hq

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        hq_feature: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)

        upscaled_embedding_sam = self.output_upscaling(src)
        upscaled_embedding_ours = self.embedding_maskfeature(upscaled_embedding_sam) + hq_feature

        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            if i < 4:
                hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
            else:
                hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)

        b, c, h, w = upscaled_embedding_sam.shape
        masks_sam = (hyper_in[:, :4] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
        masks_ours = (hyper_in[:, 4:] @ upscaled_embedding_ours.view(b, c, h * w)).view(b, -1, h, w)
        masks = torch.cat([masks_sam, masks_ours], dim=1)

        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred
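
# Visualization helpers: overlay each predicted mask (and, when given, the prompt
# box/points) on the input image and save one figure per mask.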
def show_anns(masks, input_point, input_box, input_label, filename, image, ious, boundary_ious):
    if len(masks) == 0:
        return

    for i, (mask, iou, biou) in enumerate(zip(masks, ious, boundary_ious)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        if input_box is not None:
            show_box(input_box, plt.gca())
        if (input_point is not None) and (input_label is not None):
            show_points(input_point, input_label, plt.gca())

        plt.axis('off')
        plt.savefig(filename + '_' + str(i) + '.png', bbox_inches='tight', pad_inches=-0.1)
        plt.close()

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

def get_args_parser():
    parser = argparse.ArgumentParser('HQ-SAM', add_help=False)

    parser.add_argument("--output", type=str, required=True,
                        help="Path to the directory where masks and checkpoints will be output")
    parser.add_argument("--model-type", type=str, default="vit_l",
                        help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b']")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="The path to the SAM checkpoint to use for mask generation.")
    parser.add_argument("--device", type=str, default="cuda",
                        help="The device to run generation on.")

    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--learning_rate', default=1e-3, type=float)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--lr_drop_epoch', default=10, type=int)
    parser.add_argument('--max_epoch_num', default=12, type=int)
    parser.add_argument('--input_size', default=[1024, 1024], type=int, nargs=2)
    parser.add_argument('--batch_size_train', default=4, type=int)
    parser.add_argument('--batch_size_valid', default=1, type=int)
    parser.add_argument('--model_save_fre', default=1, type=int)

    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--rank', default=0, type=int,
                        help='rank of the distributed process')
    parser.add_argument('--local_rank', type=int, help='local rank for dist')
    parser.add_argument('--find_unused_params', action='store_true')

    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--visualize', action='store_true')
    parser.add_argument("--restore-model", type=str,
                        help="The path to the hq_decoder training checkpoint for evaluation")

    return parser.parse_args()
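
# Example launch (a sketch; adjust GPU count, paths, and the launcher to your
# PyTorch version -- newer releases use `torchrun` instead of torch.distributed.launch):
#   python -m torch.distributed.launch --nproc_per_node=8 train.py \
#       --checkpoint ./pretrained_checkpoint/sam_vit_l_0b3195.pth \
#       --model-type vit_l --output work_dirs/hq_sam_l
# Evaluation-only sketch (assumes a previously trained HQ decoder checkpoint):
#   python -m torch.distributed.launch --nproc_per_node=1 train.py \
#       --checkpoint ./pretrained_checkpoint/sam_vit_l_0b3195.pth \
#       --model-type vit_l --output work_dirs/hq_sam_l \
#       --eval --restore-model work_dirs/hq_sam_l/epoch_11.pth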
def main(net, train_datasets, valid_datasets, args):

    misc.init_distributed_mode(args)
    print('world size: {}'.format(args.world_size))
    print('rank: {}'.format(args.rank))
    print('local_rank: {}'.format(args.local_rank))
    print("args: " + str(args) + '\n')

    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    ### --- Step 1: Train or Valid dataset ---
    if not args.eval:
        print("--- create training dataloader ---")
        train_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")
        train_dataloaders, train_datasets = create_dataloaders(train_im_gt_list,
                                                               my_transforms=[
                                                                   RandomHFlip(),
                                                                   LargeScaleJitter()
                                                               ],
                                                               batch_size=args.batch_size_train,
                                                               training=True)
        print(len(train_dataloaders), " train dataloaders created")

    print("--- create valid dataloader ---")
    valid_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
    valid_dataloaders, valid_datasets = create_dataloaders(valid_im_gt_list,
                                                           my_transforms=[
                                                               Resize(args.input_size)
                                                           ],
                                                           batch_size=args.batch_size_valid,
                                                           training=False)
    print(len(valid_dataloaders), " valid dataloaders created")

    ### --- Step 2: DistributedDataParallel ---
    if torch.cuda.is_available():
        net.cuda()
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu], find_unused_parameters=args.find_unused_params)
    net_without_ddp = net.module

    ### --- Step 3: Train or Evaluate ---
    if not args.eval:
        print("--- define optimizer ---")
        optimizer = optim.Adam(net_without_ddp.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop_epoch)
        lr_scheduler.last_epoch = args.start_epoch

        train(args, net, optimizer, train_dataloaders, valid_dataloaders, lr_scheduler)
    else:
        sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
        _ = sam.to(device=args.device)
        sam = torch.nn.parallel.DistributedDataParallel(sam, device_ids=[args.gpu], find_unused_parameters=args.find_unused_params)

        if args.restore_model:
            print("restore model from:", args.restore_model)
            if torch.cuda.is_available():
                net_without_ddp.load_state_dict(torch.load(args.restore_model))
            else:
                net_without_ddp.load_state_dict(torch.load(args.restore_model, map_location="cpu"))

        evaluate(args, net, sam, valid_dataloaders, args.visualize)
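
# train() builds a second, frozen SAM model (also DDP-wrapped) that is run under
# torch.no_grad() purely to produce image and prompt embeddings; only the HQ
# decoder `net` receives gradients. A decoder-only checkpoint is saved every
# `model_save_fre` epochs and merged into the full SAM checkpoint at the end.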
def train(args, net, optimizer, train_dataloaders, valid_dataloaders, lr_scheduler):
    if misc.is_main_process():
        os.makedirs(args.output, exist_ok=True)

    epoch_start = args.start_epoch
    epoch_num = args.max_epoch_num
    train_num = len(train_dataloaders)

    net.train()
    _ = net.to(device=args.device)

    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    _ = sam.to(device=args.device)
    sam = torch.nn.parallel.DistributedDataParallel(sam, device_ids=[args.gpu], find_unused_parameters=args.find_unused_params)

    for epoch in range(epoch_start, epoch_num):
        print("epoch: ", epoch, " learning rate: ", optimizer.param_groups[0]["lr"])
        metric_logger = misc.MetricLogger(delimiter="  ")
        train_dataloaders.batch_sampler.sampler.set_epoch(epoch)

        for data in metric_logger.log_every(train_dataloaders, 1000):
            inputs, labels = data['image'], data['label']
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()

            imgs = inputs.permute(0, 2, 3, 1).cpu().numpy()

            # input prompt
            input_keys = ['box', 'point', 'noise_mask']
            labels_box = misc.masks_to_boxes(labels[:, 0, :, :])
            try:
                labels_points = misc.masks_sample_points(labels[:, 0, :, :])
            except:
                # less than 10 points
                input_keys = ['box', 'noise_mask']
            labels_256 = F.interpolate(labels, size=(256, 256), mode='bilinear')
            labels_noisemask = misc.masks_noise(labels_256)
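
            # Each image gets one randomly chosen prompt type per iteration: the GT
            # bounding box, sampled foreground points, or a noised version of the GT
            # mask. If point sampling fails above (mask too small to sample points),
            # the 'point' option is dropped for this batch.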
            batched_input = []
            for b_i in range(len(imgs)):
                dict_input = dict()
                input_image = torch.as_tensor(imgs[b_i].astype(dtype=np.uint8), device=sam.device).permute(2, 0, 1).contiguous()
                dict_input['image'] = input_image
                input_type = random.choice(input_keys)
                if input_type == 'box':
                    dict_input['boxes'] = labels_box[b_i:b_i+1]
                elif input_type == 'point':
                    point_coords = labels_points[b_i:b_i+1]
                    dict_input['point_coords'] = point_coords
                    dict_input['point_labels'] = torch.ones(point_coords.shape[1], device=point_coords.device)[None, :]
                elif input_type == 'noise_mask':
                    dict_input['mask_inputs'] = labels_noisemask[b_i:b_i+1]
                else:
                    raise NotImplementedError
                dict_input['original_size'] = imgs[b_i].shape[:2]
                batched_input.append(dict_input)

            with torch.no_grad():
                batched_output, interm_embeddings = sam(batched_input, multimask_output=False)

            batch_len = len(batched_output)
            encoder_embedding = torch.cat([batched_output[i_l]['encoder_embedding'] for i_l in range(batch_len)], dim=0)
            image_pe = [batched_output[i_l]['image_pe'] for i_l in range(batch_len)]
            sparse_embeddings = [batched_output[i_l]['sparse_embeddings'] for i_l in range(batch_len)]
            dense_embeddings = [batched_output[i_l]['dense_embeddings'] for i_l in range(batch_len)]

            masks_hq = net(
                image_embeddings=encoder_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
                hq_token_only=True,
                interm_embeddings=interm_embeddings,
            )

            loss_mask, loss_dice = loss_masks(masks_hq, labels / 255.0, len(masks_hq))
            loss = loss_mask + loss_dice
            loss_dict = {"loss_mask": loss_mask, "loss_dice": loss_dice}

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = misc.reduce_dict(loss_dict)
            losses_reduced_scaled = sum(loss_dict_reduced.values())
            loss_value = losses_reduced_scaled.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            metric_logger.update(training_loss=loss_value, **loss_dict_reduced)

        print("Finished epoch: ", epoch)
        metric_logger.synchronize_between_processes()
        print("Averaged stats:", metric_logger)
        train_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0}

        lr_scheduler.step()
        test_stats = evaluate(args, net, sam, valid_dataloaders)
        train_stats.update(test_stats)

        net.train()

        if epoch % args.model_save_fre == 0:
            model_name = "/epoch_" + str(epoch) + ".pth"
            print('come here save at', args.output + model_name)
            misc.save_on_master(net.module.state_dict(), args.output + model_name)

    # Finish training
    print("Training Reaches The Maximum Epoch Number")

    # merge sam and hq_decoder
    if misc.is_main_process():
        sam_ckpt = torch.load(args.checkpoint)
        hq_decoder = torch.load(args.output + model_name)
        for key in hq_decoder.keys():
            sam_key = 'mask_decoder.' + key
            if sam_key not in sam_ckpt.keys():
                sam_ckpt[sam_key] = hq_decoder[key]
        model_name = "/sam_hq_epoch_" + str(epoch) + ".pth"
        torch.save(sam_ckpt, args.output + model_name)
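
# A usage sketch (paths and epoch number are illustrative): the merged checkpoint
# written above contains the original SAM weights plus the new HQ decoder entries
# under the 'mask_decoder.' prefix, so it can be loaded as a single state_dict,
# e.g. torch.load("<output>/sam_hq_epoch_11.pth"), by an HQ-SAM model definition.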

def compute_iou(preds, target):
    assert target.shape[1] == 1, 'only support one mask per image now'
    if preds.shape[2] != target.shape[2] or preds.shape[3] != target.shape[3]:
        postprocess_preds = F.interpolate(preds, size=target.size()[2:], mode='bilinear', align_corners=False)
    else:
        postprocess_preds = preds
    iou = 0
    for i in range(0, len(preds)):
        iou = iou + misc.mask_iou(postprocess_preds[i], target[i])
    return iou / len(preds)

def compute_boundary_iou(preds, target):
    assert target.shape[1] == 1, 'only support one mask per image now'
    if preds.shape[2] != target.shape[2] or preds.shape[3] != target.shape[3]:
        postprocess_preds = F.interpolate(preds, size=target.size()[2:], mode='bilinear', align_corners=False)
    else:
        postprocess_preds = preds
    iou = 0
    for i in range(0, len(preds)):
        iou = iou + misc.boundary_iou(target[i], postprocess_preds[i])
    return iou / len(preds)
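
# evaluate() prompts SAM with the ground-truth bounding box only (input_keys is
# ['box']) and reports mean mask IoU and boundary IoU per validation dataloader.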
def evaluate(args, net, sam, valid_dataloaders, visualize=False):
    net.eval()
    print("Validating...")
    test_stats = {}

    for k in range(len(valid_dataloaders)):
        metric_logger = misc.MetricLogger(delimiter="  ")
        valid_dataloader = valid_dataloaders[k]
        print('valid_dataloader len:', len(valid_dataloader))

        for data_val in metric_logger.log_every(valid_dataloader, 1000):
            imidx_val, inputs_val, labels_val, shapes_val, labels_ori = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape'], data_val['ori_label']

            if torch.cuda.is_available():
                inputs_val = inputs_val.cuda()
                labels_val = labels_val.cuda()
                labels_ori = labels_ori.cuda()

            imgs = inputs_val.permute(0, 2, 3, 1).cpu().numpy()

            labels_box = misc.masks_to_boxes(labels_val[:, 0, :, :])
            input_keys = ['box']
            batched_input = []
            for b_i in range(len(imgs)):
                dict_input = dict()
                input_image = torch.as_tensor(imgs[b_i].astype(dtype=np.uint8), device=sam.device).permute(2, 0, 1).contiguous()
                dict_input['image'] = input_image
                input_type = random.choice(input_keys)
                if input_type == 'box':
                    dict_input['boxes'] = labels_box[b_i:b_i+1]
                elif input_type == 'point':
                    point_coords = labels_points[b_i:b_i+1]
                    dict_input['point_coords'] = point_coords
                    dict_input['point_labels'] = torch.ones(point_coords.shape[1], device=point_coords.device)[None, :]
                elif input_type == 'noise_mask':
                    dict_input['mask_inputs'] = labels_noisemask[b_i:b_i+1]
                else:
                    raise NotImplementedError
                dict_input['original_size'] = imgs[b_i].shape[:2]
                batched_input.append(dict_input)

            with torch.no_grad():
                batched_output, interm_embeddings = sam(batched_input, multimask_output=False)

            batch_len = len(batched_output)
            encoder_embedding = torch.cat([batched_output[i_l]['encoder_embedding'] for i_l in range(batch_len)], dim=0)
            image_pe = [batched_output[i_l]['image_pe'] for i_l in range(batch_len)]
            sparse_embeddings = [batched_output[i_l]['sparse_embeddings'] for i_l in range(batch_len)]
            dense_embeddings = [batched_output[i_l]['dense_embeddings'] for i_l in range(batch_len)]

            masks_sam, masks_hq = net(
                image_embeddings=encoder_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
                hq_token_only=False,
                interm_embeddings=interm_embeddings,
            )

            iou = compute_iou(masks_hq, labels_ori)
            boundary_iou = compute_boundary_iou(masks_hq, labels_ori)

            if visualize:
                print("visualize")
                os.makedirs(args.output, exist_ok=True)
                masks_hq_vis = (F.interpolate(masks_hq.detach(), (1024, 1024), mode="bilinear", align_corners=False) > 0).cpu()
                for ii in range(len(imgs)):
                    base = data_val['imidx'][ii].item()
                    print('base:', base)
                    save_base = os.path.join(args.output, str(k) + '_' + str(base))
                    imgs_ii = imgs[ii].astype(dtype=np.uint8)
                    show_iou = torch.tensor([iou.item()])
                    show_boundary_iou = torch.tensor([boundary_iou.item()])
                    show_anns(masks_hq_vis[ii], None, labels_box[ii].cpu(), None, save_base, imgs_ii, show_iou, show_boundary_iou)

            loss_dict = {"val_iou_" + str(k): iou, "val_boundary_iou_" + str(k): boundary_iou}
            loss_dict_reduced = misc.reduce_dict(loss_dict)
            metric_logger.update(**loss_dict_reduced)

        print('============================')
        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        print("Averaged stats:", metric_logger)
        resstat = {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0}
        test_stats.update(resstat)

    return test_stats

if __name__ == "__main__":

    ### --------------- Configuring the Train and Valid datasets ---------------

    dataset_dis = {"name": "DIS5K-TR",
                   "im_dir": "./data/DIS5K/DIS-TR/im",
                   "gt_dir": "./data/DIS5K/DIS-TR/gt",
                   "im_ext": ".jpg",
                   "gt_ext": ".png"}

    dataset_thin = {"name": "ThinObject5k-TR",
                    "im_dir": "./data/thin_object_detection/ThinObject5K/images_train",
                    "gt_dir": "./data/thin_object_detection/ThinObject5K/masks_train",
                    "im_ext": ".jpg",
                    "gt_ext": ".png"}

    dataset_fss = {"name": "FSS",
                   "im_dir": "./data/cascade_psp/fss_all",
                   "gt_dir": "./data/cascade_psp/fss_all",
                   "im_ext": ".jpg",
                   "gt_ext": ".png"}

    dataset_duts = {"name": "DUTS-TR",
                    "im_dir": "./data/cascade_psp/DUTS-TR",
                    "gt_dir": "./data/cascade_psp/DUTS-TR",
                    "im_ext": ".jpg",
                    "gt_ext": ".png"}

    dataset_duts_te = {"name": "DUTS-TE",
                       "im_dir": "./data/cascade_psp/DUTS-TE",
                       "gt_dir": "./data/cascade_psp/DUTS-TE",
                       "im_ext": ".jpg",
                       "gt_ext": ".png"}

    dataset_ecssd = {"name": "ECSSD",
                     "im_dir": "./data/cascade_psp/ecssd",
                     "gt_dir": "./data/cascade_psp/ecssd",
                     "im_ext": ".jpg",
                     "gt_ext": ".png"}

    dataset_msra = {"name": "MSRA10K",
                    "im_dir": "./data/cascade_psp/MSRA_10K",
                    "gt_dir": "./data/cascade_psp/MSRA_10K",
                    "im_ext": ".jpg",
                    "gt_ext": ".png"}

    # valid set
    dataset_coift_val = {"name": "COIFT",
                         "im_dir": "./data/thin_object_detection/COIFT/images",
                         "gt_dir": "./data/thin_object_detection/COIFT/masks",
                         "im_ext": ".jpg",
                         "gt_ext": ".png"}

    dataset_hrsod_val = {"name": "HRSOD",
                         "im_dir": "./data/thin_object_detection/HRSOD/images",
                         "gt_dir": "./data/thin_object_detection/HRSOD/masks_max255",
                         "im_ext": ".jpg",
                         "gt_ext": ".png"}

    dataset_thin_val = {"name": "ThinObject5k-TE",
                        "im_dir": "./data/thin_object_detection/ThinObject5K/images_test",
                        "gt_dir": "./data/thin_object_detection/ThinObject5K/masks_test",
                        "im_ext": ".jpg",
                        "gt_ext": ".png"}

    dataset_dis_val = {"name": "DIS5K-VD",
                       "im_dir": "./data/DIS5K/DIS-VD/im",
                       "gt_dir": "./data/DIS5K/DIS-VD/gt",
                       "im_ext": ".jpg",
                       "gt_ext": ".png"}

    train_datasets = [dataset_dis, dataset_thin, dataset_fss, dataset_duts, dataset_duts_te, dataset_ecssd, dataset_msra]
    valid_datasets = [dataset_dis_val, dataset_coift_val, dataset_hrsod_val, dataset_thin_val]

    args = get_args_parser()
    net = MaskDecoderHQ(args.model_type)

    main(net, train_datasets, valid_datasets, args)