Theo Viel
commited on
Commit
·
da3baa6
1
Parent(s):
ac224e3
Advanced post-processing
Browse files- Demo.ipynb +2 -2
- post_processing/page_elt_pp.py +203 -0
- post_processing/text_pp.py +225 -0
- post_processing/wbf.py +292 -0
Demo.ipynb
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:98c4ca2b9c91864ea8f45b0d91900ce9460582ce2a9419a02efe8d5188f60b88
|
| 3 |
+
size 1482484
|
post_processing/page_elt_pp.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def expand_boxes(boxes, r_x=(1, 1), r_y=(1, 1), size_agnostic=True):
|
| 5 |
+
"""
|
| 6 |
+
Expands bounding boxes by a specified ratio.
|
| 7 |
+
Expected box format is normalized [x_min, y_min, x_max, y_max].
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
|
| 11 |
+
r_x (tuple, optional): Left, right expansion ratios. Defaults to (1, 1) (no expansion).
|
| 12 |
+
r_y (tuple, optional): Up, down expansion ratios. Defaults to (1, 1) (no expansion).
|
| 13 |
+
size_agnostic (bool, optional): Expand independently of the bbox shape. Defaults to True.
|
| 14 |
+
|
| 15 |
+
Returns:
|
| 16 |
+
numpy.ndarray: Adjusted bounding boxes clipped to the [0, 1] range.
|
| 17 |
+
"""
|
| 18 |
+
old_boxes = boxes.copy()
|
| 19 |
+
|
| 20 |
+
if not size_agnostic:
|
| 21 |
+
h = boxes[:, 3] - boxes[:, 1]
|
| 22 |
+
w = boxes[:, 2] - boxes[:, 0]
|
| 23 |
+
else:
|
| 24 |
+
h, w = 1, 1
|
| 25 |
+
|
| 26 |
+
boxes[:, 0] -= w * (r_x[0] - 1) # left
|
| 27 |
+
boxes[:, 2] += w * (r_x[1] - 1) # right
|
| 28 |
+
boxes[:, 1] -= h * (r_y[0] - 1) # up
|
| 29 |
+
boxes[:, 3] += h * (r_y[1] - 1) # down
|
| 30 |
+
|
| 31 |
+
boxes = np.clip(boxes, 0, 1)
|
| 32 |
+
|
| 33 |
+
# Enforce non-overlapping boxes
|
| 34 |
+
for i in range(len(boxes)):
|
| 35 |
+
for j in range(i + 1, len(boxes)):
|
| 36 |
+
iou = bb_iou_array(boxes[i][None], boxes[j])[0]
|
| 37 |
+
old_iou = bb_iou_array(old_boxes[i][None], old_boxes[j])[0]
|
| 38 |
+
# print(iou, old_iou)
|
| 39 |
+
if iou > 0.05 and old_iou < 0.1:
|
| 40 |
+
if boxes[i, 1] < boxes[j, 1]: # i above j
|
| 41 |
+
boxes[j, 1] = min(old_boxes[j, 1], boxes[i, 3])
|
| 42 |
+
if old_iou > 0:
|
| 43 |
+
boxes[i, 3] = max(old_boxes[i, 3], boxes[j, 1])
|
| 44 |
+
else:
|
| 45 |
+
boxes[i, 1] = min(old_boxes[i, 1], boxes[j, 3])
|
| 46 |
+
if old_iou > 0:
|
| 47 |
+
boxes[j, 3] = max(old_boxes[j, 3], boxes[i, 1])
|
| 48 |
+
|
| 49 |
+
return boxes
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def merge_boxes(b1, b2):
|
| 53 |
+
"""
|
| 54 |
+
Merges two bounding boxes into a single box that encompasses both.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
b1 (numpy.ndarray): First bounding box [x_min, y_min, x_max, y_max].
|
| 58 |
+
b2 (numpy.ndarray): Second bounding box [x_min, y_min, x_max, y_max].
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
numpy.ndarray: A single bounding box that covers both input boxes.
|
| 62 |
+
"""
|
| 63 |
+
b = b1.copy()
|
| 64 |
+
b[0] = min(b1[0], b2[0])
|
| 65 |
+
b[1] = min(b1[1], b2[1])
|
| 66 |
+
b[2] = max(b1[2], b2[2])
|
| 67 |
+
b[3] = max(b1[3], b2[3])
|
| 68 |
+
return b
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def bb_iou_array(boxes, new_box):
|
| 72 |
+
"""
|
| 73 |
+
Calculates the Intersection over Union (IoU) between a box and an array of boxes.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
|
| 77 |
+
new_box (numpy.ndarray): A single bounding box [x_min, y_min, x_max, y_max].
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
numpy.ndarray: Array of IoU values between the new_box and each box in the array.
|
| 81 |
+
"""
|
| 82 |
+
# bb interesection over union
|
| 83 |
+
xA = np.maximum(boxes[:, 0], new_box[0])
|
| 84 |
+
yA = np.maximum(boxes[:, 1], new_box[1])
|
| 85 |
+
xB = np.minimum(boxes[:, 2], new_box[2])
|
| 86 |
+
yB = np.minimum(boxes[:, 3], new_box[3])
|
| 87 |
+
|
| 88 |
+
interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0)
|
| 89 |
+
|
| 90 |
+
# compute the area of both the prediction and ground-truth rectangles
|
| 91 |
+
boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
| 92 |
+
boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1])
|
| 93 |
+
|
| 94 |
+
iou = interArea / (boxAArea + boxBArea - interArea)
|
| 95 |
+
|
| 96 |
+
return iou
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def match_with_title(
|
| 100 |
+
bbox, title_bboxes, match_dist=0.1, delta=1.5, already_matched=[]
|
| 101 |
+
):
|
| 102 |
+
"""
|
| 103 |
+
Matches a bounding box with a title bounding box based on IoU or proximity.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
bbox (numpy.ndarray): Bounding box to match with title [x_min, y_min, x_max, y_max].
|
| 107 |
+
title_bboxes (numpy.ndarray): Array of title bounding boxes with shape (N, 4).
|
| 108 |
+
match_dist (float, optional): Maximum distance for matching. Defaults to 0.1.
|
| 109 |
+
delta (float, optional): Multiplier for matching several titles. Defaults to 1.5.
|
| 110 |
+
already_matched (list, optional): List of already matched title indices. Defaults to [].
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
tuple or None: If matched, returns a tuple of (merged_bbox, updated_title_bboxes).
|
| 114 |
+
If no match is found, returns None, None.
|
| 115 |
+
"""
|
| 116 |
+
if not len(title_bboxes):
|
| 117 |
+
return None, None
|
| 118 |
+
|
| 119 |
+
dist_above = np.abs(title_bboxes[:, 3] - bbox[1])
|
| 120 |
+
dist_below = np.abs(bbox[3] - title_bboxes[:, 1])
|
| 121 |
+
|
| 122 |
+
dist_left = np.abs(title_bboxes[:, 0] - bbox[0])
|
| 123 |
+
dist_center = np.abs(title_bboxes[:, 0] + title_bboxes[:, 2] - bbox[0] - bbox[2]) / 2
|
| 124 |
+
|
| 125 |
+
dists = np.min([dist_above, dist_below], 0)
|
| 126 |
+
dists += np.min([dist_left, dist_center], 0) / 2
|
| 127 |
+
|
| 128 |
+
ious = bb_iou_array(title_bboxes, bbox)
|
| 129 |
+
dists = np.where(ious > 0, min(match_dist, np.min(dists)), dists)
|
| 130 |
+
|
| 131 |
+
if len(already_matched):
|
| 132 |
+
dists[already_matched] = match_dist * 10 # Remove already matched titles
|
| 133 |
+
|
| 134 |
+
# print(dists)
|
| 135 |
+
matches = None # noqa
|
| 136 |
+
if np.min(dists) <= match_dist:
|
| 137 |
+
matches = np.where(
|
| 138 |
+
dists <= min(match_dist, np.min(dists) * delta)
|
| 139 |
+
)[0]
|
| 140 |
+
|
| 141 |
+
if matches is not None:
|
| 142 |
+
new_bbox = bbox
|
| 143 |
+
for match in matches:
|
| 144 |
+
new_bbox = merge_boxes(new_bbox, title_bboxes[match])
|
| 145 |
+
return new_bbox, list(matches)
|
| 146 |
+
else:
|
| 147 |
+
return None, None
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def match_boxes_with_title(
|
| 151 |
+
boxes, confs, labels, classes, to_match_labels=["chart"], remove_matched_titles=False
|
| 152 |
+
):
|
| 153 |
+
"""
|
| 154 |
+
Matches charts with title.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
|
| 158 |
+
confs (numpy.ndarray): Array of confidence scores with shape (N,).
|
| 159 |
+
labels (numpy.ndarray): Array of labels with shape (N,).
|
| 160 |
+
classes (list): List of class names.
|
| 161 |
+
to_match_labels (list): List of class names to match with titles.
|
| 162 |
+
remove_matched_titles (bool): Whether to remove matched titles from the boxes.
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
boxes (numpy.ndarray): Array of bounding boxes with shape (M, 4).
|
| 166 |
+
confs (numpy.ndarray): Array of confidence scores with shape (M,).
|
| 167 |
+
labels (numpy.ndarray): Array of labels with shape (M,).
|
| 168 |
+
found_title (list): List of indices of matched titles.
|
| 169 |
+
no_found_title (list): List of indices of unmatched titles.
|
| 170 |
+
"""
|
| 171 |
+
# Put titles at the end
|
| 172 |
+
title_ids = np.where(labels == classes.index("title"))[0]
|
| 173 |
+
order = np.concatenate([np.delete(np.arange(len(boxes)), title_ids), title_ids])
|
| 174 |
+
boxes = boxes[order]
|
| 175 |
+
confs = confs[order]
|
| 176 |
+
labels = labels[order]
|
| 177 |
+
|
| 178 |
+
# Ids
|
| 179 |
+
title_ids = np.where(labels == classes.index("title"))[0]
|
| 180 |
+
to_match = np.where(np.isin(labels, [classes.index(c) for c in to_match_labels]))[0]
|
| 181 |
+
|
| 182 |
+
# Matching
|
| 183 |
+
found_title, already_matched = [], []
|
| 184 |
+
for i in range(len(boxes)):
|
| 185 |
+
if i not in to_match:
|
| 186 |
+
continue
|
| 187 |
+
merged_box, matched_title_ids = match_with_title(
|
| 188 |
+
boxes[i],
|
| 189 |
+
boxes[title_ids],
|
| 190 |
+
already_matched=already_matched,
|
| 191 |
+
)
|
| 192 |
+
if matched_title_ids is not None:
|
| 193 |
+
# print(f'Merged {classes[int(labels[i])]} at idx #{i} with title {matched_title_ids[-1]}') # noqa
|
| 194 |
+
boxes[i] = merged_box
|
| 195 |
+
already_matched += matched_title_ids
|
| 196 |
+
found_title.append(i)
|
| 197 |
+
|
| 198 |
+
if remove_matched_titles and len(already_matched):
|
| 199 |
+
boxes = np.delete(boxes, title_ids[already_matched], axis=0)
|
| 200 |
+
confs = np.delete(confs, title_ids[already_matched], axis=0)
|
| 201 |
+
labels = np.delete(labels, title_ids[already_matched], axis=0)
|
| 202 |
+
|
| 203 |
+
return boxes, confs, labels, found_title
|
post_processing/text_pp.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def get_overlaps(boxes, other_boxes, normalize="box_only"):
|
| 5 |
+
"""
|
| 6 |
+
Checks if a box overlaps with any other box.
|
| 7 |
+
Boxes are expeceted in format (x0, y0, x1, y1)
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
boxes (np array [4] or [n x 4]): Boxes.
|
| 11 |
+
other_boxes (np array [m x 4]): Other boxes.
|
| 12 |
+
|
| 13 |
+
Returns:
|
| 14 |
+
np array [n x m]: Overlaps.
|
| 15 |
+
"""
|
| 16 |
+
if boxes.ndim == 1:
|
| 17 |
+
boxes = boxes[None, :]
|
| 18 |
+
|
| 19 |
+
x0, y0, x1, y1 = (
|
| 20 |
+
boxes[:, 0][:, None], boxes[:, 1][:, None], boxes[:, 2][:, None], boxes[:, 3][:, None]
|
| 21 |
+
)
|
| 22 |
+
areas = ((y1 - y0) * (x1 - x0))
|
| 23 |
+
|
| 24 |
+
x0_other, y0_other, x1_other, y1_other = (
|
| 25 |
+
other_boxes[:, 0][None, :],
|
| 26 |
+
other_boxes[:, 1][None, :],
|
| 27 |
+
other_boxes[:, 2][None, :],
|
| 28 |
+
other_boxes[:, 3][None, :]
|
| 29 |
+
)
|
| 30 |
+
areas_other = ((y1_other - y0_other) * (x1_other - x0_other))
|
| 31 |
+
|
| 32 |
+
# Intersection
|
| 33 |
+
inter_y0 = np.maximum(y0, y0_other)
|
| 34 |
+
inter_y1 = np.minimum(y1, y1_other)
|
| 35 |
+
inter_x0 = np.maximum(x0, x0_other)
|
| 36 |
+
inter_x1 = np.minimum(x1, x1_other)
|
| 37 |
+
inter_area = np.maximum(0, inter_y1 - inter_y0) * np.maximum(0, inter_x1 - inter_x0)
|
| 38 |
+
|
| 39 |
+
# Overlap
|
| 40 |
+
if normalize == "box_only": # Only consider box included in other box
|
| 41 |
+
overlaps = inter_area / areas
|
| 42 |
+
elif normalize == "all": # Consider box included in other box and other box included in box
|
| 43 |
+
overlaps = inter_area / np.minimum(areas, areas_other[:, None])
|
| 44 |
+
else:
|
| 45 |
+
raise ValueError(f"Invalid normalization: {normalize}")
|
| 46 |
+
return overlaps
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_distances(title_boxes, other_boxes):
|
| 50 |
+
"""
|
| 51 |
+
Computes the distances between title and table/chart boxes.
|
| 52 |
+
Distance is computed as the sum of the vertical and horizontal distances.
|
| 53 |
+
Horizontal distance uses min(boxes center dist, boxes left dist).
|
| 54 |
+
Vertical distance uses min(top_title to bottom_other dists, bottom_title to top_other dists).
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
title_boxes (np array [n_titles x 4]): Title boxes.
|
| 58 |
+
other_boxes (np array [n_other x 4]): Other boxes.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
np array [n_titles x n_other]: Distances between titles and other boxes.
|
| 62 |
+
"""
|
| 63 |
+
x0_title, xc_title, y0_title, y1_title = (
|
| 64 |
+
title_boxes[:, 0],
|
| 65 |
+
(title_boxes[:, 0] + title_boxes[:, 2]) / 2,
|
| 66 |
+
title_boxes[:, 1],
|
| 67 |
+
title_boxes[:, 3]
|
| 68 |
+
)
|
| 69 |
+
x0_other, xc_other, y0_other, y1_other = (
|
| 70 |
+
other_boxes[:, 0],
|
| 71 |
+
(other_boxes[:, 0] + other_boxes[:, 2]) / 2,
|
| 72 |
+
other_boxes[:, 1],
|
| 73 |
+
other_boxes[:, 3]
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
x_dists = np.min([
|
| 77 |
+
np.abs(xc_title[:, None] - xc_other[None, :]), # Title center to other center
|
| 78 |
+
np.abs(x0_title[:, None] - x0_other[None, :]), # Title left to other left
|
| 79 |
+
], axis=0)
|
| 80 |
+
|
| 81 |
+
y_dists = np.min([
|
| 82 |
+
np.abs(y1_title[:, None] - y0_other[None, :]), # Title above other
|
| 83 |
+
np.abs(y0_title[:, None] - y1_other[None, :]), # Title below other
|
| 84 |
+
], axis=0)
|
| 85 |
+
|
| 86 |
+
dists = y_dists + x_dists / 2
|
| 87 |
+
return dists
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def find_titles(title_boxes, table_boxes, chart_boxes, max_dist=0.1):
|
| 91 |
+
"""
|
| 92 |
+
Associates titles to tables and charts.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
title_boxes (np array [n_titles x 4]): Title boxes.
|
| 96 |
+
table_boxes (np array [n_tables x 4]): Table boxes.
|
| 97 |
+
chart_boxes (np array [n_charts x 4]): Chart boxes.
|
| 98 |
+
max_dist (float, optional): Maximum distance between title and table/chart. Defaults to 0.1.
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
dict: Dictionary of assigned titles.
|
| 102 |
+
- Keys are the indices of the titles,
|
| 103 |
+
- Values are tuples of:
|
| 104 |
+
- str: Whether the title is assigned to a "chart" or "table"
|
| 105 |
+
- int: index of the assigned table/chart
|
| 106 |
+
"""
|
| 107 |
+
if not len(title_boxes) or not (len(table_boxes) or len(chart_boxes)):
|
| 108 |
+
return {}
|
| 109 |
+
|
| 110 |
+
# print(title_boxes.shape, table_boxes.shape, chart_boxes.shape)
|
| 111 |
+
|
| 112 |
+
# Get distances
|
| 113 |
+
chart_distances = np.ones((len(title_boxes), 0))
|
| 114 |
+
if len(chart_boxes):
|
| 115 |
+
chart_distances = get_distances(title_boxes, chart_boxes)
|
| 116 |
+
chart_overlaps = get_overlaps(title_boxes, chart_boxes, normalize="box_only")
|
| 117 |
+
# print(chart_overlaps, "chart_overlaps", chart_overlaps.shape)
|
| 118 |
+
# print(chart_distances, "chart_distances", chart_distances.shape)
|
| 119 |
+
chart_distances = np.where(chart_overlaps > 0.25, 0, chart_distances)
|
| 120 |
+
|
| 121 |
+
# print(chart_distances)
|
| 122 |
+
|
| 123 |
+
table_distances = np.ones((len(title_boxes), 0))
|
| 124 |
+
if len(table_boxes):
|
| 125 |
+
table_distances = get_distances(title_boxes, table_boxes)
|
| 126 |
+
if len(chart_boxes): # Penalize table titles that are inside charts
|
| 127 |
+
table_distances = np.where(
|
| 128 |
+
chart_overlaps.max(1, keepdims=True) > 0.25, table_distances * 10, table_distances
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# print(table_distances, "table_distances")
|
| 132 |
+
|
| 133 |
+
# Assign to tables
|
| 134 |
+
assigned_titles = {}
|
| 135 |
+
for i, table in enumerate(table_boxes):
|
| 136 |
+
best_match = np.argmin(table_distances[:, i])
|
| 137 |
+
if table_distances[best_match, i] < max_dist:
|
| 138 |
+
assigned_titles[best_match] = ("table", i)
|
| 139 |
+
table_distances[best_match] = np.inf
|
| 140 |
+
chart_distances[best_match] = np.inf
|
| 141 |
+
|
| 142 |
+
# Assign to charts
|
| 143 |
+
for i, chart in enumerate(chart_boxes):
|
| 144 |
+
best_match = np.argmin(chart_distances[:, i])
|
| 145 |
+
if chart_distances[best_match, i] < max_dist:
|
| 146 |
+
assigned_titles[best_match] = ("chart", i)
|
| 147 |
+
chart_distances[best_match] = np.inf
|
| 148 |
+
|
| 149 |
+
return assigned_titles
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def postprocess_included(
|
| 153 |
+
boxes, labels, confs, class_="title", classes=["table", "chart", "title", "infographic"]
|
| 154 |
+
):
|
| 155 |
+
"""
|
| 156 |
+
Post process title predictions.
|
| 157 |
+
- Remove titles that are included in other boxes
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
boxes (numpy.ndarray [N, 4]): Array of bounding boxes.
|
| 161 |
+
labels (numpy.ndarray [N]): Array of labels.
|
| 162 |
+
confs (numpy.ndarray [N]): Array of confidences.
|
| 163 |
+
class_ (str, optional): Class to postprocess. Defaults to "title".
|
| 164 |
+
classes (list, optional): Classes. Defaults to ["table", "chart", "title", "infographic"].
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
boxes (numpy.ndarray): Array of bounding boxes.
|
| 168 |
+
labels (numpy.ndarray): Array of labels.
|
| 169 |
+
confs (numpy.ndarray): Array of confidences.
|
| 170 |
+
"""
|
| 171 |
+
boxes_to_pp = boxes[labels == classes.index(class_)]
|
| 172 |
+
confs_to_pp = confs[labels == classes.index(class_)]
|
| 173 |
+
|
| 174 |
+
order = np.argsort(confs_to_pp) # least to most confident for NMS
|
| 175 |
+
boxes_to_pp, confs_to_pp = boxes_to_pp[order], confs_to_pp[order]
|
| 176 |
+
|
| 177 |
+
if len(boxes_to_pp) == 0:
|
| 178 |
+
return boxes, labels, confs
|
| 179 |
+
|
| 180 |
+
# other_boxes = boxes[labels != classes.index("title")]
|
| 181 |
+
|
| 182 |
+
inclusion_classes = ["table", "infographic", "chart"]
|
| 183 |
+
if class_ in ["header_footer", "title"]:
|
| 184 |
+
inclusion_classes.append("paragraph")
|
| 185 |
+
|
| 186 |
+
other_boxes = boxes[np.isin(
|
| 187 |
+
labels,
|
| 188 |
+
[classes.index(c) for c in inclusion_classes])
|
| 189 |
+
]
|
| 190 |
+
|
| 191 |
+
# Remove boxes included in other_boxes
|
| 192 |
+
kept_boxes, kept_confs = [], []
|
| 193 |
+
for i, b in enumerate(boxes_to_pp):
|
| 194 |
+
# # Inclusion NMS
|
| 195 |
+
# if i < len(titles) - 1:
|
| 196 |
+
# overlaps_titles = get_overlaps(t, titles[i + 1:], normalize="all")
|
| 197 |
+
# if overlaps_titles.max() > 0.9:
|
| 198 |
+
# continue
|
| 199 |
+
|
| 200 |
+
# print(t)
|
| 201 |
+
# print(other_boxes)
|
| 202 |
+
if len(other_boxes) > 0:
|
| 203 |
+
overlaps = get_overlaps(b, other_boxes, normalize="box_only")
|
| 204 |
+
if overlaps.max() > 0.9:
|
| 205 |
+
continue
|
| 206 |
+
|
| 207 |
+
kept_boxes.append(b)
|
| 208 |
+
kept_confs.append(confs_to_pp[i])
|
| 209 |
+
|
| 210 |
+
# Aggregate
|
| 211 |
+
kept_boxes = np.stack(kept_boxes) if len(kept_boxes) else np.empty((0, 4))
|
| 212 |
+
kept_confs = np.stack(kept_confs) if len(kept_confs) else np.empty(0)
|
| 213 |
+
|
| 214 |
+
boxes_pp = np.concatenate(
|
| 215 |
+
[boxes[labels != classes.index(class_)], kept_boxes]
|
| 216 |
+
)
|
| 217 |
+
confs_pp = np.concatenate(
|
| 218 |
+
[confs[labels != classes.index(class_)], kept_confs]
|
| 219 |
+
)
|
| 220 |
+
labels_pp = np.concatenate([
|
| 221 |
+
labels[labels != classes.index(class_)],
|
| 222 |
+
np.ones(len(kept_boxes)) * classes.index(class_)
|
| 223 |
+
])
|
| 224 |
+
|
| 225 |
+
return boxes_pp, labels_pp, confs_pp
|
post_processing/wbf.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from:
|
| 2 |
+
# https://github.com/ZFTurbo/Weighted-Boxes-Fusion/blob/master/ensemble_boxes/ensemble_boxes_wbf.py
|
| 3 |
+
|
| 4 |
+
import warnings
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def prefilter_boxes(boxes, scores, labels, weights, thr, class_agnostic=False):
|
| 9 |
+
"""
|
| 10 |
+
Reformats and filters boxes.
|
| 11 |
+
Output is a dict of boxes to merge separately.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
boxes (list[np array[n x 4]]): List of boxes. One list per model.
|
| 15 |
+
scores (list[np array[n]]): List of confidences.
|
| 16 |
+
labels (list[np array[n]]): List of labels.
|
| 17 |
+
weights (list): Model weights.
|
| 18 |
+
thr (float): Confidence threshold
|
| 19 |
+
class_agnostic (bool, optional): Merge boxes from different classes. Defaults to False.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
dict[np array [? x 8]]: Filtered boxes.
|
| 23 |
+
"""
|
| 24 |
+
# Create dict with boxes stored by its label
|
| 25 |
+
new_boxes = dict()
|
| 26 |
+
|
| 27 |
+
for t in range(len(boxes)):
|
| 28 |
+
assert len(boxes[t]) == len(scores[t]), "len(boxes) != len(scores)"
|
| 29 |
+
assert len(boxes[t]) == len(labels[t]), "len(boxes) != len(labels)"
|
| 30 |
+
|
| 31 |
+
for j in range(len(boxes[t])):
|
| 32 |
+
score = scores[t][j]
|
| 33 |
+
if score < thr:
|
| 34 |
+
continue
|
| 35 |
+
label = int(labels[t][j])
|
| 36 |
+
box_part = boxes[t][j]
|
| 37 |
+
x1 = float(box_part[0])
|
| 38 |
+
y1 = float(box_part[1])
|
| 39 |
+
x2 = float(box_part[2])
|
| 40 |
+
y2 = float(box_part[3])
|
| 41 |
+
|
| 42 |
+
# Box data checks
|
| 43 |
+
if x2 < x1:
|
| 44 |
+
warnings.warn("X2 < X1 value in box. Swap them.")
|
| 45 |
+
x1, x2 = x2, x1
|
| 46 |
+
if y2 < y1:
|
| 47 |
+
warnings.warn("Y2 < Y1 value in box. Swap them.")
|
| 48 |
+
y1, y2 = y2, y1
|
| 49 |
+
|
| 50 |
+
array = np.array([x1, x2, y1, y2])
|
| 51 |
+
if array.min() < 0 or array.max() > 1:
|
| 52 |
+
warnings.warn("Coordinates outside [0, 1]")
|
| 53 |
+
array = np.clip(array, 0, 1)
|
| 54 |
+
x1, x2, y1, y2 = array
|
| 55 |
+
|
| 56 |
+
if (x2 - x1) * (y2 - y1) == 0.0:
|
| 57 |
+
warnings.warn("Zero area box skipped: {}.".format(box_part))
|
| 58 |
+
continue
|
| 59 |
+
|
| 60 |
+
# [label, score, weight, model index, x1, y1, x2, y2]
|
| 61 |
+
b = [int(label), float(score) * weights[t], weights[t], t, x1, y1, x2, y2]
|
| 62 |
+
|
| 63 |
+
label_k = "*" if class_agnostic else label
|
| 64 |
+
if label_k not in new_boxes:
|
| 65 |
+
new_boxes[label_k] = []
|
| 66 |
+
new_boxes[label_k].append(b)
|
| 67 |
+
|
| 68 |
+
# Sort each list in dict by score and transform it to numpy array
|
| 69 |
+
for k in new_boxes:
|
| 70 |
+
current_boxes = np.array(new_boxes[k])
|
| 71 |
+
new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]]
|
| 72 |
+
|
| 73 |
+
return new_boxes
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def merge_labels(labels, confs):
|
| 77 |
+
"""
|
| 78 |
+
Custom function for merging labels.
|
| 79 |
+
If all labels are the same, return the unique value.
|
| 80 |
+
Else, return the label of the most confident non-title (class 2) box.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
labels (np array [n]): Labels.
|
| 84 |
+
confs (np array [n]): Confidence.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
int: Label.
|
| 88 |
+
"""
|
| 89 |
+
if len(np.unique(labels)) == 1:
|
| 90 |
+
return labels[0]
|
| 91 |
+
else: # Most confident and not a title
|
| 92 |
+
confs = confs[confs != 2]
|
| 93 |
+
labels = labels[labels != 2]
|
| 94 |
+
return labels[np.argmax(confs)]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_weighted_box(boxes, conf_type="avg"):
|
| 98 |
+
"""
|
| 99 |
+
Merges boxes by using the weighted fusion.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
boxes (np array [n x 8]): Boxes to merge.
|
| 103 |
+
conf_type (str, optional): Confidence merging type. Defaults to "avg".
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
np array [8]: Merged box.
|
| 107 |
+
"""
|
| 108 |
+
box = np.zeros(8, dtype=np.float32)
|
| 109 |
+
conf = 0
|
| 110 |
+
conf_list = []
|
| 111 |
+
w = 0
|
| 112 |
+
for b in boxes:
|
| 113 |
+
box[4:] += b[1] * b[4:]
|
| 114 |
+
conf += b[1]
|
| 115 |
+
conf_list.append(b[1])
|
| 116 |
+
w += b[2]
|
| 117 |
+
|
| 118 |
+
box[0] = merge_labels(
|
| 119 |
+
np.array([b[0] for b in boxes]), np.array([b[1] for b in boxes])
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
box[1] = np.max(conf_list) if conf_type == "max" else np.mean(conf_list)
|
| 123 |
+
box[2] = w
|
| 124 |
+
box[3] = -1 # model index field is retained for consistency but is not used.
|
| 125 |
+
box[4:] /= conf
|
| 126 |
+
return box
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def get_biggest_box(boxes, conf_type="avg"):
|
| 130 |
+
"""
|
| 131 |
+
Merges boxes by using the biggest box.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
boxes (np array [n x 8]): Boxes to merge.
|
| 135 |
+
conf_type (str, optional): Confidence merging type. Defaults to "avg".
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
np array [8]: Merged box.
|
| 139 |
+
"""
|
| 140 |
+
box = np.zeros(8, dtype=np.float32)
|
| 141 |
+
box[4:] = boxes[0][4:]
|
| 142 |
+
conf_list = []
|
| 143 |
+
w = 0
|
| 144 |
+
for b in boxes:
|
| 145 |
+
box[4] = min(box[4], b[4])
|
| 146 |
+
box[5] = min(box[5], b[5])
|
| 147 |
+
box[6] = max(box[6], b[6])
|
| 148 |
+
box[7] = max(box[7], b[7])
|
| 149 |
+
conf_list.append(b[1])
|
| 150 |
+
w += b[2]
|
| 151 |
+
|
| 152 |
+
box[0] = merge_labels(
|
| 153 |
+
np.array([b[0] for b in boxes]), np.array([b[1] for b in boxes])
|
| 154 |
+
)
|
| 155 |
+
# print(box[0], np.array([b[0] for b in boxes]))
|
| 156 |
+
|
| 157 |
+
box[1] = np.max(conf_list) if conf_type == "max" else np.mean(conf_list)
|
| 158 |
+
box[2] = w
|
| 159 |
+
box[3] = -1 # model index field is retained for consistency but is not used.
|
| 160 |
+
return box
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def find_matching_box_fast(boxes_list, new_box, match_iou):
|
| 164 |
+
"""
|
| 165 |
+
Reimplementation of find_matching_box with numpy instead of loops.
|
| 166 |
+
Gives significant speed up for larger arrays (~100x).
|
| 167 |
+
This was previously the bottleneck since the function is called for every entry in the array.
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
def bb_iou_array(boxes, new_box):
|
| 171 |
+
# bb interesection over union
|
| 172 |
+
xA = np.maximum(boxes[:, 0], new_box[0])
|
| 173 |
+
yA = np.maximum(boxes[:, 1], new_box[1])
|
| 174 |
+
xB = np.minimum(boxes[:, 2], new_box[2])
|
| 175 |
+
yB = np.minimum(boxes[:, 3], new_box[3])
|
| 176 |
+
|
| 177 |
+
interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0)
|
| 178 |
+
|
| 179 |
+
# compute the area of both the prediction and ground-truth rectangles
|
| 180 |
+
boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
| 181 |
+
boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1])
|
| 182 |
+
|
| 183 |
+
iou = interArea / (boxAArea + boxBArea - interArea)
|
| 184 |
+
|
| 185 |
+
return iou
|
| 186 |
+
|
| 187 |
+
if boxes_list.shape[0] == 0:
|
| 188 |
+
return -1, match_iou
|
| 189 |
+
|
| 190 |
+
ious = bb_iou_array(boxes_list[:, 4:], new_box[4:])
|
| 191 |
+
# ious[boxes[:, 0] != new_box[0]] = -1
|
| 192 |
+
|
| 193 |
+
best_idx = np.argmax(ious)
|
| 194 |
+
best_iou = ious[best_idx]
|
| 195 |
+
|
| 196 |
+
if best_iou <= match_iou:
|
| 197 |
+
best_iou = match_iou
|
| 198 |
+
best_idx = -1
|
| 199 |
+
|
| 200 |
+
return best_idx, best_iou
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def weighted_boxes_fusion(
|
| 204 |
+
boxes_list,
|
| 205 |
+
scores_list,
|
| 206 |
+
labels_list,
|
| 207 |
+
iou_thr=0.5,
|
| 208 |
+
skip_box_thr=0.0,
|
| 209 |
+
conf_type="avg",
|
| 210 |
+
merge_type="weighted",
|
| 211 |
+
class_agnostic=False,
|
| 212 |
+
):
|
| 213 |
+
"""
|
| 214 |
+
Custom WBF implementation that supports a class_agnostic mode and a biggest box fusion.
|
| 215 |
+
Boxes are expected to be in normalized (x0, y0, x1, y1) format.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
boxes_list (list[np.ndarray[n x 4]]): List of boxes. One list per model.
|
| 219 |
+
scores_list (list[np.ndarray[n]]): List of confidences.
|
| 220 |
+
labels_list (list[np.ndarray[n]]): List of labels.
|
| 221 |
+
iou_thr (float, optional): IoU threshold for matching. Defaults to 0.55.
|
| 222 |
+
skip_box_thr (float, optional): Exclude boxes with score < skip_box_thr. Defaults to 0.0.
|
| 223 |
+
conf_type (str, optional): Confidence merging type ("avg" or "max"). Defaults to "avg".
|
| 224 |
+
merge_type (str, optional): Merge type ("weighted" or "biggest"). Defaults to "weighted".
|
| 225 |
+
class_agnostic (bool, optional): Merge boxes from different classes. Defaults to False.
|
| 226 |
+
|
| 227 |
+
Returns:
|
| 228 |
+
np array[N x 4]: Merged boxes,
|
| 229 |
+
np array[N]: Merged confidences,
|
| 230 |
+
np array[N]: Merged labels.
|
| 231 |
+
"""
|
| 232 |
+
weights = np.ones(len(boxes_list))
|
| 233 |
+
|
| 234 |
+
assert conf_type in ["avg", "max"], 'Conf type must be "avg" or "max"'
|
| 235 |
+
assert merge_type in ["weighted", "biggest"], 'Conf type must be "weighted" or "biggest"'
|
| 236 |
+
|
| 237 |
+
filtered_boxes = prefilter_boxes(
|
| 238 |
+
boxes_list,
|
| 239 |
+
scores_list,
|
| 240 |
+
labels_list,
|
| 241 |
+
weights,
|
| 242 |
+
skip_box_thr,
|
| 243 |
+
class_agnostic=class_agnostic,
|
| 244 |
+
)
|
| 245 |
+
if len(filtered_boxes) == 0:
|
| 246 |
+
return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,))
|
| 247 |
+
|
| 248 |
+
overall_boxes = []
|
| 249 |
+
for label in filtered_boxes:
|
| 250 |
+
boxes = filtered_boxes[label]
|
| 251 |
+
clusters = []
|
| 252 |
+
|
| 253 |
+
# Clusterize boxes
|
| 254 |
+
for j in range(len(boxes)):
|
| 255 |
+
ids = [i for i in range(len(boxes)) if i != j]
|
| 256 |
+
index, best_iou = find_matching_box_fast(boxes[ids], boxes[j], iou_thr)
|
| 257 |
+
|
| 258 |
+
if index != -1:
|
| 259 |
+
index = ids[index]
|
| 260 |
+
cluster_idx = [
|
| 261 |
+
clust_idx
|
| 262 |
+
for clust_idx, clust in enumerate(clusters)
|
| 263 |
+
if (j in clust or index in clust)
|
| 264 |
+
]
|
| 265 |
+
if len(cluster_idx):
|
| 266 |
+
cluster_idx = cluster_idx[0]
|
| 267 |
+
clusters[cluster_idx] = list(
|
| 268 |
+
set(clusters[cluster_idx] + [index, j])
|
| 269 |
+
)
|
| 270 |
+
else:
|
| 271 |
+
clusters.append([index, j])
|
| 272 |
+
else:
|
| 273 |
+
clusters.append([j])
|
| 274 |
+
|
| 275 |
+
for j, c in enumerate(clusters):
|
| 276 |
+
if merge_type == "weighted":
|
| 277 |
+
weighted_box = get_weighted_box(boxes[c], conf_type)
|
| 278 |
+
elif merge_type == "biggest":
|
| 279 |
+
weighted_box = get_biggest_box(boxes[c], conf_type)
|
| 280 |
+
|
| 281 |
+
if conf_type == "max":
|
| 282 |
+
weighted_box[1] = weighted_box[1] / weights.max()
|
| 283 |
+
else: # avg
|
| 284 |
+
weighted_box[1] = weighted_box[1] * len(c) / weights.sum()
|
| 285 |
+
overall_boxes.append(weighted_box)
|
| 286 |
+
|
| 287 |
+
overall_boxes = np.array(overall_boxes)
|
| 288 |
+
overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]]
|
| 289 |
+
boxes = overall_boxes[:, 4:]
|
| 290 |
+
scores = overall_boxes[:, 1]
|
| 291 |
+
labels = overall_boxes[:, 0]
|
| 292 |
+
return boxes, scores, labels
|