Base app.
- .gitattributes +3 -0
- README.md +5 -5
- app.py +119 -0
- backbones.py +82 -0
- cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt +3 -0
- datasets.py +282 -0
- embeddings.py +182 -0
- encoders.py +414 -0
- models.py +412 -0
- requirements.txt +5 -0
- testing_loading.py +97 -0
- transformers_pos.py +198 -0
- transforms.py +276 -0
.gitattributes
CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/saiapr_tc-12.zip filter=lfs diff=lfs merge=lfs -text
+cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt filter=lfs diff=lfs merge=lfs -text
+data/val-sim_metric.json filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,10 +1,10 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: ProbingREC
+emoji: π
+colorFrom: blue
+colorTo: gray
 sdk: gradio
-sdk_version: 3.
+sdk_version: 3.4
 app_file: app.py
 pinned: false
 ---
app.py
ADDED
@@ -0,0 +1,119 @@
from models import IntuitionKillingMachine
from transforms import undo_box_transforms_batch, ToTensor, Normalize, SquarePad, Resize, NormalizeBoxCoords
from torchvision.transforms import Compose
from encoders import get_tokenizer
from PIL import Image, ImageDraw
from zipfile import ZipFile
from copy import copy
import gradio as gr
import pandas as pd
import torch


def parse_model_args(model_path):
    _, _, dataset, max_length, input_size, backbone, num_heads, num_layers, num_conv, _, _, mu, mask_pooling = model_path.split('_')[:13]
    return {
        'dataset': dataset,
        'max_length': int(max_length),
        'input_size': int(input_size),
        'backbone': backbone,
        'num_heads': int(num_heads),
        'num_layers': int(num_layers),
        'num_conv': int(num_conv),
        'mu': float(mu),
        'mask_pooling': bool(mask_pooling == '1')
    }


class Prober:
    def __init__(self,
                 df_path=None,
                 dataset_path=None,
                 model_checkpoint=None):
        params = parse_model_args(model_checkpoint)
        mean = [0.485, 0.456, 0.406]
        sdev = [0.229, 0.224, 0.225]
        self.tokenizer = get_tokenizer()
        self.df = pd.read_json(df_path)[['sample_idx', 'bbox', 'file_path', 'sent']]
        self.df.loc[:, "image_id"] = self.df.file_path.apply(lambda x: int(x.split('/')[-1][:-4]))
        self.df.file_path = self.df.file_path.apply(lambda x: x.replace('refer/data/images/', ''))
        self.model = IntuitionKillingMachine(
            backbone=params['backbone'],
            pretrained=True,
            num_heads=params['num_heads'],
            num_layers=params['num_layers'],
            num_conv=params['num_conv'],
            segmentation_head=bool(params['mu'] > 0.0),
            mask_pooling=params['mask_pooling']
        )
        self.load_model(model_checkpoint)
        self.transform = Compose([
            ToTensor(),
            Normalize(mean, sdev),
            SquarePad(),
            Resize(size=(params['input_size'], params['input_size'])),
            NormalizeBoxCoords(),
        ])
        self.max_length = 30
        self.zipfile = ZipFile(dataset_path, 'r')

    def load_model(self, model_checkpoint):
        checkpoint = torch.load(
            model_checkpoint, map_location=lambda storage, loc: storage
        )

        # strip 'model.' from pl checkpoint
        state_dict = {
            k[len('model.'):]: v
            for k, v in checkpoint['state_dict'].items()
        }

        missing, _ = self.model.load_state_dict(state_dict, strict=False)

        # ensure the only missing keys are those of the segmentation head only
        assert [k for k in missing if 'segm' not in k] == []

        self.model = self.model.eval()

    @torch.no_grad()
    def probe(self, idx, re, search_by_sample_id: bool = True):
        if search_by_sample_id:
            img_path, target = self.df.loc[idx][['file_path', 'bbox']].values
        else:
            img_path, target = self.df[self.df.image_id == idx][['file_path', 'bbox']].values[0]
        img = Image.open(self.zipfile.open(img_path)).convert('RGB')
        W0, H0 = img.size
        sample = {
            'image': img,
            'image_size': (H0, W0),  # image original size
            'bbox': torch.tensor([copy(target)]),
            'bbox_raw': torch.tensor([copy(target)]),
            'mask': torch.ones((1, H0, W0), dtype=torch.float32),  # visibility mask
            'mask_bbox': None,  # target bbox mask
        }
        sample = self.transform(sample)
        tok = self.tokenizer(re,
                             max_length=30,
                             return_tensors='pt',
                             truncation=True)
        inn = {'image': torch.stack([sample['image']]),
               'mask': torch.stack([sample['mask']]),
               'tok': tok}
        output = undo_box_transforms_batch(self.model(inn)[0],
                                           [sample['tr_param']]).numpy().tolist()[0]
        img1 = ImageDraw.Draw(img)
        # img1.rectangle(target, outline="#0000FF00", width=3)
        img1.rectangle(output, outline="#00FF0000", width=3)
        return img


prober = Prober(
    df_path='data/val-sim_metric.json',
    dataset_path="data/saiapr_tc-12.zip",
    model_checkpoint="cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt"
)

demo = gr.Interface(fn=prober.probe, inputs=["number", "text", "checkbox"], outputs="image")

demo.queue(concurrency_count=10)
demo.launch(debug=True)
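For reference, a minimal sketch (not part of the commit) of how the underscore-separated checkpoint path above decodes under parse_model_args(); only the first 13 '_'-fields of the path are consumed, and the values below follow directly from the directory name shipped in this commit.

# Sketch: decoding the bundled checkpoint path the same way parse_model_args() does.
ckpt = ("cache/20211220_191132_refclef_32_512_resnet50_8_6_8_"
        "0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt")
fields = ckpt.split('_')[:13]
# fields[2:9] -> dataset, max_length, input_size, backbone, num_heads, num_layers, num_conv
# fields[11]  -> mu, fields[12] -> mask_pooling flag
assert fields[2] == 'refclef' and fields[5] == 'resnet50'
# so Prober builds IntuitionKillingMachine(backbone='resnet50', num_heads=8,
# num_layers=6, num_conv=8, segmentation_head=True (mu=0.1 > 0), mask_pooling=False)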
backbones.py
ADDED
@@ -0,0 +1,82 @@
import torch

import torch.nn as nn

from torchvision.ops.misc import FrozenBatchNorm2d

from torchvision.models import resnet, detection, segmentation

import timm


# https://detectron2.readthedocs.io/en/latest/modules/layers.html#detectron2.layers.FrozenBatchNorm2d.convert_frozen_batchnorm
@torch.no_grad()
def convert_frozen_batchnorm(module):
    bn_module = (
        nn.modules.batchnorm.BatchNorm2d,
        nn.modules.batchnorm.SyncBatchNorm
    )
    res = module
    if isinstance(module, bn_module):
        res = FrozenBatchNorm2d(module.num_features)
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for name, child in module.named_children():
            new_child = convert_frozen_batchnorm(child)
            if new_child is not child:
                res.add_module(name, new_child)
    return res


def get_backbone(backbone, pretrained=True):
    if backbone in ('resnet18', 'resnet34', 'resnet50', 'resnet101'):
        # pretrained on ImageNet for classification
        model = resnet.__dict__[backbone](
            pretrained=pretrained, norm_layer=FrozenBatchNorm2d
        )
    elif backbone == 'resnet50d':
        # pretrained on COCO for detection
        model = convert_frozen_batchnorm(
            detection.fasterrcnn_resnet50_fpn(pretrained=pretrained).backbone.body
        )
    elif backbone == 'resnet50s':
        # pretrained on COCO for segmentation
        model = convert_frozen_batchnorm(
            segmentation.deeplabv3_resnet50(pretrained=pretrained).backbone
        )
    elif backbone == 'resnet101s':
        # pretrained on COCO for segmentation
        model = convert_frozen_batchnorm(
            segmentation.deeplabv3_resnet101(pretrained=pretrained).backbone
        )

    elif backbone in ('cspdarknet53', 'efficientnet-b0', 'efficientnet-b3'):
        # model = convert_frozen_batchnorm(
        #     timm.create_model(
        #         backbone.replace('-', '_'),
        #         pretrained=True,
        #         features_only=True,
        #         #out_indices=(1, 2, 3, 4)
        #     )
        # )
        model = convert_frozen_batchnorm(
            timm.create_model(
                backbone.replace('-', '_'),
                pretrained=pretrained,
                num_classes=0,
                global_pool=''
            )
        )

    else:
        raise RuntimeError(f'{backbone} is not a valid backbone')

    # empty cache (dealloc modules other than the backbone)
    torch.cuda.empty_cache()

    return model
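A minimal usage sketch of get_backbone(), assuming torch and torchvision are available; the ResNet path returns a plain ImageNet classifier whose BatchNorm layers have been replaced by FrozenBatchNorm2d (pretrained=False here only to skip the weight download).

import torch
from backbones import get_backbone

model = get_backbone('resnet50', pretrained=False)
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # classifier forward pass
print(logits.shape)  # torch.Size([1, 1000])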
cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2aaaf1696c537a1a2b049ddfa150d36770b6e92c8524ca4e3706755c00648f26
size 1752031089
datasets.py
ADDED
@@ -0,0 +1,282 @@
import os

import json

import random

import torch

import ijson

import numpy as np

from PIL import Image

from torchvision.transforms import ToTensor

from torchvision.ops import box_convert, clip_boxes_to_image

from re_classifier import REClassifier

from utils import progressbar


def collate_fn(batch):
    image = torch.stack([s['image'] for s in batch], dim=0)

    image_size = torch.FloatTensor([s['image_size'] for s in batch])

    # bbox = torch.stack([s['bbox'] for s in batch], dim=0)
    bbox = torch.cat([s['bbox'] for s in batch], dim=0)

    # bbox_raw = torch.stack([s['bbox_raw'] for s in batch], dim=0)
    bbox_raw = torch.cat([s['bbox_raw'] for s in batch], dim=0)

    expr = [s['expr'] for s in batch]

    tok = None
    if batch[0]['tok'] is not None:
        tok = {
            'input_ids': torch.cat([s['tok']['input_ids'] for s in batch], dim=0),
            'attention_mask': torch.cat([s['tok']['attention_mask'] for s in batch], dim=0)
        }

        # dynamic batching
        max_length = max([s['tok']['length'] for s in batch])
        tok = {
            'input_ids': tok['input_ids'][:, :max_length],
            'attention_mask': tok['attention_mask'][:, :max_length],
        }

    mask = None
    if batch[0]['mask'] is not None:
        mask = torch.stack([s['mask'] for s in batch], dim=0)

    mask_bbox = None
    if batch[0]['mask_bbox'] is not None:
        mask_bbox = torch.stack([s['mask_bbox'] for s in batch], dim=0)

    tr_param = [s['tr_param'] for s in batch]

    return {
        'image': image,
        'image_size': image_size,
        'bbox': bbox,
        'bbox_raw': bbox_raw,
        'expr': expr,
        'tok': tok,
        'tr_param': tr_param,
        'mask': mask,
        'mask_bbox': mask_bbox,
    }


class RECDataset(torch.utils.data.Dataset):
    def __init__(self, transform=None, tokenizer=None, max_length=32, with_mask_bbox=False):
        super().__init__()
        self.samples = []  # list of samples: [(file_name, expression, bbox)]
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = int(max_length)
        self.with_mask_bbox = bool(with_mask_bbox)

    def tokenize(self, inp, max_length):
        return self.tokenizer(
            inp,
            return_tensors='pt',
            padding='max_length',
            return_token_type_ids=False,
            return_attention_mask=True,
            add_special_tokens=True,
            truncation=True,
            max_length=max_length
        )

    def print_stats(self):
        print(f'{len(self.samples)} samples')
        lens = [len(expr.split()) for _, expr, _ in self.samples]
        print('expression lengths stats: '
              f'min={np.min(lens):.1f}, '
              f'mean={np.mean(lens):.1f}, '
              f'median={np.median(lens):.1f}, '
              f'max={np.max(lens):.1f}, '
              f'99.9P={np.percentile(lens, 99.9):.1f}'
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_name, expr, bbox = self.samples[idx]

        if not os.path.exists(file_name):
            raise IOError(f'{file_name} not found')
        img = Image.open(file_name).convert('RGB')

        # if isinstance(expr, (list, tuple)):
        #     expr = random.choice(expr)

        # image size as read from disk (PIL)
        W0, H0 = img.size

        # # ensure box coordinates fall inside the image
        # bbox = clip_boxes_to_image(bbox, (H0, W0))
        # assert torch.all(bbox[:, (0, 1)] <= bbox[:, (2, 3)])  # xyxy format

        sample = {
            'image': img,
            'image_size': (H0, W0),  # image original size
            'bbox': bbox.clone(),  # box transformations are inplace ops
            'bbox_raw': bbox.clone(),  # raw boxes w/o any transformation (in pixels)
            'expr': expr,
            'tok': None,
            'mask': torch.ones((1, H0, W0), dtype=torch.float32),  # visibility mask
            'mask_bbox': None,  # target bbox mask
        }

        # apply transforms
        if self.transform is None:
            sample['image'] = ToTensor()(sample['image'])
        else:
            sample = self.transform(sample)

        # tokenize after the transformations (just in case there was a left<>right substitution)
        if self.tokenizer is not None:
            sample['tok'] = self.tokenize(sample['expr'], self.max_length)
            sample['tok']['length'] = sample['tok']['attention_mask'].sum(1).item()

        # bbox segmentation mask
        if self.with_mask_bbox:
            # image size after transforms
            _, H, W = sample['image'].size()

            # transformed bbox in pixels
            bbox = sample['bbox'].clone()
            bbox[:, (0, 2)] *= W
            bbox[:, (1, 3)] *= H
            bbox = clip_boxes_to_image((bbox + 0.5).long(), (H, W))

            # output mask
            sample['mask_bbox'] = torch.zeros((1, H, W), dtype=torch.float32)
            for x1, y1, x2, y2 in bbox.tolist():
                sample['mask_bbox'][:, y1:y2+1, x1:x2+1] = 1.0

        return sample


class RegionDescriptionsVisualGnome(RECDataset):
    def __init__(self, data_root, transform=None, tokenizer=None,
                 max_length=32, with_mask_bbox=False):
        super().__init__(transform=transform, tokenizer=tokenizer,
                         max_length=max_length, with_mask_bbox=with_mask_bbox)

        # if available, read COCO IDs from the val, testA and testB splits from
        # the RefCOCO dataset
        try:
            with open('./refcoco_valtest_ids.txt', 'r') as fh:
                refcoco_ids = [int(lin.strip()) for lin in fh.readlines()]
        except:
            refcoco_ids = []

        def path_from_url(fname):
            return os.path.join(data_root, fname[fname.index('VG_100K'):])

        with open(os.path.join(data_root, 'image_data.json'), 'r') as f:
            image_data = {
                data['image_id']: path_from_url(data['url'])
                for data in json.load(f)
                if data['coco_id'] is None or data['coco_id'] not in refcoco_ids
            }
        print(f'{len(image_data)} images')

        self.samples = []

        with open(os.path.join(data_root, 'region_descriptions.json'), 'r') as f:
            for record in progressbar(ijson.items(f, 'item.regions.item'), desc='loading data'):
                if record['image_id'] not in image_data:
                    continue
                file_name = image_data[record['image_id']]

                expr = record['phrase']

                bbox = [record['x'], record['y'], record['width'], record['height']]
                bbox = torch.atleast_2d(torch.FloatTensor(bbox))
                bbox = box_convert(bbox, 'xywh', 'xyxy')  # xyxy

                self.samples.append((file_name, expr, bbox))

        self.print_stats()


class ReferDataset(RECDataset):
    def __init__(self, data_root, dataset, split_by, split, transform=None,
                 tokenizer=None, max_length=32, with_mask_bbox=False):
        super().__init__(transform=transform, tokenizer=tokenizer,
                         max_length=max_length, with_mask_bbox=with_mask_bbox)

        # https://github.com/lichengunc/refer
        try:
            import sys
            sys.path.append('refer')
            from refer import REFER
        except:
            raise RuntimeError('create a symlink to valid refer compilation '
                               '(see https://github.com/lichengunc/refer)')

        refer = REFER(data_root, dataset, split_by)
        ref_ids = sorted(refer.getRefIds(split=split))

        self.samples = []

        for rid in progressbar(ref_ids, desc='loading data'):
            ref = refer.Refs[rid]
            ann = refer.refToAnn[rid]

            file_name = refer.Imgs[ref['image_id']]['file_name']
            if dataset == 'refclef':
                file_name = os.path.join(
                    'refer', 'data', 'images', 'saiapr_tc-12', file_name
                )
            else:
                coco_set = file_name.split('_')[1]
                file_name = os.path.join(
                    'refer', 'data', 'images', 'mscoco', coco_set, file_name
                )

            bbox = ann['bbox']
            bbox = torch.atleast_2d(torch.FloatTensor(bbox))
            bbox = box_convert(bbox, 'xywh', 'xyxy')  # xyxy

            sentences = [s['sent'] for s in ref['sentences']]
            if 'train' in split:  # remove repeated expressions
                sentences = list(set(sentences))
            sentences = sorted(sentences)

            self.samples += [(file_name, expr, bbox) for expr in sentences]

        self.print_stats()


class RefCLEF(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'test')
        super().__init__('refer/data', 'refclef', 'berkeley', *args, **kwargs)


class RefCOCO(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'trainval', 'testA', 'testB')
        super().__init__('refer/data', 'refcoco', 'unc', *args, **kwargs)


class RefCOCOp(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'trainval', 'testA', 'testB')
        super().__init__('refer/data', 'refcoco+', 'unc', *args, **kwargs)


class RefCOCOg(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'test')
        super().__init__('refer/data', 'refcocog', 'umd', *args, **kwargs)
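A sketch of how a split is wired to a DataLoader with the custom collate_fn, assuming the refer toolkit and the RefCLEF images are in place (see the comment in ReferDataset); the transform chain and the 512/32 sizes mirror what app.py derives from the bundled checkpoint, and batch_size=16 is just an example value.

from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from datasets import RefCLEF, collate_fn
from encoders import get_tokenizer
from transforms import ToTensor, Normalize, SquarePad, Resize, NormalizeBoxCoords

transform = Compose([
    ToTensor(),
    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    SquarePad(),
    Resize(size=(512, 512)),
    NormalizeBoxCoords(),
])
dataset = RefCLEF('val', transform=transform, tokenizer=get_tokenizer(), max_length=32)
loader = DataLoader(dataset, batch_size=16, collate_fn=collate_fn)
batch = next(iter(loader))  # dict with 'image', 'bbox', 'tok', 'mask', ...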
embeddings.py
ADDED
@@ -0,0 +1,182 @@
import math

import torch

from torch import nn


# adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionEmbedding1D(nn.Module):
    def __init__(self, embedding_dim, dropout=0.1, max_len=128):
        super().__init__()

        # self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2) * (-math.log(10000.0) / embedding_dim))
        pe = torch.zeros(max_len, embedding_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # .transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # # x: Tensor, shape [batch_size, seq_len, embedding_dim]
        # x = x + self.pe[:, :x.size(1)]
        # return self.dropout(x)
        N, T, _ = x.size()
        return self.pe[:, :T].repeat(N, 1, 1)


class LearnedPositionEmbedding1D(nn.Module):
    def __init__(self, embedding_dim, max_len=128):
        super().__init__()
        self.pe = nn.Parameter(torch.Tensor(1, max_len, embedding_dim))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_normal_(self.pe)

    def forward(self, x):
        N, T, _ = x.size()
        return self.pe[:, :T].repeat(N, 1, 1)


# https://huggingface.co/transformers/_modules/transformers/models/detr/modeling_detr.html
class PositionEmbedding2D(nn.Module):
    def __init__(self, embedding_dim, temperature=10000, normalize=False,
                 scale=None):
        super().__init__()
        assert embedding_dim % 2 == 0
        self.half_embedding_dim = embedding_dim // 2
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, pixel_values, pixel_mask):
        assert pixel_mask is not None, "No pixel mask provided"
        if pixel_mask.dim() == 4 and pixel_mask.size(1) == 1:
            pixel_mask = pixel_mask.squeeze(1)
        y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
        x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale

        dim_t = torch.arange(self.half_embedding_dim, dtype=torch.float32, device=pixel_values.device)
        dim_t = self.temperature ** (2 * torch.divide(dim_t, 2, rounding_mode='floor') / self.half_embedding_dim)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((
            pos_x[:, :, :, 0::2].sin(),
            pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((
            pos_y[:, :, :, 0::2].sin(),
            pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


# https://huggingface.co/transformers/_modules/transformers/models/detr/modeling_detr.html
class LearnedPositionEmbedding2D(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        assert embedding_dim % 2 == 0, 'embedding dimensionality must be even'
        self.rows_embeddings = nn.Embedding(50, embedding_dim//2)
        self.cols_embeddings = nn.Embedding(50, embedding_dim//2)

    def forward(self, pixel_values, pixel_mask=None):
        h, w = pixel_values.shape[-2:]
        i = torch.arange(w, device=pixel_values.device)
        j = torch.arange(h, device=pixel_values.device)
        x_emb = self.cols_embeddings(i)
        y_emb = self.rows_embeddings(j)
        pos = torch.cat([x_emb.unsqueeze(0).repeat(h, 1, 1), y_emb.unsqueeze(1).repeat(1, w, 1)], dim=-1)
        pos = pos.permute(2, 0, 1)
        pos = pos.unsqueeze(0)
        pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
        return pos


class Box8PositionEmbedding2D(nn.Module):
    def __init__(self, embedding_dim, with_projection=True):
        super().__init__()

        self.proj = None
        if with_projection:
            self.proj = nn.Linear(8, embedding_dim)
            nn.init.xavier_normal_(self.proj.weight)
            nn.init.zeros_(self.proj.bias)

    def forward(self, fmap, fmap_mask=None):
        N, _, H, W = fmap.size()

        y1, x1 = torch.meshgrid(
            torch.arange(H, device=fmap.device, dtype=torch.float)/H,
            torch.arange(W, device=fmap.device, dtype=torch.float)/W
        )
        y2, x2 = x1+1.0/W, y1+1.0/H
        ww, hh = x2-x1, y2-y1
        # x1, y1 = 2*x1-1, 2*y1-1
        # x2, y2 = 2*x2-1, 2*y2-1
        xc, yc = x1+0.5/W, y1+0.5/H

        pos = torch.stack([x1, y1, x2, y2, xc, yc, ww, hh], dim=-1)
        if self.proj is not None:
            pos = self.proj(pos)
        pos = pos.permute(2, 0, 1)
        pos = pos.unsqueeze(0).repeat(N, 1, 1, 1)
        return pos

    def encode_boxes(self, boxes):
        x1, y1, x2, y2 = boxes.unbind(-1)
        ww, hh = x2-x1, y2-y1
        xc, yc = x1+0.5*ww, y1+0.5*hh
        pos = torch.stack([x1, y1, x2, y2, xc, yc, ww, hh], dim=-1)
        if self.proj is not None:
            pos = self.proj(pos)
        return pos


class RelativePositionEmbedding2D(nn.Module):
    def __init__(self, embedding_dim, spatial_bins=(16, 16), with_projection=True):
        super().__init__()

        assert isinstance(spatial_bins, (list, tuple)) and len(spatial_bins) == 2
        self.spatial_bins = spatial_bins

        self.proj = None
        if with_projection:
            self.proj = nn.Linear(2*spatial_bins[0]*spatial_bins[1], embedding_dim)
            nn.init.xavier_normal_(self.proj.weight)
            nn.init.zeros_(self.proj.bias)

    def forward(self, fmap, fmap_mask=None):
        N, _, H, W = fmap.size()

        BH, BW = self.spatial_bins
        yc, xc = torch.meshgrid(
            0.5/BH + torch.arange(BH, device=fmap.device, dtype=torch.float)/BH,
            0.5/BW + torch.arange(BW, device=fmap.device, dtype=torch.float)/BW
        )

        pos = torch.stack([xc, yc], dim=-1).view(-1, 1, 2)
        pos = (pos - pos.transpose(0, 1)).reshape(BH, BW, -1)  # relative positions

        if self.proj is not None:
            pos = self.proj(pos)

        pos = pos.permute(2, 0, 1)
        pos = pos.unsqueeze(0)

        if H != BH or W != BW:
            pos = nn.functional.interpolate(pos, (H, W), mode='nearest')

        pos = pos.repeat(N, 1, 1, 1)

        return pos
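A minimal sketch (assuming only torch is installed) of what Box8PositionEmbedding2D produces: each cell of an NxDxHxW feature map gets an 8-value box code (x1, y1, x2, y2, xc, yc, w, h) for the region it covers, optionally projected to embedding_dim channels.

import torch
from embeddings import Box8PositionEmbedding2D

pos_emb = Box8PositionEmbedding2D(embedding_dim=256, with_projection=True)
fmap = torch.randn(2, 256, 16, 16)  # N x D x H x W feature map
pos = pos_emb(fmap)                 # one positional code per spatial cell
print(pos.shape)                    # torch.Size([2, 256, 16, 16])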
encoders.py
ADDED
@@ -0,0 +1,414 @@
import os

from collections import OrderedDict

import torch

import transformers

import torch.nn.functional as F

from torch import nn

from torchvision.models import detection

from backbones import get_backbone

from embeddings import Box8PositionEmbedding2D

EPS = 1e-5

TRANSFORMER_MODEL = 'bert-base-uncased'
# TRANSFORMER_MODEL = 'distilroberta-base'


def get_tokenizer(cache=None):
    if cache is None:
        return transformers.BertTokenizer.from_pretrained(TRANSFORMER_MODEL)

    model_path = os.path.join(cache, TRANSFORMER_MODEL)
    os.makedirs(model_path, exist_ok=True)

    if os.path.exists(os.path.join(model_path, 'config.json')):
        return transformers.BertTokenizer.from_pretrained(model_path)

    tokenizer = transformers.BertTokenizer.from_pretrained(TRANSFORMER_MODEL)
    tokenizer.save_pretrained(model_path)

    return tokenizer


def weight_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu'))
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Embedding):
        nn.init.xavier_normal_(m.weight)


class ImageEncoder(nn.Module):
    def __init__(self, backbone='resnet50', out_channels=256, pretrained=True,
                 freeze_pretrained=False, with_pos=True):
        super().__init__()

        model = get_backbone(backbone, pretrained)

        if pretrained and freeze_pretrained:
            for p in model.parameters():
                p.requires_grad = False

        if 'resnet' in backbone:
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({'layer4': 'output'})
            )
            channels = 512 if backbone in ('resnet18', 'resnet34') else 2048

        elif backbone in ('cspdarknet53', 'efficientnet-b0', 'efficientnet-b3'):
            output_layer_name = list(model.named_children())[-1][0]
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({output_layer_name: 'output'})
            )
            channels = {
                'cspdarknet53': 1024,
                'efficientnet-b0': 1280,
                'efficientnet-b3': 1536
            }[backbone]

        else:
            raise RuntimeError('not a valid backbone')

        in_channels = channels+8 if with_pos else channels

        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (1, 1), 1, bias=False),
            nn.GroupNorm(1, out_channels, eps=EPS),
            # nn.ReLU(inplace=True),
        )
        self.proj.apply(weight_init)

        self.pos_emb = None
        if with_pos:
            self.pos_emb = Box8PositionEmbedding2D(with_projection=False)

        self.out_channels = out_channels

    def forward(self, img, mask=None):
        x = self.backbone(img)['output']
        if self.pos_emb is not None:
            x = torch.cat([x, self.pos_emb(x)], dim=1)
        x = self.proj(x)  # NxDxHxW

        x_mask = None
        if mask is not None:
            _, _, H, W = x.size()
            x_mask = F.interpolate(mask, (H, W), mode='bilinear')
            x_mask = (x_mask > 0.5).long()

        return x, x_mask


class FPNImageEncoder(nn.Module):
    def __init__(self,
                 backbone='resnet50', out_channels=256, pretrained=True,
                 freeze_pretrained=False, with_pos=True):
        super().__init__()

        model = get_backbone(backbone, pretrained)

        if pretrained and freeze_pretrained:
            for p in model.parameters():
                p.requires_grad = False

        if 'resnet' in backbone:
            if backbone in ('resnet18', 'resnet34'):
                in_channels_list = [64, 128, 256, 512]
            else:
                in_channels_list = [256, 512, 1024, 2048]
            return_layers = OrderedDict({
                'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'
            })

        # elif backbone == 'cspdarknet53':
        #     in_channels_list = [128, 256, 512, 1024]
        #     return_layers = OrderedDict({
        #         '1':'0', '2':'1', '3':'2', '4':'3'
        #     })

        else:
            raise RuntimeError('not a valid backbone')

        self.backbone = model

        self.fpn = detection.backbone_utils.BackboneWithFPN(
            backbone=self.backbone,
            return_layers=return_layers,
            in_channels_list=in_channels_list,
            out_channels=out_channels
        )

        self.fpn.fpn.extra_blocks = None  # removes the 'pool' layer added by default

        self.out_channels = out_channels

        in_channels = int(out_channels + float(with_pos) * 8)

        self.proj = nn.ModuleDict({
            level: nn.Sequential(
                nn.Conv2d(in_channels, out_channels, (1, 1), 1, bias=False),
                nn.GroupNorm(1, out_channels, eps=EPS),
                # nn.ReLU(inplace=True),
            ) for level in return_layers.values()
        })
        self.proj.apply(weight_init)

        self.pos_emb = None
        if with_pos:
            self.pos_emb = Box8PositionEmbedding2D(with_projection=False)

    def forward(self, x, mask=None):
        x = self.fpn(x)

        # smallest feature map (eg. 16x16 for an input of 512x512 pixels)
        _, _, H, W = list(x.values())[-1].size()

        x_out = None
        for level, fmap in x.items():
            # fmap = torch.relu(fmap)  # FPN blocks end in a conv2d, w/o activ.
            if self.pos_emb is not None:
                fmap = torch.cat([fmap, self.pos_emb(fmap)], dim=1)  # +Pos
            fmap = self.proj[level](fmap)  # Conv+BN+ReLU
            fmap = F.interpolate(fmap, (H, W), mode='nearest')  # to a smaller size
            if x_out is None:
                x_out = fmap
            else:
                x_out += fmap

        x_mask = None
        if mask is not None:
            x_mask = F.interpolate(mask, (H, W), mode='bilinear')
            x_mask = (x_mask > 0.5).long()

        return x_out, x_mask


class TransformerImageEncoder(nn.Module):
    def __init__(self,
                 backbone='resnet50', out_channels=256, pretrained=True,
                 freeze_pretrained=False, num_heads=8, num_layers=6,
                 dropout_p=0.1):
        super().__init__()

        model = get_backbone(backbone, pretrained)

        if pretrained and freeze_pretrained:
            for p in model.parameters():
                p.requires_grad = False

        if 'resnet' in backbone:
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({'layer4': 'output'})
            )
            channels = 512 if backbone in ('resnet18', 'resnet34') else 2048

        elif backbone in ('cspdarknet53', 'efficientnet-b0', 'efficientnet-b3'):
            output_layer_name = list(model.named_children())[-1][0]
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({output_layer_name: 'output'})
            )
            channels = {
                'cspdarknet53': 1024,
                'efficientnet-b0': 1280,
                'efficientnet-b3': 1536
            }[backbone]

        else:
            raise RuntimeError('not a valid backbone')

        self.proj = nn.Sequential(
            nn.Conv2d(channels, out_channels, (1, 1), 1, bias=False),
            nn.GroupNorm(1, out_channels, eps=EPS),
            # nn.ReLU(inplace=True),
        )
        self.proj.apply(weight_init)

        from transformers_pos import (
            TransformerEncoder,
            TransformerEncoderLayer,
        )

        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=out_channels,
                nhead=num_heads,
                dropout=dropout_p,
                batch_first=True
            ),
            num_layers=num_layers
        )

        self.pos_emb = Box8PositionEmbedding2D(embedding_dim=out_channels)

        self.out_channels = out_channels

    def flatten(self, x):
        N, _, H, W = x.size()
        x = x.to(memory_format=torch.channels_last)
        x = x.permute(0, 2, 3, 1).view(N, H*W, -1)  # NxHWxD
        return x

    def forward(self, img, mask=None):
        x = self.backbone(img)['output']
        x = self.proj(x)  # NxDxHxW

        N, _, H, W = x.size()

        pos = self.pos_emb(x)  # NxDxHxW
        pos = self.flatten(pos)  # NxRxD

        x = self.flatten(x)  # NxRxD

        # visibility mask
        x_mask = None
        if mask is not None:
            x_mask = F.interpolate(mask, (H, W), mode='bilinear')
            x_mask = (x_mask > 0.5).long()

        if mask is None:
            x = self.encoder(x, pos=pos)  # NxRxD
        else:
            mask = self.flatten(x_mask).squeeze(-1)
            x = self.encoder(x, src_key_padding_mask=(mask==0), pos=pos)  # NxRxD

        x = x.permute(0, 2, 1).view(N, -1, H, W)  # NxDxHxW

        return x, x_mask


class LanguageEncoder(nn.Module):
    def __init__(self, out_features=256, dropout_p=0.2,
                 freeze_pretrained=False, global_pooling=True):
        super().__init__()
        self.language_model = transformers.AutoModel.from_pretrained(
            TRANSFORMER_MODEL
        )

        if freeze_pretrained:
            for p in self.language_model.parameters():
                p.requires_grad = False

        self.out_features = out_features

        self.proj = nn.Sequential(
            nn.Linear(768, out_features),
            nn.LayerNorm(out_features, eps=1e-5),
            # nn.ReLU(inplace=True),
            # nn.Dropout(dropout_p),
        )
        self.proj.apply(weight_init)

        self.global_pooling = bool(global_pooling)

    def forward(self, z):
        res = self.language_model(
            input_ids=z['input_ids'],
            position_ids=None,
            attention_mask=z['attention_mask']
        )

        if self.global_pooling:
            z, z_mask = self.proj(res.pooler_output), None
        else:
            z, z_mask = self.proj(res.last_hidden_state), z['attention_mask']

        return z, z_mask


class RNNLanguageEncoder(nn.Module):
    def __init__(self,
                 model_type='gru', hidden_size=1024, num_layers=2,
                 out_features=256, dropout_p=0.2, global_pooling=True):
        super().__init__()
        self.embeddings = transformers.AutoModel.from_pretrained(
            TRANSFORMER_MODEL
        ).embeddings.word_embeddings
        self.embeddings.weight.requires_grad = True

        # self.dropout_emb = nn.Dropout(0.5)
        self.dropout_emb = nn.Dropout(dropout_p)

        assert model_type in ('gru', 'lstm')
        self.rnn = (nn.GRU if model_type == 'gru' else nn.LSTM)(
            input_size=self.embeddings.weight.size(1),
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout_p,
            batch_first=True,
            bidirectional=True
        )

        self.proj = nn.Sequential(
            nn.Linear(2*hidden_size, out_features),
            nn.LayerNorm(out_features, eps=1e-5),
            # nn.ReLU(inplace=True),
            # nn.Dropout(dropout_p),
        )
        self.proj.apply(weight_init)

        self.out_features = out_features

        self.global_pooling = bool(global_pooling)
        assert global_pooling  # only w/ global pooling

    def forward(self, z):
        z_mask = z['attention_mask']

        z = self.dropout_emb(self.embeddings(z['input_ids']))
        z, h_n = self.rnn(z, None)

        if isinstance(self.rnn, nn.LSTM):
            h_n = h_n[0]

        # hidden states as (num_layers, num_directions, batch, hidden_size)
        h_n = h_n.view(self.rnn.num_layers, 2, z.size(0), self.rnn.hidden_size)

        # last hidden states
        h_n = h_n[-1].permute(1, 0, 2).reshape(z.size(0), -1)
        h_n = self.proj(h_n)
        return h_n, z_mask


class SimpleEncoder(nn.Module):
    def __init__(self, out_features=256, dropout_p=0.1, global_pooling=True):
        super().__init__()
        self.embeddings = transformers.AutoModel.from_pretrained(
            TRANSFORMER_MODEL
        ).embeddings.word_embeddings
        self.embeddings.weight.requires_grad = True

        # self.dropout_emb = nn.Dropout(0.5)
        self.dropout_emb = nn.Dropout(dropout_p)

        self.proj = nn.Sequential(
            nn.Linear(768, out_features),
            nn.LayerNorm(out_features, eps=1e-5),
            # nn.ReLU(inplace=True),
            # nn.Dropout(dropout_p),
        )
        self.proj.apply(weight_init)

        self.out_features = out_features

        self.global_pooling = bool(global_pooling)
        assert not self.global_pooling  # only w/o global pooling

    def forward(self, z):
        z_mask = z['attention_mask']
        z = self.embeddings(z['input_ids'])
        z = self.proj(self.dropout_emb(z))
        # z[:, 0] = torch.mean(z[:, 1:], 1)
        return z, z_mask
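A sketch of the token path used by the model (downloads bert-base-uncased on first use): tokenize an expression with get_tokenizer() and run it through LanguageEncoder with global_pooling=False, as models.py does; the example phrase is illustrative only.

import torch
from encoders import get_tokenizer, LanguageEncoder

tokenizer = get_tokenizer()
lan_enc = LanguageEncoder(out_features=256, global_pooling=False)
tok = tokenizer('the man on the left', return_tensors='pt',
                padding='max_length', truncation=True, max_length=32)
with torch.no_grad():
    z, z_mask = lan_enc({'input_ids': tok['input_ids'],
                         'attention_mask': tok['attention_mask']})
print(z.shape, z_mask.shape)  # (1, 32, 256), (1, 32)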
models.py
ADDED
@@ -0,0 +1,412 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops import box_convert
import embeddings as emb
import encoders as enc
from encoders import weight_init

def conv3x3(in_channels, out_channels, num_groups=0):
    return nn.Sequential(
        # Conv2d w/o bias since BatchNorm2d/GroupNorm already accounts for it (affine=True)
        nn.Conv2d(in_channels, out_channels, (3, 3), 1, 1, bias=False),
        nn.BatchNorm2d(out_channels) if num_groups < 1 else nn.GroupNorm(num_groups, out_channels),
        nn.ReLU(inplace=True),
    )


class IntuitionKillingMachine(nn.Module):
    def __init__(self,
                 backbone='resnet50', pretrained=True, embedding_size=256,
                 num_heads=8, num_layers=6, num_conv=4, dropout_p=0.1,
                 segmentation_head=True, mask_pooling=True):
        super().__init__()

        if backbone.endswith('+tr'):
            self.vis_enc = enc.TransformerImageEncoder(
                backbone=backbone.rstrip('+tr'),
                out_channels=embedding_size,
                pretrained=pretrained,
            )

        elif backbone.endswith('+fpn'):
            self.vis_enc = enc.FPNImageEncoder(
                backbone=backbone.rstrip('+fpn'),
                out_channels=embedding_size,
                pretrained=pretrained,
                with_pos=False
            )
        else:
            self.vis_enc = enc.ImageEncoder(
                backbone=backbone,
                out_channels=embedding_size,
                pretrained=pretrained,
                with_pos=False
            )

        # freeze ResNet stem
        if 'resnet' in backbone:
            self.vis_enc.backbone.conv1.requires_grad = False
            self.vis_enc.backbone.conv1.eval()

        self.vis_pos_emb = emb.LearnedPositionEmbedding2D(
            embedding_dim=embedding_size
        )

        self.lan_enc = enc.LanguageEncoder(
            out_features=embedding_size,
            global_pooling=False,
            dropout_p=dropout_p
        )

        self.lan_pos_emb = emb.LearnedPositionEmbedding1D(
            embedding_dim=embedding_size
        )

        from transformers_pos import (
            XTransformerEncoder,
            TransformerEncoder,
            TransformerEncoderLayer,
        )

        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=embedding_size,
                nhead=num_heads,
                dropout=dropout_p,
                batch_first=True
            ),
            num_layers=num_layers
        )

        # ---
        # CONV PRE-HEAD (NECK?)

        if num_conv > 0:
            self.pre_head = nn.Sequential(*[
                conv3x3(embedding_size, embedding_size) for _ in range(num_conv)
            ])
            self.pre_head.apply(weight_init)
        else:
            self.pre_head = nn.Identity()

        # ---
        # OUTPUT HEADS

        # box prediction
        self.head = nn.Sequential(
            nn.Linear(embedding_size, 4, bias=True),
            nn.Sigmoid()
        )
        self.head.apply(weight_init)

        # box segmentation mask
        self.segm_head = None
        if segmentation_head:
            self.segm_head = nn.Sequential(
                nn.Conv2d(embedding_size, 1, (3, 3), 1, 1, bias=True),
                #nn.Sigmoid()
            )
            self.segm_head.apply(weight_init)

        # ---

        self.mask_pooling = bool(mask_pooling)

        if self.mask_pooling and self.segm_head is None:
            raise RuntimeError('mask pooling w/o a segmentation head does not make sense')

        self.embedding_size = embedding_size

    # def slow_param_ids(self, **kwargs):
    #     return []

    def slow_param_ids(self, slow_visual_backbone=True, slow_language_backbone=True):
        ids = []

        if slow_visual_backbone:
            ids += [id(p) for p in self.vis_enc.backbone.parameters()]
            if hasattr(self.vis_enc, 'encoder'):  # +tr
                ids += [id(p) for p in self.vis_enc.encoder.parameters()]

        if slow_language_backbone:
            if isinstance(self.lan_enc, enc.LanguageEncoder):
                ids += [id(p) for p in self.lan_enc.language_model.parameters()]
            else:
                ids += [id(p) for p in self.lan_enc.embeddings.parameters()]

        return ids

    def flatten(self, x):
        N, D, H, W = x.size()
        x = x.to(memory_format=torch.channels_last)
        x = x.permute(0, 2, 3, 1).view(N, H*W, D)
        return x  # NxHWxD

    def unflatten(self, x, size):
        N, R, D = x.size()
        H, W = size
        assert R == H*W, 'wrong tensor size'
        x = x.permute(0, 2, 1).to(memory_format=torch.contiguous_format)
        x = x.view(N, D, H, W)
        return x  # NxDxHxW

    def forward(self, input):
        img, mask, tok = input['image'], input['mask'], input['tok']

        # ---
        # VISUAL EMBEDDINGS

        x, x_mask = self.vis_enc(img, mask)  # NxDxHxW, NxHxW
        x_pos = self.vis_pos_emb(x, x_mask)

        N, D, H, W = x.size()  # save dims before flatten

        x = self.flatten(x)  # NxRxD
        x_mask = self.flatten(x_mask).squeeze(-1)  # NxR
        x_pos = self.flatten(x_pos)  # NxRxD

        # ---
        # LANGUAGE EMBEDDINGS

        z, z_mask = self.lan_enc(tok)  # NxTxD, NxT
        z_pos = self.lan_pos_emb(z)  # NxTxD

        # ---
        # V+L TRANSFORMER

        # [...visual...]+[[CLS]...language tokens...[SEP]]
        xz = torch.cat([x, z], dim=1)
        xz_mask = torch.cat([x_mask, z_mask], dim=1)
        xz_pos = torch.cat([x_pos, z_pos], dim=1)

        xz = self.encoder(xz, src_key_padding_mask=(xz_mask==0), pos=xz_pos)  #, size=(H,W))

        # restore spatiality of visual embeddings after cross-modal encoding
        xz_vis = xz[:, :H*W, ...]
        xz_vis = self.unflatten(xz_vis, (H, W))

        x_mask = self.unflatten(x_mask.unsqueeze(-1), (H, W))

        # ---

        # convolutional pre-head
        xz_vis = self.pre_head(xz_vis)

        # ---

        # segmentation head w/ (opt.) pooling
        segm_mask, pooled_feat = None, None
        if self.segm_head is not None:
            segm_mask = torch.sigmoid(self.segm_head(xz_vis)) * x_mask
            if self.mask_pooling:  # box mask guided pooling
                pooled_feat = (segm_mask * xz_vis).sum((2, 3)) / segm_mask.sum((2, 3))
            segm_mask = F.interpolate(segm_mask, img.size()[2:], mode='bilinear', align_corners=True)

        # if not mask_pooling, do the pooling using all visual feats (equiv. to a uniform mask)
        if pooled_feat is None:
            pooled_feat = (x_mask * xz_vis).sum((2, 3)) / x_mask.sum((2, 3))

        # bbox prediction
        pred = self.head(pooled_feat)
        pred = box_convert(pred, 'cxcywh', 'xyxy')

        return pred, segm_mask

class HeadlessMachine(nn.Module):
    def __init__(self,
                 backbone='resnet50', pretrained=True, embedding_size=256,
                 num_heads=8, num_layers=6, num_conv=4, dropout_p=0.1,
                 segmentation_head=True, mask_pooling=True):
        super().__init__()

        if backbone.endswith('+tr'):
            self.vis_enc = enc.TransformerImageEncoder(
                backbone=backbone.rstrip('+tr'),
                out_channels=embedding_size,
                pretrained=pretrained,
            )

        elif backbone.endswith('+fpn'):
            self.vis_enc = enc.FPNImageEncoder(
                backbone=backbone.rstrip('+fpn'),
                out_channels=embedding_size,
                pretrained=pretrained,
                with_pos=False
            )
        else:
            self.vis_enc = enc.ImageEncoder(
                backbone=backbone,
                out_channels=embedding_size,
                pretrained=pretrained,
                with_pos=False
            )

        # freeze ResNet stem
        if 'resnet' in backbone:
            self.vis_enc.backbone.conv1.requires_grad = False
            self.vis_enc.backbone.conv1.eval()

        self.vis_pos_emb = emb.LearnedPositionEmbedding2D(
            embedding_dim=embedding_size
        )

        self.lan_enc = enc.LanguageEncoder(
            out_features=embedding_size,
            global_pooling=False,
            dropout_p=dropout_p
        )

        self.lan_pos_emb = emb.LearnedPositionEmbedding1D(
            embedding_dim=embedding_size
        )

        from transformers_pos import (
            XTransformerEncoder,
            TransformerEncoder,
            TransformerEncoderLayer,
        )

        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=embedding_size,
                nhead=num_heads,
                dropout=dropout_p,
                batch_first=True
            ),
            num_layers=num_layers
        )

        # ---
        # CONV PRE-HEAD (NECK?)

        if num_conv > 0:
            self.pre_head = nn.Sequential(*[
                conv3x3(embedding_size, embedding_size) for _ in range(num_conv)
            ])
            self.pre_head.apply(weight_init)
        else:
            self.pre_head = nn.Identity()

        # ---
        # OUTPUT HEADS

        # box prediction
        self.head = nn.Sequential(
            nn.Linear(embedding_size, 4, bias=True),
            nn.Sigmoid()
        )
        self.head.apply(weight_init)

        # box segmentation mask
        self.segm_head = None
        if segmentation_head:
            self.segm_head = nn.Sequential(
                nn.Conv2d(embedding_size, 1, (3, 3), 1, 1, bias=True),
                #nn.Sigmoid()
            )
            self.segm_head.apply(weight_init)

        # ---

        self.mask_pooling = bool(mask_pooling)

        if self.mask_pooling and self.segm_head is None:
            raise RuntimeError('mask pooling w/o a segmentation head does not make sense')

        self.embedding_size = embedding_size

    # def slow_param_ids(self, **kwargs):
    #     return []

    def slow_param_ids(self, slow_visual_backbone=True, slow_language_backbone=True):
        ids = []

        if slow_visual_backbone:
            ids += [id(p) for p in self.vis_enc.backbone.parameters()]
            if hasattr(self.vis_enc, 'encoder'):  # +tr
                ids += [id(p) for p in self.vis_enc.encoder.parameters()]

        if slow_language_backbone:
            if isinstance(self.lan_enc, enc.LanguageEncoder):
                ids += [id(p) for p in self.lan_enc.language_model.parameters()]
            else:
                ids += [id(p) for p in self.lan_enc.embeddings.parameters()]

        return ids

    def flatten(self, x):
        N, D, H, W = x.size()
        x = x.to(memory_format=torch.channels_last)
        x = x.permute(0, 2, 3, 1).view(N, H*W, D)
        return x  # NxHWxD

    def unflatten(self, x, size):
        N, R, D = x.size()
        H, W = size
        assert R == H*W, 'wrong tensor size'
        x = x.permute(0, 2, 1).to(memory_format=torch.contiguous_format)
        x = x.view(N, D, H, W)
        return x  # NxDxHxW

    def forward(self, input):
        img, mask, tok = input['image'], input['mask'], input['tok']

        # ---
        # VISUAL EMBEDDINGS

        x, x_mask = self.vis_enc(img, mask)  # NxDxHxW, NxHxW
        x_pos = self.vis_pos_emb(x, x_mask)

        N, D, H, W = x.size()  # save dims before flatten

        x = self.flatten(x)  # NxRxD
        x_mask = self.flatten(x_mask).squeeze(-1)  # NxR
        x_pos = self.flatten(x_pos)  # NxRxD

        # ---
        # LANGUAGE EMBEDDINGS

        z, z_mask = self.lan_enc(tok)  # NxTxD, NxT
        z_pos = self.lan_pos_emb(z)  # NxTxD

        # ---
        # V+L TRANSFORMER

        # [...visual...]+[[CLS]...language tokens...[SEP]]
        xz = torch.cat([x, z], dim=1)
        xz_mask = torch.cat([x_mask, z_mask], dim=1)
        xz_pos = torch.cat([x_pos, z_pos], dim=1)

        xz = self.encoder(xz, src_key_padding_mask=(xz_mask==0), pos=xz_pos)  #, size=(H,W))

        # restore spatiality of visual embeddings after cross-modal encoding
        xz_vis = xz[:, :H*W, ...]
        xz_vis = self.unflatten(xz_vis, (H, W))

        x_mask = self.unflatten(x_mask.unsqueeze(-1), (H, W))
|
388 |
+
|
389 |
+
# ---
|
390 |
+
|
391 |
+
# convolutional pre-head
|
392 |
+
xz_vis = self.pre_head(xz_vis)
|
393 |
+
|
394 |
+
# ---
|
395 |
+
|
396 |
+
# segmentation head w/ (opt.) pooling
|
397 |
+
segm_mask, pooled_feat = None, None
|
398 |
+
if self.segm_head is not None:
|
399 |
+
segm_mask = torch.sigmoid(self.segm_head(xz_vis)) * x_mask
|
400 |
+
if self.mask_pooling: # box mask guided pooling
|
401 |
+
pooled_feat = (segm_mask * xz_vis).sum((2, 3)) / segm_mask.sum((2, 3))
|
402 |
+
segm_mask = F.interpolate(segm_mask, img.size()[2:], mode='bilinear', align_corners=True)
|
403 |
+
|
404 |
+
# if not mask_pooling, do the pooling using all visual feats (equiv. to a uniform mask)
|
405 |
+
if pooled_feat is None:
|
406 |
+
pooled_feat = (x_mask * xz_vis).sum((2, 3)) / x_mask.sum((2, 3))
|
407 |
+
|
408 |
+
# bbox prediction
|
409 |
+
pred = self.head(pooled_feat)
|
410 |
+
pred = box_convert(pred, 'cxcywh', 'xyxy')
|
411 |
+
|
412 |
+
return pred, segm_mask
|
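For orientation, a minimal smoke test of the forward pass above. This is a sketch and not part of the commit: it assumes the repo's models.py and encoders.py are importable, the hyper-parameter values are illustrative, and the input dict mirrors the one built in testing_loading.py below (normalized Nx3xHxW images, an Nx1xHxW visibility mask, and a Hugging Face tokenizer output under 'tok').

import torch
from models import IntuitionKillingMachine
from encoders import get_tokenizer

# illustrative hyper-parameters; pretrained=False just avoids a backbone download
model = IntuitionKillingMachine(backbone='resnet50', pretrained=False,
                                num_heads=8, num_layers=6, num_conv=8,
                                segmentation_head=True, mask_pooling=False).eval()

tok = get_tokenizer()('the tree on the left', max_length=30,
                      return_tensors='pt', truncation=True)
batch = {
    'image': torch.rand(1, 3, 512, 512),   # already normalized + square-padded in the real pipeline
    'mask': torch.ones(1, 1, 512, 512),    # visibility mask (1 = valid pixel)
    'tok': tok,
}
with torch.no_grad():
    pred_box, segm_mask = model(batch)     # box as normalized xyxy, plus an optional mask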
requirements.txt
ADDED
@@ -0,0 +1,5 @@
Pillow==9.1.0
timm==0.6.7
torch==1.9.0
torchvision==0.10.0
transformers==4.12.3
testing_loading.py
ADDED
@@ -0,0 +1,97 @@
from models import IntuitionKillingMachine
from transforms import undo_box_transforms_batch, ToTensor, Normalize, SquarePad, Resize, NormalizeBoxCoords
from torchvision.transforms import Compose
from encoders import get_tokenizer
from PIL import Image, ImageDraw
from zipfile import ZipFile
from copy import copy
import pandas as pd
import torch

def parse_model_args(model_path):
    _, _, dataset, max_length, input_size, backbone, num_heads, num_layers, num_conv, _, _, mu, mask_pooling = model_path.split('_')[:13]
    return {
        'dataset': dataset,
        'max_length': int(max_length),
        'input_size': int(input_size),
        'backbone': backbone,
        'num_heads': int(num_heads),
        'num_layers': int(num_layers),
        'num_conv': int(num_conv),
        'mu': float(mu),
        'mask_pooling': bool(mask_pooling == '1')
    }


class Prober:
    def __init__(self,
                 df_path=None,
                 dataset_path=None,
                 model_checkpoint=None):
        params = parse_model_args(model_checkpoint)
        mean = [0.485, 0.456, 0.406]
        sdev = [0.229, 0.224, 0.225]
        self.tokenizer = get_tokenizer()
        self.df = pd.read_json(df_path)[['sample_idx', 'bbox', 'file_path', 'sent']]
        self.df.loc[:, "image_id"] = self.df.file_path.apply(lambda x: int(x.split('/')[-1][:-4]))
        self.df.file_path = self.df.file_path.apply(lambda x: x.replace('refer/data/images/', ''))
        self.model = IntuitionKillingMachine(
            backbone=params['backbone'],
            pretrained=True,
            num_heads=params['num_heads'],
            num_layers=params['num_layers'],
            num_conv=params['num_conv'],
            segmentation_head=bool(params['mu'] > 0.0),
            mask_pooling=params['mask_pooling']
        )
        self.transform = Compose([
            ToTensor(),
            Normalize(mean, sdev),
            SquarePad(),
            Resize(size=(params['input_size'], params['input_size'])),
            NormalizeBoxCoords(),
        ])
        self.max_length = 30
        self.zipfile = ZipFile(dataset_path, 'r')

    @torch.no_grad()
    def probe(self, idx, re, search_by_sample_id: bool = True):
        if search_by_sample_id:
            img_path, target = self.df.loc[idx][['file_path', 'bbox']].values
        else:
            img_path, target = self.df[self.df.image_id == idx][['file_path', 'bbox']].values[0]
        img = Image.open(self.zipfile.open(img_path)).convert('RGB')
        W0, H0 = img.size
        sample = {
            'image': img,
            'image_size': (H0, W0),  # image original size
            'bbox': torch.tensor([copy(target)]),
            'bbox_raw': torch.tensor([copy(target)]),
            'mask': torch.ones((1, H0, W0), dtype=torch.float32),  # visibility mask
            'mask_bbox': None,  # target bbox mask
        }
        print('inn bbox: ', sample['bbox'])
        sample = self.transform(sample)
        tok = self.tokenizer(re,
                             max_length=30,
                             return_tensors='pt',
                             truncation=True)
        inn = {'image': torch.stack([sample['image']]),
               'mask': torch.stack([sample['mask']]),
               'bbox': torch.stack([sample['bbox']]),
               'tok': tok}
        output = undo_box_transforms_batch(self.model(inn)[0],
                                           [sample['tr_param']]).numpy().tolist()[0]
        img1 = ImageDraw.Draw(img)
        # img1.rectangle(target, outline="#0000FF00", width=3)
        img1.rectangle(output, outline="#00FF0000", width=3)
        return img

if __name__ == "__main__":
    prober = Prober(
        df_path='data/val-sim_metric.json',
        dataset_path="data/saiapr_tc-12.zip",
        model_checkpoint="cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt"
    )
    prober.probe(0, "tree")
    print("Done")
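A note on parse_model_args: every hyper-parameter is recovered positionally from the checkpoint directory name. Worked out by hand for the bundled checkpoint (so treat these values as illustrative rather than captured output), it resolves to roughly the following.

from testing_loading import parse_model_args

args = parse_model_args(
    'cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt'
)
# expected: {'dataset': 'refclef', 'max_length': 32, 'input_size': 512,
#            'backbone': 'resnet50', 'num_heads': 8, 'num_layers': 6,
#            'num_conv': 8, 'mu': 0.1, 'mask_pooling': False}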
transformers_pos.py
ADDED
@@ -0,0 +1,198 @@
import copy
from typing import Optional, Any

import torch

from torch import Tensor
from torch import nn
from torch.nn import functional as F


def conv3x3(in_channels, out_channels, num_groups=0):
    return nn.Sequential(
        # Conv2d w/o bias since BatchNorm2d/GroupNorm already accounts for it (affine=True)
        nn.Conv2d(in_channels, out_channels, (3, 3), 1, 1, bias=False),
        nn.BatchNorm2d(out_channels) if num_groups < 1 else nn.GroupNorm(num_groups, out_channels),
        nn.ReLU(inplace=True),
    )


class XTransformerEncoder(nn.Module):
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, num_conv=2, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

        d_model = encoder_layer.linear1.in_features
        self.conv = nn.ModuleList([
            nn.Sequential(*[
                conv3x3(d_model, d_model) for _ in range(num_conv)
            ]) for _ in range(num_layers)
        ])

    def flatten(self, x):
        N, D, H, W = x.size()
        x = x.to(memory_format=torch.channels_last)
        x = x.permute(0, 2, 3, 1).view(N, H*W, D)
        return x  # NxHWxD

    def unflatten(self, x, size):
        N, R, D = x.size()
        H, W = size
        assert R == H*W, 'wrong tensor size'
        x = x.permute(0, 2, 1).to(memory_format=torch.contiguous_format)
        x = x.view(N, D, H, W)
        return x  # NxDxHxW

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, size=None) -> Tensor:
        output = src

        for i, mod in enumerate(self.layers):
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)

            vis = self.unflatten(output[:, :size[0]*size[1]], size)
            vis = self.flatten(self.conv[i](vis))

            output = torch.cat([vis, output[:, size[0]*size[1]:]], dim=1)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerEncoder(nn.Module):
    r"""TransformerEncoder is a stack of N encoder layers

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = src

        for mod in self.layers:
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerEncoderLayer(nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    """
    __constants__ = ['batch_first']

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5, batch_first=False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                               **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """

        q = k = src if pos is None else src + pos

        src2 = self.self_attn(q, k, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
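As a quick sanity check of the pos-aware encoder above, a minimal sketch (shapes are assumptions: with batch_first=True, src and pos are N x S x D, and the key-padding mask is N x S with True marking padded positions, matching the xz_mask==0 call in models.py):

import torch
from transformers_pos import TransformerEncoder, TransformerEncoderLayer

layer = TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True)
encoder = TransformerEncoder(layer, num_layers=2)

src = torch.rand(2, 20, 256)                # fused visual + language tokens
pos = torch.rand(2, 20, 256)                # position embeddings, added to queries/keys only
pad = torch.zeros(2, 20, dtype=torch.bool)  # True marks padded positions
out = encoder(src, src_key_padding_mask=pad, pos=pos)  # -> 2 x 20 x 256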
transforms.py
ADDED
@@ -0,0 +1,276 @@
import torch

from torchvision import transforms

from torchvision.transforms import Compose

from PIL import Image


class ToTensor(transforms.ToTensor):
    def __call__(self, input):
        if not isinstance(input, dict):
            return super().__call__(input)
        assert 'image' in input
        input['image'] = super().__call__(input['image'])
        return input


class Normalize(transforms.Normalize):
    def __call__(self, input):
        if not isinstance(input, dict):
            return super().__call__(input)
        assert 'image' in input
        input['image'] = super().__call__(input['image'])
        return input


class NormalizeBoxCoords(transforms.ToTensor):
    def __call__(self, input):
        if not isinstance(input, dict):
            return super().__call__(input)
        assert 'image' in input and 'bbox' in input
        _, H, W = input['image'].size()
        input['bbox'][:, (0, 2)] /= W
        input['bbox'][:, (1, 3)] /= H

        if 'tr_param' not in input:
            input['tr_param'] = []
        input['tr_param'].append({'normalize_box_coords': (H, W)})

        return input


class SquarePad(torch.nn.Module):
    def __call__(self, input):
        if isinstance(input, Image.Image):
            raise NotImplementedError('put the SquarePad transform after ToTensor')

        assert 'image' in input
        _, h, w = input['image'].size()

        max_wh = max(w, h)
        xp = int(0.5 * (max_wh - w))
        yp = int(0.5 * (max_wh - h))
        padding = (xp, yp, (max_wh-xp)-w, (max_wh-yp)-h)

        input['image'] = transforms.functional.pad(
            input['image'], padding, fill=0, padding_mode='constant'
        )
        # input['image'] = transforms.functional.pad(
        #     input['image'], padding, padding_mode='edge'
        # )

        if 'mask' in input:
            input['mask'] = transforms.functional.pad(
                input['mask'], padding, fill=0, padding_mode='constant'
            )

        if 'bbox' in input:
            input['bbox'][:, (0, 2)] += xp
            input['bbox'][:, (1, 3)] += yp

        if 'tr_param' not in input:
            input['tr_param'] = []
        input['tr_param'].append({'square_pad': padding})

        return input


class Resize(transforms.Resize):
    def __call__(self, input):
        if not isinstance(input, dict):
            return super().__call__(input)

        assert 'image' in input

        if not torch.is_tensor(input['image']):
            raise NotImplementedError('put the Resize transform after ToTensor')

        _, img_h, img_w = input['image'].size()

        if isinstance(self.size, int):
            dst_h = self.size if img_h < img_w else int(self.size * img_h / img_w)
            dst_w = self.size if img_w < img_h else int(self.size * img_w / img_h)
        else:
            dst_h, dst_w = self.size

        input['image'] = super().__call__(input['image'])

        if 'mask' in input:
            input['mask'] = super().__call__(input['mask'])

        sx, sy = dst_w / img_w, dst_h / img_h

        if 'bbox' in input:
            input['bbox'][:, (0, 2)] *= sx
            input['bbox'][:, (1, 3)] *= sy

        if 'tr_param' not in input:
            input['tr_param'] = []
        input['tr_param'].append({'resize': (sx, sy)})

        return input


class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
    def __call__(self, input):
        if not isinstance(input, dict):
            return super().__call__(input)

        assert 'image' in input

        if not torch.is_tensor(input['image']):
            raise NotImplementedError('use Resize after ToTensor')

        result = super().__call__(input['image'])
        if result is input['image']:  # not flipped
            return input
        input['image'] = result

        if 'mask' in input:
            input['mask'] = torch.flip(input['mask'], dims=(-1,))

        img_w = input['image'].size(2)

        if 'bbox' in input:
            input['bbox'][:, (0, 2)] = img_w - input['bbox'][:, (2, 0)]

        if 'expr' in input:
            input['expr'] = input['expr'].replace('left', '<LEFT>').replace('right', 'left').replace('<LEFT>', 'right')

        return input


class RandomAffine(transforms.RandomAffine):
    def get_params(self, *args, **kwargs):
        self.params = super().get_params(*args, **kwargs)
        return self.params

    def __call__(self, input):
        if not isinstance(input, dict):
            return super().__call__(input)

        assert 'image' in input

        if not torch.is_tensor(input['image']):
            raise NotImplementedError('put the Resize transform after ToTensor')

        #self.fill = input['image'].mean((1,2))  # set fill value to the mean pixel value
        result = super().__call__(input['image'])
        if result is input['image']:  # not transformed
            return input
        input['image'] = result

        _, img_h, img_w = input['image'].size()

        angle, translate, scale, shear = self.params
        center = (img_w * 0.5, img_h * 0.5)
        matrix = transforms.functional._get_inverse_affine_matrix(center, angle, translate, scale, shear)
        matrix = torch.FloatTensor([matrix[:3], matrix[3:], [0, 0, 1]])
        matrix = torch.linalg.inv(matrix)

        if 'mask' in input:
            input['mask'] = transforms.functional.affine(
                input['mask'], *self.params, self.interpolation, self.fill
            )

        if 'bbox' in input:
            for i, (x1, y1, x2, y2) in enumerate(input['bbox']):
                pt = matrix @ torch.FloatTensor([
                    [x1, y1, 1],
                    [x2, y1, 1],
                    [x2, y2, 1],
                    [x1, y2, 1]
                ]).T
                x_min, y_min, _ = pt.min(dim=1).values
                x_max, y_max, _ = pt.max(dim=1).values
                input['bbox'][i, :] = torch.FloatTensor([x_min, y_min, x_max, y_max])

        # if 'tr_param' not in input:
        #     input['tr_param'] = []
        # input['tr_param'].append({'random_affine': matrix[:2, :].tolist()})

        return input


class ColorJitter(transforms.ColorJitter):
    def __call__(self, input):
        if not isinstance(input, dict):
            return super().__call__(input)
        assert 'image' in input
        input['image'] = super().__call__(input['image'])
        return input


def get_transform(split, input_size=512):
    mean = [0.485, 0.456, 0.406]
    sdev = [0.229, 0.224, 0.225]

    if split in ('train', 'trainval'):
        transform = Compose([
            # ColorJitter(brightness=0.5, saturation=0.5),  # before normalization
            ToTensor(),
            Normalize(mean, sdev),  # first normalize so that the mean is ~0
            SquarePad(),  # zero pad (approx mean pixel value)
            Resize(size=(input_size, input_size)),
            # RandomHorizontalFlip(p=0.5),
            RandomAffine(degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            NormalizeBoxCoords(),
        ])
    elif split in ('val', 'test', 'testA', 'testB', 'testC'):
        transform = Compose([
            ToTensor(),
            Normalize(mean, sdev),
            SquarePad(),
            Resize(size=(input_size, input_size)),
            NormalizeBoxCoords(),
        ])
    elif split in ('visu',):
        transform = Compose([
            ToTensor(),
            SquarePad(),
            Resize(size=(input_size, input_size)),
            NormalizeBoxCoords(),
        ])
    else:
        raise ValueError(f'\'{split}\' is not a valid data split')

    return transform


def denormalize(img):
    mean = [0.485, 0.456, 0.406]
    sdev = [0.229, 0.224, 0.225]
    return Normalize(
        mean=[-m/s for m, s in zip(mean, sdev)], std=[1./s for s in sdev]
    )(img)


def undo_box_transforms(bbox, tr_param):
    # undo validation mode transformations
    bbox = bbox.clone()
    for tr in tr_param[::-1]:
        if 'resize' in tr:
            sx, sy = tr['resize']
            bbox[:, (0, 2)] /= sx
            bbox[:, (1, 3)] /= sy
        elif 'square_pad' in tr:
            px, py, _, _ = tr['square_pad']
            bbox[:, (0, 2)] -= px
            bbox[:, (1, 3)] -= py
        elif 'normalize_box_coords' in tr:
            img_h, img_w = tr['normalize_box_coords']
            bbox[:, (0, 2)] *= img_w
            bbox[:, (1, 3)] *= img_h
        else:
            continue
    return bbox


def undo_box_transforms_batch(bbox, tr_param):
    output = []
    for i in range(bbox.size(0)):
        bb = undo_box_transforms(torch.atleast_2d(bbox[i]), tr_param[i])
        output.append(bb)
    return torch.cat(output, dim=0)
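Finally, a sketch of the dict-based transform round trip (the sample keys and the tr_param bookkeeping follow testing_loading.py above; the image, box, and prediction values are dummies):

import torch
from PIL import Image
from transforms import get_transform, undo_box_transforms_batch

sample = {
    'image': Image.new('RGB', (640, 480)),            # placeholder image
    'mask': torch.ones(1, 480, 640),                  # visibility mask
    'bbox': torch.tensor([[100., 50., 300., 200.]]),  # xyxy, original pixel coords
}
sample = get_transform('val', input_size=512)(sample)  # each op is logged in sample['tr_param']

pred = torch.tensor([[0.2, 0.1, 0.6, 0.4]])             # dummy prediction, normalized xyxy
restored = undo_box_transforms_batch(pred, [sample['tr_param']])  # back to 640x480 pixel coords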