import io
import ast

import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoTokenizer,
    AutoModelForCausalLM,
)

import matplotlib
matplotlib.use("Agg")  # headless backend: render figures without a display
import matplotlib.pyplot as plt

model_root = "qihoo360/fg-clip2-base"

model = AutoModelForCausalLM.from_pretrained(model_root, trust_remote_code=True)
device = model.device

tokenizer = AutoTokenizer.from_pretrained(model_root)
image_processor = AutoImageProcessor.from_pretrained(model_root)


def resize_short_edge(image, target_size=2048):
    """Upscale the image so its shorter edge is at least `target_size` pixels.

    Images whose short edge already meets the target are returned unchanged.
    """
    if isinstance(image, str):
        image = Image.open(image)
    width, height = image.size
    short_edge = min(width, height)
    if short_edge >= target_size:
        return image
    scale = target_size / short_edge
    new_width = int(width * scale)
    new_height = int(height * scale)
    resized_image = image.resize((new_width, new_height))
    return resized_image


def Get_Densefeature(image, candidate_labels):
    """
    Takes a PIL image and a string containing a Python list of candidate labels
    (e.g. "['black cat', 'window']") and returns a grid of per-label
    dense-feature similarity heatmaps as a single PIL image.
    """
    # The textbox supplies a Python-list literal, so parse it instead of splitting on commas.
    candidate_labels = ast.literal_eval(candidate_labels)
    assert len(candidate_labels) != 0, "at least one label is required"
    print(candidate_labels)

    image = image.convert("RGB")
    image = resize_short_edge(image, target_size=1024)
    image_input = image_processor(images=image, max_num_patches=4096, return_tensors="pt").to(device)
    captions = candidate_labels

    with torch.no_grad():
        dense_image_feature = model.get_image_dense_feature(**image_input)

        # spatial_shapes holds the (height, width) of the patch grid; only the
        # first real_h * real_w tokens correspond to actual image patches.
        spatial_values = image_input["spatial_shapes"][0]
        real_h = spatial_values[0].item()
        real_w = spatial_values[1].item()
        real_pixel_tokens_num = real_w * real_h
        dense_image_feature = dense_image_feature[0][:real_pixel_tokens_num]

        captions = [caption.lower() for caption in captions]
        caption_input = tokenizer(
            captions, padding="max_length", max_length=64, truncation=True, return_tensors="pt"
        ).to(device)
        text_feature = model.get_text_features(**caption_input, walk_type="box")

        # L2-normalize both modalities so the dot product below is a cosine similarity.
        text_feature = text_feature / text_feature.norm(p=2, dim=-1, keepdim=True)
        dense_image_feature = dense_image_feature / dense_image_feature.norm(p=2, dim=-1, keepdim=True)

        similarity = dense_image_feature @ text_feature.T
        similarity = similarity.cpu()

    # Lay the per-label heatmaps out on a grid with up to 3 columns.
    num_classes = len(captions)
    cols = 3
    rows = (num_classes + cols - 1) // cols

    aspect_ratio = real_w / real_h
    fig_width_inch = 3 * cols
    fig_height_inch = fig_width_inch / aspect_ratio * rows / cols

    fig, axes = plt.subplots(rows, cols, figsize=(fig_width_inch, fig_height_inch))
    fig.subplots_adjust(wspace=0.01, hspace=0.01)

    # subplots() returns a 1-D or 2-D array of Axes here (cols is always 3);
    # flatten it so the panels can be indexed uniformly.
    axes = np.asarray(axes).flatten()

    for cls_index in range(num_classes):
        # Reshape the per-token similarity back into the patch grid and render it as a heatmap.
        similarity_map = similarity[:, cls_index].numpy()
        show_image = similarity_map.reshape((real_h, real_w))
        ax = axes[cls_index]
        ax.imshow(show_image, cmap="viridis", aspect="equal")  # keep the original aspect ratio
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis("off")

    # Hide any unused cells in the grid.
    for idx in range(num_classes, len(axes)):
        axes[idx].axis("off")

    # Render the figure into an in-memory PNG; the buffer must stay open because
    # PIL loads the image from it lazily.
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)
    plt.close(fig)
    pil_img = Image.open(buf)
    return pil_img


with gr.Blocks() as demo:
    gr.Markdown("# FG-CLIP 2 Densefeature")
    gr.Markdown(
        "This app uses the FG-CLIP 2 model (qihoo360/fg-clip2-base) to visualize "
        "dense-feature similarity maps; it runs on CPU."
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(label="Input a list of labels, e.g. ['a','b','c']")
            dfs_button = gr.Button("Run Densefeature", visible=True)
        with gr.Column():
            dfs_output = gr.Image(label="Similarity Visualization", type="pil")

    examples = [
        ["./cat_dfclor.jpg", str(["电脑", "黑猫", "窗户", "window", "white cat", "book"])],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
    )
    dfs_button.click(fn=Get_Densefeature, inputs=[image_input, text_input], outputs=dfs_output)

demo.launch()
# demo.launch(server_name="0.0.0.0", server_port=7862, share=True)
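
# Illustrative sketch (not part of the app): Get_Densefeature can also be called
# directly without the Gradio UI. It expects a PIL image and a stringified Python
# list of labels, and returns the heatmap grid as a PIL image. The label choices
# and output filename below are examples only.
#
#   img = Image.open("./cat_dfclor.jpg")
#   heatmaps = Get_Densefeature(img, str(["black cat", "window", "book"]))
#   heatmaps.save("densefeature_heatmaps.png")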