🔀 [Merge] branch 'MODELv2' into TEST
yolo/utils/bounding_box_utils.py  CHANGED  (+32 -21)
@@ -108,12 +108,13 @@ def transform_bbox(bbox: Tensor, indicator="xywh -> xyxy"):
     return bbox.to(dtype=data_type)
 
 
-def generate_anchors(image_size: List[int], anchors_list: List[Tuple[int]]):
+def generate_anchors(image_size: List[int], strides: List[int]):
     """
     Find the anchor maps for each w, h.
 
     Args:
-        anchors_list List[Tuple[int]]: ...
+        image_size List: the size of the augmented input image
+        strides List[8, 16, 32, ...]: the stride size for each prediction layer
 
     Returns:
         all_anchors [HW x 2]:
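For orientation, a call-site sketch of the new signature (values illustrative, not part of the diff): the precomputed anchor-shape list is gone, and only per-layer strides are passed.

from yolo.utils.bounding_box_utils import generate_anchors

all_anchors, all_scalers = generate_anchors(image_size=[640, 640], strides=[8, 16, 32])
print(all_anchors.shape)   # torch.Size([8400, 2]) -- one (x, y) center per grid cell
print(all_scalers.shape)   # torch.Size([8400])    -- the stride that produced each center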
@@ -122,15 +123,14 @@ def generate_anchors(image_size: List[int], anchors_list: List[Tuple[int]]):
     W, H = image_size
     anchors = []
     scaler = []
-    for ...:
-        ...
-        anchor_num = anchor_wh[0] * anchor_wh[1]
+    for stride in strides:
+        anchor_num = W // stride * H // stride
         scaler.append(torch.full((anchor_num,), stride))
         shift = stride // 2
-        ...
-        ...
-        ...
-        anchor = torch.stack([...
+        h = torch.arange(0, H, stride) + shift
+        w = torch.arange(0, W, stride) + shift
+        anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
+        anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)
         anchors.append(anchor)
     all_anchors = torch.cat(anchors, dim=0)
     all_scalers = torch.cat(scaler, dim=0)
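As a sanity check, here is a minimal standalone sketch (mirroring the loop body above, not part of the diff) of the anchors one stride level produces: grid-cell centers offset by half a stride. For a 640x640 image with strides [8, 16, 32] the totals are 80*80 + 40*40 + 20*20 = 8400 points, which is exactly the hard-coded 8400 that BoxMatcher still carries below.

import torch

stride, W, H = 8, 640, 640
shift = stride // 2
h = torch.arange(0, H, stride) + shift
w = torch.arange(0, W, stride) + shift
anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)

print(anchor[:3])        # tensor([[ 4,  4], [12,  4], [20,  4]]) -- first cell centers
print(anchor.shape[0])   # 6400, i.e. (640 // 8) ** 2 cells at this level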
@@ -172,6 +172,7 @@ class BoxMatcher:
         Returns:
             [batch x targets x anchors]: The probabilities from `pred_cls` corresponding to the class indices specified in `target_cls`.
         """
+        # TODO: Turn 8400 to HW
         target_cls = target_cls.expand(-1, -1, 8400)
         predict_cls = predict_cls.transpose(1, 2)
         cls_probabilities = torch.gather(predict_cls, 1, target_cls)
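The gather above selects, for every (target, anchor) pair, the predicted score of that target's class; the TODO notes that 8400 should become the actual H*W anchor count rather than a constant. A toy sketch of the same indexing pattern (shapes and values hypothetical):

import torch

batch, targets, anchors, classes = 1, 2, 4, 3
predict_cls = torch.rand(batch, anchors, classes)            # per-anchor class scores
target_cls = torch.randint(0, classes, (batch, targets, 1))  # class index per target

idx = target_cls.expand(-1, -1, anchors)                     # [batch x targets x anchors]
cls_probabilities = torch.gather(predict_cls.transpose(1, 2), 1, idx)
print(cls_probabilities.shape)                               # torch.Size([1, 2, 4])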
@@ -266,24 +267,34 @@ class BoxMatcher:
 
 class Vec2Box:
     def __init__(self, model: YOLO, image_size, device):
-        ...
-        ...
-        ...
-        ...
-        ...
-        for predict_head in dummy_output["Main"]:
-            _, _, *anchor_num = predict_head[2].shape
-            anchors_num.append(anchor_num)
+        self.device = device
+
+        if getattr(model, "strides", None):
+            logger.info(f"🈶 Found stride of model {model.strides}")
+            self.strides = model.strides
         else:
-            logger.info(...
-            ...
+            logger.info("🧸 Found no stride in model, performing a dummy test for auto-anchor size")
+            self.strides = self.create_auto_anchor(model, image_size)
 
+        # TODO: this is an exception for onnx; remove it once the onnx device issue is fixed
         if not isinstance(model, YOLO):
             device = torch.device("cpu")
 
-        anchor_grid, scaler = generate_anchors(image_size, ...
+        anchor_grid, scaler = generate_anchors(image_size, self.strides)
         self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
-        ...
+
+    def create_auto_anchor(self, model: YOLO, image_size):
+        dummy_input = torch.zeros(1, 3, *image_size).to(self.device)
+        dummy_output = model(dummy_input)
+        strides = []
+        for predict_head in dummy_output["Main"]:
+            _, _, *anchor_num = predict_head[2].shape
+            strides.append(image_size[1] // anchor_num[1])
+        return strides
+
+    def update(self, image_size):
+        anchor_grid, scaler = generate_anchors(image_size, self.strides)
+        self.anchor_grid, self.scaler = anchor_grid.to(self.device), scaler.to(self.device)
 
     def __call__(self, predicts):
         preds_cls, preds_anc, preds_box = [], [], []
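create_auto_anchor infers each head's stride from the spatial size of a dummy forward pass, so models that don't expose `strides` still work, and `update` rebuilds the grid when the input resolution changes. A hedged sketch of the stride arithmetic and intended usage (the head shapes and call sites here are assumptions, not from the diff):

# Stride inference: a 640-wide input whose prediction map is 80 cells wide
# implies a stride of 640 // 80 = 8, matching create_auto_anchor above.
image_size = [640, 640]
head_widths = [80, 40, 20]                     # hypothetical per-head feature-map widths
strides = [image_size[1] // w for w in head_widths]
print(strides)                                 # [8, 16, 32]

# Intended usage (names as in the class above):
# vec2box = Vec2Box(model, image_size=[640, 640], device="cuda")
# vec2box.update([1280, 1280])  # regenerate anchor_grid / scaler for a new input size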