Upload 4 files
- xdecoder/utils/box_ops.py +93 -0
- xdecoder/utils/config.py +140 -0
- xdecoder/utils/it_contrastive.py +59 -0
- xdecoder/utils/misc.py +157 -0
xdecoder/utils/box_ops.py
ADDED
@@ -0,0 +1,93 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2,
         (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)

def box_xywh_to_xyxy(x):
    x0, y0, x1, y1 = x.unbind(-1)
    b = [x0, y0, (x0 + x1), (y0 + y1)]
    return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / area


def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks

    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.

    Returns a [N, 4] tensors, with the boxes in xyxy format
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    y = torch.arange(0, h, dtype=torch.float)
    x = torch.arange(0, w, dtype=torch.float)
    y, x = torch.meshgrid(y, x)

    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
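
Not part of the uploaded file, just a minimal usage sketch of the helpers above. It assumes the package is importable as xdecoder.utils; the box coordinates are made up for illustration.

    import torch
    from xdecoder.utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

    # two predicted boxes in (cx, cy, w, h) and one target box in (x0, y0, x1, y1)
    pred = box_cxcywh_to_xyxy(torch.tensor([[0.5, 0.5, 0.4, 0.4],
                                            [0.2, 0.2, 0.2, 0.2]]))
    target = torch.tensor([[0.3, 0.3, 0.7, 0.7]])

    giou = generalized_box_iou(pred, target)  # [2, 1] pairwise GIoU matrix
    giou_loss = (1 - giou).mean()             # one possible reduction into a loss term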
xdecoder/utils/config.py
ADDED
@@ -0,0 +1,140 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

import functools
import inspect

def configurable(init_func=None, *, from_config=None):
    """
    Decorate a function or a class's __init__ method so that it can be called
    with a :class:`CfgNode` object using a :func:`from_config` function that translates
    :class:`CfgNode` to arguments.

    Examples:
    ::
        # Usage 1: Decorator on __init__:
        class A:
            @configurable
            def __init__(self, a, b=2, c=3):
                pass

            @classmethod
            def from_config(cls, cfg):   # 'cfg' must be the first argument
                # Returns kwargs to be passed to __init__
                return {"a": cfg.A, "b": cfg.B}

        a1 = A(a=1, b=2)  # regular construction
        a2 = A(cfg)       # construct with a cfg
        a3 = A(cfg, b=3, c=4)  # construct with extra overwrite

        # Usage 2: Decorator on any function. Needs an extra from_config argument:
        @configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B})
        def a_func(a, b=2, c=3):
            pass

        a1 = a_func(a=1, b=2)  # regular call
        a2 = a_func(cfg)       # call with a cfg
        a3 = a_func(cfg, b=3, c=4)  # call with extra overwrite

    Args:
        init_func (callable): a class's ``__init__`` method in usage 1. The
            class must have a ``from_config`` classmethod which takes `cfg` as
            the first argument.
        from_config (callable): the from_config function in usage 2. It must take `cfg`
            as its first argument.
    """

    if init_func is not None:
        assert (
            inspect.isfunction(init_func)
            and from_config is None
            and init_func.__name__ == "__init__"
        ), "Incorrect use of @configurable. Check API documentation for examples."

        @functools.wraps(init_func)
        def wrapped(self, *args, **kwargs):
            try:
                from_config_func = type(self).from_config
            except AttributeError as e:
                raise AttributeError(
                    "Class with @configurable must have a 'from_config' classmethod."
                ) from e
            if not inspect.ismethod(from_config_func):
                raise TypeError("Class with @configurable must have a 'from_config' classmethod.")

            if _called_with_cfg(*args, **kwargs):
                explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
                init_func(self, **explicit_args)
            else:
                init_func(self, *args, **kwargs)

        return wrapped

    else:
        if from_config is None:
            return configurable  # @configurable() is made equivalent to @configurable
        assert inspect.isfunction(
            from_config
        ), "from_config argument of configurable must be a function!"

        def wrapper(orig_func):
            @functools.wraps(orig_func)
            def wrapped(*args, **kwargs):
                if _called_with_cfg(*args, **kwargs):
                    explicit_args = _get_args_from_config(from_config, *args, **kwargs)
                    return orig_func(**explicit_args)
                else:
                    return orig_func(*args, **kwargs)

            wrapped.from_config = from_config
            return wrapped

        return wrapper

def _called_with_cfg(*args, **kwargs):
    """
    Returns:
        bool: whether the arguments contain CfgNode and should be considered
            forwarded to from_config.
    """
    from omegaconf import DictConfig

    if len(args) and isinstance(args[0], (dict)):
        return True
    if isinstance(kwargs.pop("cfg", None), (dict)):
        return True
    # `from_config`'s first argument is forced to be "cfg".
    # So the above check covers all cases.
    return False

def _get_args_from_config(from_config_func, *args, **kwargs):
    """
    Use `from_config` to obtain explicit arguments.

    Returns:
        dict: arguments to be used for cls.__init__
    """
    signature = inspect.signature(from_config_func)
    if list(signature.parameters.keys())[0] != "cfg":
        if inspect.isfunction(from_config_func):
            name = from_config_func.__name__
        else:
            name = f"{from_config_func.__self__}.from_config"
        raise TypeError(f"{name} must take 'cfg' as the first argument!")
    support_var_arg = any(
        param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]
        for param in signature.parameters.values()
    )
    if support_var_arg:  # forward all arguments to from_config, if from_config accepts them
        ret = from_config_func(*args, **kwargs)
    else:
        # forward supported arguments to from_config
        supported_arg_names = set(signature.parameters.keys())
        extra_kwargs = {}
        for name in list(kwargs.keys()):
            if name not in supported_arg_names:
                extra_kwargs[name] = kwargs.pop(name)
        ret = from_config_func(*args, **kwargs)
        # forward the other arguments to __init__
        ret.update(extra_kwargs)
    return ret
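
Not part of the diff, a small sketch of how the decorator above is typically used. The class name, config keys, and values are invented for illustration, and it assumes the module is importable as xdecoder.utils.config; note this copy of _called_with_cfg accepts a plain dict as the config object.

    from xdecoder.utils.config import configurable

    class Head:
        @configurable
        def __init__(self, num_classes, hidden_dim=256):
            self.num_classes = num_classes
            self.hidden_dim = hidden_dim

        @classmethod
        def from_config(cls, cfg):
            # translate a plain-dict config into __init__ kwargs
            return {"num_classes": cfg["MODEL"]["NUM_CLASSES"],
                    "hidden_dim": cfg["MODEL"]["HIDDEN_DIM"]}

    cfg = {"MODEL": {"NUM_CLASSES": 80, "HIDDEN_DIM": 512}}
    h1 = Head(num_classes=80)        # plain construction
    h2 = Head(cfg)                   # construction from a config dict
    h3 = Head(cfg, hidden_dim=1024)  # config plus an explicit override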
xdecoder/utils/it_contrastive.py
ADDED
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

def is_dist_initialized():
    return torch.distributed.is_initialized()

def get_world_size():
    if is_dist_initialized():
        return torch.distributed.get_world_size()
    return 1

def all_gather_grad(x):
    if get_world_size() > 1:
        all_x = [torch.zeros_like(x) for _ in range(get_world_size())]
        torch.distributed.all_gather(all_x, x)
        all_x[torch.distributed.get_rank()] = x
        x = torch.cat(all_x, dim=0)
    return x

@torch.no_grad()
def all_gather_nograd(tensor):
    # from albef
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    if get_world_size() > 1:
        tensors_gather = [torch.ones_like(tensor)
                          for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

        tensor = torch.cat(tensors_gather, dim=0)
    return tensor

def image_text_contrastive_loss(image_feat, text_feat, temperature, image_id=None, text_id=None):
    # add the following 4 lines
    image_feat = all_gather_grad(image_feat)
    text_feat = all_gather_grad(text_feat)

    logits = torch.matmul(image_feat, text_feat.t())
    logits /= temperature

    if image_id is None and text_id is None:
        gt = torch.arange(logits.shape[0], device=logits.device)
        loss1 = F.cross_entropy(logits, gt)
        loss2 = F.cross_entropy(logits.t(), gt)
    else:
        image_id = all_gather_grad(image_id)
        text_id = all_gather_grad(text_id)

        gt_image = image_id.reshape((-1, 1)) == image_id.reshape((1, -1))
        gt_text = text_id.reshape((-1, 1)) == text_id.reshape((1, -1))
        gt = torch.logical_or(gt_image, gt_text)

        loss1 = -torch.sum(gt * F.log_softmax(logits, dim=1)) / gt.sum()
        loss2 = -torch.sum(gt.t() * F.log_softmax(logits.t(), dim=1)) / gt.sum()

    return (loss1 + loss2) / 2 * get_world_size()  # scale it up by the number of GPUs
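
Again not part of the upload, a single-process sketch of calling the loss above; the feature shape and temperature are arbitrary, and the L2 normalization is an assumption rather than something the function enforces. With no distributed process group initialized, get_world_size() returns 1 and the all-gather calls are no-ops.

    import torch
    import torch.nn.functional as F
    from xdecoder.utils.it_contrastive import image_text_contrastive_loss

    image_emb = torch.randn(8, 512, requires_grad=True)  # stand-ins for model outputs
    text_emb = torch.randn(8, 512, requires_grad=True)
    image_feat = F.normalize(image_emb, dim=-1)           # L2-normalize paired embeddings
    text_feat = F.normalize(text_emb, dim=-1)

    loss = image_text_contrastive_loss(image_feat, text_feat, temperature=0.07)
    loss.backward()                                        # gradients flow back to image_emb / text_emb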
xdecoder/utils/misc.py
ADDED
@@ -0,0 +1,157 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py
# Modified by Xueyan Zou
"""
Misc functions, including distributed helpers.

Mostly copy-paste from torchvision references.
"""
from typing import List, Optional

import torch
import torch.distributed as dist
import torchvision
from torch import Tensor

def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes

class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False
    elif tensor_list[0].ndim == 2:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(txt.shape) for txt in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, l = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, l), dtype=torch.bool, device=device)
        for txt, pad_txt, m in zip(tensor_list, tensor, mask):
            pad_txt[: txt.shape[0], : txt.shape[1]] = txt
            m[: txt.shape[1]] = False
    else:
        raise ValueError("not supported")
    return NestedTensor(tensor, mask)

def _collate_and_pad_divisibility(tensor_list: list, div=32):
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(
            torch.tensor([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    c,h,w = max_size
    pad_h = (div - h % div) if h % div != 0 else 0
    pad_w = (div - w % div) if w % div != 0 else 0
    max_size = (c,h+pad_h,w+pad_w)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    return padded_imgs

# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True
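
As with the other files, a usage sketch rather than part of the diff: padding two differently-sized image tensors into a single NestedTensor. The image sizes are arbitrary and the import assumes the package is on the path.

    import torch
    from xdecoder.utils.misc import nested_tensor_from_tensor_list

    imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]  # two C,H,W images
    nt = nested_tensor_from_tensor_list(imgs)

    tensors, mask = nt.decompose()
    print(tensors.shape)  # torch.Size([2, 3, 512, 640]) -- padded to the max H and W
    print(mask.shape)     # torch.Size([2, 512, 640])    -- True where a pixel is padding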