mmazuecos committed on
Commit
2d07fab
•
1 Parent(s): a0502e7
.gitattributes CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ data/saiapr_tc-12.zip filter=lfs diff=lfs merge=lfs -text
36
+ cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt filter=lfs diff=lfs merge=lfs -text
37
+ data/val-sim_metric.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: RECModel
3
- emoji: 💩
4
- colorFrom: purple
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.9
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: ProbingREC
3
+ emoji: πŸ‘
4
+ colorFrom: blue
5
+ colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 3.4
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,119 @@
1
+ from models import IntuitionKillingMachine
2
+ from transforms import undo_box_transforms_batch, ToTensor, Normalize, SquarePad, Resize, NormalizeBoxCoords
3
+ from torchvision.transforms import Compose
4
+ from encoders import get_tokenizer
5
+ from PIL import Image, ImageDraw
6
+ from zipfile import ZipFile
7
+ from copy import copy
8
+ import gradio as gr
9
+ import pandas as pd
10
+ import torch
11
+
12
+ def parse_model_args(model_path):
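+ # hyper-parameters are recovered from the underscore-separated run name in the checkpoint path
+ # (e.g. cache/20211220_191132_refclef_32_512_resnet50_8_6_8_.../best.ckpt used below)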
13
+ _, _, dataset, max_length, input_size, backbone, num_heads, num_layers, num_conv, _, _, mu, mask_pooling = model_path.split('_')[:13]
14
+ return {
15
+ 'dataset': dataset,
16
+ 'max_length': int(max_length),
17
+ 'input_size': int(input_size),
18
+ 'backbone': backbone,
19
+ 'num_heads': int(num_heads),
20
+ 'num_layers': int(num_layers),
21
+ 'num_conv': int(num_conv),
22
+ 'mu': float(mu),
23
+ 'mask_pooling': bool(mask_pooling == '1')
24
+ }
25
+
26
+
27
+ class Prober:
28
+ def __init__(self,
29
+ df_path=None,
30
+ dataset_path=None,
31
+ model_checkpoint=None):
32
+ params = parse_model_args(model_checkpoint)
33
+ mean = [0.485, 0.456, 0.406]
34
+ sdev = [0.229, 0.224, 0.225]
35
+ self.tokenizer = get_tokenizer()
36
+ self.df = pd.read_json(df_path)[['sample_idx', 'bbox', 'file_path', 'sent']]
37
+ self.df.loc[:, "image_id"] = self.df.file_path.apply(lambda x: int(x.split('/')[-1][:-4]))
38
+ self.df.file_path = self.df.file_path.apply(lambda x: x.replace('refer/data/images/', ''))
39
+ self.model = IntuitionKillingMachine(
40
+ backbone=params['backbone'],
41
+ pretrained=True,
42
+ num_heads=params['num_heads'],
43
+ num_layers=params['num_layers'],
44
+ num_conv=params['num_conv'],
45
+ segmentation_head=bool(params['mu'] > 0.0),
46
+ mask_pooling=params['mask_pooling']
47
+ )
48
+ self.load_model(model_checkpoint)
49
+ self.transform = Compose([
50
+ ToTensor(),
51
+ Normalize(mean, sdev),
52
+ SquarePad(),
53
+ Resize(size=(params['input_size'], params['input_size'])),
54
+ NormalizeBoxCoords(),
55
+ ])
56
+ self.max_length = 30
57
+ self.zipfile = ZipFile(dataset_path, 'r')
58
+
59
+ def load_model(self, model_checkpoint):
60
+ checkpoint = torch.load(
61
+ model_checkpoint, map_location=lambda storage, loc: storage
62
+ )
63
+
64
+ # strip 'model.' from pl checkpoint
65
+ state_dict = {
66
+ k[len('model.'):]: v
67
+ for k, v in checkpoint['state_dict'].items()
68
+ }
69
+
70
+ missing, _ = self.model.load_state_dict(state_dict, strict=False)
71
+
72
+ # ensure the only missing keys are those of the segmentation head
73
+ assert [k for k in missing if 'segm' not in k] == []
74
+
75
+ self.model = self.model.eval()
76
+
77
+
78
+ @torch.no_grad()
79
+ def probe(self, idx, re, search_by_sample_id: bool = True):
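+ # `idx` selects the sample: a dataframe index when search_by_sample_id is True, an image id otherwise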
80
+ if search_by_sample_id:
81
+ img_path, target = self.df.loc[idx][['file_path','bbox']].values
82
+ else:
83
+ img_path, target = self.df[self.df.image_id == idx][['file_path','bbox']].values[0]
84
+ img = Image.open(self.zipfile.open(img_path)).convert('RGB')
85
+ W0, H0 = img.size
86
+ sample = {
87
+ 'image': img,
88
+ 'image_size': (H0, W0), # image original size
89
+ 'bbox': torch.tensor([copy(target)]),
90
+ 'bbox_raw': torch.tensor([copy(target)]),
91
+ 'mask': torch.ones((1, H0, W0), dtype=torch.float32), # visibility mask
92
+ 'mask_bbox': None, # target bbox mask
93
+ }
94
+ sample = self.transform(sample)
95
+ tok = self.tokenizer(re,
96
+ max_length=30,
97
+ return_tensors='pt',
98
+ truncation=True)
99
+ inn = {'image': torch.stack([sample['image']]),
100
+ 'mask': torch.stack([sample['mask']]),
101
+ 'tok': tok}
102
+ output = undo_box_transforms_batch(self.model(inn)[0],
103
+ [sample['tr_param']]).numpy().tolist()[0]
104
+ img1 = ImageDraw.Draw(img)
105
+ #img1.rectangle(target, outline ="#0000FF00", width=3)
106
+ img1.rectangle(output, outline ="#00FF0000", width=3)
107
+ return img
108
+
109
+
110
+ prober = Prober(
111
+ df_path = 'data/val-sim_metric.json',
112
+ dataset_path = "data/saiapr_tc-12.zip",
113
+ model_checkpoint= "cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt"
114
+ )
115
+
116
+ demo = gr.Interface(fn=prober.probe, inputs=["number", "text", "checkbox"], outputs="image")
117
+
118
+ demo.queue(concurrency_count=10)
119
+ demo.launch(debug=True)
backbones.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+
3
+ import torch.nn as nn
4
+
5
+ from torchvision.ops.misc import FrozenBatchNorm2d
6
+
7
+ from torchvision.models import resnet, detection, segmentation
8
+
9
+ import timm
10
+
11
+
12
+ # https://detectron2.readthedocs.io/en/latest/modules/layers.html#detectron2.layers.FrozenBatchNorm2d.convert_frozen_batchnorm
13
+ @torch.no_grad()
14
+ def convert_frozen_batchnorm(module):
15
+ bn_module = (
16
+ nn.modules.batchnorm.BatchNorm2d,
17
+ nn.modules.batchnorm.SyncBatchNorm
18
+ )
19
+ res = module
20
+ if isinstance(module, bn_module):
21
+ res = FrozenBatchNorm2d(module.num_features)
22
+ if module.affine:
23
+ res.weight.data = module.weight.data.clone().detach()
24
+ res.bias.data = module.bias.data.clone().detach()
25
+ res.running_mean.data = module.running_mean.data
26
+ res.running_var.data = module.running_var.data
27
+ res.eps = module.eps
28
+ else:
29
+ for name, child in module.named_children():
30
+ new_child = convert_frozen_batchnorm(child)
31
+ if new_child is not child:
32
+ res.add_module(name, new_child)
33
+ return res
34
+
35
+
36
+ def get_backbone(backbone, pretrained=True):
37
+ if backbone in ('resnet18', 'resnet34', 'resnet50', 'resnet101'):
38
+ # pretrained on ImageNet for classification
39
+ model = resnet.__dict__[backbone](
40
+ pretrained=pretrained, norm_layer=FrozenBatchNorm2d
41
+ )
42
+ elif backbone == 'resnet50d':
43
+ # pretrained on COCO for detection
44
+ model = convert_frozen_batchnorm(
45
+ detection.fasterrcnn_resnet50_fpn(pretrained=pretrained).backbone.body
46
+ )
47
+ elif backbone == 'resnet50s':
48
+ # pretrained on COCO for segmentation
49
+ model = convert_frozen_batchnorm(
50
+ segmentation.deeplabv3_resnet50(pretrained=pretrained).backbone
51
+ )
52
+ elif backbone == 'resnet101s':
53
+ # pretrained on COCO for segmentation
54
+ model = convert_frozen_batchnorm(
55
+ segmentation.deeplabv3_resnet101(pretrained=pretrained).backbone
56
+ )
57
+
58
+ elif backbone in ('cspdarknet53', 'efficientnet-b0', 'efficientnet-b3'):
59
+ # model = convert_frozen_batchnorm(
60
+ # timm.create_model(
61
+ # backbone.replace('-', '_'),
62
+ # pretrained=True,
63
+ # features_only=True,
64
+ # #out_indices=(1, 2, 3, 4)
65
+ # )
66
+ # )
67
+ model = convert_frozen_batchnorm(
68
+ timm.create_model(
69
+ backbone.replace('-', '_'),
70
+ pretrained=pretrained,
71
+ num_classes=0,
72
+ global_pool=''
73
+ )
74
+ )
75
+
76
+ else:
77
+ raise RuntimeError(f'{backbone} is not a valid backbone')
78
+
79
+ # empty cache (dealloc modules other than the backbone)
80
+ torch.cuda.empty_cache()
81
+
82
+ return model
cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2aaaf1696c537a1a2b049ddfa150d36770b6e92c8524ca4e3706755c00648f26
3
+ size 1752031089
datasets.py ADDED
@@ -0,0 +1,282 @@
1
+ import os
2
+
3
+ import json
4
+
5
+ import random
6
+
7
+ import torch
8
+
9
+ import ijson
10
+
11
+ import numpy as np
12
+
13
+ from PIL import Image
14
+
15
+ from torchvision.transforms import ToTensor
16
+
17
+ from torchvision.ops import box_convert, clip_boxes_to_image
18
+
19
+ from re_classifier import REClassifier
20
+
21
+ from utils import progressbar
22
+
23
+
24
+ def collate_fn(batch):
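+ # merge per-sample dicts into batched tensors; when tokenized expressions are present,
+ # ids/attention masks are trimmed to the longest expression in the batch (dynamic batching)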
25
+ image = torch.stack([s['image'] for s in batch], dim=0)
26
+
27
+ image_size = torch.FloatTensor([s['image_size'] for s in batch])
28
+
29
+ # bbox = torch.stack([s['bbox'] for s in batch], dim=0)
30
+ bbox = torch.cat([s['bbox'] for s in batch], dim=0)
31
+
32
+ # bbox_raw = torch.stack([s['bbox_raw'] for s in batch], dim=0)
33
+ bbox_raw = torch.cat([s['bbox_raw'] for s in batch], dim=0)
34
+
35
+ expr = [s['expr'] for s in batch]
36
+
37
+ tok = None
38
+ if batch[0]['tok'] is not None:
39
+ tok = {
40
+ 'input_ids': torch.cat([s['tok']['input_ids'] for s in batch], dim=0),
41
+ 'attention_mask': torch.cat([s['tok']['attention_mask'] for s in batch], dim=0)
42
+ }
43
+
44
+ # dynamic batching
45
+ max_length = max([s['tok']['length'] for s in batch])
46
+ tok = {
47
+ 'input_ids': tok['input_ids'][:, :max_length],
48
+ 'attention_mask': tok['attention_mask'][:, :max_length],
49
+ }
50
+
51
+ mask = None
52
+ if batch[0]['mask'] is not None:
53
+ mask = torch.stack([s['mask'] for s in batch], dim=0)
54
+
55
+ mask_bbox = None
56
+ if batch[0]['mask_bbox'] is not None:
57
+ mask_bbox = torch.stack([s['mask_bbox'] for s in batch], dim=0)
58
+
59
+ tr_param = [s['tr_param'] for s in batch]
60
+
61
+ return {
62
+ 'image': image,
63
+ 'image_size': image_size,
64
+ 'bbox': bbox,
65
+ 'bbox_raw': bbox_raw,
66
+ 'expr': expr,
67
+ 'tok': tok,
68
+ 'tr_param': tr_param,
69
+ 'mask': mask,
70
+ 'mask_bbox': mask_bbox,
71
+ }
72
+
73
+
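+ # base dataset: subclasses fill self.samples with (file_name, expression, bbox) triples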
74
+ class RECDataset(torch.utils.data.Dataset):
75
+ def __init__(self, transform=None, tokenizer=None, max_length=32, with_mask_bbox=False):
76
+ super().__init__()
77
+ self.samples = [] # list of samples: [(file_name, expression, bbox)]
78
+ self.transform = transform
79
+ self.tokenizer = tokenizer
80
+ self.max_length = int(max_length)
81
+ self.with_mask_bbox = bool(with_mask_bbox)
82
+
83
+ def tokenize(self, inp, max_length):
84
+ return self.tokenizer(
85
+ inp,
86
+ return_tensors='pt',
87
+ padding='max_length',
88
+ return_token_type_ids=False,
89
+ return_attention_mask=True,
90
+ add_special_tokens=True,
91
+ truncation=True,
92
+ max_length=max_length
93
+ )
94
+
95
+ def print_stats(self):
96
+ print(f'{len(self.samples)} samples')
97
+ lens = [len(expr.split()) for _, expr, _ in self.samples]
98
+ print('expression lengths stats: '
99
+ f'min={np.min(lens):.1f}, '
100
+ f'mean={np.mean(lens):.1f}, '
101
+ f'median={np.median(lens):.1f}, '
102
+ f'max={np.max(lens):.1f}, '
103
+ f'99.9P={np.percentile(lens, 99.9):.1f}'
104
+ )
105
+
106
+ def __len__(self):
107
+ return len(self.samples)
108
+
109
+ def __getitem__(self, idx):
110
+ file_name, expr, bbox = self.samples[idx]
111
+
112
+ if not os.path.exists(file_name):
113
+ raise IOError(f'{file_name} not found')
114
+ img = Image.open(file_name).convert('RGB')
115
+
116
+ # if isinstance(expr, (list, tuple)):
117
+ # expr = random.choice(expr)
118
+
119
+ # image size as read from disk (PIL)
120
+ W0, H0 = img.size
121
+
122
+ # # ensure box coordinates fall inside the image
123
+ # bbox = clip_boxes_to_image(bbox, (H0, W0))
124
+ # assert torch.all(bbox[:, (0, 1)] <= bbox[:, (2, 3)]) # xyxy format
125
+
126
+ sample = {
127
+ 'image': img,
128
+ 'image_size': (H0, W0), # image original size
129
+ 'bbox': bbox.clone(), # box transformations are inplace ops
130
+ 'bbox_raw': bbox.clone(), # raw boxes w/o any transformation (in pixels)
131
+ 'expr': expr,
132
+ 'tok': None,
133
+ 'mask': torch.ones((1, H0, W0), dtype=torch.float32), # visibility mask
134
+ 'mask_bbox': None, # target bbox mask
135
+ }
136
+
137
+ # apply transforms
138
+ if self.transform is None:
139
+ sample['image'] = ToTensor()(sample['image'])
140
+ else:
141
+ sample = self.transform(sample)
142
+
143
+ # tokenize after the transformations (in case there was a left<>right substitution)
144
+ if self.tokenizer is not None:
145
+ sample['tok'] = self.tokenize(sample['expr'], self.max_length)
146
+ sample['tok']['length'] = sample['tok']['attention_mask'].sum(1).item()
147
+
148
+ # bbox segmentation mask
149
+ if self.with_mask_bbox:
150
+ # image size after transforms
151
+ _, H, W = sample['image'].size()
152
+
153
+ # transformed bbox in pixels
154
+ bbox = sample['bbox'].clone()
155
+ bbox[:, (0, 2)] *= W
156
+ bbox[:, (1, 3)] *= H
157
+ bbox = clip_boxes_to_image((bbox + 0.5).long(), (H, W))
158
+
159
+ # output mask
160
+ sample['mask_bbox'] = torch.zeros((1, H, W), dtype=torch.float32)
161
+ for x1, y1, x2, y2 in bbox.tolist():
162
+ sample['mask_bbox'][:, y1:y2+1, x1:x2+1] = 1.0
163
+
164
+ return sample
165
+
166
+
167
+ class RegionDescriptionsVisualGnome(RECDataset):
168
+ def __init__(self, data_root, transform=None, tokenizer=None,
169
+ max_length=32, with_mask_bbox=False):
170
+ super().__init__(transform=transform, tokenizer=tokenizer,
171
+ max_length=max_length, with_mask_bbox=with_mask_bbox)
172
+
173
+
174
+ # if available, read COCO IDs from the val, testA and testB splits from
175
+ # the RefCOCO dataset
176
+ try:
177
+ with open('./refcoco_valtest_ids.txt', 'r') as fh:
178
+ refcoco_ids = [int(lin.strip()) for lin in fh.readlines()]
179
+ except:
180
+ refcoco_ids = []
181
+
182
+ def path_from_url(fname):
183
+ return os.path.join(data_root, fname[fname.index('VG_100K'):])
184
+
185
+ with open(os.path.join(data_root, 'image_data.json'), 'r') as f:
186
+ image_data = {
187
+ data['image_id']: path_from_url(data['url'])
188
+ for data in json.load(f)
189
+ if data['coco_id'] is None or data['coco_id'] not in refcoco_ids
190
+ }
191
+ print(f'{len(image_data)} images')
192
+
193
+ self.samples = []
194
+
195
+ with open(os.path.join(data_root, 'region_descriptions.json'), 'r') as f:
196
+ for record in progressbar(ijson.items(f, 'item.regions.item'), desc='loading data'):
197
+ if record['image_id'] not in image_data:
198
+ continue
199
+ file_name = image_data[record['image_id']]
200
+
201
+ expr = record['phrase']
202
+
203
+ bbox = [record['x'], record['y'], record['width'], record['height']]
204
+ bbox = torch.atleast_2d(torch.FloatTensor(bbox))
205
+ bbox = box_convert(bbox, 'xywh', 'xyxy') # xyxy
206
+
207
+ self.samples.append((file_name, expr, bbox))
208
+
209
+ self.print_stats()
210
+
211
+
212
+ class ReferDataset(RECDataset):
213
+ def __init__(self, data_root, dataset, split_by, split, transform=None,
214
+ tokenizer=None, max_length=32, with_mask_bbox=False):
215
+ super().__init__(transform=transform, tokenizer=tokenizer,
216
+ max_length=max_length, with_mask_bbox=with_mask_bbox)
217
+
218
+ # https://github.com/lichengunc/refer
219
+ try:
220
+ import sys
221
+ sys.path.append('refer')
222
+ from refer import REFER
223
+ except:
224
+ raise RuntimeError('create a symlink to a valid refer compilation '
225
+ '(see https://github.com/lichengunc/refer)')
226
+
227
+ refer = REFER(data_root, dataset, split_by)
228
+ ref_ids = sorted(refer.getRefIds(split=split))
229
+
230
+ self.samples = []
231
+
232
+ for rid in progressbar(ref_ids, desc='loading data'):
233
+ ref = refer.Refs[rid]
234
+ ann = refer.refToAnn[rid]
235
+
236
+ file_name = refer.Imgs[ref['image_id']]['file_name']
237
+ if dataset == 'refclef':
238
+ file_name = os.path.join(
239
+ 'refer', 'data', 'images', 'saiapr_tc-12', file_name
240
+ )
241
+ else:
242
+ coco_set = file_name.split('_')[1]
243
+ file_name = os.path.join(
244
+ 'refer', 'data', 'images', 'mscoco', coco_set, file_name
245
+ )
246
+
247
+ bbox = ann['bbox']
248
+ bbox = torch.atleast_2d(torch.FloatTensor(bbox))
249
+ bbox = box_convert(bbox, 'xywh', 'xyxy') # xyxy
250
+
251
+ sentences = [s['sent'] for s in ref['sentences']]
252
+ if 'train' in split: # remove repeated expressions
253
+ sentences = list(set(sentences))
254
+ sentences = sorted(sentences)
255
+
256
+ self.samples += [(file_name, expr, bbox) for expr in sentences]
257
+
258
+ self.print_stats()
259
+
260
+
261
+ class RefCLEF(ReferDataset):
262
+ def __init__(self, *args, **kwargs):
263
+ assert args[0] in ('train', 'val', 'test')
264
+ super().__init__('refer/data', 'refclef', 'berkeley', *args, **kwargs)
265
+
266
+
267
+ class RefCOCO(ReferDataset):
268
+ def __init__(self, *args, **kwargs):
269
+ assert args[0] in ('train', 'val', 'trainval', 'testA', 'testB')
270
+ super().__init__('refer/data', 'refcoco', 'unc', *args, **kwargs)
271
+
272
+
273
+ class RefCOCOp(ReferDataset):
274
+ def __init__(self, *args, **kwargs):
275
+ assert args[0] in ('train', 'val', 'trainval', 'testA', 'testB')
276
+ super().__init__('refer/data', 'refcoco+', 'unc', *args, **kwargs)
277
+
278
+
279
+ class RefCOCOg(ReferDataset):
280
+ def __init__(self, *args, **kwargs):
281
+ assert args[0] in ('train', 'val', 'test')
282
+ super().__init__('refer/data', 'refcocog', 'umd', *args, **kwargs)
embeddings.py ADDED
@@ -0,0 +1,182 @@
1
+ import math
2
+
3
+ import torch
4
+
5
+ from torch import nn
6
+
7
+
8
+ # adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
9
+ class PositionEmbedding1D(nn.Module):
10
+ def __init__(self, embedding_dim, dropout=0.1, max_len=128):
11
+ super().__init__()
12
+
13
+ # self.dropout = nn.Dropout(p=dropout)
14
+
15
+ position = torch.arange(max_len).unsqueeze(1)
16
+ div_term = torch.exp(torch.arange(0, embedding_dim, 2) * (-math.log(10000.0) / embedding_dim))
17
+ pe = torch.zeros(max_len, embedding_dim)
18
+ pe[:, 0::2] = torch.sin(position * div_term)
19
+ pe[:, 1::2] = torch.cos(position * div_term)
20
+ pe = pe.unsqueeze(0) # .transpose(0, 1)
21
+ self.register_buffer('pe', pe)
22
+
23
+ def forward(self, x):
24
+ # # x: Tensor, shape [batch_size, seq_len, embedding_dim]
25
+ # x = x + self.pe[:, :x.size(1)]
26
+ # return self.dropout(x)
27
+ N, T, _ = x.size()
28
+ return self.pe[:, :T].repeat(N, 1, 1)
29
+
30
+
31
+ class LearnedPositionEmbedding1D(nn.Module):
32
+ def __init__(self, embedding_dim, max_len=128):
33
+ super().__init__()
34
+ self.pe = nn.Parameter(torch.Tensor(1, max_len, embedding_dim))
35
+ self.reset_parameters()
36
+
37
+ def reset_parameters(self):
38
+ nn.init.xavier_normal_(self.pe)
39
+
40
+ def forward(self, x):
41
+ N, T, _ = x.size()
42
+ return self.pe[:, :T].repeat(N, 1, 1)
43
+
44
+
45
+ # https://huggingface.co/transformers/_modules/transformers/models/detr/modeling_detr.html
46
+ class PositionEmbedding2D(nn.Module):
47
+ def __init__(self, embedding_dim, temperature=10000, normalize=False,
48
+ scale=None):
49
+ super().__init__()
50
+ assert embedding_dim % 2 == 0
51
+ self.half_embedding_dim = embedding_dim // 2
52
+ self.temperature = temperature
53
+ self.normalize = normalize
54
+ if scale is not None and normalize is False:
55
+ raise ValueError("normalize should be True if scale is passed")
56
+ if scale is None:
57
+ scale = 2 * math.pi
58
+ self.scale = scale
59
+
60
+ def forward(self, pixel_values, pixel_mask):
61
+ assert pixel_mask is not None, "No pixel mask provided"
62
+ if pixel_mask.dim() == 4 and pixel_mask.size(1) == 1:
63
+ pixel_mask = pixel_mask.squeeze(1)
64
+ y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
65
+ x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
66
+ if self.normalize:
67
+ y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
68
+ x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
69
+
70
+ dim_t = torch.arange(self.half_embedding_dim, dtype=torch.float32, device=pixel_values.device)
71
+ dim_t = self.temperature ** (2 * torch.divide(dim_t, 2, rounding_mode='floor') / self.half_embedding_dim)
72
+
73
+ pos_x = x_embed[:, :, :, None] / dim_t
74
+ pos_y = y_embed[:, :, :, None] / dim_t
75
+ pos_x = torch.stack((
76
+ pos_x[:, :, :, 0::2].sin(),
77
+ pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
78
+ pos_y = torch.stack((
79
+ pos_y[:, :, :, 0::2].sin(),
80
+ pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
81
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
82
+ return pos
83
+
84
+
85
+ # https://huggingface.co/transformers/_modules/transformers/models/detr/modeling_detr.html
86
+ class LearnedPositionEmbedding2D(nn.Module):
87
+ def __init__(self, embedding_dim):
88
+ super().__init__()
89
+ assert embedding_dim % 2 == 0, 'embedding dimensionality must be even'
90
+ self.rows_embeddings = nn.Embedding(50, embedding_dim//2)
91
+ self.cols_embeddings = nn.Embedding(50, embedding_dim//2)
92
+
93
+ def forward(self, pixel_values, pixel_mask=None):
94
+ h, w = pixel_values.shape[-2:]
95
+ i = torch.arange(w, device=pixel_values.device)
96
+ j = torch.arange(h, device=pixel_values.device)
97
+ x_emb = self.cols_embeddings(i)
98
+ y_emb = self.rows_embeddings(j)
99
+ pos = torch.cat([x_emb.unsqueeze(0).repeat(h, 1, 1), y_emb.unsqueeze(1).repeat(1, w, 1)], dim=-1)
100
+ pos = pos.permute(2, 0, 1)
101
+ pos = pos.unsqueeze(0)
102
+ pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
103
+ return pos
104
+
105
+
106
+ class Box8PositionEmbedding2D(nn.Module):
107
+ def __init__(self, embedding_dim, with_projection=True):
108
+ super().__init__()
109
+
110
+ self.proj = None
111
+ if with_projection:
112
+ self.proj = nn.Linear(8, embedding_dim)
113
+ nn.init.xavier_normal_(self.proj.weight)
114
+ nn.init.zeros_(self.proj.bias)
115
+
116
+ def forward(self, fmap, fmap_mask=None):
117
+ N, _, H, W = fmap.size()
118
+
119
+ y1, x1 = torch.meshgrid(
120
+ torch.arange(H, device=fmap.device, dtype=torch.float)/H,
121
+ torch.arange(W, device=fmap.device, dtype=torch.float)/W
122
+ )
123
+ y2, x2 = x1+1.0/W, y1+1.0/H
124
+ ww, hh = x2-x1, y2-y1
125
+ # x1, y1 = 2*x1-1, 2*y1-1
126
+ # x2, y2 = 2*x2-1, 2*y2-1
127
+ xc, yc = x1+0.5/W, y1+0.5/H
128
+
129
+ pos = torch.stack([x1, y1, x2, y2, xc, yc, ww, hh], dim=-1)
130
+ if self.proj is not None:
131
+ pos = self.proj(pos)
132
+ pos = pos.permute(2, 0, 1)
133
+ pos = pos.unsqueeze(0).repeat(N, 1, 1, 1)
134
+ return pos
135
+
136
+ def encode_boxes(self, boxes):
137
+ x1, y1, x2, y2 = boxes.unbind(-1)
138
+ ww, hh = x2-x1, y2-y1
139
+ xc, yc = x1+0.5*ww, y1+0.5*hh
140
+ pos = torch.stack([x1, y1, x2, y2, xc, yc, ww, hh], dim=-1)
141
+ if self.proj is not None:
142
+ pos = self.proj(pos)
143
+ return pos
144
+
145
+
146
+ class RelativePositionEmbedding2D(nn.Module):
147
+ def __init__(self, embedding_dim, spatial_bins=(16, 16), with_projection=True):
148
+ super().__init__()
149
+
150
+ assert isinstance(spatial_bins, (list, tuple)) and len(spatial_bins) == 2
151
+ self.spatial_bins = spatial_bins
152
+
153
+ self.proj = None
154
+ if with_projection:
155
+ self.proj = nn.Linear(2*spatial_bins[0]*spatial_bins[1], embedding_dim)
156
+ nn.init.xavier_normal_(self.proj.weight)
157
+ nn.init.zeros_(self.proj.bias)
158
+
159
+ def forward(self, fmap, fmap_mask=None):
160
+ N, _, H, W = fmap.size()
161
+
162
+ BH, BW = self.spatial_bins
163
+ yc, xc = torch.meshgrid(
164
+ 0.5/BH + torch.arange(BH, device=fmap.device, dtype=torch.float)/BH,
165
+ 0.5/BW + torch.arange(BW, device=fmap.device, dtype=torch.float)/BW
166
+ )
167
+
168
+ pos = torch.stack([xc, yc], dim=-1).view(-1, 1, 2)
169
+ pos = (pos - pos.transpose(0, 1)).reshape(BH, BW, -1) # relative positions
170
+
171
+ if self.proj is not None:
172
+ pos = self.proj(pos)
173
+
174
+ pos = pos.permute(2, 0, 1)
175
+ pos = pos.unsqueeze(0)
176
+
177
+ if H != BH or W != BW:
178
+ pos = nn.functional.interpolate(pos, (H, W), mode='nearest')
179
+
180
+ pos = pos.repeat(N, 1, 1, 1)
181
+
182
+ return pos
encoders.py ADDED
@@ -0,0 +1,414 @@
1
+ import os
2
+
3
+ from collections import OrderedDict
4
+
5
+ import torch
6
+
7
+ import transformers
8
+
9
+ import torch.nn.functional as F
10
+
11
+ from torch import nn
12
+
13
+ from torchvision.models import detection
14
+
15
+ from backbones import get_backbone
16
+
17
+ from embeddings import Box8PositionEmbedding2D
18
+
19
+ EPS = 1e-5
20
+
21
+ TRANSFORMER_MODEL = 'bert-base-uncased'
22
+ # TRANSFORMER_MODEL = 'distilroberta-base'
23
+
24
+
25
+ def get_tokenizer(cache=None):
26
+ if cache is None:
27
+ return transformers.BertTokenizer.from_pretrained(TRANSFORMER_MODEL)
28
+
29
+ model_path = os.path.join(cache, TRANSFORMER_MODEL)
30
+ os.makedirs(model_path, exist_ok=True)
31
+
32
+ if os.path.exists(os.path.join(model_path, 'config.json')):
33
+ return transformers.BertTokenizer.from_pretrained(model_path)
34
+
35
+ tokenizer = transformers.BertTokenizer.from_pretrained(TRANSFORMER_MODEL)
36
+ tokenizer.save_pretrained(model_path)
37
+
38
+ return tokenizer
39
+
40
+
41
+ def weight_init(m):
42
+ if isinstance(m, nn.Conv2d):
43
+ nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu'))
44
+ if m.bias is not None:
45
+ nn.init.zeros_(m.bias)
46
+ elif isinstance(m, nn.Linear):
47
+ nn.init.xavier_normal_(m.weight)
48
+ if m.bias is not None:
49
+ nn.init.zeros_(m.bias)
50
+ elif isinstance(m, nn.Embedding):
51
+ nn.init.xavier_normal_(m.weight)
52
+
53
+
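+ # single-scale visual encoder: backbone features + 1x1 projection; when with_pos is set,
+ # 8-d box position features are concatenated to the feature map before the projection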
54
+ class ImageEncoder(nn.Module):
55
+ def __init__(self, backbone='resnet50', out_channels=256, pretrained=True,
56
+ freeze_pretrained=False, with_pos=True):
57
+ super().__init__()
58
+
59
+ model = get_backbone(backbone, pretrained)
60
+
61
+ if pretrained and freeze_pretrained:
62
+ for p in model.parameters():
63
+ p.requires_grad = False
64
+
65
+ if 'resnet' in backbone:
66
+ self.backbone = detection.backbone_utils.IntermediateLayerGetter(
67
+ model, return_layers=OrderedDict({'layer4': 'output'})
68
+ )
69
+ channels = 512 if backbone in ('resnet18', 'resnet34') else 2048
70
+
71
+ elif backbone in ('cspdarknet53', 'efficientnet-b0', 'efficientnet-b3'):
72
+ output_layer_name = list(model.named_children())[-1][0]
73
+ self.backbone = detection.backbone_utils.IntermediateLayerGetter(
74
+ model, return_layers=OrderedDict({output_layer_name: 'output'})
75
+ )
76
+ channels = {
77
+ 'cspdarknet53': 1024,
78
+ 'efficientnet-b0': 1280,
79
+ 'efficientnet-b3': 1536
80
+ }[backbone]
81
+
82
+ else:
83
+ raise RuntimeError('not a valid backbone')
84
+
85
+ in_channels = channels+8 if with_pos else channels
86
+
87
+ self.proj = nn.Sequential(
88
+ nn.Conv2d(in_channels, out_channels, (1, 1), 1, bias=False),
89
+ nn.GroupNorm(1, out_channels, eps=EPS),
90
+ # nn.ReLU(inplace=True),
91
+ )
92
+ self.proj.apply(weight_init)
93
+
94
+ self.pos_emb = None
95
+ if with_pos:
96
+ self.pos_emb = Box8PositionEmbedding2D(with_projection=False)
97
+
98
+ self.out_channels = out_channels
99
+
100
+ def forward(self, img, mask=None):
101
+ x = self.backbone(img)['output']
102
+ if self.pos_emb is not None:
103
+ x = torch.cat([x, self.pos_emb(x)], dim=1)
104
+ x = self.proj(x) # NxDxHxW
105
+
106
+ x_mask = None
107
+ if mask is not None:
108
+ _, _, H, W = x.size()
109
+ x_mask = F.interpolate(mask, (H, W), mode='bilinear')
110
+ x_mask = (x_mask > 0.5).long()
111
+
112
+ return x, x_mask
113
+
114
+
115
+ class FPNImageEncoder(nn.Module):
116
+ def __init__(self,
117
+ backbone='resnet50', out_channels=256, pretrained=True,
118
+ freeze_pretrained=False, with_pos=True):
119
+ super().__init__()
120
+
121
+ model = get_backbone(backbone, pretrained)
122
+
123
+ if pretrained and freeze_pretrained:
124
+ for p in model.parameters():
125
+ p.requires_grad = False
126
+
127
+ if 'resnet' in backbone:
128
+ if backbone in ('resnet18', 'resnet34'):
129
+ in_channels_list = [64, 128, 256, 512]
130
+ else:
131
+ in_channels_list = [256, 512, 1024, 2048]
132
+ return_layers = OrderedDict({
133
+ 'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'
134
+ })
135
+
136
+ # elif backbone == 'cspdarknet53':
137
+ # in_channels_list = [128, 256, 512, 1024]
138
+ # return_layers = OrderedDict({
139
+ # '1':'0', '2':'1', '3':'2', '4':'3'
140
+ # })
141
+
142
+ else:
143
+ raise RuntimeError('not a valid backbone')
144
+
145
+ self.backbone = model
146
+
147
+ self.fpn = detection.backbone_utils.BackboneWithFPN(
148
+ backbone=self.backbone,
149
+ return_layers=return_layers,
150
+ in_channels_list=in_channels_list,
151
+ out_channels=out_channels
152
+ )
153
+
154
+ self.fpn.fpn.extra_blocks = None # removes the 'pool' layer added by default
155
+
156
+ self.out_channels = out_channels
157
+
158
+ in_channels = int(out_channels + float(with_pos) * 8)
159
+
160
+ self.proj = nn.ModuleDict({
161
+ level: nn.Sequential(
162
+ nn.Conv2d(in_channels, out_channels, (1, 1), 1, bias=False),
163
+ nn.GroupNorm(1, out_channels, eps=EPS),
164
+ # nn.ReLU(inplace=True),
165
+ ) for level in return_layers.values()
166
+ })
167
+ self.proj.apply(weight_init)
168
+
169
+ self.pos_emb = None
170
+ if with_pos:
171
+ self.pos_emb = Box8PositionEmbedding2D(with_projection=False)
172
+
173
+ def forward(self, x, mask=None):
174
+ x = self.fpn(x)
175
+
176
+ # smallest feature map (e.g. 16x16 for an input of 512x512 pixels)
177
+ _, _, H, W = list(x.values())[-1].size()
178
+
179
+ x_out = None
180
+ for level, fmap in x.items():
181
+ # fmap = torch.relu(fmap) # FPN blocks end in a conv2d, w/o activ.
182
+ if self.pos_emb is not None:
183
+ fmap = torch.cat([fmap, self.pos_emb(fmap)], dim=1) # +Pos
184
+ fmap = self.proj[level](fmap) # Conv+BN+ReLU
185
+ fmap = F.interpolate(fmap, (H, W), mode='nearest') # to a smaller size
186
+ if x_out is None:
187
+ x_out = fmap
188
+ else:
189
+ x_out += fmap
190
+
191
+ x_mask = None
192
+ if mask is not None:
193
+ x_mask = F.interpolate(mask, (H, W), mode='bilinear')
194
+ x_mask = (x_mask > 0.5).long()
195
+
196
+ return x_out, x_mask
197
+
198
+
199
+ class TransformerImageEncoder(nn.Module):
200
+ def __init__(self,
201
+ backbone='resnet50', out_channels=256, pretrained=True,
202
+ freeze_pretrained=False, num_heads=8, num_layers=6,
203
+ dropout_p=0.1):
204
+ super().__init__()
205
+
206
+ model = get_backbone(backbone, pretrained)
207
+
208
+ if pretrained and freeze_pretrained:
209
+ for p in model.parameters():
210
+ p.requires_grad = False
211
+
212
+ if 'resnet' in backbone:
213
+ self.backbone = detection.backbone_utils.IntermediateLayerGetter(
214
+ model, return_layers=OrderedDict({'layer4': 'output'})
215
+ )
216
+ channels = 512 if backbone in ('resnet18', 'resnet34') else 2048
217
+
218
+ elif backbone in ('cspdarknet53', 'efficientnet-b0', 'efficientnet-b3'):
219
+ output_layer_name = list(model.named_children())[-1][0]
220
+ self.backbone = detection.backbone_utils.IntermediateLayerGetter(
221
+ model, return_layers=OrderedDict({output_layer_name: 'output'})
222
+ )
223
+ channels = {
224
+ 'cspdarknet53': 1024,
225
+ 'efficientnet-b0': 1280,
226
+ 'efficientnet-b3': 1536
227
+ }[backbone]
228
+
229
+ else:
230
+ raise RuntimeError('not a valid backbone')
231
+
232
+ self.proj = nn.Sequential(
233
+ nn.Conv2d(channels, out_channels, (1, 1), 1, bias=False),
234
+ nn.GroupNorm(1, out_channels, eps=EPS),
235
+ # nn.ReLU(inplace=True),
236
+ )
237
+ self.proj.apply(weight_init)
238
+
239
+ from transformers_pos import (
240
+ TransformerEncoder,
241
+ TransformerEncoderLayer,
242
+ )
243
+
244
+ self.encoder = TransformerEncoder(
245
+ TransformerEncoderLayer(
246
+ d_model=out_channels,
247
+ nhead=num_heads,
248
+ dropout=dropout_p,
249
+ batch_first=True
250
+ ),
251
+ num_layers=num_layers
252
+ )
253
+
254
+ self.pos_emb = Box8PositionEmbedding2D(embedding_dim=out_channels)
255
+
256
+ self.out_channels = out_channels
257
+
258
+ def flatten(self, x):
259
+ N, _, H, W = x.size()
260
+ x = x.to(memory_format=torch.channels_last)
261
+ x = x.permute(0, 2, 3, 1).view(N, H*W, -1) # NxHWxD
262
+ return x
263
+
264
+ def forward(self, img, mask=None):
265
+ x = self.backbone(img)['output']
266
+ x = self.proj(x) # NxDxHxW
267
+
268
+ N, _, H, W = x.size()
269
+
270
+ pos = self.pos_emb(x) # NxDxHxW
271
+ pos = self.flatten(pos) # NxRxD
272
+
273
+ x = self.flatten(x) # NxRxD
274
+
275
+ # visibility mask
276
+ x_mask = None
277
+ if mask is not None:
278
+ x_mask = F.interpolate(mask, (H, W), mode='bilinear')
279
+ x_mask = (x_mask > 0.5).long()
280
+
281
+ if mask is None:
282
+ x = self.encoder(x, pos=pos) # NxRxD
283
+ else:
284
+ mask = self.flatten(x_mask).squeeze(-1)
285
+ x = self.encoder(x, src_key_padding_mask=(mask==0), pos=pos) # NxRxD
286
+
287
+ x = x.permute(0, 2, 1).view(N, -1, H, W) # NxDxHxW
288
+
289
+ return x, x_mask
290
+
291
+
292
+ class LanguageEncoder(nn.Module):
293
+ def __init__(self, out_features=256, dropout_p=0.2,
294
+ freeze_pretrained=False, global_pooling=True):
295
+ super().__init__()
296
+ self.language_model = transformers.AutoModel.from_pretrained(
297
+ TRANSFORMER_MODEL
298
+ )
299
+
300
+ if freeze_pretrained:
301
+ for p in self.language_model.parameters():
302
+ p.requires_grad = False
303
+
304
+ self.out_features = out_features
305
+
306
+ self.proj = nn.Sequential(
307
+ nn.Linear(768, out_features),
308
+ nn.LayerNorm(out_features, eps=1e-5),
309
+ # nn.ReLU(inplace=True),
310
+ # nn.Dropout(dropout_p),
311
+ )
312
+ self.proj.apply(weight_init)
313
+
314
+ self.global_pooling = bool(global_pooling)
315
+
316
+ def forward(self, z):
317
+ res = self.language_model(
318
+ input_ids=z['input_ids'],
319
+ position_ids=None,
320
+ attention_mask=z['attention_mask']
321
+ )
322
+
323
+ if self.global_pooling:
324
+ z, z_mask = self.proj(res.pooler_output), None
325
+ else:
326
+ z, z_mask = self.proj(res.last_hidden_state), z['attention_mask']
327
+
328
+ return z, z_mask
329
+
330
+
331
+ class RNNLanguageEncoder(nn.Module):
332
+ def __init__(self,
333
+ model_type='gru', hidden_size=1024, num_layers=2,
334
+ out_features=256, dropout_p=0.2, global_pooling=True):
335
+ super().__init__()
336
+ self.embeddings = transformers.AutoModel.from_pretrained(
337
+ TRANSFORMER_MODEL
338
+ ).embeddings.word_embeddings
339
+ self.embeddings.weight.requires_grad = True
340
+
341
+ # self.dropout_emb = nn.Dropout(0.5)
342
+ self.dropout_emb = nn.Dropout(dropout_p)
343
+
344
+ assert model_type in ('gru', 'lstm')
345
+ self.rnn = (nn.GRU if model_type == 'gru' else nn.LSTM)(
346
+ input_size=self.embeddings.weight.size(1),
347
+ hidden_size=hidden_size,
348
+ num_layers=num_layers,
349
+ dropout=dropout_p,
350
+ batch_first=True,
351
+ bidirectional=True
352
+ )
353
+
354
+ self.proj = nn.Sequential(
355
+ nn.Linear(2*hidden_size, out_features),
356
+ nn.LayerNorm(out_features, eps=1e-5),
357
+ # nn.ReLU(inplace=True),
358
+ # nn.Dropout(dropout_p),
359
+ )
360
+ self.proj.apply(weight_init)
361
+
362
+ self.out_features = out_features
363
+
364
+ self.global_pooling = bool(global_pooling)
365
+ assert global_pooling # only w/ global pooling
366
+
367
+ def forward(self, z):
368
+ z_mask = z['attention_mask']
369
+
370
+ z = self.dropout_emb(self.embeddings(z['input_ids']))
371
+ z, h_n = self.rnn(z, None)
372
+
373
+ if isinstance(self.rnn, nn.LSTM):
374
+ h_n = h_n[0]
375
+
376
+ # hidden states as (num_layers, num_directions, batch, hidden_size)
377
+ h_n = h_n.view(self.rnn.num_layers, 2, z.size(0), self.rnn.hidden_size)
378
+
379
+ # last hidden states
380
+ h_n = h_n[-1].permute(1, 0, 2).reshape(z.size(0), -1)
381
+ h_n = self.proj(h_n)
382
+ return h_n, z_mask
383
+
384
+
385
+ class SimpleEncoder(nn.Module):
386
+ def __init__(self, out_features=256, dropout_p=0.1, global_pooling=True):
387
+ super().__init__()
388
+ self.embeddings = transformers.AutoModel.from_pretrained(
389
+ TRANSFORMER_MODEL
390
+ ).embeddings.word_embeddings
391
+ self.embeddings.weight.requires_grad = True
392
+
393
+ # self.dropout_emb = nn.Dropout(0.5)
394
+ self.dropout_emb = nn.Dropout(dropout_p)
395
+
396
+ self.proj = nn.Sequential(
397
+ nn.Linear(768, out_features),
398
+ nn.LayerNorm(out_features, eps=1e-5),
399
+ # nn.ReLU(inplace=True),
400
+ # nn.Dropout(dropout_p),
401
+ )
402
+ self.proj.apply(weight_init)
403
+
404
+ self.out_features = out_features
405
+
406
+ self.global_pooling = bool(global_pooling)
407
+ assert not self.global_pooling # only w/o global pooling
408
+
409
+ def forward(self, z):
410
+ z_mask = z['attention_mask']
411
+ z = self.embeddings(z['input_ids'])
412
+ z = self.proj(self.dropout_emb(z))
413
+ # z[:, 0] = torch.mean(z[:, 1:], 1)
414
+ return z, z_mask
models.py ADDED
@@ -0,0 +1,412 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from torchvision.ops import box_convert
5
+ import embeddings as emb
6
+ import encoders as enc
7
+ from encoders import weight_init
8
+
9
+ def conv3x3(in_channels, out_channels, num_groups=0):
10
+ return nn.Sequential(
11
+ # Conv2d w/o bias since BatchNorm2d/GroupNorm already accounts for it (affine=True)
12
+ nn.Conv2d(in_channels, out_channels, (3, 3), 1, 1, bias=False),
13
+ nn.BatchNorm2d(out_channels) if num_groups < 1 else nn.GroupNorm(num_groups, out_channels),
14
+ nn.ReLU(inplace=True),
15
+ )
16
+
17
+
18
+ class IntuitionKillingMachine(nn.Module):
19
+ def __init__(self,
20
+ backbone='resnet50', pretrained=True, embedding_size=256,
21
+ num_heads=8, num_layers=6, num_conv=4, dropout_p=0.1,
22
+ segmentation_head=True, mask_pooling=True):
23
+ super().__init__()
24
+
25
+ if backbone.endswith('+tr'):
26
+ self.vis_enc = enc.TransformerImageEncoder(
27
+ backbone=backbone.rstrip('+tr'),
28
+ out_channels=embedding_size,
29
+ pretrained=pretrained,
30
+ )
31
+
32
+ elif backbone.endswith('+fpn'):
33
+ self.vis_enc = enc.FPNImageEncoder(
34
+ backbone=backbone.rstrip('+fpn'),
35
+ out_channels=embedding_size,
36
+ pretrained=pretrained,
37
+ with_pos=False
38
+ )
39
+ else:
40
+ self.vis_enc = enc.ImageEncoder(
41
+ backbone=backbone,
42
+ out_channels=embedding_size,
43
+ pretrained=pretrained,
44
+ with_pos=False
45
+ )
46
+
47
+ # freeze ResNet stem
48
+ if 'resnet' in backbone:
49
+ self.vis_enc.backbone.conv1.requires_grad = False
50
+ self.vis_enc.backbone.conv1.eval()
51
+
52
+ self.vis_pos_emb = emb.LearnedPositionEmbedding2D(
53
+ embedding_dim=embedding_size
54
+ )
55
+
56
+ self.lan_enc = enc.LanguageEncoder(
57
+ out_features=embedding_size,
58
+ global_pooling=False,
59
+ dropout_p=dropout_p
60
+ )
61
+
62
+ self.lan_pos_emb = emb.LearnedPositionEmbedding1D(
63
+ embedding_dim=embedding_size
64
+ )
65
+
66
+ from transformers_pos import (
67
+ XTransformerEncoder,
68
+ TransformerEncoder,
69
+ TransformerEncoderLayer,
70
+ )
71
+
72
+ self.encoder = TransformerEncoder(
73
+ TransformerEncoderLayer(
74
+ d_model=embedding_size,
75
+ nhead=num_heads,
76
+ dropout=dropout_p,
77
+ batch_first=True
78
+ ),
79
+ num_layers=num_layers
80
+ )
81
+
82
+ # ---
83
+ # CONV PRE-HEAD (NECK?)
84
+
85
+ if num_conv > 0:
86
+ self.pre_head = nn.Sequential(*[
87
+ conv3x3(embedding_size, embedding_size) for _ in range(num_conv)
88
+ ])
89
+ self.pre_head.apply(weight_init)
90
+ else:
91
+ self.pre_head = nn.Identity()
92
+
93
+ # ---
94
+ # OUTPUT HEADS
95
+
96
+ # box prediction
97
+ self.head = nn.Sequential(
98
+ nn.Linear(embedding_size, 4, bias=True),
99
+ nn.Sigmoid()
100
+ )
101
+ self.head.apply(weight_init)
102
+
103
+ # box segmentation mask
104
+ self.segm_head = None
105
+ if segmentation_head:
106
+ self.segm_head = nn.Sequential(
107
+ nn.Conv2d(embedding_size, 1, (3, 3), 1, 1, bias=True),
108
+ #nn.Sigmoid()
109
+ )
110
+ self.segm_head.apply(weight_init)
111
+
112
+ # ---
113
+
114
+ self.mask_pooling = bool(mask_pooling)
115
+
116
+ if self.mask_pooling and self.segm_head is None:
117
+ raise RuntimeError('mask pooling w/o a segmentation head does not make sense')
118
+
119
+ self.embedding_size = embedding_size
120
+
121
+ # def slow_param_ids(self, **kwargs):
122
+ # return []
123
+
124
+ def slow_param_ids(self, slow_visual_backbone=True, slow_language_backbone=True):
125
+ ids = []
126
+
127
+ if slow_visual_backbone:
128
+ ids += [id(p) for p in self.vis_enc.backbone.parameters()]
129
+ if hasattr(self.vis_enc, 'encoder'): # +tr
130
+ ids += [id(p) for p in self.vis_enc.encoder.parameters()]
131
+
132
+ if slow_language_backbone:
133
+ if isinstance(self.lan_enc, enc.LanguageEncoder):
134
+ ids += [id(p) for p in self.lan_enc.language_model.parameters()]
135
+ else:
136
+ ids += [id(p) for p in self.lan_enc.embeddings.parameters()]
137
+
138
+ return ids
139
+
140
+ def flatten(self, x):
141
+ N, D, H, W = x.size()
142
+ x = x.to(memory_format=torch.channels_last)
143
+ x = x.permute(0, 2, 3, 1).view(N, H*W, D)
144
+ return x # NxHWxD
145
+
146
+ def unflatten(self, x, size):
147
+ N, R, D = x.size()
148
+ H, W = size
149
+ assert R == H*W, 'wrong tensor size'
150
+ x = x.permute(0, 2, 1).to(memory_format=torch.contiguous_format)
151
+ x = x.view(N, D, H, W)
152
+ return x # NxDxHxW
153
+
154
+ def forward(self, input):
155
+ img, mask, tok = input['image'], input['mask'], input['tok']
156
+
157
+ # ---
158
+ # VISUAL EMBEDDINGS
159
+
160
+ x, x_mask = self.vis_enc(img, mask) # NxDxHxW, NxHxW
161
+ x_pos = self.vis_pos_emb(x, x_mask)
162
+
163
+ N, D, H, W = x.size() # save dims before flatten
164
+
165
+ x = self.flatten(x) # NxRxD
166
+ x_mask = self.flatten(x_mask).squeeze(-1) # NxR
167
+ x_pos = self.flatten(x_pos) # NxRxD
168
+
169
+ # ---
170
+ # LANGUAGE EMBEDDINGS
171
+
172
+ z, z_mask = self.lan_enc(tok) # NxTxD, NxT
173
+ z_pos = self.lan_pos_emb(z) # NxTxD
174
+
175
+ # ---
176
+ # V+L TRANSFORMER
177
+
178
+ # [...visual...]+[[CLS]...language tokens...[SEP]]
179
+ xz = torch.cat([x, z], dim=1)
180
+ xz_mask = torch.cat([x_mask, z_mask], dim=1)
181
+ xz_pos = torch.cat([x_pos, z_pos], dim=1)
182
+
183
+ xz = self.encoder(xz, src_key_padding_mask=(xz_mask==0), pos=xz_pos) #, size=(H,W))
184
+
185
+ # restore spatiality of visual embeddings after cross-modal encoding
186
+ xz_vis = xz[:, :H*W, ...]
187
+ xz_vis = self.unflatten(xz_vis, (H, W))
188
+
189
+ x_mask = self.unflatten(x_mask.unsqueeze(-1), (H, W))
190
+
191
+ # ---
192
+
193
+ # convolutional pre-head
194
+ xz_vis = self.pre_head(xz_vis)
195
+
196
+ # ---
197
+
198
+ # segmentation head w/ (opt.) pooling
199
+ segm_mask, pooled_feat = None, None
200
+ if self.segm_head is not None:
201
+ segm_mask = torch.sigmoid(self.segm_head(xz_vis)) * x_mask
202
+ if self.mask_pooling: # box mask guided pooling
203
+ pooled_feat = (segm_mask * xz_vis).sum((2, 3)) / segm_mask.sum((2, 3))
204
+ segm_mask = F.interpolate(segm_mask, img.size()[2:], mode='bilinear', align_corners=True)
205
+
206
+ # if not mask_pooling, do the pooling using all visual feats (equiv. to a uniform mask)
207
+ if pooled_feat is None:
208
+ pooled_feat = (x_mask * xz_vis).sum((2, 3)) / x_mask.sum((2, 3))
209
+
210
+ # bbox prediction
211
+ pred = self.head(pooled_feat)
212
+ pred = box_convert(pred, 'cxcywh', 'xyxy')
213
+
214
+ return pred, segm_mask
215
+
216
+ class HeadlessMachine(nn.Module):
217
+ def __init__(self,
218
+ backbone='resnet50', pretrained=True, embedding_size=256,
219
+ num_heads=8, num_layers=6, num_conv=4, dropout_p=0.1,
220
+ segmentation_head=True, mask_pooling=True):
221
+ super().__init__()
222
+
223
+ if backbone.endswith('+tr'):
224
+ self.vis_enc = enc.TransformerImageEncoder(
225
+ backbone=backbone.rstrip('+tr'),
226
+ out_channels=embedding_size,
227
+ pretrained=pretrained,
228
+ )
229
+
230
+ elif backbone.endswith('+fpn'):
231
+ self.vis_enc = enc.FPNImageEncoder(
232
+ backbone=backbone.rstrip('+fpn'),
233
+ out_channels=embedding_size,
234
+ pretrained=pretrained,
235
+ with_pos=False
236
+ )
237
+ else:
238
+ self.vis_enc = enc.ImageEncoder(
239
+ backbone=backbone,
240
+ out_channels=embedding_size,
241
+ pretrained=pretrained,
242
+ with_pos=False
243
+ )
244
+
245
+ # freeze ResNet stem
246
+ if 'resnet' in backbone:
247
+ self.vis_enc.backbone.conv1.requires_grad = False
248
+ self.vis_enc.backbone.conv1.eval()
249
+
250
+ self.vis_pos_emb = emb.LearnedPositionEmbedding2D(
251
+ embedding_dim=embedding_size
252
+ )
253
+
254
+ self.lan_enc = enc.LanguageEncoder(
255
+ out_features=embedding_size,
256
+ global_pooling=False,
257
+ dropout_p=dropout_p
258
+ )
259
+
260
+ self.lan_pos_emb = emb.LearnedPositionEmbedding1D(
261
+ embedding_dim=embedding_size
262
+ )
263
+
264
+ from transformers_pos import (
265
+ XTransformerEncoder,
266
+ TransformerEncoder,
267
+ TransformerEncoderLayer,
268
+ )
269
+
270
+ self.encoder = TransformerEncoder(
271
+ TransformerEncoderLayer(
272
+ d_model=embedding_size,
273
+ nhead=num_heads,
274
+ dropout=dropout_p,
275
+ batch_first=True
276
+ ),
277
+ num_layers=num_layers
278
+ )
279
+
280
+ # ---
281
+ # CONV PRE-HEAD (NECK?)
282
+
283
+ if num_conv > 0:
284
+ self.pre_head = nn.Sequential(*[
285
+ conv3x3(embedding_size, embedding_size) for _ in range(num_conv)
286
+ ])
287
+ self.pre_head.apply(weight_init)
288
+ else:
289
+ self.pre_head = nn.Identity()
290
+
291
+ # ---
292
+ # OUTPUT HEADS
293
+
294
+ # box prediction
295
+ self.head = nn.Sequential(
296
+ nn.Linear(embedding_size, 4, bias=True),
297
+ nn.Sigmoid()
298
+ )
299
+ self.head.apply(weight_init)
300
+
301
+ # box segmentation mask
302
+ self.segm_head = None
303
+ if segmentation_head:
304
+ self.segm_head = nn.Sequential(
305
+ nn.Conv2d(embedding_size, 1, (3, 3), 1, 1, bias=True),
306
+ #nn.Sigmoid()
307
+ )
308
+ self.segm_head.apply(weight_init)
309
+
310
+ # ---
311
+
312
+ self.mask_pooling = bool(mask_pooling)
313
+
314
+ if self.mask_pooling and self.segm_head is None:
315
+ raise RuntimeError('mask pooling w/o a segmentation head does not make sense')
316
+
317
+ self.embedding_size = embedding_size
318
+
319
+ # def slow_param_ids(self, **kwargs):
320
+ # return []
321
+
322
+ def slow_param_ids(self, slow_visual_backbone=True, slow_language_backbone=True):
323
+ ids = []
324
+
325
+ if slow_visual_backbone:
326
+ ids += [id(p) for p in self.vis_enc.backbone.parameters()]
327
+ if hasattr(self.vis_enc, 'encoder'): # +tr
328
+ ids += [id(p) for p in self.vis_enc.encoder.parameters()]
329
+
330
+ if slow_language_backbone:
331
+ if isinstance(self.lan_enc, enc.LanguageEncoder):
332
+ ids += [id(p) for p in self.lan_enc.language_model.parameters()]
333
+ else:
334
+ ids += [id(p) for p in self.lan_enc.embeddings.parameters()]
335
+
336
+ return ids
337
+
338
+ def flatten(self, x):
339
+ N, D, H, W = x.size()
340
+ x = x.to(memory_format=torch.channels_last)
341
+ x = x.permute(0, 2, 3, 1).view(N, H*W, D)
342
+ return x # NxHWxD
343
+
344
+ def unflatten(self, x, size):
345
+ N, R, D = x.size()
346
+ H, W = size
347
+ assert R == H*W, 'wrong tensor size'
348
+ x = x.permute(0, 2, 1).to(memory_format=torch.contiguous_format)
349
+ x = x.view(N, D, H, W)
350
+ return x # NxDxHxW
351
+
352
+ def forward(self, input):
353
+ img, mask, tok = input['image'], input['mask'], input['tok']
354
+
355
+ # ---
356
+ # VISUAL EMBEDDINGS
357
+
358
+ x, x_mask = self.vis_enc(img, mask) # NxDxHxW, NxHxW
359
+ x_pos = self.vis_pos_emb(x, x_mask)
360
+
361
+ N, D, H, W = x.size() # save dims before flatten
362
+
363
+ x = self.flatten(x) # NxRxD
364
+ x_mask = self.flatten(x_mask).squeeze(-1) # NxR
365
+ x_pos = self.flatten(x_pos) # NxRxD
366
+
367
+ # ---
368
+ # LANGUAGE EMBEDDINGS
369
+
370
+ z, z_mask = self.lan_enc(tok) # NxTxD, NxT
371
+ z_pos = self.lan_pos_emb(z) # NxTxD
372
+
373
+ # ---
374
+ # V+L TRANSFORMER
375
+
376
+ # [...visual...]+[[CLS]...language tokens...[SEP]]
377
+ xz = torch.cat([x, z], dim=1)
378
+ xz_mask = torch.cat([x_mask, z_mask], dim=1)
379
+ xz_pos = torch.cat([x_pos, z_pos], dim=1)
380
+
381
+ xz = self.encoder(xz, src_key_padding_mask=(xz_mask==0), pos=xz_pos) #, size=(H,W))
382
+
383
+ # restore spatiality of visual embeddings after cross-modal encoding
384
+ xz_vis = xz[:, :H*W, ...]
385
+ xz_vis = self.unflatten(xz_vis, (H, W))
386
+
387
+ x_mask = self.unflatten(x_mask.unsqueeze(-1), (H, W))
388
+
389
+ # ---
390
+
391
+ # convolutional pre-head
392
+ xz_vis = self.pre_head(xz_vis)
393
+
394
+ # ---
395
+
396
+ # segmentation head w/ (opt.) pooling
397
+ segm_mask, pooled_feat = None, None
398
+ if self.segm_head is not None:
399
+ segm_mask = torch.sigmoid(self.segm_head(xz_vis)) * x_mask
400
+ if self.mask_pooling: # box mask guided pooling
401
+ pooled_feat = (segm_mask * xz_vis).sum((2, 3)) / segm_mask.sum((2, 3))
402
+ segm_mask = F.interpolate(segm_mask, img.size()[2:], mode='bilinear', align_corners=True)
403
+
404
+ # if not mask_pooling, do the pooling using all visual feats (equiv. to a uniform mask)
405
+ if pooled_feat is None:
406
+ pooled_feat = (x_mask * xz_vis).sum((2, 3)) / x_mask.sum((2, 3))
407
+
408
+ # bbox prediction
409
+ pred = self.head(pooled_feat)
410
+ pred = box_convert(pred, 'cxcywh', 'xyxy')
411
+
412
+ return pred, segm_mask
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ Pillow==9.1.0
2
+ timm==0.6.7
3
+ torch==1.9.0
4
+ torchvision==0.10.0
5
+ transformers==4.12.3
testing_loading.py ADDED
@@ -0,0 +1,97 @@
1
+ from models import IntuitionKillingMachine
2
+ from transforms import undo_box_transforms_batch, ToTensor, Normalize, SquarePad, Resize, NormalizeBoxCoords
3
+ from torchvision.transforms import Compose
4
+ from encoders import get_tokenizer
5
+ from PIL import Image, ImageDraw
6
+ from zipfile import ZipFile
7
+ from copy import copy
8
+ import pandas as pd
9
+ import torch
10
+
11
+ def parse_model_args(model_path):
12
+ _, _, dataset, max_length, input_size, backbone, num_heads, num_layers, num_conv, _, _, mu, mask_pooling = model_path.split('_')[:13]
13
+ return {
14
+ 'dataset': dataset,
15
+ 'max_length': int(max_length),
16
+ 'input_size': int(input_size),
17
+ 'backbone': backbone,
18
+ 'num_heads': int(num_heads),
19
+ 'num_layers': int(num_layers),
20
+ 'num_conv': int(num_conv),
21
+ 'mu': float(mu),
22
+ 'mask_pooling': bool(mask_pooling == '1')
23
+ }
24
+
25
+
26
+ class Prober:
27
+ def __init__(self,
28
+ df_path=None,
29
+ dataset_path=None,
30
+ model_checkpoint=None):
31
+ params = parse_model_args(model_checkpoint)
32
+ mean = [0.485, 0.456, 0.406]
33
+ sdev = [0.229, 0.224, 0.225]
34
+ self.tokenizer = get_tokenizer()
35
+ self.df = pd.read_json(df_path)[['sample_idx', 'bbox', 'file_path', 'sent']]
36
+ self.df.loc[:, "image_id"] = self.df.file_path.apply(lambda x: int(x.split('/')[-1][:-4]))
37
+ self.df.file_path = self.df.file_path.apply(lambda x: x.replace('refer/data/images/', ''))
38
+ self.model = IntuitionKillingMachine(
39
+ backbone=params['backbone'],
40
+ pretrained=True,
41
+ num_heads=params['num_heads'],
42
+ num_layers=params['num_layers'],
43
+ num_conv=params['num_conv'],
44
+ segmentation_head=bool(params['mu'] > 0.0),
45
+ mask_pooling=params['mask_pooling']
46
+ )
47
+ self.transform = Compose([
48
+ ToTensor(),
49
+ Normalize(mean, sdev),
50
+ SquarePad(),
51
+ Resize(size=(params['input_size'], params['input_size'])),
52
+ NormalizeBoxCoords(),
53
+ ])
54
+ self.max_length = 30
55
+ self.zipfile = ZipFile(dataset_path, 'r')
56
+
57
+ @torch.no_grad()
58
+ def probe(self, idx, re, search_by_sample_id: bool = True):
59
+ if search_by_sample_id:
60
+ img_path, target = self.df.loc[idx][['file_path','bbox']].values
61
+ else:
62
+ img_path, target = self.df[self.df.image_id == idx][['file_path','bbox']].values[0]
63
+ img = Image.open(self.zipfile.open(img_path)).convert('RGB')
64
+ W0, H0 = img.size
65
+ sample = {
66
+ 'image': img,
67
+ 'image_size': (H0, W0), # image original size
68
+ 'bbox': torch.tensor([copy(target)]),
69
+ 'bbox_raw': torch.tensor([copy(target)]),
70
+ 'mask': torch.ones((1, H0, W0), dtype=torch.float32), # visibility mask
71
+ 'mask_bbox': None, # target bbox mask
72
+ }
73
+ print('inn bbox: ', sample['bbox'])
74
+ sample = self.transform(sample)
75
+ tok = self.tokenizer(re,
76
+ max_length=30,
77
+ return_tensors='pt',
78
+ truncation=True)
79
+ inn = {'image': torch.stack([sample['image']]),
80
+ 'mask': torch.stack([sample['mask']]),
81
+ 'bbox': torch.stack([sample['bbox']]),
82
+ 'tok': tok}
83
+ output = undo_box_transforms_batch(self.model(inn)[0],
84
+ [sample['tr_param']]).numpy().tolist()[0]
85
+ img1 = ImageDraw.Draw(img)
86
+ #img1.rectangle(target, outline ="#0000FF00", width=3)
87
+ img1.rectangle(output, outline ="#00FF0000", width=3)
88
+ return img
89
+
90
+ if __name__ == "__main__":
91
+ prober = Prober(
92
+ df_path = 'data/val-sim_metric.json',
93
+ dataset_path = "data/saiapr_tc-12.zip",
94
+ model_checkpoint= "cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt"
95
+ )
96
+ prober.probe(0, "tree")
97
+ print("Done")
transformers_pos.py ADDED
@@ -0,0 +1,198 @@
1
+ import copy
2
+ from typing import Optional, Any
3
+
4
+ import torch
5
+
6
+ from torch import Tensor
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+
11
+ def conv3x3(in_channels, out_channels, num_groups=0):
12
+ return nn.Sequential(
13
+ # Conv2d w/o bias since BatchNorm2d/GroupNorm already accounts for it (affine=True)
14
+ nn.Conv2d(in_channels, out_channels, (3, 3), 1, 1, bias=False),
15
+ nn.BatchNorm2d(out_channels) if num_groups < 1 else nn.GroupNorm(num_groups, out_channels),
16
+ nn.ReLU(inplace=True),
17
+ )
18
+
19
+
20
+ class XTransformerEncoder(nn.Module):
21
+ __constants__ = ['norm']
22
+ def __init__(self, encoder_layer, num_layers, num_conv=2, norm=None):
23
+ super().__init__()
24
+ self.layers = _get_clones(encoder_layer, num_layers)
25
+ self.num_layers = num_layers
26
+ self.norm = norm
27
+
28
+ d_model = encoder_layer.linear1.in_features
29
+ self.conv = nn.ModuleList([
30
+ nn.Sequential(*[
31
+ conv3x3(d_model, d_model) for _ in range(num_conv)
32
+ ]) for _ in range(num_layers)
33
+ ])
34
+
35
+ def flatten(self, x):
36
+ N, D, H, W = x.size()
37
+ x = x.to(memory_format=torch.channels_last)
38
+ x = x.permute(0, 2, 3, 1).view(N, H*W, D)
39
+ return x # NxHWxD
40
+
41
+ def unflatten(self, x, size):
42
+ N, R, D = x.size()
43
+ H, W = size
44
+ assert R == H*W, 'wrong tensor size'
45
+ x = x.permute(0, 2, 1).to(memory_format=torch.contiguous_format)
46
+ x = x.view(N, D, H, W)
47
+ return x # NxDxHxW
48
+
49
+ def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, size=None) -> Tensor:
50
+ output = src
51
+
52
+ for i, mod in enumerate(self.layers):
53
+ output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
54
+
55
+ vis = self.unflatten(output[:, :size[0]*size[1]], size)
56
+ vis = self.flatten(self.conv[i](vis))
57
+
58
+ output = torch.cat([vis, output[:, size[0]*size[1]:]], dim=1)
59
+
60
+ if self.norm is not None:
61
+ output = self.norm(output)
62
+
63
+ return output
64
+
65
+
66
+ class TransformerEncoder(nn.Module):
67
+ r"""TransformerEncoder is a stack of N encoder layers
68
+
69
+ Args:
70
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
71
+ num_layers: the number of sub-encoder-layers in the encoder (required).
72
+ norm: the layer normalization component (optional).
73
+
74
+ Examples::
75
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
76
+ >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
77
+ >>> src = torch.rand(10, 32, 512)
78
+ >>> out = transformer_encoder(src)
79
+ """
80
+ __constants__ = ['norm']
81
+
82
+ def __init__(self, encoder_layer, num_layers, norm=None):
83
+ super(TransformerEncoder, self).__init__()
84
+ self.layers = _get_clones(encoder_layer, num_layers)
85
+ self.num_layers = num_layers
86
+ self.norm = norm
87
+
88
+ def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None) -> Tensor:
89
+ r"""Pass the input through the encoder layers in turn.
90
+
91
+ Args:
92
+ src: the sequence to the encoder (required).
93
+ mask: the mask for the src sequence (optional).
94
+ src_key_padding_mask: the mask for the src keys per batch (optional).
95
+
96
+ Shape:
97
+ see the docs in Transformer class.
98
+ """
99
+ output = src
100
+
101
+ for mod in self.layers:
102
+ output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
103
+
104
+ if self.norm is not None:
105
+ output = self.norm(output)
106
+
107
+ return output
108
+
109
+
110
+ class TransformerEncoderLayer(nn.Module):
111
+ r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
112
+ This standard encoder layer is based on the paper "Attention Is All You Need".
113
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
114
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
115
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
116
+ in a different way during application.
117
+
118
+ Args:
119
+ d_model: the number of expected features in the input (required).
120
+ nhead: the number of heads in the multiheadattention models (required).
121
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
122
+ dropout: the dropout value (default=0.1).
123
+ activation: the activation function of intermediate layer, relu or gelu (default=relu).
124
+ layer_norm_eps: the eps value in layer normalization components (default=1e-5).
125
+ batch_first: If ``True``, then the input and output tensors are provided
126
+ as (batch, seq, feature). Default: ``False``.
127
+
128
+ Examples::
129
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
130
+ >>> src = torch.rand(10, 32, 512)
131
+ >>> out = encoder_layer(src)
132
+
133
+ Alternatively, when ``batch_first`` is ``True``:
134
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
135
+ >>> src = torch.rand(32, 10, 512)
136
+ >>> out = encoder_layer(src)
137
+ """
138
+ __constants__ = ['batch_first']
139
+
140
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
141
+ layer_norm_eps=1e-5, batch_first=False,
142
+ device=None, dtype=None) -> None:
143
+ factory_kwargs = {'device': device, 'dtype': dtype}
144
+ super(TransformerEncoderLayer, self).__init__()
145
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
146
+ **factory_kwargs)
147
+ # Implementation of Feedforward model
148
+ self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
149
+ self.dropout = nn.Dropout(dropout)
150
+ self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
151
+
152
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
153
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
154
+ self.dropout1 = nn.Dropout(dropout)
155
+ self.dropout2 = nn.Dropout(dropout)
156
+
157
+ self.activation = _get_activation_fn(activation)
158
+
159
+ def __setstate__(self, state):
160
+ if 'activation' not in state:
161
+ state['activation'] = F.relu
162
+ super(TransformerEncoderLayer, self).__setstate__(state)
163
+
164
+ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None) -> Tensor:
165
+ r"""Pass the input through the encoder layer.
166
+
167
+ Args:
168
+ src: the sequence to the encoder layer (required).
169
+ src_mask: the mask for the src sequence (optional).
170
+ src_key_padding_mask: the mask for the src keys per batch (optional).
171
+
172
+ Shape:
173
+ see the docs in Transformer class.
174
+ """
175
+
176
+ q = k = src if pos is None else src + pos
177
+
178
+ src2 = self.self_attn(q, k, src, attn_mask=src_mask,
179
+ key_padding_mask=src_key_padding_mask)[0]
180
+ src = src + self.dropout1(src2)
181
+ src = self.norm1(src)
182
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
183
+ src = src + self.dropout2(src2)
184
+ src = self.norm2(src)
185
+ return src
186
+
187
+
188
+ def _get_clones(module, N):
189
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
190
+
191
+
192
+ def _get_activation_fn(activation):
193
+ if activation == "relu":
194
+ return F.relu
195
+ elif activation == "gelu":
196
+ return F.gelu
197
+
198
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
transforms.py ADDED
@@ -0,0 +1,276 @@
1
+ import torch
2
+
3
+ from torchvision import transforms
4
+
5
+ from torchvision.transforms import Compose
6
+
7
+ from PIL import Image
8
+
9
+
10
+ class ToTensor(transforms.ToTensor):
11
+ def __call__(self, input):
12
+ if not isinstance(input, dict):
13
+ return super().__call__(input)
14
+ assert 'image' in input
15
+ input['image'] = super().__call__(input['image'])
16
+ return input
17
+
18
+
19
+ class Normalize(transforms.Normalize):
20
+ def __call__(self, input):
21
+ if not isinstance(input, dict):
22
+ return super().__call__(input)
23
+ assert 'image' in input
24
+ input['image'] = super().__call__(input['image'])
25
+ return input
26
+
27
+
28
+ class NormalizeBoxCoords(transforms.ToTensor):
29
+ def __call__(self, input):
30
+ if not isinstance(input, dict):
31
+ return super().__call__(input)
32
+ assert 'image' in input and 'bbox' in input
33
+ _, H, W = input['image'].size()
34
+ input['bbox'][:, (0, 2)] /= W
35
+ input['bbox'][:, (1, 3)] /= H
36
+
37
+ if 'tr_param' not in input:
38
+ input['tr_param'] = []
39
+ input['tr_param'].append({'normalize_box_coords': (H, W)})
40
+
41
+ return input
42
+
43
+
44
+ class SquarePad(torch.nn.Module):
45
+ def __call__(self, input):
46
+ if isinstance(input, Image.Image):
47
+ raise NotImplementedError('put the SquarePad transform after ToTensor')
48
+
49
+ assert 'image' in input
50
+ _, h, w = input['image'].size()
51
+
52
+ max_wh = max(w, h)
53
+ xp = int(0.5 * (max_wh - w))
54
+ yp = int(0.5 * (max_wh - h))
55
+ padding = (xp, yp, (max_wh-xp)-w, (max_wh-yp)-h)
56
+
57
+ input['image'] = transforms.functional.pad(
58
+ input['image'], padding, fill=0, padding_mode='constant'
59
+ )
60
+ # input['image'] = transforms.functional.pad(
61
+ # input['image'], padding, padding_mode='edge'
62
+ # )
63
+
64
+ if 'mask' in input:
65
+ input['mask'] = transforms.functional.pad(
66
+ input['mask'], padding, fill=0, padding_mode='constant'
67
+ )
68
+
69
+ if 'bbox' in input:
70
+ input['bbox'][:, (0, 2)] += xp
71
+ input['bbox'][:, (1, 3)] += yp
72
+
73
+ if 'tr_param' not in input:
74
+ input['tr_param'] = []
75
+ input['tr_param'].append({'square_pad': padding})
76
+
77
+ return input
78
+
79
+
80
+ class Resize(transforms.Resize):
81
+ def __call__(self, input):
82
+ if not isinstance(input, dict):
83
+ return super().__call__(input)
84
+
85
+ assert 'image' in input
86
+
87
+ if not torch.is_tensor(input['image']):
88
+ raise NotImplementedError('put the Resize transform after ToTensor')
89
+
90
+ _, img_h, img_w = input['image'].size()
91
+
92
+ if isinstance(self.size, int):
93
+ dst_h = self.size if img_h < img_w else int(self.size * img_h / img_w)
94
+ dst_w = self.size if img_w < img_h else int(self.size * img_w / img_h)
95
+ else:
96
+ dst_h, dst_w = self.size
97
+
98
+ input['image'] = super().__call__(input['image'])
99
+
100
+ if 'mask' in input:
101
+ input['mask'] = super().__call__(input['mask'])
102
+
103
+ sx, sy = dst_w / img_w, dst_h / img_h
104
+
105
+ if 'bbox' in input:
106
+ input['bbox'][:, (0, 2)] *= sx
107
+ input['bbox'][:, (1, 3)] *= sy
108
+
109
+ if 'tr_param' not in input:
110
+ input['tr_param'] = []
111
+ input['tr_param'].append({'resize': (sx, sy)})
112
+
113
+ return input
114
+
115
+
116
+ class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
117
+ def __call__(self, input):
118
+ if not isinstance(input, dict):
119
+ return super().__call__(input)
120
+
121
+ assert 'image' in input
122
+
123
+ if not torch.is_tensor(input['image']):
124
+ raise NotImplementedError('put the RandomHorizontalFlip transform after ToTensor')
125
+
126
+ result = super().__call__(input['image'])
127
+ if result is input['image']: # not flipped
128
+ return input
129
+ input['image'] = result
130
+
131
+ if 'mask' in input:
132
+ input['mask'] = torch.flip(input['mask'], dims=(-1,))
133
+
134
+ img_w = input['image'].size(2)
135
+
136
+ if 'bbox' in input:
137
+ input['bbox'][:, (0, 2)] = img_w - input['bbox'][:, (2, 0)]
138
+
139
+ if 'expr' in input:
140
+ input['expr'] = input['expr'].replace('left', '<LEFT>').replace('right', 'left').replace('<LEFT>', 'right')
141
+
142
+ return input
143
+
144
+
145
+ class RandomAffine(transforms.RandomAffine):
146
+ def get_params(self, *args, **kwargs):
147
+ self.params = super().get_params(*args, **kwargs)
148
+ return self.params
149
+
150
+ def __call__(self, input):
151
+ if not isinstance(input, dict):
152
+ return super().__call__(input)
153
+
154
+ assert 'image' in input
155
+
156
+ if not torch.is_tensor(input['image']):
157
+ raise NotImplementedError('put the RandomAffine transform after ToTensor')
158
+
159
+ #self.fill = input['image'].mean((1,2)) # set fill value to the mean pixel value
160
+ result = super().__call__(input['image'])
161
+ if result is input['image']: # not transformed
162
+ return input
163
+ input['image'] = result
164
+
165
+ _, img_h, img_w = input['image'].size()
166
+
167
+ angle, translate, scale, shear = self.params
168
+ center = (img_w * 0.5, img_h * 0.5)
169
+ matrix = transforms.functional._get_inverse_affine_matrix(center, angle, translate, scale, shear)
170
+ matrix = torch.FloatTensor([matrix[:3], matrix[3:], [0, 0, 1]])
171
+ matrix = torch.linalg.inv(matrix)
172
+
173
+ if 'mask' in input:
174
+ input['mask'] = transforms.functional.affine(
175
+ input['mask'], *self.params, self.interpolation, self.fill
176
+ )
177
+
178
+ if 'bbox' in input:
179
+ for i, (x1, y1, x2, y2) in enumerate(input['bbox']):
180
+ pt = matrix @ torch.FloatTensor([
181
+ [x1, y1, 1],
182
+ [x2, y1, 1],
183
+ [x2, y2, 1],
184
+ [x1, y2, 1]
185
+ ]).T
186
+ x_min, y_min, _ = pt.min(dim=1).values
187
+ x_max, y_max, _ = pt.max(dim=1).values
188
+ input['bbox'][i, :] = torch.FloatTensor([x_min, y_min, x_max, y_max])
189
+
190
+ # if 'tr_param' not in input:
191
+ # input['tr_param'] = []
192
+ # input['tr_param'].append({'random_affine': matrix[:2, :].tolist()})
193
+
194
+ return input
195
+
196
+
197
+ class ColorJitter(transforms.ColorJitter):
198
+ def __call__(self, input):
199
+ if not isinstance(input, dict):
200
+ return super().__call__(input)
201
+ assert 'image' in input
202
+ input['image'] = super().__call__(input['image'])
203
+ return input
204
+
205
+
206
+ def get_transform(split, input_size=512):
207
+ mean = [0.485, 0.456, 0.406]
208
+ sdev = [0.229, 0.224, 0.225]
209
+
210
+ if split in ('train', 'trainval'):
211
+ transform = Compose([
212
+ # ColorJitter(brightness=0.5, saturation=0.5), # before normalization
213
+ ToTensor(),
214
+ Normalize(mean, sdev), # first normalize so that the mean is ~0
215
+ SquarePad(), # zero pad (approx mean pixel value)
216
+ Resize(size=(input_size, input_size)),
217
+ # RandomHorizontalFlip(p=0.5),
218
+ RandomAffine(degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1)),
219
+ NormalizeBoxCoords(),
220
+ ])
221
+ elif split in ('val', 'test', 'testA', 'testB', 'testC'):
222
+ transform = Compose([
223
+ ToTensor(),
224
+ Normalize(mean, sdev),
225
+ SquarePad(),
226
+ Resize(size=(input_size, input_size)),
227
+ NormalizeBoxCoords(),
228
+ ])
229
+ elif split in ('visu',):
230
+ transform = Compose([
231
+ ToTensor(),
232
+ SquarePad(),
233
+ Resize(size=(input_size, input_size)),
234
+ NormalizeBoxCoords(),
235
+ ])
236
+ else:
237
+ raise ValueError(f'\'{split}\' is not a valid data split')
238
+
239
+ return transform
240
+
241
+
242
+ def denormalize(img):
243
+ mean = [0.485, 0.456, 0.406]
244
+ sdev = [0.229, 0.224, 0.225]
245
+ return Normalize(
246
+ mean=[-m/s for m, s in zip(mean, sdev)], std=[1./s for s in sdev]
247
+ )(img)
248
+
249
+
250
+ def undo_box_transforms(bbox, tr_param):
251
+ # undo validation mode transformations
252
+ bbox = bbox.clone()
253
+ for tr in tr_param[::-1]:
254
+ if 'resize' in tr:
255
+ sx, sy = tr['resize']
256
+ bbox[:, (0, 2)] /= sx
257
+ bbox[:, (1, 3)] /= sy
258
+ elif 'square_pad' in tr:
259
+ px, py, _, _ = tr['square_pad']
260
+ bbox[:, (0, 2)] -= px
261
+ bbox[:, (1, 3)] -= py
262
+ elif 'normalize_box_coords' in tr:
263
+ img_h, img_w = tr['normalize_box_coords']
264
+ bbox[:, (0, 2)] *= img_w
265
+ bbox[:, (1, 3)] *= img_h
266
+ else:
267
+ continue
268
+ return bbox
269
+
270
+
271
+ def undo_box_transforms_batch(bbox, tr_param):
272
+ output = []
273
+ for i in range(bbox.size(0)):
274
+ bb = undo_box_transforms(torch.atleast_2d(bbox[i]), tr_param[i])
275
+ output.append(bb)
276
+ return torch.cat(output, dim=0)
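To show how the dict-based transforms and the undo helpers fit together, here is a small round-trip sketch on a synthetic image and box (illustrative values, not part of this commit):

# Illustrative round trip: validation pipeline forward, then map the box back.
import torch
from PIL import Image
from transforms import get_transform, undo_box_transforms

img = Image.new('RGB', (640, 480), color=(128, 128, 128))
sample = {
    'image': img,
    'bbox': torch.tensor([[100., 150., 300., 400.]]),   # x1, y1, x2, y2 in pixels
    'mask': torch.ones((1, 480, 640), dtype=torch.float32),
}

sample = get_transform('val', input_size=512)(sample)
print(sample['image'].shape)    # torch.Size([3, 512, 512])
print(sample['bbox'])           # coordinates normalized to [0, 1] by NormalizeBoxCoords

# tr_param records padding, resize and normalization, so they can be reversed in order
restored = undo_box_transforms(sample['bbox'], sample['tr_param'])
print(restored)                 # approximately [[100., 150., 300., 400.]]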