pyedward commited on
Commit
a2bdad7
·
verified ·
1 Parent(s): 6e95c86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -67
app.py CHANGED
@@ -1,67 +1,67 @@
1
- from transformers import DetrImageProcessor, DetrForObjectDetection
2
- import torch
3
- from PIL import Image, ImageDraw
4
- import requests
5
- import random
6
- from IPython.display import display
7
- import gradio as gr
8
-
9
- # you can specify the revision tag if you don't want the timm dependency
10
- processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
11
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
12
-
13
- def draw_detections(image, outputs, processor, model, threshold=0.9):
14
- """
15
- Draw bounding boxes and labels on an image using detection results.
16
-
17
- Args:
18
- image (PIL.Image): Input image.
19
- outputs (dict): Model output.
20
- processor: The processor used for post-processing.
21
- model: The object detection model.
22
- threshold (float): Confidence threshold.
23
-
24
- Returns:
25
- PIL.Image: The image with bounding boxes drawn.
26
- """
27
- target_sizes = torch.tensor([image.size[::-1]])
28
- results = processor.post_process_object_detection(
29
- outputs, target_sizes=target_sizes, threshold=threshold
30
- )[0]
31
-
32
- draw_image = image.copy()
33
- draw = ImageDraw.Draw(draw_image, "RGBA")
34
-
35
- # define fixed colors per label for consistency
36
- COLORS = {}
37
- for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
38
- box = [round(i, 2) for i in box.tolist()]
39
- label_name = model.config.id2label[label.item()]
40
-
41
- # assign consistent random color for each label type
42
- if label_name not in COLORS:
43
- COLORS[label_name] = tuple(random.choices(range(256), k=3))
44
- color = COLORS[label_name]
45
-
46
- # draw translucent box
47
- draw.rectangle(box, fill=color + (80,), outline=color, width=3)
48
- draw.text((box[0] + 3, box[1] + 3),
49
- f"{label_name} {round(score.item(), 2)}",
50
- fill=(255, 255, 255, 255))
51
-
52
- return draw_image
53
-
54
-
55
- def detect_and_draw(img):
56
- inputs = processor(images=img, return_tensors="pt")
57
- outputs = model(**inputs)
58
- return draw_detections(img, outputs, processor, model)
59
-
60
- demo = gr.Interface(
61
- fn=detect_and_draw,
62
- inputs=gr.Image(type="pil"),
63
- outputs="image",
64
- title="Object Detection Viewer"
65
- )
66
-
67
- demo.launch()
 
1
+ from transformers import DetrImageProcessor, DetrForObjectDetection
2
+ import torch
3
+ from PIL import Image, ImageDraw
4
+ import requests
5
+ import random
6
+ from IPython.display import display
7
+ import gradio as gr
8
+
9
+ # you can specify the revision tag if you don't want the timm dependency
10
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
11
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
12
+
13
+ def draw_detections(image, outputs, processor, model, threshold=0.9):
14
+ """
15
+ Draw bounding boxes and labels on an image using detection results.
16
+
17
+ Args:
18
+ image (PIL.Image): Input image.
19
+ outputs (dict): Model output.
20
+ processor: The processor used for post-processing.
21
+ model: The object detection model.
22
+ threshold (float): Confidence threshold.
23
+
24
+ Returns:
25
+ PIL.Image: The image with bounding boxes drawn.
26
+ """
27
+ target_sizes = torch.tensor([image.size[::-1]])
28
+ results = processor.post_process_object_detection(
29
+ outputs, target_sizes=target_sizes, threshold=threshold
30
+ )[0]
31
+
32
+ draw_image = image.copy()
33
+ draw = ImageDraw.Draw(draw_image, "RGBA")
34
+
35
+ # define fixed colors per label for consistency
36
+ COLORS = {}
37
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
38
+ box = [round(i, 2) for i in box.tolist()]
39
+ label_name = model.config.id2label[label.item()]
40
+
41
+ # assign consistent random color for each label type
42
+ if label_name not in COLORS:
43
+ COLORS[label_name] = tuple(random.choices(range(256), k=3))
44
+ color = COLORS[label_name]
45
+
46
+ # draw translucent box
47
+ draw.rectangle(box, fill=color + (80,), outline=color, width=3)
48
+ draw.text((box[0] + 3, box[1] + 3),
49
+ f"{label_name} {round(score.item(), 2)}",
50
+ fill=(255, 255, 255, 255))
51
+
52
+ return draw_image
53
+
54
+
55
+ def detect_and_draw(img):
56
+ inputs = processor(images=img, return_tensors="pt")
57
+ outputs = model(**inputs)
58
+ return draw_detections(img, outputs, processor, model)
59
+
60
+ demo = gr.Interface(
61
+ fn=detect_and_draw,
62
+ inputs=gr.Image(type="pil"),
63
+ outputs="image",
64
+ title="Object Detection Viewer"
65
+ )
66
+
67
+ demo.launch(show_error=True)