Spaces:
Running
on
Zero
Running
on
Zero
| import logging | |
| from typing import Dict, List, Tuple, Optional, Any | |
class ObjectGroupProcessor:
    """
    Group detected objects and turn the groups into description clauses.

    Responsibilities: grouping detections by class, removing spatial
    duplicates, ordering the groups by priority, and generating/formatting
    the English clauses that describe each group.

    Two optional collaborators can be injected:

    * ``spatial_handler`` - must expose
      ``generate_spatial_description(obj, image_width, image_height, region_analyzer)``.
    * ``text_optimizer`` - must expose ``normalize_object_class_name(name)`` and
      ``format_object_count_description(name, count, scene_type, detected_objects, avg_confidence)``.

    When a collaborator is absent, the built-in fallback logic of this class
    is used instead.
    """

    # Classes that get a more lenient confidence threshold when grouping and
    # a priority boost when sorting (once at least two are present).
    _SMALL_OBJECT_CLASSES = frozenset({"potted plant", "vase", "clock", "book"})
    _FURNITURE_CLASSES = frozenset({"dining table", "chair", "sofa", "bed"})
    _VEHICLE_CLASSES = frozenset({"car", "bus", "truck", "traffic light"})

    # Two objects whose normalized centers are closer than this (Manhattan
    # distance) are treated as the same detection.
    _DUPLICATE_DISTANCE_THRESHOLD = 0.15

    # Fallback mapping from a cleaned region name to a location phrase.
    _REGION_PHRASES = {
        "top left": "in the upper left area",
        "top center": "in the upper area",
        "top right": "in the upper right area",
        "middle left": "on the left side",
        "middle center": "in the center",
        "center": "in the center",
        "middle right": "on the right side",
        "bottom left": "in the lower left area",
        "bottom center": "in the lower area",
        "bottom right": "in the lower right area",
    }

    # Fallback number words; counts above 12 become "several"/"numerous".
    _NUMBER_WORDS = {
        2: "two", 3: "three", 4: "four", 5: "five", 6: "six",
        7: "seven", 8: "eight", 9: "nine", 10: "ten",
        11: "eleven", 12: "twelve",
    }

    # Nouns whose plural is not formed by a simple suffix rule.
    _IRREGULAR_PLURALS = {"knife": "knives"}

    def __init__(self, confidence_threshold_for_description: float = 0.25,
                 spatial_handler: Optional[Any] = None,
                 text_optimizer: Optional[Any] = None):
        """
        Initialize the processor.

        Args:
            confidence_threshold_for_description: Base confidence threshold a
                class must reach (on average) to be described.
            spatial_handler: Optional spatial-description helper.
            text_optimizer: Optional text normalization/formatting helper.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.confidence_threshold_for_description = confidence_threshold_for_description
        self.spatial_handler = spatial_handler
        self.text_optimizer = text_optimizer

    def group_objects_by_class(self, confident_objects: List[Dict],
                               object_statistics: Optional[Dict]) -> Dict[str, List[Dict]]:
        """
        Group detections by their ``class_name``.

        With ``object_statistics`` the per-class ``count`` and
        ``avg_confidence`` decide which classes are kept: the class average
        must reach a dynamic threshold (more lenient for small-object classes
        and for classes with three or more instances), and at most ``count``
        objects are kept per class. Without statistics every object that has
        a usable class name is grouped as-is.

        Args:
            confident_objects: Detections that already passed confidence filtering.
            object_statistics: Optional mapping ``class_name -> stats dict``.

        Returns:
            Mapping of class name to the list of objects kept for that class.
        """
        objects_by_class: Dict[str, List[Dict]] = {}

        if not object_statistics:
            # Fallback: plain grouping, skipping objects without a real class name.
            for obj in confident_objects:
                name = obj.get("class_name", "unknown object")
                if not name or name == "unknown object":
                    continue
                objects_by_class.setdefault(name, []).append(obj)
            return objects_by_class

        base_threshold = self.confidence_threshold_for_description
        for class_name, stats in object_statistics.items():
            count = stats.get("count", 0)
            avg_confidence = stats.get("avg_confidence", 0)

            if class_name in self._SMALL_OBJECT_CLASSES:
                threshold = max(0.15, base_threshold * 0.6)
            elif count >= 3:
                threshold = max(0.2, base_threshold * 0.8)
            else:
                threshold = base_threshold

            if count <= 0 or avg_confidence < threshold:
                continue

            matching = [obj for obj in confident_objects
                        if obj.get("class_name") == class_name]
            if not matching:
                continue

            # Never keep more objects than the statistics report.
            objects_by_class[class_name] = matching[:count]
            self.logger.debug("%s: %d objects before spatial deduplication",
                              class_name, len(objects_by_class[class_name]))

        return objects_by_class

    def remove_duplicate_objects(self, objects_by_class: Dict[str, List[Dict]]) -> Dict[str, List[Dict]]:
        """
        Drop objects whose center nearly coincides with an already kept object.

        Objects are visited in the iteration order of ``objects_by_class``; an
        object is dropped when the Manhattan distance between its
        ``normalized_center`` (default ``[0.5, 0.5]``) and any previously kept
        center is below ``_DUPLICATE_DISTANCE_THRESHOLD``. Classes left with
        no objects are omitted from the result.

        Args:
            objects_by_class: Objects grouped by class name.

        Returns:
            The grouped objects after de-duplication.
        """
        deduplicated: Dict[str, List[Dict]] = {}
        # NOTE(review): kept positions are shared across ALL classes, so an
        # object can be dropped for being close to an object of a different
        # class. This mirrors the existing behavior - confirm it is intended.
        kept_positions: List[Any] = []
        threshold = self._DUPLICATE_DISTANCE_THRESHOLD

        for class_name, group_of_objects in objects_by_class.items():
            unique_objects = []
            for obj in group_of_objects:
                position = obj.get("normalized_center", [0.5, 0.5])
                is_duplicate = any(
                    abs(position[0] - kept[0]) + abs(position[1] - kept[1]) < threshold
                    for kept in kept_positions
                )
                if is_duplicate:
                    continue
                unique_objects.append(obj)
                kept_positions.append(position)

            if unique_objects:
                deduplicated[class_name] = unique_objects
                self.logger.debug("%s: %d objects after spatial deduplication",
                                  class_name, len(unique_objects))

        return deduplicated

    def sort_object_groups(self, objects_by_class: Dict[str, List[Dict]]) -> List[Tuple[str, List[Dict]]]:
        """
        Order the object groups for description.

        Groups are sorted by priority (lower first), then by descending object
        count, then by descending average ``normalized_area``. The sort is
        stable, so fully tied groups keep their input order.

        Args:
            objects_by_class: Objects grouped by class name.

        Returns:
            List of ``(class_name, objects)`` tuples in description order.
        """
        def sort_key(item: Tuple[str, List[Dict]]):
            class_name, group = item
            count = len(group)
            total_area = sum(obj.get("normalized_area", 0.0) for obj in group)
            avg_area = total_area / count if count else 0
            return (self._group_priority(class_name, count), -count, -avg_area)

        return sorted(objects_by_class.items(), key=sort_key)

    def _group_priority(self, class_name: str, count: int) -> int:
        """Return the sort priority of a group (0 is described first, 3 last)."""
        normalized = self._normalize_object_class_name(class_name)
        if normalized == "person":
            return 0
        if normalized in self._FURNITURE_CLASSES:
            return 1
        if normalized in self._VEHICLE_CLASSES:
            return 2
        if count >= 3:
            # Any class with many instances is promoted one level.
            return 2
        if normalized in self._SMALL_OBJECT_CLASSES and count >= 2:
            return 2
        return 3

    def generate_object_clauses(self, sorted_object_groups: List[Tuple[str, List[Dict]]],
                                object_statistics: Optional[Dict],
                                scene_type: str,
                                image_width: Optional[int],
                                image_height: Optional[int],
                                region_analyzer: Optional[Any] = None) -> List[str]:
        """
        Build one description clause per object group.

        Each clause is "<Count phrase> <location phrase>", for example
        "Three cars are primarily in the middle left area". When statistics
        contain a ``count`` for the class, that count is used for the wording;
        otherwise the size of the group is used. The same count drives the
        verb form of the location phrase so subject and verb always agree.

        Args:
            sorted_object_groups: Groups in description order.
            object_statistics: Optional mapping ``class_name -> stats dict``.
            scene_type: Scene type forwarded to the count formatter.
            image_width: Image width forwarded to the spatial handler.
            image_height: Image height forwarded to the spatial handler.
            region_analyzer: Optional region analyzer forwarded to the spatial handler.

        Returns:
            List of clauses (without trailing periods).
        """
        object_clauses = []

        for class_name, group_of_objects in sorted_object_groups:
            detected_count = len(group_of_objects)
            self.logger.debug("Final count for %s: %d", class_name, detected_count)
            if detected_count == 0:
                continue

            normalized_class_name = self._normalize_object_class_name(class_name)

            described_count = detected_count
            if object_statistics and class_name in object_statistics:
                described_count = object_statistics[class_name].get("count", detected_count)

            count_phrase = self._format_object_count_description(
                normalized_class_name,
                described_count,
                scene_type=scene_type
            )
            # The literal below is treated as a "nothing to describe" marker;
            # presumably it is emitted by the text optimizer - verify there.
            if not count_phrase or count_phrase == "no specific objects clearly identified":
                continue

            location_phrase = self._generate_location_description(
                group_of_objects, described_count, image_width, image_height, region_analyzer
            )

            capitalized = count_phrase[0].upper() + count_phrase[1:]
            object_clauses.append(f"{capitalized} {location_phrase}")

        return object_clauses

    def format_object_clauses(self, object_clauses: List[str]) -> str:
        """
        Join the clauses into a single description string.

        The first clause becomes its own sentence; any remaining clauses follow
        after "The scene features:". The input list is not modified.

        Args:
            object_clauses: Clauses produced by ``generate_object_clauses``.

        Returns:
            The formatted description, or a fixed notice when there are no clauses.
        """
        if not object_clauses:
            return "No common objects were confidently identified for detailed description."

        first_clause, *remaining_clauses = object_clauses
        result = first_clause + "."

        if remaining_clauses:
            joined = ". ".join(remaining_clauses)
            if joined and not joined.endswith("."):
                joined += "."
            result += " The scene features: " + joined

        return result

    def _generate_location_description(self, group_of_objects: List[Dict], count: int,
                                       image_width: Optional[int], image_height: Optional[int],
                                       region_analyzer: Optional[Any] = None) -> str:
        """
        Describe where a group of objects is located.

        ``count`` selects the verb form: exactly 1 yields "is ...", anything
        else yields "are ...". For a single object the spatial handler (or the
        region-phrase fallback) is tried first; otherwise the description is
        derived from the distinct ``region`` values of the group.

        Args:
            group_of_objects: Objects of one class.
            count: Number of objects being described.
            image_width: Image width forwarded to the spatial handler.
            image_height: Image height forwarded to the spatial handler.
            region_analyzer: Optional region analyzer forwarded to the spatial handler.

        Returns:
            A location phrase starting with "is" or "are".
        """
        singular = count == 1

        if not group_of_objects:
            # Nothing to locate; avoid indexing into an empty group.
            return "is positioned in the scene" if singular else "are visible in the scene"

        if singular:
            first_object = group_of_objects[0]
            if self.spatial_handler:
                spatial_desc = self.spatial_handler.generate_spatial_description(
                    first_object, image_width, image_height, region_analyzer
                )
            else:
                spatial_desc = self._get_spatial_description_phrase(first_object.get("region", ""))
            if spatial_desc:
                return f"is {spatial_desc}"

        valid_regions = self._collect_valid_regions(group_of_objects)

        if singular:
            if not valid_regions:
                return "is positioned in the scene"
            if len(valid_regions) == 1:
                phrase = self._get_spatial_description_phrase(valid_regions[0])
                return f"is primarily {phrase}" if phrase else "is positioned in the scene"
            verb = "is"
        else:
            if not valid_regions:
                return "are visible in the scene"
            if len(valid_regions) == 1:
                clean_region = valid_regions[0].replace('_', ' ')
                return f"are primarily in the {clean_region} area"
            verb = "are"

        if len(valid_regions) == 2:
            clean_region1 = valid_regions[0].replace('_', ' ')
            clean_region2 = valid_regions[1].replace('_', ' ')
            return f"{verb} mainly across the {clean_region1} and {clean_region2} areas"
        return f"{verb} distributed in various parts of the scene"

    @staticmethod
    def _collect_valid_regions(group_of_objects: List[Dict]) -> List[str]:
        """Return the sorted distinct ``region`` values, excluding blank and "unknown"."""
        distinct_regions = sorted({obj.get("region", "") for obj in group_of_objects
                                   if obj.get("region")})
        return [region for region in distinct_regions
                if region != "unknown" and region.strip()]

    def _get_spatial_description_phrase(self, region: str) -> str:
        """
        Fallback translation of a region name into a location phrase.

        Args:
            region: Region name such as ``"top_left"``; underscores, case and
                surrounding whitespace are ignored.

        Returns:
            The phrase for the region, or an empty string for empty, "unknown"
            or unmapped regions.
        """
        if not region or region == "unknown":
            return ""
        clean_region = region.replace('_', ' ').strip().lower()
        return self._REGION_PHRASES.get(clean_region, "")

    def _normalize_object_class_name(self, class_name: str) -> str:
        """
        Normalize a class name for display.

        Delegates to the text optimizer when one is configured. The fallback
        replaces underscores with spaces, strips and lower-cases the name, and
        returns "object" for empty or non-string input.

        Args:
            class_name: Raw class name.

        Returns:
            The normalized class name.
        """
        if self.text_optimizer:
            return self.text_optimizer.normalize_object_class_name(class_name)
        if not class_name or not isinstance(class_name, str):
            return "object"
        return class_name.replace('_', ' ').strip().lower()

    def _format_object_count_description(self, class_name: str, count: int,
                                         scene_type: Optional[str] = None,
                                         detected_objects: Optional[List[Dict]] = None,
                                         avg_confidence: float = 0.0) -> str:
        """
        Format a class name together with its count, e.g. "a car" or "three cars".

        Delegates to the text optimizer when one is configured. The fallback
        returns an empty string for non-positive counts, uses "a"/"an" for one
        object, number words for 2-12, "several" up to 20 and "numerous" above.

        Args:
            class_name: Normalized class name.
            count: Number of objects.
            scene_type: Scene type (used only by the text optimizer).
            detected_objects: All detections of this class (text optimizer only).
            avg_confidence: Average detection confidence (text optimizer only).

        Returns:
            The formatted count description.
        """
        if self.text_optimizer:
            return self.text_optimizer.format_object_count_description(
                class_name, count, scene_type, detected_objects, avg_confidence
            )

        if count <= 0:
            return ""

        # Guard against an empty name so the article lookup cannot fail.
        noun = class_name if class_name else "object"

        if count == 1:
            article = "an" if noun[0].lower() in 'aeiou' else "a"
            return f"{article} {noun}"

        plural_form = self._pluralize(noun)
        if count in self._NUMBER_WORDS:
            return f"{self._NUMBER_WORDS[count]} {plural_form}"
        if count <= 20:
            return f"several {plural_form}"
        return f"numerous {plural_form}"

    @classmethod
    def _pluralize(cls, noun: str) -> str:
        """Return a simple English plural of ``noun`` (fallback formatter only)."""
        if noun in cls._IRREGULAR_PLURALS:
            return cls._IRREGULAR_PLURALS[noun]
        # Sibilant endings take "es": bus -> buses, bench -> benches.
        if noun.endswith(("ss", "us", "sh", "ch", "x")):
            return noun + "es"
        # Other names ending in "s" are taken to be plural already (skis, scissors).
        if noun.endswith("s"):
            return noun
        return noun + "s"