| import os |
|
|
import subprocess
import sys

# Directory holding the downloaded model checkpoints.
os.makedirs("weights", exist_ok=True)


def _run_command(args):
    """Run *args* without a shell; print a warning and return False on failure.

    Failures are reported but not fatal (best-effort, like a bare
    ``os.system`` call): a genuinely missing package or checkpoint surfaces
    when it is first used.
    """
    try:
        completed = subprocess.run(args, check=False)
    except OSError as err:  # e.g. the executable (wget) is not installed
        print(f"Could not run {args[0]!r}: {err}")
        return False
    if completed.returncode != 0:
        print(f"Command failed (exit {completed.returncode}): {' '.join(args)}")
        return False
    return True


def _download_once(url, destination, quiet=True):
    """Download *url* to *destination* with wget, unless the file already exists.

    wget writes to ``<destination>.part``, which is renamed into place only
    after a successful download, so an interrupted transfer never leaves a
    truncated file that a later start would mistake for a complete one.
    """
    if os.path.isfile(destination):
        return
    partial = destination + ".part"
    args = ["wget"] + (["-q"] if quiet else []) + ["-O", partial, url]
    if _run_command(args):
        os.replace(partial, destination)
    elif os.path.exists(partial):
        os.remove(partial)


def _pip_install(*requirements):
    """Install *requirements* into the interpreter that is running this script."""
    # sys.executable guarantees the packages land where the imports below look.
    _run_command([sys.executable, "-m", "pip", "install", *requirements])


_pip_install("--upgrade", "pip")
_download_once(
    "https://raw.githubusercontent.com/asharma381/cs291I/main/backend/original_images/000749.png",
    "000749.png",
    quiet=False,
)
_download_once(
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    os.path.join("weights", "sam_vit_h_4b8939.pth"),
)
_download_once(
    "https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth",
    os.path.join("weights", "ram_plus_swin_large_14m.pth"),
)
_download_once(
    "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
    os.path.join("weights", "groundingdino_swint_ogc.pth"),
)
_pip_install("git+https://github.com/xinyu1205/recognize-anything.git")
_pip_install("git+https://github.com/IDEA-Research/GroundingDINO.git")
_pip_install("git+https://github.com/facebookresearch/segment-anything.git")
_pip_install("openai==0.27.4")
_pip_install("tenacity")
|
|
|
|
import re
from typing import List, Tuple

import cv2
import gradio as gr
import groundingdino.config.GroundingDINO_SwinT_OGC
import numpy as np
import openai
import torch
from groundingdino.util.inference import Model
from PIL import Image, ImageDraw
from ram import get_transform
from ram import inference_ram as inference
from ram.models import ram_plus
from scipy.spatial.distance import cdist
from segment_anything import SamPredictor, sam_model_registry
from supervision import Detections
from tenacity import retry, stop_after_attempt, wait_fixed
|
|
# Query CUDA once; every model below is moved to `device`.
_cuda_available = torch.cuda.is_available()
device = "cuda" if _cuda_available else "cpu"

# Model singletons, left as None here and created on first use by the
# loader functions that declare them `global`.
ram_model = None
# Multiplier currently applied to ram_model.class_threshold; get_tags_ram
# rescales the thresholds relative to this value.
ram_threshold_multiplier = 1
gdino_model = None
sam_model = None
sam_predictor = None

print("CUDA Available:", _cuda_available)
|
|
|
|
def get_tags_ram(
    image: Image.Image, threshold_multiplier=0.8, weights_folder="weights"
) -> List[str]:
    """Tag *image* with RAM++ and return the recognised tags.

    The RAM++ model (Swin-L backbone, 384 px input) is loaded from
    ``{weights_folder}/ram_plus_swin_large_14m.pth`` on the first call and
    cached in the module-level ``ram_model``. Later calls reuse it and
    ignore *weights_folder*.

    Args:
        image: PIL image to tag.
        threshold_multiplier: Positive factor applied to the model's initial
            per-class thresholds (lower values produce more tags). It is
            applied relative to the previous call's factor, so the thresholds
            always equal ``initial * threshold_multiplier``.
        weights_folder: Directory containing the RAM++ checkpoint.

    Returns:
        The tags, whitespace-stripped, with empty entries dropped. An image
        with no tags yields ``[]`` rather than ``[""]``.

    Raises:
        ValueError: If *threshold_multiplier* is not positive. Zero would
            make the next relative rescale divide by zero, and negative
            values would let every class pass.
    """
    global ram_model, ram_threshold_multiplier
    if threshold_multiplier <= 0:
        raise ValueError(
            f"threshold_multiplier must be positive, got {threshold_multiplier!r}"
        )

    if ram_model is None:
        print("Loading RAM++ Model...")
        ram_model = ram_plus(
            pretrained=f"{weights_folder}/ram_plus_swin_large_14m.pth",
            vit="swin_l",
            image_size=384,
        )
        ram_model.eval()
        ram_model = ram_model.to(device)

    # class_threshold currently equals initial * ram_threshold_multiplier;
    # rescale it to initial * threshold_multiplier. Skipped when unchanged so
    # repeated calls don't accumulate floating-point drift.
    if threshold_multiplier != ram_threshold_multiplier:
        ram_model.class_threshold *= threshold_multiplier / ram_threshold_multiplier
        ram_threshold_multiplier = threshold_multiplier

    transform = get_transform()
    batch = transform(image).unsqueeze(0).to(device)
    res = inference(batch, ram_model)
    # res[0] is a "|"-separated tag string; an empty string must not become [""]
    # (an empty tag would later "match" every GroundingDINO phrase).
    return [tag for tag in (part.strip() for part in res[0].split("|")) if tag]
|
|
|
|
def get_gdino_result(
    image: Image.Image,
    classes: List[str],
    box_threshold: float = 0.25,
    weights_folder="weights",
    text_threshold: float = 0.25,
) -> Tuple[Detections, List[str]]:
    """Run GroundingDINO on *image* with the caption ``", ".join(classes)``.

    The Swin-T OGC model is loaded from
    ``{weights_folder}/groundingdino_swint_ogc.pth`` on the first call and
    cached in the module-level ``gdino_model``.

    Args:
        image: PIL image (any mode; converted to RGB first).
        classes: Class names/phrases to ground, joined into a single caption.
        box_threshold: Minimum box confidence for a detection to be kept.
        weights_folder: Directory containing the GroundingDINO checkpoint.
        text_threshold: Minimum token confidence for a word to be included in
            a detection's phrase.

    Returns:
        ``(detections, phrases)`` exactly as returned by
        ``Model.predict_with_caption``: one phrase per detected box.
    """
    global gdino_model

    if gdino_model is None:
        print("Loading GroundingDINO Model...")
        config_path = groundingdino.config.GroundingDINO_SwinT_OGC.__file__
        gdino_model = Model(
            model_config_path=config_path,
            model_checkpoint_path=f"{weights_folder}/groundingdino_swint_ogc.pth",
            device=device,
        )

    # GroundingDINO's Model.predict_with_caption follows the cv2.imread
    # convention: it treats the ndarray as BGR and converts BGR->RGB itself.
    # np.array() of a PIL image is RGB, so swap the channels first, or the
    # model would see red and blue exchanged.
    image_bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)

    detections, phrases = gdino_model.predict_with_caption(
        image=image_bgr,
        caption=", ".join(classes),
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )

    return detections, phrases
|
|
|
|
def get_sam_model(weights_folder="weights"):
    """Return the shared ViT-H SAM model, building it on first use.

    The model is created from ``{weights_folder}/sam_vit_h_4b8939.pth``,
    moved to ``device``, and cached in the module-level ``sam_model``.
    Later calls return the cached instance and ignore *weights_folder*.
    """
    global sam_model
    if sam_model is not None:
        return sam_model

    checkpoint_path = f"{weights_folder}/sam_vit_h_4b8939.pth"
    build_vit_h = sam_model_registry["vit_h"]
    sam_model = build_vit_h(checkpoint=checkpoint_path)
    sam_model.to(device=device)
    return sam_model
|
|
|
|
def filter_tags_gdino(image: Image.Image, tags: List[str]) -> List[str]:
    """Keep only the tags that GroundingDINO can localise in *image*.

    All tags are grounded together in one GroundingDINO call. A tag survives
    if it appears as a whole word (or word sequence) in the phrase of at
    least one detected box covering less than 90% of the image. Near
    full-frame boxes are ignored, presumably because they describe the whole
    scene rather than a placeable region.

    Args:
        image: PIL image to ground the tags in.
        tags: Candidate tags (e.g. from ``get_tags_ram``). Empty strings are
            ignored.

    Returns:
        The surviving tags, in their original order. Returns ``[]`` without
        running the detector when there are no non-empty tags.
    """
    candidate_tags = [tag for tag in tags if tag]
    if not candidate_tags:
        return []

    detections, phrases = get_gdino_result(image, candidate_tags)
    max_area = 0.9 * image.size[0] * image.size[1]
    usable_phrases = [
        phrase for phrase, area in zip(phrases, detections.area) if area < max_area
    ]

    filtered_tags = []
    for tag in candidate_tags:
        # Whole-word match: a plain substring test would accept "man" for the
        # phrase "woman" or "car" for "carpet". IGNORECASE because the phrases
        # are presumably lower-cased by GroundingDINO while tags might not be.
        pattern = re.compile(rf"\b{re.escape(tag)}\b", re.IGNORECASE)
        if any(pattern.search(phrase) for phrase in usable_phrases):
            filtered_tags.append(tag)
    return filtered_tags
|
|
|
|
def read_file_to_string(file_path: str) -> str:
    """Return the UTF-8 contents of *file_path*, or ``""`` if it can't be read.

    This is best-effort: a missing file or any other read error is reported
    with ``print`` and swallowed, and the caller receives an empty string.
    """
    try:
        with open(file_path, "r", encoding="utf8") as handle:
            return handle.read()
    except FileNotFoundError:
        print(f"The file {file_path} was not found.")
    except Exception as e:
        print(f"An error occurred while reading {file_path}: {e}")
    return ""
|
|
|
|
@retry(wait=wait_fixed(2), stop=stop_after_attempt(5), reraise=True)
def completion_with_backoff(**kwargs):
    """Call ``openai.ChatCompletion.create(**kwargs)``, retrying on failure.

    Any exception triggers a retry after a fixed 2-second pause. The retry
    is bounded: after 5 attempts the last exception is re-raised as-is
    (``reraise=True``, not wrapped in ``tenacity.RetryError``). Without a stop
    condition, a permanent error such as an invalid or empty API key would
    retry forever and hang the caller.
    """
    return openai.ChatCompletion.create(**kwargs)
|
|
|
|
def gpt4(
    usr_prompt: str, sys_prompt: str = "", api_key: str = "", model: str = "gpt-4"
) -> str:
    """Send one system + user chat turn to OpenAI and return the reply text.

    Side effect: assigns *api_key* to the process-wide ``openai.api_key``
    before the request, so it remains set for later calls.

    Args:
        usr_prompt: Content of the user message.
        sys_prompt: Content of the system message.
        api_key: OpenAI API key to use.
        model: Chat model name.

    Returns:
        The ``content`` of the first choice's message.
    """
    openai.api_key = api_key

    conversation = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": usr_prompt},
    ]
    # Low temperature keeps the tag choice close to deterministic.
    request = dict(
        model=model,
        messages=conversation,
        temperature=0.2,
        max_tokens=1000,
        frequency_penalty=0.0,
    )
    reply = completion_with_backoff(**request)

    first_choice = reply["choices"][0]
    return first_choice["message"]["content"]
|
|
|
|
def select_best_tag(
    filtered_tags: List[str], object_to_place: str, api_key: str = ""
) -> str:
    """Ask GPT-4 which of *filtered_tags* best suits placing *object_to_place*.

    The user prompt is the text of ``user_template.txt`` with ``{object}``
    filled in, immediately followed by the candidate tags one per line. The
    system prompt is the text of ``system_template.txt``. Both files are read,
    by relative path, on every call.

    Returns:
        GPT-4's raw reply text.
    """
    template_text = read_file_to_string("user_template.txt")
    prompt_head = template_text.format(object=object_to_place)
    candidate_lines = "\n".join(filtered_tags)
    system_text = read_file_to_string("system_template.txt")
    return gpt4(prompt_head + candidate_lines, system_text, api_key=api_key)
|
|
|
|
def _detect_lowering_threshold(image, prompt, weights_folder="weights", start=0.25, step=0.02):
    """Detect *prompt* with GroundingDINO, lowering box_threshold until a box appears.

    Starts at *start* and subtracts *step* after each empty result. It stops
    after the first attempt made at a non-positive threshold.

    Raises:
        ValueError: If no box is found at any threshold tried.
    """
    box_threshold = start
    detections, _ = get_gdino_result(
        image=image, classes=[prompt], box_threshold=box_threshold,
        weights_folder=weights_folder,
    )
    while len(detections.xyxy) == 0 and box_threshold > 0:
        box_threshold -= step
        detections, _ = get_gdino_result(
            image=image, classes=[prompt], box_threshold=box_threshold,
            weights_folder=weights_folder,
        )
    if len(detections.xyxy) == 0:
        raise ValueError(f"GroundingDINO found no region matching {prompt!r}.")
    return detections


def _deepest_mask_point(mask):
    """Return (x, y) of the True pixel of *mask* farthest from any False pixel.

    Pixels outside the image count as False, so the point also keeps clear of
    the image border. Ties resolve to the first such pixel in row-major order.
    """
    # One-pixel zero border makes the image edge act as "outside the region".
    padded = np.pad(mask.astype(np.uint8), pad_width=1, mode="constant", constant_values=0)
    # Exact Euclidean distance from every pixel to its nearest zero pixel,
    # computed in linear time (no pairwise distance matrix).
    distance = cv2.distanceTransform(padded, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
    row, col = np.unravel_index(int(np.argmax(distance)), distance.shape)
    # Undo the padding offset to get coordinates in the original mask.
    return int(col) - 1, int(row) - 1


def get_location_gsam(
    image: Image.Image, prompt: str, weights_folder="weights"
) -> Tuple[int, int]:
    """Return the point inside the *prompt* region with the most clearance.

    Steps:
      1. Find boxes for *prompt* with GroundingDINO. If none are found, the
         box threshold is lowered from 0.25 in 0.02 steps.
      2. Segment each box with SAM, keeping the highest-scoring of its
         candidate masks.
      3. Take the union of those masks.
      4. Pick the mask pixel farthest from the region's edge (or the image
         border).

    Args:
        image: PIL image; ``np.array(image)`` is passed to SAM's ``set_image``
            (assumed RGB, as produced by the caller's ``convert("RGB")``).
        prompt: Text describing the target region (e.g. the selected tag).
        weights_folder: Directory holding the model checkpoints.

    Returns:
        ``(x, y)`` integer pixel coordinates in *image*. If SAM returns only
        empty masks, this falls back to the centre of the first detected box.

    Raises:
        ValueError: If GroundingDINO finds no box for *prompt* at any
            threshold tried.
    """
    global sam_predictor

    detections = _detect_lowering_threshold(image, prompt, weights_folder=weights_folder)

    sam_model = get_sam_model(weights_folder)
    if sam_predictor is None:
        print("Loading SAM Model...")
        sam_predictor = SamPredictor(sam_model)
    sam_predictor.set_image(np.array(image))

    best_masks = []
    for box in detections.xyxy:
        masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
        best_masks.append(masks[np.argmax(scores)])
    combined_mask = np.any(np.stack(best_masks), axis=0)

    if not combined_mask.any():
        x0, y0, x1, y1 = detections.xyxy[0]
        return int((x0 + x1) / 2), int((y0 + y1) / 2)

    return _deepest_mask_point(combined_mask)
|
|
|
|
def run_octo_pipeline(input_image, object, api_key):
    """Run the three-stage placement pipeline and return the annotated image.

    Stage 1: RAM++ proposes tags, and GroundingDINO keeps those it can localise.
    Stage 2: GPT-4 picks the tag where *object* should be placed.
    Stage 3: Grounded-SAM finds the deepest point of that region, and a red dot
    (radius 10 px) is drawn there on an RGB copy of the input.

    Args:
        input_image: PIL image from the Gradio input (``None`` if cleared).
        object: Name of the object to place. The parameter name is kept for
            existing callers even though it shadows the builtin locally.
        api_key: OpenAI API key; surrounding whitespace is ignored.

    Returns:
        A one-element list ``[annotated_image]`` for the Gradio gallery.

    Raises:
        gr.Error: For a missing image, object name or API key, when no tag
            survives filtering, or when GPT-4 returns an empty reply, so that
            Gradio shows a readable message instead of a stack trace.
    """
    print("Inside run_octo_pipeline with input_image=", input_image, "object=", object)

    if input_image is None:
        raise gr.Error("Please provide an input image.")
    object_name = (object or "").strip()
    if not object_name:
        raise gr.Error("Please enter the object to place.")
    key = (api_key or "").strip()
    if not key:
        raise gr.Error("Please enter an OpenAI API key.")

    print("Loading Image...")
    image = input_image.convert("RGB")

    print("Stage 1...")
    tags = get_tags_ram(image, threshold_multiplier=0.8)
    print("RAM++ Tags", tags)
    filtered_tags = filter_tags_gdino(image, tags)
    print("Filtered Tags", filtered_tags)
    if not filtered_tags:
        raise gr.Error("No localisable objects or surfaces were found in the image.")

    print("Stage 2...")
    selected_tag = select_best_tag(filtered_tags, object_name, api_key=key).strip()
    print("GPT-4 Selected Tag", selected_tag)
    if not selected_tag:
        raise gr.Error("GPT-4 did not return a tag.")

    print("Stage 3...")
    x, y = get_location_gsam(image, selected_tag)
    print("G-SAM Location", "(" + str(x) + "," + str(y) + ")")

    draw = ImageDraw.Draw(image)
    radius = 10
    draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill="red")
    return [image]
|
|
|
|
block = gr.Blocks()

with block:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", value="000749.png")
            # Named object_textbox rather than ``object`` so the builtin
            # ``object`` is not shadowed at module scope.
            object_textbox = gr.Textbox(label="Object", placeholder="Enter an object")
            # type="password" keeps the key masked in the browser.
            api_key = gr.Textbox(
                label="OpenAI API Key",
                placeholder="Enter OpenAI API Key",
                type="password",
            )

        with gr.Column():
            gallery = gr.Gallery(
                label="Output",
                show_label=False,
                elem_id="gallery",
                preview=True,
                object_fit="scale-down",
            )

    iface = gr.Interface(
        fn=run_octo_pipeline,
        inputs=[input_image, object_textbox, api_key],
        outputs=gallery,
    )

iface.launch()
|
|