from typing import Optional

import gradio as gr
import spaces
import supervision as sv
import torch
from PIL import Image

from utils.florence import load_florence_model, run_florence_inference, \
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference
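
# `utils.florence` and `utils.sam` are local helper modules shipped with the
# Space; they are assumed to expose the loaders and inference wrappers used
# below (load_florence_model returning a (model, processor) pair,
# load_sam_image_model returning a model, and the run_*_inference functions
# wrapping the respective forward passes).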
| DEVICE = torch.device("cuda") | |
| # DEVICE = torch.device("cpu") | |
| torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() | |
| if torch.cuda.get_device_properties(0).major >= 8: | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
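
# A sketch of a CPU fallback for local runs without CUDA (hypothetical, not
# part of the Space; the autocast and TF32 lines above would also need to be
# guarded behind torch.cuda.is_available()):
#
#   DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")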
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
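
# Both models are loaded once at import time, so every request served by the
# Gradio app reuses the same weights instead of reloading them per call.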
# ZeroGPU Spaces expect the GPU-bound entry point to be decorated with
# `spaces.GPU`; without it the `spaces` import above goes unused and no GPU
# is allocated for the call.
@spaces.GPU
def process_image(image_input, text_input) -> Optional[Image.Image]:
    if image_input is None:
        gr.Info("Please upload an image.")
        return None
    if not text_input:
        gr.Info("Please enter a text prompt.")
        return None
    # Stage 1: Florence-2 turns the open-vocabulary text prompt into
    # bounding boxes.
    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image_input,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=text_input
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image_input.size
    )
    # Stage 2: SAM refines each detected box into a segmentation mask.
    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
    if len(detections) == 0:
        gr.Info("No objects detected.")
        return None
    # Render the first mask as a black-and-white image.
    return Image.fromarray(detections.mask[0].astype("uint8") * 255)
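
# Only the first mask is returned above. A minimal sketch for merging every
# detected mask into a single image instead (assumes `detections.mask` is
# the boolean (N, H, W) array that supervision produces):
#
#   import numpy as np
#   combined = np.logical_or.reduce(detections.mask)
#   merged = Image.fromarray(combined.astype("uint8") * 255)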
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(
                type='pil', label='Upload image')
            text_input_component = gr.Textbox(
                label='Text prompt',
                placeholder='Enter text prompts')
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(label='Output mask')
    # Both the button click and pressing Enter in the textbox trigger the
    # same pipeline; listeners must be registered inside the Blocks context.
    submit_button_component.click(
        fn=process_image,
        inputs=[
            image_input_component,
            text_input_component
        ],
        outputs=[
            image_output_component,
        ]
    )
    text_input_component.submit(
        fn=process_image,
        inputs=[
            image_input_component,
            text_input_component
        ],
        outputs=[
            image_output_component,
        ]
    )
demo.launch(debug=False, show_error=True)
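
# Once the app is up, the same pipeline can be exercised programmatically.
# A minimal sketch using the official `gradio_client` package (the URL and
# image path are placeholder assumptions; since no api_name is set above,
# the endpoint is addressed by function index):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860")
#   mask = client.predict(handle_file("example.jpg"), "a dog", fn_index=0)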