Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
| from typing import Dict, List, Any, Optional, Tuple | |
| import spaces | |
| from detection_model import DetectionModel | |
| from color_mapper import ColorMapper | |
| from evaluation_metrics import EvaluationMetrics | |
| from style import Style | |
| from image_processor import ImageProcessor | |
# Module-level ImageProcessor shared by every handler below; it owns the
# loaded detection model instances consulted by get_all_classes().
image_processor = ImageProcessor()
def get_all_classes():
    """
    Get all available COCO classes from the currently active model or fallback to standard COCO classes

    Returns:
        List of tuples (class_id, class_name)
    """
    # Prefer class names reported by any model that is already loaded.
    for model_instance in image_processor.model_instances.values():
        if not (model_instance and model_instance.is_model_loaded):
            continue
        try:
            return list(model_instance.class_names.items())
        except Exception:
            # A model that cannot report names is skipped; fall through.
            continue

    # Fallback: the standard 80-class COCO label set, ids assigned in order.
    fallback_names = (
        'person', 'bicycle', 'car', 'motorcycle', 'airplane',
        'bus', 'train', 'truck', 'boat', 'traffic light',
        'fire hydrant', 'stop sign', 'parking meter', 'bench',
        'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
        'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
        'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
        'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
        'bottle', 'wine glass', 'cup', 'fork', 'knife',
        'spoon', 'bowl', 'banana', 'apple', 'sandwich',
        'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
        'donut', 'cake', 'chair', 'couch', 'potted plant',
        'bed', 'dining table', 'toilet', 'tv', 'laptop',
        'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
        'oven', 'toaster', 'sink', 'refrigerator', 'book',
        'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
        'toothbrush',
    )
    return list(enumerate(fallback_names))
| def process_and_plot(image, model_name, confidence_threshold, filter_classes=None): | |
| """ | |
| Process image and create plots for statistics with enhanced visualization | |
| Args: | |
| image: Input image | |
| model_name: Name of the model to use | |
| confidence_threshold: Confidence threshold for detection | |
| filter_classes: Optional list of classes to filter results | |
| Returns: | |
| Tuple of (result_image, result_text, formatted_stats, plot_figure) | |
| """ | |
| class_ids = None | |
| if filter_classes: | |
| class_ids = [] | |
| for class_str in filter_classes: | |
| try: | |
| # Extract ID from format "id: name" | |
| class_id = int(class_str.split(":")[0].strip()) | |
| class_ids.append(class_id) | |
| except: | |
| continue | |
| # Execute detection | |
| result_image, result_text, stats = image_processor.process_image( | |
| image, | |
| model_name, | |
| confidence_threshold, | |
| class_ids | |
| ) | |
| # Format the statistics for better display | |
| formatted_stats = image_processor.format_json_for_display(stats) | |
| if not stats or "class_statistics" not in stats or not stats["class_statistics"]: | |
| # Create the table | |
| fig, ax = plt.subplots(figsize=(8, 6)) | |
| ax.text(0.5, 0.5, "No detection data available", | |
| ha='center', va='center', fontsize=14, fontfamily='Arial') | |
| ax.set_xlim(0, 1) | |
| ax.set_ylim(0, 1) | |
| ax.axis('off') | |
| plot_figure = fig | |
| else: | |
| # Prepare visualization data | |
| available_classes = dict(get_all_classes()) | |
| viz_data = image_processor.prepare_visualization_data(stats, available_classes) | |
| # Create plot | |
| plot_figure = EvaluationMetrics.create_enhanced_stats_plot(viz_data) | |
| return result_image, result_text, formatted_stats, plot_figure | |
def create_interface():
    """Create the Gradio Blocks interface with the app's custom styling.

    Builds the full VisionScout UI: an upload/settings panel on the left,
    result/statistics tabs on the right, and wires the detect button and
    quick-filter buttons to their handlers.

    Returns:
        The assembled gr.Blocks demo, ready for .launch().
    """
    css = Style.get_css()

    # Gather metadata for all selectable detection models.
    available_models = DetectionModel.get_available_models()
    model_choices = [model["model_file"] for model in available_models]
    # NOTE(review): model_labels is built but never used below — presumably
    # intended for a labelled dropdown; confirm before removing.
    model_labels = [f"{model['name']} - {model['inference_speed']}" for model in available_models]

    # Class-filter options rendered as "id: name" strings for the dropdown.
    available_classes = get_all_classes()
    class_choices = [f"{id}: {name}" for id, name in available_classes]

    # Build the Gradio Blocks layout.
    with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="teal", secondary_hue="blue")) as demo:
        # Page header with title and divider
        with gr.Group(elem_classes="app-header"):
            gr.HTML("""
                <div style="text-align: center; width: 100%;">
                    <h1 class="app-title">VisionScout</h1>
                    <h2 class="app-subtitle">Detect and identify objects in your images</h2>
                    <div class="app-divider"></div>
                </div>
            """)

        # Currently selected model file; medium-size model is the default.
        current_model = gr.State("yolov8m.pt")

        # Main content area
        with gr.Row(equal_height=True):
            # Left side - input controls (image upload and settings)
            with gr.Column(scale=4, elem_classes="input-panel"):
                with gr.Group():
                    gr.HTML('<div class="section-heading">Upload Image</div>')
                    image_input = gr.Image(type="pil", label="Upload an image", elem_classes="upload-box")

                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            model_dropdown = gr.Dropdown(
                                choices=model_choices,
                                value="yolov8m.pt",
                                label="Select Model",
                                info="Choose different models based on your needs for speed vs. accuracy"
                            )

                        # Display a description of the currently selected model
                        model_info = gr.Markdown(DetectionModel.get_model_description("yolov8m.pt"))

                        confidence = gr.Slider(
                            minimum=0.1,
                            maximum=0.9,
                            value=0.25,
                            step=0.05,
                            label="Confidence Threshold",
                            info="Higher values show fewer but more confident detections"
                        )

                        with gr.Accordion("Filter Classes", open=False):
                            # Quick-select buttons for common object categories
                            gr.HTML('<div class="section-heading" style="font-size: 1rem;">Common Categories</div>')
                            with gr.Row():
                                people_btn = gr.Button("People", size="sm")
                                vehicles_btn = gr.Button("Vehicles", size="sm")
                                animals_btn = gr.Button("Animals", size="sm")
                                objects_btn = gr.Button("Common Objects", size="sm")

                            # Class selection dropdown (empty selection = show all)
                            class_filter = gr.Dropdown(
                                choices=class_choices,
                                multiselect=True,
                                label="Select Classes to Display",
                                info="Leave empty to show all detected objects"
                            )

                    # Detect button — triggers the full processing pipeline
                    detect_btn = gr.Button("Detect Objects", variant="primary", elem_classes="detect-btn")

                # Usage instructions
                with gr.Group(elem_classes="how-to-use"):
                    gr.HTML('<div class="section-heading">How to Use</div>')
                    gr.Markdown("""
                        1. Upload an image or use the camera
                        2. (Optional) Adjust settings like confidence threshold or model size (n, m, x)
                        3. Optionally filter to specific object classes
                        4. Click "Detect Objects" button

                        The model will identify objects in your image and display them with bounding boxes.

                        **Note:** Detection quality depends on image clarity and model settings.
                    """)

            # Right side - results display
            with gr.Column(scale=6, elem_classes="output-panel"):
                with gr.Tabs(elem_classes="tabs"):
                    with gr.Tab("Detection Result"):
                        result_image = gr.Image(type="pil", label="Detection Result")

                        # Detection details summary
                        with gr.Group(elem_classes="result-details-box"):
                            gr.HTML('<div class="section-heading">Detection Details</div>')
                            # Textbox sized wide so details are easy to read
                            result_text = gr.Textbox(
                                label=None,
                                lines=12,
                                max_lines=15,
                                elem_classes="wide-result-text",
                                elem_id="detection-details",
                                container=False,
                                scale=2,
                                min_width=600
                            )

                    with gr.Tab("Statistics"):
                        with gr.Row():
                            with gr.Column(scale=3, elem_classes="plot-column"):
                                gr.HTML('<div class="section-heading">Object Distribution</div>')
                                plot_output = gr.Plot(
                                    label=None,
                                    elem_classes="large-plot-container"
                                )

                            # JSON data on the right-hand column for clarity
                            with gr.Column(scale=2, elem_classes="stats-column"):
                                gr.HTML('<div class="section-heading">Detection Statistics</div>')
                                stats_json = gr.JSON(
                                    label=None,  # remove label
                                    elem_classes="enhanced-json-display"
                                )

        # Wire the detect button to the processing pipeline.
        detect_btn.click(
            fn=process_and_plot,
            inputs=[image_input, current_model, confidence, class_filter],
            outputs=[result_image, result_text, stats_json, plot_output]
        )

        # Keep the model State and its description in sync with the dropdown.
        model_dropdown.change(
            fn=lambda model: (model, DetectionModel.get_model_description(model)),
            inputs=[model_dropdown],
            outputs=[current_model, model_info]
        )

        # COCO class-id groups backing each quick-select button.
        people_classes = [0]  # person
        vehicles_classes = [1, 2, 3, 4, 5, 6, 7, 8]  # assorted vehicles
        animals_classes = list(range(14, 24))  # animals in COCO
        common_objects = [41, 42, 43, 44, 45, 67, 73, 74, 76]  # common household items

        # Link the quick-select buttons: each replaces the dropdown selection
        # with the "id: name" entries belonging to its category.
        people_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in people_classes],
            outputs=class_filter
        )

        vehicles_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in vehicles_classes],
            outputs=class_filter
        )

        animals_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in animals_classes],
            outputs=class_filter
        )

        objects_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in common_objects],
            outputs=class_filter
        )

        # Example images (paths are relative to the app's working directory).
        example_images = [
            "room_01.jpg",
            "street_01.jpg",
            "street_02.jpg",
            "street_03.jpg"
        ]

        # Add example images to the input component; no cached outputs.
        gr.Examples(
            examples=example_images,
            inputs=image_input,
            outputs=None,
            fn=None,
            cache_examples=False,
        )

        # Footer
        gr.HTML("""
            <div class="footer">
                <p>Powered by YOLOv8 and Ultralytics • Created with Gradio</p>
                <p>Model can detect 80 different classes of objects</p>
            </div>
        """)

    return demo
if __name__ == "__main__":
    # Removed unused `import time` — nothing in this entry point used it.
    demo = create_interface()
    demo.launch()