Spaces:
Runtime error
Runtime error
import re | |
import numpy as np | |
import torch | |
import torch.distributed as dist | |
import collections | |
import logging | |
def get_area(pos):
    """
    Compute the area of every box in a batch.

    Args
        pos: [B, N, 4] boxes laid out as (x1, x2, y1, y2)
    Return
        area : [B, N], (x2 - x1) * (y2 - y1) for each box
    """
    x1, x2, y1, y2 = (pos[:, :, i] for i in range(4))
    return (x2 - x1) * (y2 - y1)
def get_relative_distance(pos):
    """
    Pairwise coordinate differences between all boxes of each batch item.

    Args
        pos: [B, N, 4] boxes laid out as (x1, x2, y1, y2)
    Return
        out : [B, N, N, 4] where out[b, i, j] = pos[b, j] - pos[b, i]
    """
    anchor = pos.unsqueeze(2)  # [B, N, 1, 4]: the box differences are taken from
    target = pos.unsqueeze(1)  # [B, 1, N, 4]: the box differences are taken to
    return target - anchor
class LossMeter(object):
    """Computes and stores the running average of the last ``maxlen`` values."""

    def __init__(self, maxlen=100):
        # Bounded deque: once it holds ``maxlen`` items, appending drops the oldest.
        self.vals = collections.deque([], maxlen=maxlen)

    def __len__(self):
        """Number of values currently in the window."""
        return len(self.vals)

    def update(self, new_val):
        """Record ``new_val``, evicting the oldest value if the window is full."""
        self.vals.append(new_val)

    def val(self):
        """Return the mean of the values currently in the window.

        Raises:
            ZeroDivisionError: if no value has been recorded yet.
        """
        return sum(self.vals) / len(self.vals)

    def __repr__(self):
        # val is a method and must be called; an empty meter reports nan rather
        # than letting repr() raise ZeroDivisionError.
        if not self.vals:
            return str(float("nan"))
        return str(self.val())
def count_parameters(model):
    """Return the total number of elements across ``model``'s parameters with requires_grad set."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
def load_state_dict(state_dict_path, loc='cpu'):
    """
    Load a checkpoint and strip the ``module.`` prefix from its keys.

    The prefix is stripped so a multi-GPU checkpoint's keys can be loaded as
    single-GPU keys. The mapping returned by ``torch.load`` is modified in place
    rather than rebuilt, so its type and attributes are kept. Renamed entries
    are re-inserted after the untouched ones.

    Args:
        state_dict_path: path passed straight to ``torch.load``.
        loc: ``map_location`` for ``torch.load`` (default ``'cpu'``).
    Returns:
        The loaded mapping with ``module.`` prefixes removed.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(state_dict_path, map_location=loc)
    prefix = "module."
    # Snapshot the matching keys first because the mapping is mutated below.
    prefixed = [k for k in state_dict.keys() if k.startswith(prefix)]
    for old_key in prefixed:
        state_dict[old_key[len(prefix):]] = state_dict.pop(old_key)
    return state_dict
def set_global_logging_level(level=logging.ERROR, prefices=None):
    """
    Override logging levels of different modules based on their name as a prefix.
    It needs to be invoked after the modules have been loaded so that their loggers have been initialized.

    Args:
        - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
        - prefices: a str or a list of str prefices to match (e.g. ["transformers", "torch"]). Optional.
            Default (None, or an empty list) is equivalent to `[""]` and matches all active loggers.
            The match is a case-sensitive, literal `module_name.startswith(prefix)`;
            characters such as "." in a prefix have no special meaning.
    """
    # A single string would otherwise be split into one-character prefixes.
    if isinstance(prefices, str):
        prefices = [prefices]
    # Use None instead of a mutable list default. An empty list keeps the
    # previous behaviour of matching every logger.
    prefixes = tuple(prefices) if prefices else ("",)
    # Iterate over a snapshot in case loggerDict is mutated during the loop.
    for name in list(logging.root.manager.loggerDict):
        if name.startswith(prefixes):
            logging.getLogger(name).setLevel(level)
def get_iou(anchors, gt_boxes):
    """
    Pairwise intersection-over-union between two sets of boxes.

    Boxes are (x1, y1, x2, y2) with inclusive pixel coordinates, so a box's
    width is ``x2 - x1 + 1`` and its height is ``y2 - y1 + 1``.

    Args
        anchors: (N, 4) torch float tensor, or a single box of shape (4,)
        gt_boxes: (K, 4) torch float tensor, or a single box of shape (4,)
    Return
        overlaps: (N, K) torch tensor; overlaps[i, j] is the IoU of
            anchors[i] and gt_boxes[j]
    """
    # Promote single boxes to a batch of one.
    if anchors.size() == (4,):
        anchors = anchors.view(1, 4)
    if gt_boxes.size() == (4,):
        gt_boxes = gt_boxes.view(1, 4)

    anchors_area = (
        (anchors[:, 2] - anchors[:, 0] + 1) *
        (anchors[:, 3] - anchors[:, 1] + 1)
    ).unsqueeze(1)  # [N, 1]
    gt_boxes_area = (
        (gt_boxes[:, 2] - gt_boxes[:, 0] + 1) *
        (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)
    ).unsqueeze(0)  # [1, K]

    # [N, 1, 4] and [1, K, 4] broadcast to cover every (anchor, gt) pair.
    boxes = anchors.unsqueeze(1)
    query_boxes = gt_boxes.unsqueeze(0)

    # Intersection extents; disjoint pairs yield negative extents, clamped to 0.
    iw = (
        torch.min(boxes[..., 2], query_boxes[..., 2])
        - torch.max(boxes[..., 0], query_boxes[..., 0])
        + 1
    ).clamp(min=0)
    ih = (
        torch.min(boxes[..., 3], query_boxes[..., 3])
        - torch.max(boxes[..., 1], query_boxes[..., 1])
        + 1
    ).clamp(min=0)

    inter = iw * ih
    union = anchors_area + gt_boxes_area - inter
    return inter / union
def xywh_to_xyxy(boxes):
    """Convert rows of [x y w h] to [x1 y1 x2 y2] with inclusive corners (x2 = x + w - 1)."""
    top_left = boxes[:, 0:2]
    size = boxes[:, 2:4]
    bottom_right = top_left + size - 1
    return np.concatenate((top_left, bottom_right), axis=1)