fffiloni commited on
Commit
0f8c994
1 Parent(s): 40ef1c3

Upload 4 files

Browse files
xdecoder/utils/box_ops.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Utilities for bounding box manipulation and GIoU.
4
+ """
5
+ import torch
6
+ from torchvision.ops.boxes import box_area
7
+
8
+
9
+ def box_cxcywh_to_xyxy(x):
10
+ x_c, y_c, w, h = x.unbind(-1)
11
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
12
+ (x_c + 0.5 * w), (y_c + 0.5 * h)]
13
+ return torch.stack(b, dim=-1)
14
+
15
+
16
+ def box_xyxy_to_cxcywh(x):
17
+ x0, y0, x1, y1 = x.unbind(-1)
18
+ b = [(x0 + x1) / 2, (y0 + y1) / 2,
19
+ (x1 - x0), (y1 - y0)]
20
+ return torch.stack(b, dim=-1)
21
+
22
+ def box_xywh_to_xyxy(x):
23
+ x0, y0, x1, y1 = x.unbind(-1)
24
+ b = [x0, y0, (x0 + x1), (y0 + y1)]
25
+ return torch.stack(b, dim=-1)
26
+
27
+
28
+ # modified from torchvision to also return the union
29
+ def box_iou(boxes1, boxes2):
30
+ area1 = box_area(boxes1)
31
+ area2 = box_area(boxes2)
32
+
33
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
34
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
35
+
36
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
37
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
38
+
39
+ union = area1[:, None] + area2 - inter
40
+
41
+ iou = inter / union
42
+ return iou, union
43
+
44
+
45
+ def generalized_box_iou(boxes1, boxes2):
46
+ """
47
+ Generalized IoU from https://giou.stanford.edu/
48
+
49
+ The boxes should be in [x0, y0, x1, y1] format
50
+
51
+ Returns a [N, M] pairwise matrix, where N = len(boxes1)
52
+ and M = len(boxes2)
53
+ """
54
+ # degenerate boxes gives inf / nan results
55
+ # so do an early check
56
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
57
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
58
+ iou, union = box_iou(boxes1, boxes2)
59
+
60
+ lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
61
+ rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
62
+
63
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
64
+ area = wh[:, :, 0] * wh[:, :, 1]
65
+
66
+ return iou - (area - union) / area
67
+
68
+
69
+ def masks_to_boxes(masks):
70
+ """Compute the bounding boxes around the provided masks
71
+
72
+ The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
73
+
74
+ Returns a [N, 4] tensors, with the boxes in xyxy format
75
+ """
76
+ if masks.numel() == 0:
77
+ return torch.zeros((0, 4), device=masks.device)
78
+
79
+ h, w = masks.shape[-2:]
80
+
81
+ y = torch.arange(0, h, dtype=torch.float)
82
+ x = torch.arange(0, w, dtype=torch.float)
83
+ y, x = torch.meshgrid(y, x)
84
+
85
+ x_mask = (masks * x.unsqueeze(0))
86
+ x_max = x_mask.flatten(1).max(-1)[0]
87
+ x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
88
+
89
+ y_mask = (masks * y.unsqueeze(0))
90
+ y_max = y_mask.flatten(1).max(-1)[0]
91
+ y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
92
+
93
+ return torch.stack([x_min, y_min, x_max, y_max], 1)
xdecoder/utils/config.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ import functools
5
+ import inspect
6
+
7
+ def configurable(init_func=None, *, from_config=None):
8
+ """
9
+ Decorate a function or a class's __init__ method so that it can be called
10
+ with a :class:`CfgNode` object using a :func:`from_config` function that translates
11
+ :class:`CfgNode` to arguments.
12
+
13
+ Examples:
14
+ ::
15
+ # Usage 1: Decorator on __init__:
16
+ class A:
17
+ @configurable
18
+ def __init__(self, a, b=2, c=3):
19
+ pass
20
+
21
+ @classmethod
22
+ def from_config(cls, cfg): # 'cfg' must be the first argument
23
+ # Returns kwargs to be passed to __init__
24
+ return {"a": cfg.A, "b": cfg.B}
25
+
26
+ a1 = A(a=1, b=2) # regular construction
27
+ a2 = A(cfg) # construct with a cfg
28
+ a3 = A(cfg, b=3, c=4) # construct with extra overwrite
29
+
30
+ # Usage 2: Decorator on any function. Needs an extra from_config argument:
31
+ @configurable(from_config=lambda cfg: {"a: cfg.A, "b": cfg.B})
32
+ def a_func(a, b=2, c=3):
33
+ pass
34
+
35
+ a1 = a_func(a=1, b=2) # regular call
36
+ a2 = a_func(cfg) # call with a cfg
37
+ a3 = a_func(cfg, b=3, c=4) # call with extra overwrite
38
+
39
+ Args:
40
+ init_func (callable): a class's ``__init__`` method in usage 1. The
41
+ class must have a ``from_config`` classmethod which takes `cfg` as
42
+ the first argument.
43
+ from_config (callable): the from_config function in usage 2. It must take `cfg`
44
+ as its first argument.
45
+ """
46
+
47
+ if init_func is not None:
48
+ assert (
49
+ inspect.isfunction(init_func)
50
+ and from_config is None
51
+ and init_func.__name__ == "__init__"
52
+ ), "Incorrect use of @configurable. Check API documentation for examples."
53
+
54
+ @functools.wraps(init_func)
55
+ def wrapped(self, *args, **kwargs):
56
+ try:
57
+ from_config_func = type(self).from_config
58
+ except AttributeError as e:
59
+ raise AttributeError(
60
+ "Class with @configurable must have a 'from_config' classmethod."
61
+ ) from e
62
+ if not inspect.ismethod(from_config_func):
63
+ raise TypeError("Class with @configurable must have a 'from_config' classmethod.")
64
+
65
+ if _called_with_cfg(*args, **kwargs):
66
+ explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
67
+ init_func(self, **explicit_args)
68
+ else:
69
+ init_func(self, *args, **kwargs)
70
+
71
+ return wrapped
72
+
73
+ else:
74
+ if from_config is None:
75
+ return configurable # @configurable() is made equivalent to @configurable
76
+ assert inspect.isfunction(
77
+ from_config
78
+ ), "from_config argument of configurable must be a function!"
79
+
80
+ def wrapper(orig_func):
81
+ @functools.wraps(orig_func)
82
+ def wrapped(*args, **kwargs):
83
+ if _called_with_cfg(*args, **kwargs):
84
+ explicit_args = _get_args_from_config(from_config, *args, **kwargs)
85
+ return orig_func(**explicit_args)
86
+ else:
87
+ return orig_func(*args, **kwargs)
88
+
89
+ wrapped.from_config = from_config
90
+ return wrapped
91
+
92
+ return wrapper
93
+
94
+ def _called_with_cfg(*args, **kwargs):
95
+ """
96
+ Returns:
97
+ bool: whether the arguments contain CfgNode and should be considered
98
+ forwarded to from_config.
99
+ """
100
+ from omegaconf import DictConfig
101
+
102
+ if len(args) and isinstance(args[0], (dict)):
103
+ return True
104
+ if isinstance(kwargs.pop("cfg", None), (dict)):
105
+ return True
106
+ # `from_config`'s first argument is forced to be "cfg".
107
+ # So the above check covers all cases.
108
+ return False
109
+
110
+ def _get_args_from_config(from_config_func, *args, **kwargs):
111
+ """
112
+ Use `from_config` to obtain explicit arguments.
113
+
114
+ Returns:
115
+ dict: arguments to be used for cls.__init__
116
+ """
117
+ signature = inspect.signature(from_config_func)
118
+ if list(signature.parameters.keys())[0] != "cfg":
119
+ if inspect.isfunction(from_config_func):
120
+ name = from_config_func.__name__
121
+ else:
122
+ name = f"{from_config_func.__self__}.from_config"
123
+ raise TypeError(f"{name} must take 'cfg' as the first argument!")
124
+ support_var_arg = any(
125
+ param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]
126
+ for param in signature.parameters.values()
127
+ )
128
+ if support_var_arg: # forward all arguments to from_config, if from_config accepts them
129
+ ret = from_config_func(*args, **kwargs)
130
+ else:
131
+ # forward supported arguments to from_config
132
+ supported_arg_names = set(signature.parameters.keys())
133
+ extra_kwargs = {}
134
+ for name in list(kwargs.keys()):
135
+ if name not in supported_arg_names:
136
+ extra_kwargs[name] = kwargs.pop(name)
137
+ ret = from_config_func(*args, **kwargs)
138
+ # forward the other arguments to __init__
139
+ ret.update(extra_kwargs)
140
+ return ret
xdecoder/utils/it_contrastive.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ def is_dist_initialized():
6
+ return torch.distributed.is_initialized()
7
+
8
+ def get_world_size():
9
+ if is_dist_initialized():
10
+ return torch.distributed.get_world_size()
11
+ return 1
12
+
13
+ def all_gather_grad(x):
14
+ if get_world_size() > 1:
15
+ all_x = [torch.zeros_like(x) for _ in range(get_world_size())]
16
+ torch.distributed.all_gather(all_x, x)
17
+ all_x[torch.distributed.get_rank()] = x
18
+ x = torch.cat(all_x, dim=0)
19
+ return x
20
+
21
+ @torch.no_grad()
22
+ def all_gather_nograd(tensor):
23
+ # from albef
24
+ """
25
+ Performs all_gather operation on the provided tensors.
26
+ *** Warning ***: torch.distributed.all_gather has no gradient.
27
+ """
28
+ if get_world_size() > 1:
29
+ tensors_gather = [torch.ones_like(tensor)
30
+ for _ in range(torch.distributed.get_world_size())]
31
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
32
+
33
+ tensor = torch.cat(tensors_gather, dim=0)
34
+ return tensor
35
+
36
+ def image_text_contrastive_loss(image_feat, text_feat, temperature, image_id=None, text_id=None):
37
+ # add the following 4 lines
38
+ image_feat = all_gather_grad(image_feat)
39
+ text_feat = all_gather_grad(text_feat)
40
+
41
+ logits = torch.matmul(image_feat, text_feat.t())
42
+ logits /= temperature
43
+
44
+ if image_id is None and text_id is None:
45
+ gt = torch.arange(logits.shape[0], device=logits.device)
46
+ loss1 = F.cross_entropy(logits, gt)
47
+ loss2 = F.cross_entropy(logits.t(), gt)
48
+ else:
49
+ image_id = all_gather_grad(image_id)
50
+ text_id = all_gather_grad(text_id)
51
+
52
+ gt_image = image_id.reshape((-1, 1)) == image_id.reshape((1, -1))
53
+ gt_text = text_id.reshape((-1, 1)) == text_id.reshape((1, -1))
54
+ gt = torch.logical_or(gt_image, gt_text)
55
+
56
+ loss1 = -torch.sum(gt * F.log_softmax(logits, dim=1)) / gt.sum()
57
+ loss2 = -torch.sum(gt.t() * F.log_softmax(logits.t(), dim=1)) / gt.sum()
58
+
59
+ return (loss1 + loss2) / 2 * get_world_size() # scale it up by the number of GPUs
xdecoder/utils/misc.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py
3
+ # Modified by Xueyan Zou
4
+ """
5
+ Misc functions, including distributed helpers.
6
+
7
+ Mostly copy-paste from torchvision references.
8
+ """
9
+ from typing import List, Optional
10
+
11
+ import torch
12
+ import torch.distributed as dist
13
+ import torchvision
14
+ from torch import Tensor
15
+
16
+ def _max_by_axis(the_list):
17
+ # type: (List[List[int]]) -> List[int]
18
+ maxes = the_list[0]
19
+ for sublist in the_list[1:]:
20
+ for index, item in enumerate(sublist):
21
+ maxes[index] = max(maxes[index], item)
22
+ return maxes
23
+
24
+ class NestedTensor(object):
25
+ def __init__(self, tensors, mask: Optional[Tensor]):
26
+ self.tensors = tensors
27
+ self.mask = mask
28
+
29
+ def to(self, device):
30
+ # type: (Device) -> NestedTensor # noqa
31
+ cast_tensor = self.tensors.to(device)
32
+ mask = self.mask
33
+ if mask is not None:
34
+ assert mask is not None
35
+ cast_mask = mask.to(device)
36
+ else:
37
+ cast_mask = None
38
+ return NestedTensor(cast_tensor, cast_mask)
39
+
40
+ def decompose(self):
41
+ return self.tensors, self.mask
42
+
43
+ def __repr__(self):
44
+ return str(self.tensors)
45
+
46
+ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
47
+ # TODO make this more general
48
+ if tensor_list[0].ndim == 3:
49
+ if torchvision._is_tracing():
50
+ # nested_tensor_from_tensor_list() does not export well to ONNX
51
+ # call _onnx_nested_tensor_from_tensor_list() instead
52
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
53
+
54
+ # TODO make it support different-sized images
55
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
56
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
57
+ batch_shape = [len(tensor_list)] + max_size
58
+ b, c, h, w = batch_shape
59
+ dtype = tensor_list[0].dtype
60
+ device = tensor_list[0].device
61
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
62
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
63
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
64
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
65
+ m[: img.shape[1], : img.shape[2]] = False
66
+ elif tensor_list[0].ndim == 2:
67
+ if torchvision._is_tracing():
68
+ # nested_tensor_from_tensor_list() does not export well to ONNX
69
+ # call _onnx_nested_tensor_from_tensor_list() instead
70
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
71
+
72
+ # TODO make it support different-sized images
73
+ max_size = _max_by_axis([list(txt.shape) for txt in tensor_list])
74
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
75
+ batch_shape = [len(tensor_list)] + max_size
76
+ b, c, l = batch_shape
77
+ dtype = tensor_list[0].dtype
78
+ device = tensor_list[0].device
79
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
80
+ mask = torch.ones((b, l), dtype=torch.bool, device=device)
81
+ for txt, pad_txt, m in zip(tensor_list, tensor, mask):
82
+ pad_txt[: txt.shape[0], : txt.shape[1]] = txt
83
+ m[: txt.shape[1]] = False
84
+ else:
85
+ raise ValueError("not supported")
86
+ return NestedTensor(tensor, mask)
87
+
88
+ def _collate_and_pad_divisibility(tensor_list: list, div=32):
89
+ max_size = []
90
+ for i in range(tensor_list[0].dim()):
91
+ max_size_i = torch.max(
92
+ torch.tensor([img.shape[i] for img in tensor_list]).to(torch.float32)
93
+ ).to(torch.int64)
94
+ max_size.append(max_size_i)
95
+ max_size = tuple(max_size)
96
+
97
+ c,h,w = max_size
98
+ pad_h = (div - h % div) if h % div != 0 else 0
99
+ pad_w = (div - w % div) if w % div != 0 else 0
100
+ max_size = (c,h+pad_h,w+pad_w)
101
+
102
+ # work around for
103
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
104
+ # m[: img.shape[1], :img.shape[2]] = False
105
+ # which is not yet supported in onnx
106
+ padded_imgs = []
107
+ padded_masks = []
108
+ for img in tensor_list:
109
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
110
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
111
+ padded_imgs.append(padded_img)
112
+
113
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
114
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
115
+ padded_masks.append(padded_mask.to(torch.bool))
116
+
117
+ return padded_imgs
118
+
119
+ # _onnx_nested_tensor_from_tensor_list() is an implementation of
120
+ # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
121
+ @torch.jit.unused
122
+ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
123
+ max_size = []
124
+ for i in range(tensor_list[0].dim()):
125
+ max_size_i = torch.max(
126
+ torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
127
+ ).to(torch.int64)
128
+ max_size.append(max_size_i)
129
+ max_size = tuple(max_size)
130
+
131
+ # work around for
132
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
133
+ # m[: img.shape[1], :img.shape[2]] = False
134
+ # which is not yet supported in onnx
135
+ padded_imgs = []
136
+ padded_masks = []
137
+ for img in tensor_list:
138
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
139
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
140
+ padded_imgs.append(padded_img)
141
+
142
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
143
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
144
+ padded_masks.append(padded_mask.to(torch.bool))
145
+
146
+ tensor = torch.stack(padded_imgs)
147
+ mask = torch.stack(padded_masks)
148
+
149
+ return NestedTensor(tensor, mask=mask)
150
+
151
+
152
+ def is_dist_avail_and_initialized():
153
+ if not dist.is_available():
154
+ return False
155
+ if not dist.is_initialized():
156
+ return False
157
+ return True