# Page-header residue captured with this file (not part of the program):
#   pyedward's picture
#   Update app.py
#   a2bdad7 verified
#   raw
#   history blame
#   2.29 kB
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image, ImageDraw
import requests
import random
from IPython.display import display
import gradio as gr
# Checkpoint name and revision shared by the processor and the model below.
# you can specify the revision tag if you don't want the timm dependency
_CHECKPOINT = "facebook/detr-resnet-50"
_REVISION = "no_timm"

processor = DetrImageProcessor.from_pretrained(_CHECKPOINT, revision=_REVISION)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT, revision=_REVISION)
def draw_detections(image, outputs, processor, model, threshold=0.9):
    """
    Draw bounding boxes and labels on an image using detection results.

    Args:
        image (PIL.Image): Input image. It is not modified; drawing happens
            on an RGB copy.
        outputs: Raw model output, handed unchanged to
            ``processor.post_process_object_detection``.
        processor: The processor used for post-processing.
        model: The object detection model; ``model.config.id2label`` is used
            to turn label ids into names.
        threshold (float): Confidence threshold forwarded to post-processing.

    Returns:
        PIL.Image: An RGB copy of the image with bounding boxes drawn.
    """
    # PIL sizes are (width, height); the reversed pair is what gets passed
    # to post-processing as the target size.
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    # An "RGBA" drawing context only blends onto RGB images: other modes
    # (L, P, CMYK, ...) raise "mode mismatch", and RGBA images get their
    # pixels overwritten instead of blended. convert() always returns a new
    # image, so the caller's image is left untouched.
    draw_image = image.convert("RGB")
    draw = ImageDraw.Draw(draw_image, "RGBA")

    # One colour per label name. Seeding a private generator with the name
    # makes the colour stable across calls, not just within one call.
    colors = {}
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(i, 2) for i in box.tolist()]
        label_name = model.config.id2label[label.item()]

        if label_name not in colors:
            rng = random.Random(label_name)
            colors[label_name] = tuple(rng.choices(range(256), k=3))
        color = colors[label_name]

        # draw translucent box
        draw.rectangle(box, fill=color + (80,), outline=color, width=3)
        draw.text(
            (box[0] + 3, box[1] + 3),
            f"{label_name} {round(score.item(), 2)}",
            fill=(255, 255, 255, 255),
        )

    return draw_image
def detect_and_draw(img):
    """
    Run object detection on an image and return it with detections drawn.

    Args:
        img (PIL.Image): Image to analyse, or None when no image was supplied.

    Returns:
        PIL.Image: The annotated image produced by ``draw_detections``.

    Raises:
        gr.Error: If ``img`` is None.
    """
    # Without this guard a missing image only fails later, inside the
    # processor, with a message that does not point at the real cause.
    if img is None:
        raise gr.Error("Please provide an image before running detection.")

    inputs = processor(images=img, return_tensors="pt")

    # Inference only: no gradients are needed, so skip autograd tracking.
    with torch.no_grad():
        outputs = model(**inputs)

    return draw_detections(img, outputs, processor, model)
# Input component configured with type="pil", matching what detect_and_draw
# passes on to the processor and to draw_detections.
image_input = gr.Image(type="pil")

# Single-function interface: image in, annotated image out.
demo = gr.Interface(
    title="Object Detection Viewer",
    fn=detect_and_draw,
    inputs=image_input,
    outputs="image",
)

# Start the app with show_error enabled.
demo.launch(show_error=True)