Theo Viel committed on
Commit
da3baa6
·
1 Parent(s): ac224e3

Advanced post-processing

Browse files
Demo.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1452423fcb5fbc1cb08085f1169727e61238f763d0994f3d8a98b98621a0fc89
3
- size 302483
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98c4ca2b9c91864ea8f45b0d91900ce9460582ce2a9419a02efe8d5188f60b88
3
+ size 1482484
post_processing/page_elt_pp.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
def expand_boxes(boxes, r_x=(1, 1), r_y=(1, 1), size_agnostic=True):
    """
    Expands bounding boxes by a specified ratio.
    Expected box format is normalized [x_min, y_min, x_max, y_max].

    The input array is never modified: the expansion is applied to a
    floating-point copy, so integer inputs are accepted as well.

    Args:
        boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
        r_x (tuple, optional): Left, right expansion ratios. Defaults to (1, 1) (no expansion).
        r_y (tuple, optional): Up, down expansion ratios. Defaults to (1, 1) (no expansion).
        size_agnostic (bool, optional): Expand independently of the bbox shape. Defaults to True.

    Returns:
        numpy.ndarray: Adjusted bounding boxes clipped to the [0, 1] range.
    """
    old_boxes = np.array(boxes, copy=True)
    if not np.issubdtype(old_boxes.dtype, np.floating):
        # In-place float arithmetic is not possible on integer arrays.
        old_boxes = old_boxes.astype(float)
    if old_boxes.size == 0:
        return old_boxes

    boxes = old_boxes.copy()

    if size_agnostic:
        h, w = 1, 1
    else:
        h = boxes[:, 3] - boxes[:, 1]
        w = boxes[:, 2] - boxes[:, 0]

    boxes[:, 0] -= w * (r_x[0] - 1)  # left
    boxes[:, 2] += w * (r_x[1] - 1)  # right
    boxes[:, 1] -= h * (r_y[0] - 1)  # up
    boxes[:, 3] += h * (r_y[1] - 1)  # down

    boxes = np.clip(boxes, 0, 1)

    # Enforce non-overlapping boxes: if the expansion created an overlap that
    # was (almost) absent before, pull the vertical edges back.
    n_boxes = len(boxes)
    for i in range(n_boxes):
        for j in range(i + 1, n_boxes):
            iou = bb_iou_array(boxes[i][None], boxes[j])[0]
            old_iou = bb_iou_array(old_boxes[i][None], old_boxes[j])[0]
            if not (iou > 0.05 and old_iou < 0.1):
                continue

            if boxes[i, 1] < boxes[j, 1]:  # i above j
                boxes[j, 1] = min(old_boxes[j, 1], boxes[i, 3])
                if old_iou > 0:
                    boxes[i, 3] = max(old_boxes[i, 3], boxes[j, 1])
            else:  # j above i
                boxes[i, 1] = min(old_boxes[i, 1], boxes[j, 3])
                if old_iou > 0:
                    boxes[j, 3] = max(old_boxes[j, 3], boxes[i, 1])

    return boxes
50
+
51
+
52
def merge_boxes(b1, b2):
    """
    Merges two bounding boxes into a single box that encompasses both.

    The result uses the promoted dtype of both inputs, so merging an integer
    box with a float box does not truncate coordinates. Inputs are not modified.

    Args:
        b1 (numpy.ndarray): First bounding box [x_min, y_min, x_max, y_max].
        b2 (numpy.ndarray): Second bounding box [x_min, y_min, x_max, y_max].

    Returns:
        numpy.ndarray: A single bounding box that covers both input boxes.
    """
    first, second = np.asarray(b1), np.asarray(b2)
    merged = first.astype(np.result_type(first, second), copy=True)
    merged[:2] = np.minimum(first[:2], second[:2])  # top-left corner
    merged[2:4] = np.maximum(first[2:4], second[2:4])  # bottom-right corner
    return merged
69
+
70
+
71
def bb_iou_array(boxes, new_box):
    """
    Calculates the Intersection over Union (IoU) between a box and an array of boxes.

    Pairs whose union has zero area (degenerate boxes) get an IoU of 0 instead
    of NaN.

    Args:
        boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
        new_box (numpy.ndarray): A single bounding box [x_min, y_min, x_max, y_max].

    Returns:
        numpy.ndarray: Array of IoU values between the new_box and each box in the array.
    """
    # Corners of the intersection rectangles
    x_left = np.maximum(boxes[:, 0], new_box[0])
    y_top = np.maximum(boxes[:, 1], new_box[1])
    x_right = np.minimum(boxes[:, 2], new_box[2])
    y_bottom = np.minimum(boxes[:, 3], new_box[3])

    inter_area = np.maximum(x_right - x_left, 0) * np.maximum(y_bottom - y_top, 0)

    boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    new_box_area = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1])
    union_area = boxes_area + new_box_area - inter_area

    # Guarded division: entries with an empty union stay at 0.
    iou = np.zeros(
        union_area.shape, dtype=np.result_type(union_area.dtype, np.float32)
    )
    np.divide(inter_area, union_area, out=iou, where=union_area > 0)
    return iou
97
+
98
+
99
def match_with_title(
    bbox, title_bboxes, match_dist=0.1, delta=1.5, already_matched=None
):
    """
    Matches a bounding box with a title bounding box based on IoU or proximity.

    The distance to a title is the smallest vertical gap (title above or below
    the box) plus half of the smallest horizontal offset (left edges or centers).
    Titles overlapping the box are treated as at least as close as the best
    candidate.

    Args:
        bbox (numpy.ndarray): Bounding box to match with title [x_min, y_min, x_max, y_max].
        title_bboxes (numpy.ndarray): Array of title bounding boxes with shape (N, 4).
        match_dist (float, optional): Maximum distance for matching. Defaults to 0.1.
        delta (float, optional): Multiplier for matching several titles. Defaults to 1.5.
        already_matched (list, optional): Indices of already matched titles. Defaults to None.

    Returns:
        tuple: (merged_bbox, matched_title_indices) when at least one title matches,
            (None, None) otherwise.
    """
    if not len(title_bboxes):
        return None, None

    dist_above = np.abs(title_bboxes[:, 3] - bbox[1])
    dist_below = np.abs(bbox[3] - title_bboxes[:, 1])

    dist_left = np.abs(title_bboxes[:, 0] - bbox[0])
    dist_center = np.abs(title_bboxes[:, 0] + title_bboxes[:, 2] - bbox[0] - bbox[2]) / 2

    dists = np.minimum(dist_above, dist_below) + np.minimum(dist_left, dist_center) / 2

    # Overlapping titles are considered as close as the closest candidate.
    ious = bb_iou_array(title_bboxes, bbox)
    dists = np.where(ious > 0, min(match_dist, np.min(dists)), dists)

    if already_matched is not None and len(already_matched):
        dists[already_matched] = match_dist * 10  # Remove already matched titles

    best_dist = np.min(dists)
    if not best_dist <= match_dist:
        return None, None

    # Keep every title within `delta` times the best distance.
    matches = np.where(dists <= min(match_dist, best_dist * delta))[0]

    new_bbox = bbox
    for match in matches:
        new_bbox = merge_boxes(new_bbox, title_bboxes[match])
    return new_bbox, matches.tolist()
148
+
149
+
150
def match_boxes_with_title(
    boxes, confs, labels, classes, to_match_labels=("chart",), remove_matched_titles=False
):
    """
    Matches charts (or other requested classes) with their title.

    Detections are first reordered so that titles come last. Each box whose
    class is in `to_match_labels` is then enlarged to include its matched
    title(s). A title is matched to at most one box.

    Args:
        boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
        confs (numpy.ndarray): Array of confidence scores with shape (N,).
        labels (numpy.ndarray): Array of labels with shape (N,).
        classes (list): List of class names. Must contain "title".
        to_match_labels (list): List of class names to match with titles.
        remove_matched_titles (bool): Whether to remove matched titles from the boxes.

    Returns:
        boxes (numpy.ndarray): Array of bounding boxes with shape (M, 4).
        confs (numpy.ndarray): Array of confidence scores with shape (M,).
        labels (numpy.ndarray): Array of labels with shape (M,).
        found_title (list): Indices (in the reordered arrays, before any title
            removal) of the boxes that were merged with a title.
    """
    title_label = classes.index("title")

    # Put titles at the end
    is_title = labels == title_label
    order = np.concatenate([np.flatnonzero(~is_title), np.flatnonzero(is_title)])
    boxes = boxes[order]
    confs = confs[order]
    labels = labels[order]

    # Ids
    title_ids = np.flatnonzero(labels == title_label)
    to_match_ids = np.flatnonzero(
        np.isin(labels, [classes.index(c) for c in to_match_labels])
    )

    # Matching
    found_title, already_matched = [], []
    for i in to_match_ids.tolist():
        merged_box, matched_title_ids = match_with_title(
            boxes[i],
            boxes[title_ids],
            already_matched=already_matched,
        )
        if matched_title_ids is None:
            continue
        boxes[i] = merged_box
        already_matched += matched_title_ids
        found_title.append(i)

    if remove_matched_titles and len(already_matched):
        # `already_matched` indexes into the title subset; map back to rows.
        matched_rows = title_ids[already_matched]
        boxes = np.delete(boxes, matched_rows, axis=0)
        confs = np.delete(confs, matched_rows, axis=0)
        labels = np.delete(labels, matched_rows, axis=0)

    return boxes, confs, labels, found_title
post_processing/text_pp.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
def get_overlaps(boxes, other_boxes, normalize="box_only"):
    """
    Computes how much each box overlaps with each of the other boxes.
    Boxes are expected in format (x0, y0, x1, y1).

    Args:
        boxes (np array [4] or [n x 4]): Boxes.
        other_boxes (np array [m x 4]): Other boxes.
        normalize (str, optional): "box_only" divides the intersection by the
            area of the box (detects box included in other box). "all" divides
            by the smaller of the two areas (detects inclusion either way).
            Defaults to "box_only".

    Returns:
        np array [n x m]: Overlaps. Entries with a zero-area denominator are 0.

    Raises:
        ValueError: If `normalize` is not "box_only" or "all".
    """
    if normalize not in ("box_only", "all"):
        raise ValueError(f"Invalid normalization: {normalize}")

    if boxes.ndim == 1:
        boxes = boxes[None, :]

    # Column vectors [n x 1] for boxes, row vectors [1 x m] for other boxes,
    # so that every operation broadcasts to [n x m].
    x0, y0, x1, y1 = (boxes[:, k][:, None] for k in range(4))
    x0_other, y0_other, x1_other, y1_other = (
        other_boxes[:, k][None, :] for k in range(4)
    )
    areas = (y1 - y0) * (x1 - x0)
    areas_other = (y1_other - y0_other) * (x1_other - x0_other)

    # Intersection
    inter_h = np.maximum(0, np.minimum(y1, y1_other) - np.maximum(y0, y0_other))
    inter_w = np.maximum(0, np.minimum(x1, x1_other) - np.maximum(x0, x0_other))
    inter_area = inter_h * inter_w

    # Overlap
    if normalize == "box_only":
        denominator = np.broadcast_to(areas, inter_area.shape)
    else:
        denominator = np.minimum(areas, areas_other)

    overlaps = np.zeros(
        inter_area.shape, dtype=np.result_type(inter_area.dtype, np.float32)
    )
    np.divide(inter_area, denominator, out=overlaps, where=denominator > 0)
    return overlaps
47
+
48
+
49
def get_distances(title_boxes, other_boxes):
    """
    Computes the distances between title and table/chart boxes.
    Distance is computed as the vertical distance plus half the horizontal distance.
    Horizontal distance uses min(boxes center dist, boxes left dist).
    Vertical distance uses min(top_title to bottom_other dists, bottom_title to top_other dists).

    A single box given as a 1-D array is treated as a batch of one.

    Args:
        title_boxes (np array [n_titles x 4]): Title boxes.
        other_boxes (np array [n_other x 4]): Other boxes.

    Returns:
        np array [n_titles x n_other]: Distances between titles and other boxes.
    """
    title_boxes = np.asarray(title_boxes)
    other_boxes = np.asarray(other_boxes)
    if title_boxes.ndim == 1:
        title_boxes = title_boxes[None, :]
    if other_boxes.ndim == 1:
        other_boxes = other_boxes[None, :]

    # Titles as column vectors, other boxes as row vectors -> [n_titles x n_other]
    title_left = title_boxes[:, 0][:, None]
    title_center = ((title_boxes[:, 0] + title_boxes[:, 2]) / 2)[:, None]
    title_top = title_boxes[:, 1][:, None]
    title_bottom = title_boxes[:, 3][:, None]

    other_left = other_boxes[:, 0][None, :]
    other_center = ((other_boxes[:, 0] + other_boxes[:, 2]) / 2)[None, :]
    other_top = other_boxes[:, 1][None, :]
    other_bottom = other_boxes[:, 3][None, :]

    x_dists = np.minimum(
        np.abs(title_center - other_center),  # Title center to other center
        np.abs(title_left - other_left),  # Title left to other left
    )
    y_dists = np.minimum(
        np.abs(title_bottom - other_top),  # Title above other
        np.abs(title_top - other_bottom),  # Title below other
    )

    return y_dists + x_dists / 2
88
+
89
+
90
def find_titles(title_boxes, table_boxes, chart_boxes, max_dist=0.1):
    """
    Associates titles to tables and charts.

    Assignment is greedy: tables are processed first, then charts, each taking
    its closest still-unassigned title if it is closer than `max_dist`.
    Titles mostly inside a chart (overlap > 0.25) get distance 0 to that chart
    and a 10x penalty towards every table.

    Args:
        title_boxes (np array [n_titles x 4]): Title boxes.
        table_boxes (np array [n_tables x 4]): Table boxes.
        chart_boxes (np array [n_charts x 4]): Chart boxes.
        max_dist (float, optional): Maximum distance between title and table/chart. Defaults to 0.1.

    Returns:
        dict: Dictionary of assigned titles.
            - Keys are the indices of the titles (plain int),
            - Values are tuples of:
                - str: Whether the title is assigned to a "chart" or "table"
                - int: index of the assigned table/chart
    """
    n_titles, n_tables, n_charts = len(title_boxes), len(table_boxes), len(chart_boxes)
    if not n_titles or not (n_tables or n_charts):
        return {}

    # Title-to-chart distances
    chart_distances = np.ones((n_titles, 0))
    title_in_chart = None
    if n_charts:
        chart_overlaps = get_overlaps(title_boxes, chart_boxes, normalize="box_only")
        chart_distances = np.where(
            chart_overlaps > 0.25, 0, get_distances(title_boxes, chart_boxes)
        )
        title_in_chart = chart_overlaps.max(1, keepdims=True) > 0.25

    # Title-to-table distances
    table_distances = np.ones((n_titles, 0))
    if n_tables:
        table_distances = get_distances(title_boxes, table_boxes)
        if title_in_chart is not None:  # Penalize table titles that are inside charts
            table_distances = np.where(
                title_in_chart, table_distances * 10, table_distances
            )

    assigned_titles = {}

    # Assign to tables
    for i in range(n_tables):
        best_match = int(np.argmin(table_distances[:, i]))
        if table_distances[best_match, i] < max_dist:
            assigned_titles[best_match] = ("table", i)
            # A title can only be used once: disable its whole row.
            table_distances[best_match] = np.inf
            chart_distances[best_match] = np.inf

    # Assign to charts
    for i in range(n_charts):
        best_match = int(np.argmin(chart_distances[:, i]))
        if chart_distances[best_match, i] < max_dist:
            assigned_titles[best_match] = ("chart", i)
            chart_distances[best_match] = np.inf

    return assigned_titles
150
+
151
+
152
def postprocess_included(
    boxes, labels, confs, class_="title", classes=("table", "chart", "title", "infographic")
):
    """
    Post process predictions of one class.
    - Remove boxes of `class_` that are included (> 90% of their area) in a
      table, infographic or chart box (and paragraph box for titles and
      header/footers, when that class exists).

    Inclusion classes missing from `classes` are ignored, and `class_` is never
    compared against itself.

    Args:
        boxes (numpy.ndarray [N, 4]): Array of bounding boxes.
        labels (numpy.ndarray [N]): Array of labels.
        confs (numpy.ndarray [N]): Array of confidences.
        class_ (str, optional): Class to postprocess. Defaults to "title".
        classes (list, optional): Classes. Defaults to ["table", "chart", "title", "infographic"].

    Returns:
        boxes (numpy.ndarray): Array of bounding boxes.
        labels (numpy.ndarray): Array of labels.
        confs (numpy.ndarray): Array of confidences.
        Boxes of other classes come first (original order), followed by the
        kept `class_` boxes sorted by increasing confidence. Inputs are
        returned unchanged when there is no `class_` box.
    """
    classes = list(classes)
    target_label = classes.index(class_)
    is_target = labels == target_label

    if not is_target.any():
        return boxes, labels, confs

    order = np.argsort(confs[is_target])  # least to most confident
    boxes_to_pp = boxes[is_target][order]
    confs_to_pp = confs[is_target][order]

    inclusion_classes = ["table", "infographic", "chart"]
    if class_ in ("header_footer", "title"):
        inclusion_classes.append("paragraph")
    inclusion_labels = [
        classes.index(c) for c in inclusion_classes if c in classes and c != class_
    ]
    other_boxes = boxes[np.isin(labels, inclusion_labels)]

    # Remove boxes included in other_boxes
    keep = np.ones(len(boxes_to_pp), dtype=bool)
    if len(other_boxes) > 0:
        overlaps = get_overlaps(boxes_to_pp, other_boxes, normalize="box_only")
        keep = ~(overlaps.max(axis=1) > 0.9)

    kept_boxes = boxes_to_pp[keep]
    kept_confs = confs_to_pp[keep]

    # Aggregate
    boxes_pp = np.concatenate([boxes[~is_target], kept_boxes])
    confs_pp = np.concatenate([confs[~is_target], kept_confs])
    labels_pp = np.concatenate(
        [labels[~is_target], np.ones(len(kept_boxes)) * target_label]
    )

    return boxes_pp, labels_pp, confs_pp
post_processing/wbf.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from:
2
+ # https://github.com/ZFTurbo/Weighted-Boxes-Fusion/blob/master/ensemble_boxes/ensemble_boxes_wbf.py
3
+
4
+ import warnings
5
+ import numpy as np
6
+
7
+
8
def prefilter_boxes(boxes, scores, labels, weights, thr, class_agnostic=False):
    """
    Reformats and filters boxes.
    Output is a dict of boxes to merge separately.

    Boxes with a score below `thr` or with zero area are dropped, swapped
    corners are fixed and coordinates are clipped to [0, 1] (each with a warning).

    Args:
        boxes (list[np array[n x 4]]): List of boxes. One list per model.
        scores (list[np array[n]]): List of confidences.
        labels (list[np array[n]]): List of labels.
        weights (list): Model weights.
        thr (float): Confidence threshold
        class_agnostic (bool, optional): Merge boxes from different classes. Defaults to False.

    Returns:
        dict[np array [? x 8]]: Filtered boxes, keyed by label (or "*" when
            class agnostic). Rows are
            [label, score * weight, weight, model index, x1, y1, x2, y2],
            sorted by decreasing weighted score.
    """
    grouped = {}

    for t, model_boxes in enumerate(boxes):
        assert len(model_boxes) == len(scores[t]), "len(boxes) != len(scores)"
        assert len(model_boxes) == len(labels[t]), "len(boxes) != len(labels)"
        weight = weights[t]

        for box_part, score, raw_label in zip(model_boxes, scores[t], labels[t]):
            if score < thr:
                continue
            label = int(raw_label)
            x1, y1, x2, y2 = (float(box_part[k]) for k in range(4))

            # Box data checks
            if x2 < x1:
                warnings.warn("X2 < X1 value in box. Swap them.")
                x1, x2 = x2, x1
            if y2 < y1:
                warnings.warn("Y2 < Y1 value in box. Swap them.")
                y1, y2 = y2, y1

            coords = np.array([x1, x2, y1, y2])
            if coords.min() < 0 or coords.max() > 1:
                warnings.warn("Coordinates outside [0, 1]")
                x1, x2, y1, y2 = np.clip(coords, 0, 1)

            if (x2 - x1) * (y2 - y1) == 0.0:
                warnings.warn(f"Zero area box skipped: {box_part}.")
                continue

            # [label, score, weight, model index, x1, y1, x2, y2]
            row = [label, float(score) * weight, weight, t, x1, y1, x2, y2]
            grouped.setdefault("*" if class_agnostic else label, []).append(row)

    # Sort each group by decreasing score and convert it to a numpy array
    new_boxes = {}
    for key, rows in grouped.items():
        rows = np.array(rows)
        new_boxes[key] = rows[rows[:, 1].argsort()[::-1]]

    return new_boxes
74
+
75
+
76
def merge_labels(labels, confs, title_label=2):
    """
    Custom function for merging labels.
    If all labels are the same, return the unique value.
    Else, return the label of the most confident non-title box.

    Args:
        labels (np array [n]): Labels.
        confs (np array [n]): Confidence.
        title_label (int, optional): Label id of the title class. Defaults to 2.

    Returns:
        int: Label.
    """
    labels = np.asarray(labels)
    confs = np.asarray(confs)

    if len(np.unique(labels)) == 1:
        return labels[0]

    # Several distinct labels: at least one is not a title. The same mask has
    # to be applied to both arrays so that indices stay aligned.
    not_title = labels != title_label
    return labels[not_title][np.argmax(confs[not_title])]
95
+
96
+
97
def get_weighted_box(boxes, conf_type="avg"):
    """
    Merges boxes by using the weighted fusion.

    Coordinates are averaged with the box confidences as weights. When all
    confidences are zero, a plain average is used instead of dividing by zero.

    Args:
        boxes (np array [n x 8]): Boxes to merge, rows are
            [label, score, weight, model index, x1, y1, x2, y2].
        conf_type (str, optional): Confidence merging type. Defaults to "avg".

    Returns:
        np array [8]: Merged box.
    """
    boxes = np.asarray(boxes)
    confs = boxes[:, 1]
    coords = boxes[:, 4:]

    box = np.zeros(8, dtype=np.float32)

    total_conf = confs.sum()
    if total_conf > 0:
        box[4:] = (confs[:, None] * coords).sum(axis=0) / total_conf
    else:
        box[4:] = coords.mean(axis=0)

    box[0] = merge_labels(boxes[:, 0], confs)
    box[1] = confs.max() if conf_type == "max" else confs.mean()
    box[2] = boxes[:, 2].sum()
    box[3] = -1  # model index field is retained for consistency but is not used.
    return box
127
+
128
+
129
def get_biggest_box(boxes, conf_type="avg"):
    """
    Merges boxes by using the biggest box, i.e. the smallest box that
    encloses all of them.

    Args:
        boxes (np array [n x 8]): Boxes to merge, rows are
            [label, score, weight, model index, x1, y1, x2, y2].
        conf_type (str, optional): Confidence merging type. Defaults to "avg".

    Returns:
        np array [8]: Merged box.
    """
    boxes = np.asarray(boxes)
    confs = boxes[:, 1]

    box = np.zeros(8, dtype=np.float32)
    box[4:6] = boxes[:, 4:6].min(axis=0)  # top-left corner
    box[6:8] = boxes[:, 6:8].max(axis=0)  # bottom-right corner

    box[0] = merge_labels(boxes[:, 0], confs)
    box[1] = confs.max() if conf_type == "max" else confs.mean()
    box[2] = boxes[:, 2].sum()
    box[3] = -1  # model index field is retained for consistency but is not used.
    return box
161
+
162
+
163
def find_matching_box_fast(boxes_list, new_box, match_iou):
    """
    Reimplementation of find_matching_box with numpy instead of loops.
    Gives significant speed up for larger arrays (~100x).
    This was previously the bottleneck since the function is called for every entry in the array.

    Args:
        boxes_list (np array [n x 8]): Candidate boxes, coordinates in columns 4 to 7.
        new_box (np array [8]): Box to match, coordinates in entries 4 to 7.
        match_iou (float): IoU threshold; a match requires an IoU strictly above it.

    Returns:
        tuple: (index of the best matching candidate, its IoU), or
            (-1, match_iou) when no candidate is above the threshold.
    """
    if boxes_list.shape[0] == 0:
        return -1, match_iou

    candidates, target = boxes_list[:, 4:], new_box[4:]

    # Intersection over union between the target and every candidate
    inter_w = np.maximum(
        np.minimum(candidates[:, 2], target[2]) - np.maximum(candidates[:, 0], target[0]), 0
    )
    inter_h = np.maximum(
        np.minimum(candidates[:, 3], target[3]) - np.maximum(candidates[:, 1], target[1]), 0
    )
    inter_area = inter_w * inter_h

    candidates_area = (candidates[:, 2] - candidates[:, 0]) * (candidates[:, 3] - candidates[:, 1])
    target_area = (target[2] - target[0]) * (target[3] - target[1])
    union_area = candidates_area + target_area - inter_area

    # Guarded division: an empty union would give NaN, which argmax would
    # select and report as a match.
    ious = np.zeros(
        union_area.shape, dtype=np.result_type(union_area.dtype, np.float32)
    )
    np.divide(inter_area, union_area, out=ious, where=union_area > 0)

    best_idx = int(np.argmax(ious))
    best_iou = ious[best_idx]

    if best_iou <= match_iou:
        return -1, match_iou

    return best_idx, best_iou
201
+
202
+
203
def weighted_boxes_fusion(
    boxes_list,
    scores_list,
    labels_list,
    iou_thr=0.5,
    skip_box_thr=0.0,
    conf_type="avg",
    merge_type="weighted",
    class_agnostic=False,
):
    """
    Custom WBF implementation that supports a class_agnostic mode and a biggest box fusion.
    Boxes are expected to be in normalized (x0, y0, x1, y1) format.

    Each box is linked to its best match (IoU above `iou_thr`); linked boxes
    form clusters, and each cluster is fused into one box. All models have the
    same weight.

    Args:
        boxes_list (list[np.ndarray[n x 4]]): List of boxes. One list per model.
        scores_list (list[np.ndarray[n]]): List of confidences.
        labels_list (list[np.ndarray[n]]): List of labels.
        iou_thr (float, optional): IoU threshold for matching. Defaults to 0.5.
        skip_box_thr (float, optional): Exclude boxes with score < skip_box_thr. Defaults to 0.0.
        conf_type (str, optional): Confidence merging type ("avg" or "max"). Defaults to "avg".
        merge_type (str, optional): Merge type ("weighted" or "biggest"). Defaults to "weighted".
        class_agnostic (bool, optional): Merge boxes from different classes. Defaults to False.

    Returns:
        np array[N x 4]: Merged boxes,
        np array[N]: Merged confidences,
        np array[N]: Merged labels.
        Outputs are sorted by decreasing confidence.
    """
    weights = np.ones(len(boxes_list))

    assert conf_type in ["avg", "max"], 'Conf type must be "avg" or "max"'
    assert merge_type in ["weighted", "biggest"], 'Merge type must be "weighted" or "biggest"'

    filtered_boxes = prefilter_boxes(
        boxes_list,
        scores_list,
        labels_list,
        weights,
        skip_box_thr,
        class_agnostic=class_agnostic,
    )
    if len(filtered_boxes) == 0:
        return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,))

    merge_fn = get_weighted_box if merge_type == "weighted" else get_biggest_box

    overall_boxes = []
    for boxes in filtered_boxes.values():
        n_boxes = len(boxes)
        clusters = []  # list of sets of box indices

        # Clusterize boxes
        for j in range(n_boxes):
            ids = [i for i in range(n_boxes) if i != j]
            index, _ = find_matching_box_fast(boxes[ids], boxes[j], iou_thr)

            members = {j}
            if index != -1:
                members.add(ids[index])

            # Every existing cluster sharing a box with the new pair is merged
            # into one, so that no box ends up in two clusters.
            hits = [k for k, clust in enumerate(clusters) if clust & members]
            if not hits:
                clusters.append(members)
                continue
            for k in hits:
                members |= clusters[k]
            clusters[hits[0]] = members
            for k in reversed(hits[1:]):
                del clusters[k]

        # Fuse each cluster into a single box
        for cluster in clusters:
            c = sorted(cluster)
            weighted_box = merge_fn(boxes[c], conf_type)

            if conf_type == "max":
                weighted_box[1] = weighted_box[1] / weights.max()
            else:  # avg
                weighted_box[1] = weighted_box[1] * len(c) / weights.sum()
            overall_boxes.append(weighted_box)

    overall_boxes = np.array(overall_boxes)
    overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]]
    boxes = overall_boxes[:, 4:]
    scores = overall_boxes[:, 1]
    labels = overall_boxes[:, 0]
    return boxes, scores, labels