from typing import List

import numpy as np
import torch
def get_slice_bboxes(
    image_height: int,
    image_width: int,
    slice_height: int = 640,
    slice_width: int = 640,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
) -> List[List[int]]:
    """
    Calculate bounding boxes for overlapping slices that tile the image.

    Returns:
        List of [x_min, y_min, x_max, y_max] slice coordinates. Slices that
        would extend past the image border are shifted inward, so every slice
        keeps the full slice_height x slice_width size whenever the image is
        at least that large.
    """
    slice_bboxes = []
    y_max = y_min = 0
    y_overlap = int(slice_height * overlap_height_ratio)
    x_overlap = int(slice_width * overlap_width_ratio)
    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height
        while x_max < image_width:
            x_max = x_min + slice_width
            # Shift boundary slices inward instead of truncating them.
            if y_max > image_height:
                y_max = image_height
                y_min = max(0, image_height - slice_height)
            if x_max > image_width:
                x_max = image_width
                x_min = max(0, image_width - slice_width)
            slice_bboxes.append([x_min, y_min, x_max, y_max])
            # Step right by slice_width minus the overlap.
            x_min = x_max - x_overlap
        # Step down by slice_height minus the overlap.
        y_min = y_max - y_overlap
    return slice_bboxes
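

# Illustrative usage sketch (added for clarity; the 1080x1920 frame size is an
# assumed example, not something this module prescribes). With the default
# 640x640 tiles and 20% overlap, a Full HD frame yields a 4 x 2 = 8 tile grid,
# and the bottom-right tile is shifted inward to [1280, 440, 1920, 1080].
def _example_get_slice_bboxes() -> None:
    bboxes = get_slice_bboxes(image_height=1080, image_width=1920)
    assert len(bboxes) == 8
    assert bboxes[0] == [0, 0, 640, 640]
    assert bboxes[-1] == [1280, 440, 1920, 1080]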
def slice_image(
    image: np.ndarray,
    slice_bboxes: List[List[int]],
) -> List[np.ndarray]:
    """Crop the image into slices based on the provided bounding boxes.

    Note: NumPy basic slicing returns views, so the returned crops share
    memory with the original image.
    """
    slices = []
    for bbox in slice_bboxes:
        x_min, y_min, x_max, y_max = bbox
        slices.append(image[y_min:y_max, x_min:x_max])
    return slices
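

# Hedged end-to-end check (frame size and dtype are assumed toy values): every
# crop of a frame larger than the slice size comes back exactly
# slice_height x slice_width, because boundary tiles are shifted inward
# rather than truncated.
def _example_slice_image() -> None:
    image = np.zeros((1080, 1920, 3), dtype=np.uint8)
    tiles = slice_image(image, get_slice_bboxes(1080, 1920))
    assert all(tile.shape[:2] == (640, 640) for tile in tiles)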
def shift_bboxes(
    bboxes: List[List[float]],
    slice_coords: List[int],
) -> List[List[float]]:
    """
    Shift bounding boxes from slice coordinates to global image coordinates.

    Args:
        bboxes: List of [x_min, y_min, x_max, y_max] in slice coordinates.
        slice_coords: [x_min, y_min, x_max, y_max] of the slice itself.
    """
    shift_x = slice_coords[0]
    shift_y = slice_coords[1]
    shifted = []
    for box in bboxes:
        # Translate both corners by the slice's top-left offset.
        shifted.append([
            box[0] + shift_x,
            box[1] + shift_y,
            box[2] + shift_x,
            box[3] + shift_y,
        ])
    return shifted
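

# Worked example (values assumed for illustration): a detection at
# (10, 20)-(50, 60) inside the tile whose top-left corner is (512, 440)
# lands at (522, 460)-(562, 500) in full-image coordinates.
def _example_shift_bboxes() -> None:
    shifted = shift_bboxes([[10.0, 20.0, 50.0, 60.0]], [512, 440, 1152, 1080])
    assert shifted == [[522.0, 460.0, 562.0, 500.0]]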
def batched_nms(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    idxs: torch.Tensor,
    iou_threshold: float = 0.5,
) -> torch.Tensor:
    """
    Perform class-aware non-maximum suppression.

    Uses torchvision's implementation when available; otherwise falls back to
    a pure-PyTorch greedy NMS. Returns the indices of kept boxes, sorted by
    decreasing score.
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    # Prefer the efficient torchvision implementation.
    try:
        import torchvision
        return torchvision.ops.batched_nms(boxes, scores, idxs, iou_threshold)
    except ImportError:
        pass
    # Note: ultralytics' non_max_suppression operates on raw end-to-end model
    # output rather than plain boxes, so it is not a drop-in replacement here.
    # Custom greedy batched NMS (slow but dependency-free).
    keep_indices = []
    for label in idxs.unique():
        mask = idxs == label
        cls_boxes = boxes[mask]
        cls_scores = scores[mask]
        original_indices = torch.where(mask)[0]
        # Process boxes in order of decreasing score.
        sorted_indices = torch.argsort(cls_scores, descending=True)
        cls_boxes = cls_boxes[sorted_indices]
        original_indices = original_indices[sorted_indices]
        while cls_boxes.size(0) > 0:
            # Keep the highest-scoring remaining box.
            keep_indices.append(original_indices[0])
            if cls_boxes.size(0) == 1:
                break
            current_box = cls_boxes[0].unsqueeze(0)
            rest_boxes = cls_boxes[1:]
            # IoU between the kept box and all remaining boxes.
            x1 = torch.max(current_box[:, 0], rest_boxes[:, 0])
            y1 = torch.max(current_box[:, 1], rest_boxes[:, 1])
            x2 = torch.min(current_box[:, 2], rest_boxes[:, 2])
            y2 = torch.min(current_box[:, 3], rest_boxes[:, 3])
            inter_area = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
            box_area = (current_box[:, 2] - current_box[:, 0]) * (current_box[:, 3] - current_box[:, 1])
            rest_area = (rest_boxes[:, 2] - rest_boxes[:, 0]) * (rest_boxes[:, 3] - rest_boxes[:, 1])
            union_area = box_area + rest_area - inter_area
            iou = inter_area / (union_area + 1e-6)
            # Drop boxes that overlap the kept box too much.
            keep_mask = iou < iou_threshold
            cls_boxes = rest_boxes[keep_mask]
            original_indices = original_indices[1:][keep_mask]
    keep = torch.stack(keep_indices)
    # Match torchvision's contract: indices sorted by decreasing score.
    return keep[scores[keep].argsort(descending=True)]
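

# Minimal sanity sketch (boxes, scores, and classes are assumed toy values):
# two heavily overlapping boxes of the same class collapse to one, while an
# identical box of a different class survives, since NMS runs per class.
def _example_batched_nms() -> None:
    boxes = torch.tensor([
        [0.0, 0.0, 10.0, 10.0],  # class 0, score 0.9 -> kept
        [1.0, 1.0, 11.0, 11.0],  # class 0, IoU ~0.68 with box 0 -> suppressed
        [0.0, 0.0, 10.0, 10.0],  # class 1, same coords, other class -> kept
    ])
    scores = torch.tensor([0.9, 0.8, 0.7])
    idxs = torch.tensor([0, 0, 1])
    keep = batched_nms(boxes, scores, idxs, iou_threshold=0.5)
    assert keep.tolist() == [0, 2]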