dragonSwing committed
Commit
5b31094
1 Parent(s): 78f7941

Add application files

.gitignore ADDED
@@ -0,0 +1,200 @@
1
+ # Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,metals
2
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,metals
3
+
4
+ ### Metals ###
5
+ .metals/
6
+ .bloop/
7
+ project/**/metals.sbt
8
+
9
+ ### Python ###
10
+ # Byte-compiled / optimized / DLL files
11
+ __pycache__/
12
+ *.py[cod]
13
+ *$py.class
14
+
15
+ # C extensions
16
+ *.so
17
+
18
+ # Distribution / packaging
19
+ .Python
20
+ build/
21
+ develop-eggs/
22
+ dist/
23
+ downloads/
24
+ eggs/
25
+ .eggs/
26
+ lib/
27
+ lib64/
28
+ parts/
29
+ sdist/
30
+ var/
31
+ wheels/
32
+ share/python-wheels/
33
+ *.egg-info/
34
+ .installed.cfg
35
+ *.egg
36
+ MANIFEST
37
+
38
+ # PyInstaller
39
+ # Usually these files are written by a python script from a template
40
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
41
+ *.manifest
42
+ *.spec
43
+
44
+ # Installer logs
45
+ pip-log.txt
46
+ pip-delete-this-directory.txt
47
+
48
+ # Unit test / coverage reports
49
+ htmlcov/
50
+ .tox/
51
+ .nox/
52
+ .coverage
53
+ .coverage.*
54
+ .cache
55
+ nosetests.xml
56
+ coverage.xml
57
+ *.cover
58
+ *.py,cover
59
+ .hypothesis/
60
+ .pytest_cache/
61
+ cover/
62
+
63
+ # Translations
64
+ *.mo
65
+ *.pot
66
+
67
+ # Django stuff:
68
+ *.log
69
+ local_settings.py
70
+ db.sqlite3
71
+ db.sqlite3-journal
72
+
73
+ # Flask stuff:
74
+ instance/
75
+ .webassets-cache
76
+
77
+ # Scrapy stuff:
78
+ .scrapy
79
+
80
+ # Sphinx documentation
81
+ docs/_build/
82
+
83
+ # PyBuilder
84
+ .pybuilder/
85
+ target/
86
+
87
+ # Jupyter Notebook
88
+ .ipynb_checkpoints
89
+
90
+ # IPython
91
+ profile_default/
92
+ ipython_config.py
93
+
94
+ # pyenv
95
+ # For a library or package, you might want to ignore these files since the code is
96
+ # intended to run in multiple environments; otherwise, check them in:
97
+ # .python-version
98
+
99
+ # pipenv
100
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
102
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
103
+ # install all needed dependencies.
104
+ #Pipfile.lock
105
+
106
+ # poetry
107
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
108
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
109
+ # commonly ignored for libraries.
110
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
111
+ #poetry.lock
112
+
113
+ # pdm
114
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
115
+ #pdm.lock
116
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
117
+ # in version control.
118
+ # https://pdm.fming.dev/#use-with-ide
119
+ .pdm.toml
120
+
121
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
122
+ __pypackages__/
123
+
124
+ # Celery stuff
125
+ celerybeat-schedule
126
+ celerybeat.pid
127
+
128
+ # SageMath parsed files
129
+ *.sage.py
130
+
131
+ # Environments
132
+ .env
133
+ .venv
134
+ env/
135
+ venv/
136
+ ENV/
137
+ env.bak/
138
+ venv.bak/
139
+
140
+ # Spyder project settings
141
+ .spyderproject
142
+ .spyproject
143
+
144
+ # Rope project settings
145
+ .ropeproject
146
+
147
+ # mkdocs documentation
148
+ /site
149
+
150
+ # mypy
151
+ .mypy_cache/
152
+ .dmypy.json
153
+ dmypy.json
154
+
155
+ # Pyre type checker
156
+ .pyre/
157
+
158
+ # pytype static type analyzer
159
+ .pytype/
160
+
161
+ # Cython debug symbols
162
+ cython_debug/
163
+
164
+ # PyCharm
165
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
166
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
167
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
168
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
169
+ #.idea/
170
+
171
+ ### Python Patch ###
172
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
173
+ poetry.toml
174
+
175
+ # ruff
176
+ .ruff_cache/
177
+
178
+ # LSP config files
179
+ pyrightconfig.json
180
+
181
+ ### VisualStudioCode ###
182
+ .vscode/*
183
+ !.vscode/settings.json
184
+ !.vscode/tasks.json
185
+ !.vscode/launch.json
186
+ !.vscode/extensions.json
187
+ !.vscode/*.code-snippets
188
+
189
+ # Local History for Visual Studio Code
190
+ .history/
191
+
192
+ # Built Visual Studio Code Extensions
193
+ *.vsix
194
+
195
+ ### VisualStudioCode Patch ###
196
+ # Ignore all local history of files
197
+ .history
198
+ .ionide
199
+
200
+ # End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,metals
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Binh Le
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
annotate_anything.py ADDED
@@ -0,0 +1,384 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+ import tempfile
6
+
7
+ import numpy as np
8
+ import supervision as sv
9
+ from groundingdino.util.inference import Model as DinoModel
10
+ from imutils import paths
11
+ from PIL import Image
12
+ from segment_anything import sam_model_registry
13
+ from segment_anything import SamAutomaticMaskGenerator
14
+ from segment_anything import SamPredictor
15
+ from supervision.detection.utils import xywh_to_xyxy
16
+ from tqdm import tqdm
17
+
18
+ sys.path.append("tag2text")
19
+
20
+ from tag2text.models import tag2text
21
+ from config import *
22
+ from utils import detect, download_file_hf, segment, generate_tags, show_anns_sv
23
+
24
+
25
+ def process(
26
+ tag2text_model,
27
+ grounding_dino_model,
28
+ sam_predictor,
29
+ sam_automask_generator,
30
+ image_path,
31
+ task,
32
+ prompt,
33
+ box_threshold,
34
+ text_threshold,
35
+ iou_threshold,
36
+ device,
37
+ output_dir=None,
38
+ save_mask=False,
39
+ ):
40
+ detections = None
41
+ metadata = {"image": {}, "annotations": [], "assets": {}}
42
+
43
+ if save_mask:
44
+ metadata["assets"]["intermediate_mask"] = []
45
+
46
+ try:
47
+ # Load image
48
+ image = Image.open(image_path)
49
+ image_pil = image.convert("RGB")
50
+ image = np.array(image_pil)
51
+
52
+ # Extract image metadata
53
+ filename = os.path.basename(image_path)
54
+ basename = os.path.splitext(filename)[0]
55
+ h, w = image.shape[:2]
56
+ metadata["image"]["file_name"] = filename
57
+ metadata["image"]["width"] = w
58
+ metadata["image"]["height"] = h
59
+
60
+ # Generate tags
61
+ if task in ["auto", "detect"] and prompt == "":
62
+ tags, caption = generate_tags(tag2text_model, image_pil, "None", device)
63
+ prompt = " . ".join(tags)
64
+ # print(f"Caption: {caption}")
65
+ # print(f"Tags: {tags}")
66
+
67
+ # ToDo: Extract metadata
68
+ metadata["image"]["caption"] = caption
69
+ metadata["image"]["tags"] = tags
70
+
71
+ if prompt:
72
+ metadata["prompt"] = prompt
73
+
74
+ # Detect boxes
75
+ if prompt != "":
76
+ detections, _, classes = detect(
77
+ grounding_dino_model,
78
+ image,
79
+ caption=prompt,
80
+ box_threshold=box_threshold,
81
+ text_threshold=text_threshold,
82
+ iou_threshold=iou_threshold,
83
+ post_process=True,
84
+ )
85
+
86
+ # Save detection image
87
+ if output_dir:
88
+ # Draw boxes
89
+ box_annotator = sv.BoxAnnotator()
90
+ labels = [
91
+ f"{classes[class_id] if class_id is not None else 'Unknown'} {confidence:0.2f}"
92
+ for _, _, confidence, class_id, _ in detections
93
+ ]
94
+ box_image = box_annotator.annotate(
95
+ scene=image, detections=detections, labels=labels
96
+ )
97
+ box_image_path = os.path.join(output_dir, basename + "_detect.png")
98
+ metadata["assets"]["detection"] = box_image_path
99
+ Image.fromarray(box_image).save(box_image_path)
100
+
101
+ # Segmentation
102
+ if task in ["auto", "segment"]:
103
+ if detections:
104
+ masks, scores = segment(
105
+ sam_predictor, image=image, boxes=detections.xyxy
106
+ )
107
+ detections.mask = masks
108
+ else:
109
+ masks = sam_automask_generator.generate(image)
110
+ sorted_generated_masks = sorted(
111
+ masks, key=lambda x: x["area"], reverse=True
112
+ )
113
+
114
+ xywh = np.array([mask["bbox"] for mask in sorted_generated_masks])
115
+ mask = np.array(
116
+ [mask["segmentation"] for mask in sorted_generated_masks]
117
+ )
118
+ scores = np.array(
119
+ [mask["predicted_iou"] for mask in sorted_generated_masks]
120
+ )
121
+ detections = sv.Detections(
122
+ xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask
123
+ )
124
+
125
+ # Save annotated image
126
+ if output_dir:
127
+ mask_annotator = sv.MaskAnnotator()
128
+ mask_image, res = show_anns_sv(detections)
129
+ annotated_image = mask_annotator.annotate(image, detections=detections)
130
+
131
+ mask_image_path = os.path.join(output_dir, basename + "_mask.png")
132
+ metadata["assets"]["mask"] = mask_image_path
133
+ Image.fromarray(mask_image).save(mask_image_path)
134
+
135
+ # Save annotation encoding from https://github.com/LUSSeg/ImageNet-S
136
+ mask_enc_path = os.path.join(output_dir, basename + "_mask_enc.npy")
137
+ np.save(mask_enc_path, res)
138
+ metadata["assets"]["mask_enc"] = mask_enc_path
139
+
140
+ annotated_image_path = os.path.join(
141
+ output_dir, basename + "_annotate.png"
142
+ )
143
+ metadata["assets"]["annotate"] = annotated_image_path
144
+ Image.fromarray(annotated_image).save(annotated_image_path)
145
+
146
+ # ToDo: Extract metadata
147
+ if detections:
148
+ id = 1
149
+ for (xyxy, mask, confidence, class_id, _), area, box_area, score in zip(
150
+ detections, detections.area, detections.box_area, scores
151
+ ):
152
+ annotation = {
153
+ "id": id,
154
+ "bbox": [int(x) for x in xyxy],
155
+ "box_area": float(box_area),
156
+ }
157
+ if class_id is not None:
158
+ annotation["box_confidence"] = float(confidence)
159
+ annotation["label"] = classes[class_id] if class_id is not None else "Unknown"
160
+ if mask is not None:
161
+ annotation["area"] = int(area)
162
+ annotation["predicted_iou"] = float(score)
163
+ metadata["annotations"].append(annotation)
164
+
165
+ if output_dir and save_mask:
166
+ mask_image_path = os.path.join(
167
+ output_dir, f"{basename}_mask_{id}.png"
168
+ )
169
+ metadata["assets"]["intermediate_mask"].append(mask_image_path)
170
+ Image.fromarray(mask * 255).save(mask_image_path)
171
+
172
+ id += 1
173
+
174
+ if output_dir:
175
+ meta_file_path = os.path.join(output_dir, basename + "_meta.json")
176
+ with open(meta_file_path, "w") as fp:
177
+ json.dump(metadata, fp)
178
+ else:
179
+ meta_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
180
+ meta_file_path = meta_file.name
+ with open(meta_file_path, "w") as fp:
+ json.dump(metadata, fp)
181
+
182
+ return meta_file_path
183
+ except Exception as error:
184
+ raise ValueError(f"global exception: {error}")
185
+
186
+
187
+ def main(args: argparse.Namespace) -> None:
188
+ device = args.device
189
+ prompt = args.prompt
190
+ task = args.task
191
+
192
+ tag2text_model = None
193
+ grounding_dino_model = None
194
+ sam_predictor = None
195
+ sam_automask_generator = None
196
+
197
+ box_threshold = args.box_threshold
198
+ text_threshold = args.text_threshold
199
+ iou_threshold = args.iou_threshold
200
+ save_mask = args.save_mask
201
+
202
+ # load model
203
+ if task in ["auto", "detect"] and prompt == "":
204
+ print("Loading Tag2Text model...")
205
+ tag2text_type = args.tag2text
206
+ tag2text_checkpoint = os.path.join(
207
+ abs_weight_dir, tag2text_dict[tag2text_type]["checkpoint_file"]
208
+ )
209
+ if not os.path.exists(tag2text_checkpoint):
210
+ print(f"Downloading weights for Tag2Text {tag2text_type} model")
211
+ os.system(
212
+ f"wget {tag2text_dict[tag2text_type]['checkpoint_url']} -O {tag2text_checkpoint}"
213
+ )
214
+ tag2text_model = tag2text.tag2text_caption(
215
+ pretrained=tag2text_checkpoint,
216
+ image_size=384,
217
+ vit="swin_b",
218
+ delete_tag_index=delete_tag_index,
219
+ )
220
+ # threshold for tagging
221
+ # we reduce the threshold to obtain more tags
222
+ tag2text_model.threshold = 0.64
223
+ tag2text_model.to(device)
224
+ tag2text_model.eval()
225
+
226
+ if task in ["auto", "detect"] or prompt != "":
227
+ print("Loading Grounding Dino model...")
228
+ dino_type = args.dino
229
+ dino_checkpoint = os.path.join(
230
+ abs_weight_dir, dino_dict[dino_type]["checkpoint_file"]
231
+ )
232
+ dino_config_file = os.path.join(
233
+ abs_weight_dir, dino_dict[dino_type]["config_file"]
234
+ )
235
+ if not os.path.exists(dino_checkpoint):
236
+ print(f"Downloading weights for Grounding Dino {dino_type} model")
237
+ dino_repo_id = dino_dict[dino_type]["repo_id"]
238
+ download_file_hf(
239
+ repo_id=dino_repo_id,
240
+ filename=dino_dict[dino_type]["config_file"],
241
+ cache_dir=weight_dir,
242
+ )
243
+ download_file_hf(
244
+ repo_id=dino_repo_id,
245
+ filename=dino_dict[dino_type]["checkpoint_file"],
246
+ cache_dir=weight_dir,
247
+ )
248
+ grounding_dino_model = DinoModel(
249
+ model_config_path=dino_config_file, model_checkpoint_path=dino_checkpoint
250
+ )
251
+
252
+ if task in ["auto", "segment"]:
253
+ print("Loading SAM...")
254
+ sam_type = args.sam
255
+ sam_checkpoint = os.path.join(
256
+ abs_weight_dir, sam_dict[sam_type]["checkpoint_file"]
257
+ )
258
+ if not os.path.exists(sam_checkpoint):
259
+ print(f"Downloading weights for SAM {sam_type}")
260
+ os.system(
261
+ f"wget {sam_dict[sam_type]['checkpoint_url']} -O {sam_checkpoint}"
262
+ )
263
+ sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint)
264
+ sam.to(device=device)
265
+ sam_predictor = SamPredictor(sam)
266
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
267
+
268
+ if not os.path.exists(args.input):
269
+ raise ValueError("The input directory doesn't exist!")
270
+ elif not os.path.isdir(args.input):
271
+ image_paths = [args.input]
272
+ else:
273
+ image_paths = paths.list_images(args.input)
274
+
275
+ os.makedirs(args.output, exist_ok=True)
276
+
277
+ with tqdm(image_paths) as pbar:
278
+ for image_path in pbar:
279
+ pbar.set_postfix_str(f"Processing {image_path}")
280
+ process(
281
+ tag2text_model=tag2text_model,
282
+ grounding_dino_model=grounding_dino_model,
283
+ sam_predictor=sam_predictor,
284
+ sam_automask_generator=sam_automask_generator,
285
+ image_path=image_path,
286
+ task=task,
287
+ prompt=prompt,
288
+ box_threshold=box_threshold,
289
+ text_threshold=text_threshold,
290
+ iou_threshold=iou_threshold,
291
+ device=device,
292
+ output_dir=args.output,
293
+ save_mask=save_mask,
294
+ )
295
+
296
+
297
+ if __name__ == "__main__":
298
+ if not os.path.exists(abs_weight_dir):
299
+ os.makedirs(abs_weight_dir, exist_ok=True)
300
+
301
+ parser = argparse.ArgumentParser(
302
+ description=(
303
+ "Runs automatic detection and mask generation on an input image or directory of images"
304
+ )
305
+ )
306
+
307
+ parser.add_argument(
308
+ "--input",
309
+ "-i",
310
+ type=str,
311
+ required=True,
312
+ help="Path to either a single input image or folder of images.",
313
+ )
314
+
315
+ parser.add_argument(
316
+ "--output",
317
+ "-o",
318
+ type=str,
319
+ required=True,
320
+ help=(
321
+ "Path to the directory where masks will be output. Output will be either a folder "
322
+ "of PNGs per image or a single json with COCO-style masks."
323
+ ),
324
+ )
325
+
326
+ parser.add_argument(
327
+ "--sam",
328
+ type=str,
329
+ default=default_sam,
330
+ choices=sam_dict.keys(),
331
+ help="The type of SAM model to load",
332
+ )
333
+
334
+ parser.add_argument(
335
+ "--tag2text",
336
+ type=str,
337
+ default=default_tag2text,
338
+ choices=tag2text_dict.keys(),
339
+ help="The type of Tag2Text model to use for tag and caption generation.",
340
+ )
341
+
342
+ parser.add_argument(
343
+ "--dino",
344
+ type=str,
345
+ default=default_dino,
346
+ choices=dino_dict.keys(),
347
+ help="The type of Grounding DINO model to load",
348
+ )
349
+
350
+ parser.add_argument(
351
+ "--task",
352
+ help="Task to run",
353
+ default="auto",
354
+ choices=["auto", "detect", "segment"],
355
+ type=str,
356
+ )
357
+ parser.add_argument(
358
+ "--prompt",
359
+ help="Detection prompt",
360
+ default="",
361
+ type=str,
362
+ )
363
+
364
+ parser.add_argument(
365
+ "--box-threshold", type=float, default=0.25, help="box threshold"
366
+ )
367
+ parser.add_argument(
368
+ "--text-threshold", type=float, default=0.2, help="text threshold"
369
+ )
370
+ parser.add_argument(
371
+ "--iou-threshold", type=float, default=0.5, help="iou threshold"
372
+ )
373
+
374
+ parser.add_argument(
375
+ "--save-mask",
376
+ action="store_true",
377
+ default=False,
378
+ help="If set, save all intermediate masks.",
379
+ )
380
+ parser.add_argument(
381
+ "--device", type=str, default="cuda", help="The device to run generation on."
382
+ )
383
+ args = parser.parse_args()
384
+ main(args)
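For reference, the sketch below shows one way the per-image metadata written by `process()` could be read back downstream. It relies only on what the script above defines (`<basename>_meta.json` files with `image`, `annotations`, and `assets` entries, and annotation fields such as `id`, `bbox`, `box_area`, `label`, `area`, `predicted_iou`); the `outputs/` directory name and the printing logic are illustrative, not part of the committed code.

```python
import json
import os
from glob import glob

# Read every *_meta.json produced by annotate_anything.py (directory name is
# an example; use whatever was passed to --output).
for meta_path in glob(os.path.join("outputs", "*_meta.json")):
    with open(meta_path) as fp:
        metadata = json.load(fp)

    image_info = metadata["image"]
    print(image_info["file_name"], image_info["width"], image_info["height"])

    # Each annotation carries at least "id", "bbox" (xyxy) and "box_area";
    # "label", "box_confidence", "area" and "predicted_iou" are present only
    # when detection / segmentation ran for that object.
    for ann in metadata.get("annotations", []):
        label = ann.get("label", "unlabeled")
        x0, y0, x1, y1 = ann["bbox"]
        print(f"  #{ann['id']} {label}: box=({x0}, {y0}, {x1}, {y1})")
```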
app.py ADDED
@@ -0,0 +1,277 @@
1
+ import json
2
+ import os
3
+ import sys
4
+ import tempfile
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import supervision as sv
9
+ import torch
10
+ from groundingdino.util.inference import Model as DinoModel
11
+ from PIL import Image
12
+ from segment_anything import build_sam
13
+ from segment_anything import SamAutomaticMaskGenerator
14
+ from segment_anything import SamPredictor
15
+ from supervision.detection.utils import mask_to_polygons
16
+ from supervision.detection.utils import xywh_to_xyxy
17
+
18
+ # segment anything
19
+ # Grounding DINO
20
+
21
+ sys.path.append("tag2text")
22
+
23
+ from tag2text.models import tag2text
24
+ from config import *
25
+ from utils import download_file_hf, detect, segment, show_anns, generate_tags
26
+
27
+ if not os.path.exists(abs_weight_dir):
28
+ os.makedirs(abs_weight_dir, exist_ok=True)
29
+
30
+ sam_checkpoint = os.path.join(abs_weight_dir, sam_dict[default_sam]["checkpoint_file"])
31
+ if not os.path.exists(sam_checkpoint):
32
+ os.system(f"wget {sam_dict[default_sam]['checkpoint_url']} -O {sam_checkpoint}")
33
+
34
+ tag2text_checkpoint = os.path.join(
35
+ abs_weight_dir, tag2text_dict[default_tag2text]["checkpoint_file"]
36
+ )
37
+ if not os.path.exists(tag2text_checkpoint):
38
+ os.system(
39
+ f"wget {tag2text_dict[default_tag2text]['checkpoint_url']} -O {tag2text_checkpoint}"
40
+ )
41
+
42
+ dino_checkpoint = os.path.join(
43
+ abs_weight_dir, dino_dict[default_dino]["checkpoint_file"]
44
+ )
45
+ dino_config_file = os.path.join(abs_weight_dir, dino_dict[default_dino]["config_file"])
46
+ if not os.path.exists(dino_checkpoint):
47
+ dino_repo_id = dino_dict[default_dino]["repo_id"]
48
+ download_file_hf(
49
+ repo_id=dino_repo_id,
50
+ filename=dino_dict[default_dino]["config_file"],
51
+ cache_dir=weight_dir,
52
+ )
53
+ download_file_hf(
54
+ repo_id=dino_repo_id,
55
+ filename=dino_dict[default_dino]["checkpoint_file"],
56
+ cache_dir=weight_dir,
57
+ )
58
+
59
+ # load model
60
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
+ tag2text_model = tag2text.tag2text_caption(
62
+ pretrained=tag2text_checkpoint,
63
+ image_size=384,
64
+ vit="swin_b",
65
+ delete_tag_index=delete_tag_index,
66
+ )
67
+ # threshold for tagging
68
+ # we reduce the threshold to obtain more tags
69
+ tag2text_model.threshold = 0.64
70
+ tag2text_model.to(device)
71
+ tag2text_model.eval()
72
+
73
+
74
+ sam = build_sam(checkpoint=sam_checkpoint)
75
+ sam.to(device=device)
76
+ sam_predictor = SamPredictor(sam)
77
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
78
+
79
+ grounding_dino_model = DinoModel(
80
+ model_config_path=dino_config_file, model_checkpoint_path=dino_checkpoint
81
+ )
82
+
83
+
84
+ def process(image_path, task, prompt, box_threshold, text_threshold, iou_threshold):
85
+ global tag2text_model, sam_predictor, sam_automask_generator, grounding_dino_model, device
86
+ output_gallery = []
87
+ detections = None
88
+ metadata = {"image": {}, "annotations": []}
89
+
90
+ try:
91
+ # Load image
92
+ image = Image.open(image_path)
93
+ image_pil = image.convert("RGB")
94
+ image = np.array(image_pil)
95
+
96
+ # Extract image metadata
97
+ filename = os.path.basename(image_path)
98
+ h, w = image.shape[:2]
99
+ metadata["image"]["file_name"] = filename
100
+ metadata["image"]["width"] = w
101
+ metadata["image"]["height"] = h
102
+
103
+ # Generate tags
104
+ if task in ["auto", "detect"] and prompt == "":
105
+ tags, caption = generate_tags(tag2text_model, image_pil, "None", device)
106
+ prompt = " . ".join(tags)
107
+ print(f"Caption: {caption}")
108
+ print(f"Tags: {tags}")
109
+
110
+ # ToDo: Extract metadata
111
+ metadata["image"]["caption"] = caption
112
+ metadata["image"]["tags"] = tags
113
+
114
+ if prompt:
115
+ metadata["prompt"] = prompt
116
+ print(f"Prompt: {prompt}")
117
+
118
+ # Detect boxes
119
+ if prompt != "":
120
+ detections, phrases, classes = detect(
121
+ grounding_dino_model,
122
+ image,
123
+ caption=prompt,
124
+ box_threshold=box_threshold,
125
+ text_threshold=text_threshold,
126
+ iou_threshold=iou_threshold,
127
+ post_process=True,
128
+ )
129
+
130
+ # Draw boxes
131
+ box_annotator = sv.BoxAnnotator()
132
+ labels = [
133
+ f"{classes[class_id] if class_id is not None else 'Unknown'} {confidence:0.2f}"
134
+ for _, _, confidence, class_id, _ in detections
135
+ ]
136
+ image = box_annotator.annotate(
137
+ scene=image, detections=detections, labels=labels
138
+ )
139
+ output_gallery.append(image)
140
+
141
+ # Segmentation
142
+ if task in ["auto", "segment"]:
143
+ if detections:
144
+ masks, scores = segment(
145
+ sam_predictor, image=image, boxes=detections.xyxy
146
+ )
147
+ detections.mask = masks
148
+ else:
149
+ masks = sam_automask_generator.generate(image)
150
+ sorted_generated_masks = sorted(
151
+ masks, key=lambda x: x["area"], reverse=True
152
+ )
153
+
154
+ xywh = np.array([mask["bbox"] for mask in sorted_generated_masks])
155
+ mask = np.array(
156
+ [mask["segmentation"] for mask in sorted_generated_masks]
157
+ )
158
+ scores = np.array(
159
+ [mask["predicted_iou"] for mask in sorted_generated_masks]
160
+ )
161
+ detections = sv.Detections(
162
+ xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask
163
+ )
164
+ # opacity = 0.4
165
+ # mask_image, _ = show_anns_sam(masks)
166
+ # annotated_image = np.uint8(mask_image * opacity + image * (1 - opacity))
167
+
168
+ mask_annotator = sv.MaskAnnotator()
169
+ mask_image = np.zeros_like(image, dtype=np.uint8)
170
+ mask_image = mask_annotator.annotate(
171
+ mask_image, detections=detections, opacity=1
172
+ )
173
+ annotated_image = mask_annotator.annotate(image, detections=detections)
174
+ output_gallery.append(mask_image)
175
+ output_gallery.append(annotated_image)
176
+
177
+ # ToDo: Extract metadata
178
+ if detections:
179
+ id = 1
180
+ for (xyxy, mask, confidence, class_id, _), area, box_area, score in zip(
181
+ detections, detections.area, detections.box_area, scores
182
+ ):
183
+ annotation = {
184
+ "id": id,
185
+ "bbox": [int(x) for x in xyxy],
186
+ "box_area": float(box_area),
187
+ }
188
+ if class_id is not None:
189
+ annotation["box_confidence"] = float(confidence)
190
+ annotation["label"] = classes[class_id] if class_id is not None else "Unknown"
191
+ if mask is not None:
192
+ # annotation["segmentation"] = mask_to_polygons(mask)
193
+ annotation["area"] = int(area)
194
+ annotation["predicted_iou"] = float(score)
195
+ metadata["annotations"].append(annotation)
196
+ id += 1
197
+
198
+ meta_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
199
+ meta_file_path = meta_file.name
200
+ with open(meta_file_path, "w") as fp:
201
+ json.dump(metadata, fp)
202
+
203
+ return output_gallery, meta_file_path
204
+ except Exception as error:
205
+ raise gr.Error(f"global exception: {error}")
206
+
207
+
208
+ title = "Annotate Anything"
209
+
210
+ with gr.Blocks(css="style.css", title=title) as demo:
211
+ with gr.Row(elem_classes=["container"]):
212
+ with gr.Column(scale=1):
213
+ input_image = gr.Image(type="filepath", label="Input")
214
+ task = gr.Dropdown(
215
+ ["detect", "segment", "auto"], value="auto", label="task_type"
216
+ )
217
+ text_prompt = gr.Textbox(label="Detection Prompt")
218
+ with gr.Accordion("Advanced parameters", open=False):
219
+ box_threshold = gr.Slider(
220
+ minimum=0,
221
+ maximum=1,
222
+ value=0.3,
223
+ step=0.05,
224
+ label="Box threshold",
225
+ info="Minimum confidence for keeping a Grounding DINO box prediction",
226
+ )
227
+ text_threshold = gr.Slider(
228
+ minimum=0,
229
+ maximum=1,
230
+ value=0.25,
231
+ step=0.05,
232
+ label="Text threshold",
233
+ info="Minimum confidence for matching a text phrase to a detected box",
234
+ )
235
+ iou_threshold = gr.Slider(
236
+ minimum=0,
237
+ maximum=1,
238
+ value=0.5,
239
+ step=0.05,
240
+ label="IOU threshold",
241
+ info="IoU threshold used to suppress overlapping boxes",
242
+ )
243
+ run_button = gr.Button("Run")
244
+
245
+ with gr.Column(scale=2):
246
+ gallery = gr.Gallery(
247
+ label="Generated images", show_label=False, elem_id="gallery"
248
+ ).style(preview=True, grid=2, object_fit="scale-down")
249
+ meta_file = gr.File(label="Metadata file")
250
+
251
+ with gr.Row(elem_classes=["container"]):
252
+ gr.Examples(
253
+ [
254
+ ["examples/dog.png", "auto", ""],
255
+ ["examples/eiffel.png", "auto", ""],
256
+ ["examples/eiffel.png", "segment", ""],
257
+ ["examples/girl.png", "auto", "girl . face"],
258
+ ["examples/horse.png", "detect", "horse"],
259
+ ["examples/horses.jpg", "auto", "horse"],
260
+ ["examples/traffic.jpg", "auto", ""],
261
+ ],
262
+ [input_image, task, text_prompt],
263
+ )
264
+ run_button.click(
265
+ fn=process,
266
+ inputs=[
267
+ input_image,
268
+ task,
269
+ text_prompt,
270
+ box_threshold,
271
+ text_threshold,
272
+ iou_threshold,
273
+ ],
274
+ outputs=[gallery, meta_file],
275
+ )
276
+
277
+ demo.queue(concurrency_count=2).launch()
config.py ADDED
@@ -0,0 +1,68 @@
+ import os
+
+ # Configurations
+ tag2text_dict = {
+     "swin_14m": {
+         "checkpoint_url": "https://huggingface.co/spaces/xinyu1205/Tag2Text/resolve/main/tag2text_swin_14m.pth",
+         "checkpoint_file": "tag2text_swin_14m.pth",
+     }
+ }
+
+ sam_dict = {
+     "default": {
+         "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+         "checkpoint_file": "sam_vit_h_4b8939.pth",
+     },
+     "vit_h": {
+         "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+         "checkpoint_file": "sam_vit_h_4b8939.pth",
+     },
+     "vit_l": {
+         "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
+         "checkpoint_file": "sam_vit_l_0b3195.pth",
+     },
+     "vit_b": {
+         "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
+         "checkpoint_file": "sam_vit_b_01ec64.pth",
+     },
+ }
+
+ dino_dict = {
+     "swinb": {
+         "repo_id": "ShilongLiu/GroundingDINO",
+         "config_file": "GroundingDINO_SwinB.cfg.py",
+         "checkpoint_file": "groundingdino_swinb_cogcoor.pth",
+     },
+     "swint_ogc": {
+         "repo_id": "ShilongLiu/GroundingDINO",
+         "config_file": "GroundingDINO_SwinT_OGC.cfg.py",
+         "checkpoint_file": "groundingdino_swint_ogc.pth",
+     },
+ }
+
+ default_sam = "default"
+ default_tag2text = "swin_14m"
+ default_dino = "swint_ogc"
+
+ root_dir = os.path.dirname(os.path.abspath(__file__))
+ weight_dir = "weights"
+ abs_weight_dir = os.path.join(root_dir, weight_dir)
+
+ tag2text_checkpoint = "tag2text_swin_14m.pth"
+ tag2text_url = "https://huggingface.co/spaces/xinyu1205/Tag2Text/resolve/main/tag2text_swin_14m.pth"
+ sam_checkpoint = "sam_vit_h_4b8939.pth"
+ sam_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+ output_dir = "outputs"
+
+ dino_config_file = "GroundingDINO_SwinB.cfg.py"
+ dino_repo_id = "ShilongLiu/GroundingDINO"
+ dino_checkpoint = "groundingdino_swinb_cogcoor.pth"
+
+ iou_threshold = 0.5
+ box_threshold = 0.3
+ text_threshold = 0.25
+
+ # Filter out attribute and action categories that are difficult to ground
+ delete_tag_index = []
+ for i in range(3012, 3429):
+     delete_tag_index.append(i)
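As a quick illustration of how these dictionaries are consumed, the sketch below resolves a SAM checkpoint path the way app.py and annotate_anything.py do: join `abs_weight_dir` with the entry's `checkpoint_file` and fetch `checkpoint_url` when the file is missing. The helper name `resolve_sam_checkpoint` is hypothetical, and `urllib` stands in for the `wget` / `download_file_hf` calls used in the scripts, so treat this as a hedged example rather than project code.

```python
import os
import urllib.request

from config import abs_weight_dir, default_sam, sam_dict


def resolve_sam_checkpoint(sam_type: str = default_sam) -> str:
    """Return a local path to the requested SAM checkpoint, downloading it if needed."""
    entry = sam_dict[sam_type]
    checkpoint_path = os.path.join(abs_weight_dir, entry["checkpoint_file"])

    if not os.path.exists(checkpoint_path):
        os.makedirs(abs_weight_dir, exist_ok=True)
        # The committed scripts shell out to wget here; urllib keeps this
        # sketch self-contained.
        urllib.request.urlretrieve(entry["checkpoint_url"], checkpoint_path)

    return checkpoint_path


if __name__ == "__main__":
    print(resolve_sam_checkpoint())
```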
examples/dog.png ADDED
examples/eiffel.jpg ADDED
examples/eiffel.png ADDED
examples/girl.png ADDED
examples/horse.png ADDED
examples/horses.jpg ADDED
examples/traffic.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,31 @@
+ accelerate
+ addict
+ gradio
+ huggingface_hub
+ matplotlib
+ numpy
+ onnxruntime
+ opencv_python
+ Pillow
+ pycocotools
+ pycocoevalcap
+ PyYAML
+ requests
+ setuptools
+ supervision
+ termcolor
+ timm
+ torch
+ torchvision
+ transformers
+ yapf
+ numba
+ scipy
+ safetensors
+ pynvml
+ fairscale
+ imutils
+ argparse
+ tqdm
+ git+https://github.com/facebookresearch/segment-anything.git
+ git+https://github.com/IDEA-Research/GroundingDINO
style.css ADDED
@@ -0,0 +1,11 @@
+ .container {
+     max-width: 1368px;
+     margin-left: auto;
+     margin-right: auto;
+ }
+
+ #row-flex {
+     display: flex;
+     align-items: center;
+     justify-content: center;
+ }
tag2text/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 OPPO LLC
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
tag2text/README.md ADDED
@@ -0,0 +1,101 @@
+ # :label: Tag2Text: Guiding Vision-Language Model via Image Tagging
+
+ Official PyTorch implementation of <a href="https://arxiv.org/abs/2303.05657">Tag2Text</a>, an efficient and controllable vision-language model with tagging guidance. Code is available now!
+
+ Welcome to try out the [Tag2Text Web demo🤗](https://huggingface.co/spaces/xinyu1205/Tag2Text)! Both tagging and captioning are included.
+
+ Tag2Text is now combined with [Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything), which can automatically recognize, detect, and segment an image! Tag2Text showcases powerful image recognition capabilities:
+ ![](./images/tag2text_grounded_sam.jpg)
+
+ ## :fire: News
+
+ - **`2023/05/20`**: Tag2Text is combined with [VideoChat](https://github.com/OpenGVLab/Ask-Anything), providing powerful tagging and captioning capabilities as a fundamental component!
+ - **`2023/04/20`**: We marry [Tag2Text with Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything) to provide powerful image recognition capabilities!
+ - **`2023/04/10`**: Code and checkpoints are available now!
+ - **`2023/03/14`**: The [Tag2Text web demo 🤗](https://huggingface.co/spaces/xinyu1205/Tag2Text) is available on Hugging Face Space!
+
+ ## :bulb: Highlight
+
+ - **Tagging.** Without manual annotations, Tag2Text achieves **superior** image tag recognition across [**3,429**](./data/tag_list.txt) commonly used categories.
+ - **Efficient.** Tagging guidance effectively enhances the performance of vision-language models on both **generation-based** and **alignment-based** tasks.
+ - **Controllable.** Tag2Text permits users to input **desired tags**, providing flexibility in composing corresponding texts based on the input tags.
+
+ <p align="center">
+ <table class="tg">
+ <tr>
+ <td class="tg-c3ow"><img src="images/tag2text_framework.png" align="center" width="800" ></td>
+ </tr>
+ </table>
+ </p>
+
+ ## :writing_hand: TODO
+
+ - [x] Release demo.
+ - [x] Release checkpoints.
+ - [x] Release inference code.
+ - [ ] Release training code.
+ - [ ] Release training datasets.
+
+ ## :toolbox: Checkpoints
+
+ <!-- insert a table -->
+
+ <table>
+ <thead>
+ <tr style="text-align: right;">
+ <th></th>
+ <th>name</th>
+ <th>backbone</th>
+ <th>Data</th>
+ <th>Illustration</th>
+ <th>Checkpoint</th>
+ </tr>
+ </thead>
+ <tbody>
+ <tr>
+ <th>1</th>
+ <td>Tag2Text-Swin</td>
+ <td>Swin-Base</td>
+ <td>COCO, VG, SBU, CC-3M, CC-12M</td>
+ <td>Demo version with comprehensive captions.</td>
+ <td><a href="https://huggingface.co/spaces/xinyu1205/Tag2Text/blob/main/tag2text_swin_14m.pth">Download link</a></td>
+ </tr>
+ </tbody>
+ </table>
+
+ ## :running: Model Inference
+
+ 1. Install the dependencies:
+
+ <pre/>pip install -r requirements.txt</pre>
+
+ 2. Download the Tag2Text pretrained checkpoints.
+
+ 3. Get the tagging and captioning results:
+ <pre/>
+ python inference.py --image images/1641173_2291260800.jpg \
+ --pretrained pretrained/tag2text_swin_14m.pth
+ </pre>
+ Or get the tagging and specified captioning results (optional):
+ <pre/>python inference.py --image images/1641173_2291260800.jpg \
+ --pretrained pretrained/tag2text_swin_14m.pth \
+ --specified-tags "cloud,sky"</pre>
+
+ ## :black_nib: Citation
+
+ If you find our work useful for your research, please consider citing.
+
+ ```
+ @article{huang2023tag2text,
+ title={Tag2Text: Guiding Vision-Language Model via Image Tagging},
+ author={Huang, Xinyu and Zhang, Youcai and Ma, Jinyu and Tian, Weiwei and Feng, Rui and Zhang, Yuejie and Li, Yaqian and Guo, Yandong and Zhang, Lei},
+ journal={arXiv preprint arXiv:2303.05657},
+ year={2023}
+ }
+ ```
+
+ ## :hearts: Acknowledgements
+
+ This work is done with the help of the amazing code base of [BLIP](https://github.com/salesforce/BLIP), thanks very much!
+
+ We also want to thank @Cheng Rui @Shilong Liu @Ren Tianhe for their help in [marrying Tag2Text with Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything).
tag2text/configs/med_config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "architectures": [
+     "BertModel"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "type_vocab_size": 2,
+   "vocab_size": 30524,
+   "encoder_width": 768,
+   "add_cross_attention": true
+ }
tag2text/configs/q2l_config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "architectures": [
+     "BertModel"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 4,
+   "num_hidden_layers": 2,
+   "pad_token_id": 0,
+   "type_vocab_size": 2,
+   "vocab_size": 30522,
+   "encoder_width": 768,
+   "add_cross_attention": true,
+   "add_tag_cross_attention": false
+ }
tag2text/configs/swin/config_swinB_384.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
+   "vision_width": 1024,
+   "image_res": 384,
+   "window_size": 12,
+   "embed_dim": 128,
+   "depths": [ 2, 2, 18, 2 ],
+   "num_heads": [ 4, 8, 16, 32 ]
+ }
tag2text/data/tag_list.txt ADDED
@@ -0,0 +1,3429 @@
1
+ tennis
2
+ bear cub
3
+ observatory
4
+ bicycle
5
+ hillside
6
+ judge
7
+ watercolor illustration
8
+ granite
9
+ lobster
10
+ livery
11
+ stone
12
+ ceramic
13
+ ranch
14
+ cloth
15
+ smile
16
+ building
17
+ tattoo
18
+ cricketer
19
+ cheek
20
+ pear
21
+ source
22
+ winter
23
+ surface
24
+ spray
25
+ ceremony
26
+ magic
27
+ curve
28
+ container
29
+ fair
30
+ medicine
31
+ baby
32
+ tennis racquet
33
+ ornament
34
+ bamboo
35
+ duckling
36
+ song
37
+ safari
38
+ team presentation
39
+ daffodil
40
+ cross
41
+ toothpaste
42
+ shield
43
+ fashion model
44
+ capsule
45
+ map
46
+ creek
47
+ glass house
48
+ glass plate
49
+ siding
50
+ corner
51
+ water buffalo
52
+ bison
53
+ figure skater
54
+ diploma
55
+ tire
56
+ race
57
+ cable car
58
+ brain
59
+ gas stove
60
+ soap bubble
61
+ palette
62
+ snowboard
63
+ school child
64
+ trench coat
65
+ monk
66
+ fiber
67
+ kitchen window
68
+ sunglass
69
+ coffee
70
+ security
71
+ strawberry
72
+ penguin
73
+ tree root
74
+ loaf
75
+ engagement ring
76
+ lamb
77
+ vector cartoon illustration
78
+ sandwich
79
+ mountain village
80
+ shape
81
+ charm
82
+ fiction
83
+ knot
84
+ greenhouse
85
+ sushi
86
+ text
87
+ disaster
88
+ trophy
89
+ gang
90
+ strap
91
+ soccer game
92
+ cardinal
93
+ tee
94
+ turtle
95
+ water surface
96
+ grassland
97
+ dolphin
98
+ store
99
+ dirt
100
+ iceberg
101
+ pergola
102
+ farmer market
103
+ publicity portrait
104
+ tote bag
105
+ teenage girl
106
+ view mirror
107
+ session
108
+ commuter
109
+ dressing room
110
+ tricycle
111
+ christmas ball
112
+ headlight
113
+ police
114
+ armchair
115
+ chart
116
+ yacht
117
+ saw
118
+ printer
119
+ rock band
120
+ gingerbread house
121
+ tag
122
+ table lamp
123
+ hockey game
124
+ slope
125
+ font
126
+ wicker basket
127
+ jewelry
128
+ quarter
129
+ software
130
+ weapon
131
+ pin
132
+ worship
133
+ painter
134
+ goal
135
+ morning light
136
+ bike
137
+ baseball bat
138
+ elevator
139
+ cuisine
140
+ sausage
141
+ stunt
142
+ wrestler
143
+ statue
144
+ landing
145
+ pillar
146
+ willow tree
147
+ sea wave
148
+ chicken
149
+ peanut
150
+ muscle
151
+ bob
152
+ tv genre
153
+ bathroom window
154
+ radish
155
+ textile
156
+ pelican
157
+ marketplace
158
+ crest
159
+ elevation map
160
+ gift
161
+ parish
162
+ traffic light
163
+ campfire
164
+ fog
165
+ award winner
166
+ beach ball
167
+ mat
168
+ white house
169
+ plaster
170
+ moped
171
+ football team
172
+ solution
173
+ bicyclist
174
+ bit
175
+ playground
176
+ darkness
177
+ cake
178
+ maple leave
179
+ mold
180
+ cracker
181
+ blueberry
182
+ rubble
183
+ container ship
184
+ pedestrian bridge
185
+ snail
186
+ parrot
187
+ form
188
+ circuit
189
+ highlight
190
+ pickup truck
191
+ koala
192
+ rain
193
+ system
194
+ weather
195
+ raincoat
196
+ soccer team
197
+ windshield
198
+ thunderstorm
199
+ mike
200
+ bird house
201
+ bridge
202
+ grandfather
203
+ restroom
204
+ animation
205
+ wilderness
206
+ clown
207
+ banana
208
+ brown
209
+ braid
210
+ dining room
211
+ kindergarten
212
+ launch event
213
+ purple
214
+ school
215
+ stairwell
216
+ brooch
217
+ movie poster image
218
+ mountain river
219
+ shelf
220
+ wicket
221
+ headboard
222
+ buddha
223
+ flower field
224
+ dugout
225
+ cd
226
+ bald eagle
227
+ lagoon
228
+ seaweed
229
+ agriculture
230
+ emergency service
231
+ maple tree
232
+ parachute
233
+ continent
234
+ amusement park
235
+ remote
236
+ bun
237
+ tackle
238
+ hospital
239
+ garage door
240
+ birthday party
241
+ friendship
242
+ go
243
+ mausoleum
244
+ jeep
245
+ raccoon
246
+ step
247
+ ice hockey team
248
+ cigarette
249
+ lace dress
250
+ forest floor
251
+ mall
252
+ captain
253
+ milk
254
+ golf course
255
+ meal
256
+ picnic table
257
+ sail
258
+ volleyball
259
+ canal
260
+ terrace
261
+ computer desk
262
+ caravan
263
+ hotel
264
+ cheerleader
265
+ nurse
266
+ museum
267
+ marsh
268
+ fox
269
+ plateau
270
+ night
271
+ twin
272
+ letter logo
273
+ autumn tree
274
+ powder
275
+ convention
276
+ creature
277
+ lighthouse
278
+ shop window
279
+ jacket
280
+ stork
281
+ taxi
282
+ trade
283
+ blackboard
284
+ olive
285
+ road sign
286
+ resort
287
+ snowflake
288
+ cemetery
289
+ travel
290
+ evening dress
291
+ picnic
292
+ drink
293
+ winter morning
294
+ football player
295
+ snack
296
+ boxing glove
297
+ dinner party
298
+ airline
299
+ swing
300
+ port
301
+ wheelbarrow
302
+ bathroom sink
303
+ sweater
304
+ ambulance
305
+ gear
306
+ oil
307
+ wii controller
308
+ array
309
+ home office
310
+ car show
311
+ mixture
312
+ profession
313
+ tree frog
314
+ square
315
+ facility
316
+ coral reef
317
+ sea wall
318
+ pizza
319
+ exhibit
320
+ demolition
321
+ trout
322
+ ring
323
+ coffee shop
324
+ bracelet
325
+ bean
326
+ lip
327
+ fencing
328
+ landscape
329
+ sitting
330
+ package
331
+ metal
332
+ bust
333
+ king
334
+ hair
335
+ window seat
336
+ wildlife
337
+ trunk
338
+ greenery
339
+ stencil
340
+ fire hydrant
341
+ bridesmaid
342
+ plaza
343
+ alps
344
+ tower bridge
345
+ crop top
346
+ crossing
347
+ cinema
348
+ pedestrian crossing
349
+ family
350
+ shopping cart
351
+ stomach
352
+ church building
353
+ screen door
354
+ skater
355
+ soccer field
356
+ kettle
357
+ mussel
358
+ raindrop
359
+ candy cane
360
+ water lily
361
+ flower girl
362
+ desert
363
+ enclosure
364
+ christmas light
365
+ kitchen
366
+ caterpillar
367
+ plaid
368
+ bath
369
+ bush
370
+ mud
371
+ ballet
372
+ knee
373
+ adult
374
+ raft
375
+ sea view
376
+ cactus
377
+ office chair
378
+ overall
379
+ rim
380
+ scaffolding
381
+ pig
382
+ cover
383
+ poster page
384
+ sprinkle
385
+ chandelier
386
+ algae
387
+ traffic
388
+ surfboard
389
+ book
390
+ filming
391
+ flash
392
+ mansion
393
+ camouflage
394
+ trouser
395
+ ticket
396
+ weed
397
+ cab
398
+ trench
399
+ elephant
400
+ huddle
401
+ sphere
402
+ christmas decoration
403
+ city
404
+ launch
405
+ doll
406
+ christmas ornament
407
+ fabric
408
+ bikini
409
+ biplane
410
+ breakfast
411
+ neighbourhood
412
+ race track
413
+ foliage
414
+ avocado
415
+ school bus
416
+ footwear
417
+ highway
418
+ ocean view
419
+ art vector illustration
420
+ wall clock
421
+ curtain
422
+ teenager
423
+ kitchen area
424
+ robot
425
+ tusk
426
+ lounge chair
427
+ beam
428
+ paddle
429
+ camel
430
+ lid
431
+ world map
432
+ city view
433
+ newlywed
434
+ cargo ship
435
+ yellow
436
+ exhibition
437
+ bend
438
+ novel
439
+ wool
440
+ ontario
441
+ bread
442
+ campus
443
+ coastline
444
+ cutting board
445
+ booth
446
+ table top
447
+ carpet
448
+ beach chair
449
+ workout
450
+ street food
451
+ fun
452
+ costumer film designer
453
+ gadget
454
+ artist
455
+ fishing village
456
+ builder
457
+ violinist
458
+ iphone
459
+ spider web
460
+ traffic sign
461
+ ruin
462
+ rescue
463
+ clipboard
464
+ seal
465
+ film director
466
+ paw
467
+ nursery
468
+ intersection
469
+ tomato sauce
470
+ taste
471
+ paddy field
472
+ christmas tree
473
+ wave
474
+ stool
475
+ watering can
476
+ rug
477
+ daytime
478
+ subway station
479
+ craft
480
+ pine forest
481
+ black
482
+ planet
483
+ motif
484
+ christmas market
485
+ glass window
486
+ college
487
+ wheat
488
+ damage
489
+ rectangle
490
+ picture frame
491
+ chess
492
+ guest room
493
+ street corner
494
+ religion
495
+ seed
496
+ puzzle
497
+ freeway
498
+ beauty
499
+ ocean
500
+ watch
501
+ mother
502
+ garage
503
+ quote
504
+ dj
505
+ supporter
506
+ hip hop artist
507
+ muffin
508
+ eiffel tower
509
+ cash
510
+ firefighter
511
+ cauliflower
512
+ bunker
513
+ sled
514
+ manicure
515
+ shark
516
+ stall
517
+ jungle
518
+ family home
519
+ tour bus
520
+ chimney
521
+ touchdown
522
+ roundabout
523
+ coyote
524
+ street scene
525
+ tank
526
+ wedding dress
527
+ mantle
528
+ bedroom window
529
+ coconut
530
+ chapel
531
+ goat
532
+ living space
533
+ rock wall
534
+ polka dot
535
+ railway
536
+ mandala
537
+ mango
538
+ lesson
539
+ mountain landscape
540
+ team photo
541
+ bookshelf
542
+ meter
543
+ bulldog
544
+ evening sun
545
+ stick
546
+ card
547
+ pink
548
+ fish pond
549
+ paint
550
+ pill
551
+ cart
552
+ pea
553
+ van
554
+ album
555
+ football college game
556
+ mountain pass
557
+ doughnut
558
+ ski slope
559
+ match
560
+ official
561
+ shadow
562
+ organ
563
+ celebration
564
+ coin
565
+ log cabin
566
+ firework display
567
+ present
568
+ twig
569
+ chef
570
+ confetti
571
+ footpath
572
+ tour
573
+ ponytail
574
+ artwork
575
+ race car
576
+ club
577
+ season
578
+ hose
579
+ pencil
580
+ aircraft
581
+ rock formation
582
+ wardrobe
583
+ participant
584
+ politician
585
+ engineer
586
+ peace
587
+ filter
588
+ sailing boat
589
+ water bottle
590
+ service dog
591
+ poodle
592
+ loki
593
+ statesman
594
+ sleeping bag
595
+ outskirt
596
+ clock
597
+ factory
598
+ oak tree
599
+ physician
600
+ color
601
+ room
602
+ stairway
603
+ company
604
+ lady
605
+ graph
606
+ faucet
607
+ tablecloth
608
+ subway train
609
+ chocolate chip cookie
610
+ headquarters
611
+ screw
612
+ goggle
613
+ halloween
614
+ city street
615
+ swirl
616
+ cord
617
+ forward
618
+ bone
619
+ bedding
620
+ archway
621
+ wig
622
+ lobby
623
+ mask
624
+ attic
625
+ kitchen table
626
+ skylight
627
+ fire
628
+ exit
629
+ oil painting
630
+ passenger
631
+ meditation
632
+ salmon
633
+ fedora
634
+ rubber stamp
635
+ orange juice
636
+ arch
637
+ scientist
638
+ stroll
639
+ manhattan
640
+ float
641
+ baseball uniform
642
+ circle
643
+ church
644
+ decker bus
645
+ competitor
646
+ zoo
647
+ basketball team
648
+ tourist
649
+ daughter
650
+ silverware
651
+ ceiling fan
652
+ birth
653
+ vase
654
+ jack
655
+ mushroom
656
+ spiral
657
+ cage
658
+ limb
659
+ salad
660
+ ad
661
+ control
662
+ earth
663
+ party
664
+ bolt
665
+ tractor
666
+ barley
667
+ wedding photo
668
+ hawk
669
+ warehouse
670
+ vegetable garden
671
+ chocolate cake
672
+ cabbage
673
+ floor window
674
+ baby shower
675
+ magnifying glass
676
+ table
677
+ stethoscope
678
+ reading
679
+ mission
680
+ croissant
681
+ gift box
682
+ rocket
683
+ forest road
684
+ cooking
685
+ suite
686
+ hill country
687
+ motorcycle
688
+ baseball player
689
+ angle
690
+ drug
691
+ sport association
692
+ championship
693
+ family portrait
694
+ florist
695
+ softball
696
+ egret
697
+ office
698
+ plywood
699
+ jockey
700
+ mosque
701
+ brunch
702
+ beanie
703
+ office building
704
+ pattern
705
+ calendar
706
+ indoor
707
+ pepper
708
+ ledge
709
+ trail
710
+ fuel
711
+ laptop computer
712
+ tennis shoe
713
+ deck chair
714
+ guitarist
715
+ barn
716
+ surgery
717
+ cartoon illustration
718
+ nebula
719
+ railroad
720
+ mountain goat
721
+ goose
722
+ car door
723
+ cheer
724
+ liquid
725
+ hardwood floor
726
+ pathway
727
+ acorn
728
+ gull
729
+ airliner
730
+ couch
731
+ lake house
732
+ spaghetti
733
+ promenade
734
+ collection
735
+ garden
736
+ bank
737
+ robin
738
+ tennis ball
739
+ peony
740
+ gymnast
741
+ lavender
742
+ deck
743
+ test
744
+ riverside
745
+ rapper
746
+ domino
747
+ bride
748
+ mouse
749
+ basil
750
+ wedding couple
751
+ ocean wave
752
+ arm
753
+ kitchen floor
754
+ grove
755
+ family member
756
+ backyard
757
+ raspberry
758
+ forest fire
759
+ officer
760
+ hibiscus
761
+ canyon
762
+ composer
763
+ signature
764
+ olive oil
765
+ hibiscus flower
766
+ rose
767
+ vector icon
768
+ sunrise
769
+ horseback
770
+ motor scooter
771
+ office worker
772
+ tradition
773
+ ingredient
774
+ washing machine
775
+ lighting
776
+ bagel
777
+ sailboat
778
+ policeman
779
+ mare
780
+ graphic
781
+ halloween pumpkin
782
+ stock
783
+ pilot
784
+ education
785
+ team
786
+ body
787
+ horse
788
+ kimono
789
+ bazaar
790
+ bag
791
+ recording studio
792
+ parsley
793
+ entrance
794
+ denim
795
+ vet
796
+ horse farm
797
+ charcoal
798
+ architecture
799
+ glass vase
800
+ puppy
801
+ estuary
802
+ television show host
803
+ city bus
804
+ shoulder
805
+ beast
806
+ balance
807
+ golfer
808
+ roadside
809
+ denim jacket
810
+ stone wall
811
+ counter top
812
+ app icon
813
+ toast
814
+ head coach
815
+ ham
816
+ warrior
817
+ gem
818
+ refrigerator
819
+ snowman
820
+ construction worker
821
+ coal
822
+ website
823
+ morning fog
824
+ mustard
825
+ human
826
+ owl
827
+ puppy dog
828
+ piggy bank
829
+ vegetation
830
+ pirate
831
+ action film
832
+ marshmallow
833
+ thanksgiving
834
+ business
835
+ disease
836
+ signage
837
+ greeting
838
+ skate park
839
+ tile
840
+ mouth
841
+ spinach
842
+ vacation
843
+ leader
844
+ shrine
845
+ walker
846
+ science fiction film
847
+ bill
848
+ rabbit
849
+ motor boat
850
+ bar
851
+ radio
852
+ barge
853
+ tail
854
+ chainsaw
855
+ gallery
856
+ rainbow
857
+ pasta
858
+ padlock
859
+ web
860
+ pastry
861
+ ink
862
+ reef
863
+ school uniform
864
+ shawl
865
+ treasure
866
+ peach
867
+ dinner table
868
+ injury
869
+ harbor
870
+ witch
871
+ car dealership
872
+ litter
873
+ gesture
874
+ documentary
875
+ marriage
876
+ sea shell
877
+ priest
878
+ dome
879
+ kit
880
+ icon
881
+ seaside
882
+ bucket
883
+ entertainment
884
+ stable
885
+ hat
886
+ puddle
887
+ sock
888
+ shopper
889
+ technology
890
+ harbour
891
+ orbit
892
+ antler
893
+ tube
894
+ flag waving
895
+ cook
896
+ tight
897
+ commander
898
+ farmland
899
+ switch
900
+ hiker
901
+ wedding ceremony
902
+ award ceremony
903
+ champion
904
+ chopstick
905
+ farmhouse
906
+ performer
907
+ spike
908
+ accident
909
+ cruise ship
910
+ passenger train
911
+ attraction
912
+ entertainer
913
+ rear view
914
+ sidewalk
915
+ parade
916
+ racing
917
+ plane
918
+ ritual
919
+ peacock
920
+ pocket
921
+ plum
922
+ drop
923
+ carrot
924
+ floor
925
+ sunset
926
+ troop
927
+ architect
928
+ coffee table
929
+ dust
930
+ outline
931
+ leather
932
+ charity event
933
+ heat
934
+ whale
935
+ laundry
936
+ coconut tree
937
+ crosswalk
938
+ pony
939
+ ant
940
+ pipe
941
+ string
942
+ coat
943
+ angel
944
+ beef
945
+ church tower
946
+ dish
947
+ pitch
948
+ cupboard
949
+ thermometer
950
+ dirt field
951
+ fireworks
952
+ minute
953
+ cane
954
+ pajama
955
+ flower garden
956
+ autumn
957
+ trash can
958
+ dachshund
959
+ banana tree
960
+ tray
961
+ moose
962
+ roadway
963
+ carnival
964
+ antenna
965
+ pole
966
+ castle wall
967
+ ram
968
+ cattle
969
+ hay
970
+ cookie
971
+ swimmer
972
+ baseball team
973
+ strait
974
+ hedge
975
+ jet
976
+ fire pit
977
+ octopus
978
+ calf
979
+ cube
980
+ opera
981
+ cardboard box
982
+ tiara
983
+ kitchen sink
984
+ prairie
985
+ bowl
986
+ galaxy
987
+ straw hat
988
+ linen
989
+ ski resort
990
+ stitch
991
+ street lamp
992
+ motorist
993
+ icicle
994
+ stain
995
+ flora
996
+ drain
997
+ kitchen cabinet
998
+ decor
999
+ bouquet
1000
+ pound
1001
+ interior design
1002
+ nail polish
1003
+ figurine
1004
+ tomb
1005
+ disc
1006
+ twist
1007
+ blouse
1008
+ ribbon
1009
+ figure
1010
+ burger
1011
+ cork
1012
+ soccer goalkeeper
1013
+ train bridge
1014
+ drinking water
1015
+ dew
1016
+ baker
1017
+ storm cloud
1018
+ tarmac
1019
+ tv drama
1020
+ sponge
1021
+ magnet
1022
+ sailor
1023
+ entry
1024
+ swan
1025
+ exercise
1026
+ sloth
1027
+ jewel
1028
+ scuba diver
1029
+ bite
1030
+ cat tree
1031
+ tent
1032
+ can
1033
+ tennis match
1034
+ ecosystem
1035
+ picket fence
1036
+ palm
1037
+ train car
1038
+ frying pan
1039
+ rally
1040
+ tablet pc
1041
+ reindeer
1042
+ image
1043
+ wolf
1044
+ chin
1045
+ conservatory
1046
+ flood water
1047
+ cityscape
1048
+ beach sand
1049
+ car park
1050
+ pavement
1051
+ farm field
1052
+ swimming
1053
+ winter storm
1054
+ stem
1055
+ pillow
1056
+ inning
1057
+ gorilla
1058
+ desk
1059
+ avenue
1060
+ fern
1061
+ money
1062
+ pearl
1063
+ train station
1064
+ skillet
1065
+ nap
1066
+ barber
1067
+ library
1068
+ freezer
1069
+ label
1070
+ rainforest
1071
+ parking sign
1072
+ mirror
1073
+ wing
1074
+ noodle
1075
+ press room
1076
+ sculpture
1077
+ tablet
1078
+ viewer
1079
+ prayer
1080
+ mini
1081
+ mechanic
1082
+ laugh
1083
+ rice field
1084
+ hand
1085
+ mustache
1086
+ mountain road
1087
+ catwalk
1088
+ conference
1089
+ cape
1090
+ installation
1091
+ musician
1092
+ stream
1093
+ machine
1094
+ speech
1095
+ crocodile
1096
+ soccer match
1097
+ town square
1098
+ passport
1099
+ post box
1100
+ point
1101
+ stone building
1102
+ motorway
1103
+ mix
1104
+ dentist
1105
+ businessperson
1106
+ happiness
1107
+ boat
1108
+ vineyard
1109
+ treadmill
1110
+ glass wall
1111
+ water droplet
1112
+ coffee mug
1113
+ graduate
1114
+ sunflower
1115
+ parliament
1116
+ shepherd
1117
+ movie
1118
+ wine
1119
+ orchard
1120
+ tulip
1121
+ motherboard
1122
+ cup
1123
+ broom
1124
+ spot
1125
+ drawing
1126
+ polo shirt
1127
+ graduation
1128
+ film producer
1129
+ moonlight
1130
+ glow
1131
+ film format
1132
+ t shirt
1133
+ rock face
1134
+ sword
1135
+ clinic
1136
+ festival day
1137
+ meadow
1138
+ staple
1139
+ pupil
1140
+ training ground
1141
+ rider
1142
+ flower
1143
+ foal
1144
+ wharf
1145
+ foot bridge
1146
+ shooting
1147
+ top
1148
+ mast
1149
+ police car
1150
+ robe
1151
+ wedding bouquet
1152
+ stop sign
1153
+ birthday cake
1154
+ glitter
1155
+ butter
1156
+ scooter
1157
+ tundra
1158
+ superhero
1159
+ pocket watch
1160
+ inscription
1161
+ youngster
1162
+ fruit tree
1163
+ movie poster
1164
+ engine
1165
+ foundation
1166
+ motorcyclist
1167
+ take
1168
+ woman
1169
+ antelope
1170
+ country artist
1171
+ road trip
1172
+ typewriter
1173
+ tuxedo
1174
+ brand
1175
+ pine
1176
+ bathroom
1177
+ paradise
1178
+ texture
1179
+ balloon
1180
+ dining table
1181
+ home
1182
+ computer screen
1183
+ actor
1184
+ clip
1185
+ tv tower
1186
+ panorama
1187
+ summit
1188
+ cat
1189
+ plot
1190
+ eagle
1191
+ dancer
1192
+ pup
1193
+ studio shot
1194
+ tear
1195
+ bird bath
1196
+ classroom
1197
+ bookstore
1198
+ city wall
1199
+ tv programme
1200
+ blade
1201
+ easel
1202
+ buttercream
1203
+ sweet
1204
+ designer
1205
+ diamond
1206
+ handshake
1207
+ herb
1208
+ corn field
1209
+ seafront
1210
+ concrete
1211
+ street artist
1212
+ gas
1213
+ stamp
1214
+ window display
1215
+ paper
1216
+ note
1217
+ pint
1218
+ quarry
1219
+ research
1220
+ fixture
1221
+ manager
1222
+ soil
1223
+ leopard
1224
+ board game
1225
+ ladder
1226
+ stop light
1227
+ island
1228
+ ramp
1229
+ football match
1230
+ icing
1231
+ drill
1232
+ currency
1233
+ summer evening
1234
+ topping
1235
+ pyramid
1236
+ pomegranate
1237
+ cell
1238
+ ivy
1239
+ squad
1240
+ scenery
1241
+ computer
1242
+ locomotive
1243
+ surf
1244
+ mascot
1245
+ dune
1246
+ path
1247
+ duck
1248
+ twilight
1249
+ wire
1250
+ bow tie
1251
+ strike
1252
+ cormorant
1253
+ car wash
1254
+ crane
1255
+ market
1256
+ philosopher
1257
+ alarm clock
1258
+ camera
1259
+ birch
1260
+ greeting card
1261
+ plain
1262
+ clay
1263
+ donut
1264
+ lock
1265
+ moth
1266
+ laboratory
1267
+ fan
1268
+ violin
1269
+ jazz fusion artist
1270
+ mountain biker
1271
+ terrain
1272
+ magazine
1273
+ pickup
1274
+ comedy film
1275
+ smartphone
1276
+ film
1277
+ bed
1278
+ microwave oven
1279
+ tournament
1280
+ lawn
1281
+ car window
1282
+ alligator
1283
+ screen
1284
+ jetty
1285
+ shopping bag
1286
+ landscape view
1287
+ cabinetry
1288
+ friendly match
1289
+ thing
1290
+ petal
1291
+ shopping center
1292
+ transport
1293
+ ballet dancer
1294
+ shoreline
1295
+ princess
1296
+ car seat
1297
+ parking meter
1298
+ green
1299
+ vodka
1300
+ band
1301
+ rock
1302
+ costume
1303
+ warning sign
1304
+ strip
1305
+ plaque
1306
+ wheelchair
1307
+ headband
1308
+ ginger
1309
+ dice
1310
+ media
1311
+ hairdresser
1312
+ press
1313
+ living room
1314
+ stove
1315
+ player
1316
+ cherry
1317
+ workshop
1318
+ carving
1319
+ embroidery
1320
+ doodle
1321
+ adventure
1322
+ rugby player
1323
+ monument
1324
+ brush
1325
+ marker
1326
+ loft
1327
+ postcard
1328
+ collage
1329
+ ball
1330
+ professor
1331
+ dresser
1332
+ gig
1333
+ festival
1334
+ blackbird
1335
+ makeup artist
1336
+ video camera
1337
+ sticker
1338
+ peak
1339
+ wildflower
1340
+ santa hat
1341
+ rodeo
1342
+ wedding photographer
1343
+ guy
1344
+ staff
1345
+ waterfall
1346
+ operation
1347
+ defender
1348
+ falcon
1349
+ haze
1350
+ individual
1351
+ gentleman
1352
+ greyhound
1353
+ rocking chair
1354
+ rice
1355
+ garbage
1356
+ platter
1357
+ chocolate
1358
+ splash
1359
+ business suit
1360
+ cheetah
1361
+ valley
1362
+ maze
1363
+ trampoline
1364
+ garland
1365
+ slalom
1366
+ unicorn
1367
+ tree stump
1368
+ painting
1369
+ romance
1370
+ fight
1371
+ alcohol
1372
+ ghost
1373
+ fondant
1374
+ spa
1375
+ shutter
1376
+ death
1377
+ demonstration
1378
+ cotton
1379
+ pier
1380
+ flea market
1381
+ history
1382
+ savannah
1383
+ fist
1384
+ aisle
1385
+ crew
1386
+ jug
1387
+ pose
1388
+ anchor
1389
+ teapot
1390
+ boat house
1391
+ business team
1392
+ tripod
1393
+ bee
1394
+ pebble
1395
+ mattress
1396
+ canvas
1397
+ hallway
1398
+ campaign
1399
+ pod
1400
+ lake district
1401
+ article
1402
+ white
1403
+ sofa
1404
+ honey
1405
+ marathon
1406
+ pancake
1407
+ tourist attraction
1408
+ wedding gown
1409
+ battle
1410
+ shelving
1411
+ sea
1412
+ sheet music
1413
+ pie
1414
+ yarn
1415
+ construction site
1416
+ flyer
1417
+ tie
1418
+ star
1419
+ lettuce
1420
+ martial artist
1421
+ dart
1422
+ straw
1423
+ reflection
1424
+ conference room
1425
+ temperature
1426
+ rugby
1427
+ mosquito
1428
+ physicist
1429
+ rock climber
1430
+ crash
1431
+ backdrop
1432
+ toilet seat
1433
+ sand castle
1434
+ water park
1435
+ toy car
1436
+ waste
1437
+ luxury
1438
+ hangar
1439
+ rv
1440
+ tree trunk
1441
+ board
1442
+ gold
1443
+ project picture
1444
+ cap
1445
+ cottage
1446
+ relief
1447
+ attire
1448
+ microscope
1449
+ battery
1450
+ roll
1451
+ line
1452
+ parking garage
1453
+ crystal
1454
+ broadcasting
1455
+ brick wall
1456
+ lab
1457
+ flooring
1458
+ meeting
1459
+ 3d cg rendering
1460
+ desktop computer
1461
+ cowboy
1462
+ sailing ship
1463
+ junction
1464
+ hairstyle
1465
+ homework
1466
+ profile
1467
+ model
1468
+ flower pot
1469
+ street light
1470
+ salt lake
1471
+ maple
1472
+ space
1473
+ blizzard
1474
+ throw
1475
+ zebras
1476
+ brochure
1477
+ constellation
1478
+ beak
1479
+ kilt
1480
+ pond
1481
+ blue sky
1482
+ sneaker
1483
+ sand dune
1484
+ morning sun
1485
+ almond
1486
+ grill
1487
+ curl
1488
+ basketball girl game
1489
+ chameleon
1490
+ toilet bowl
1491
+ prince
1492
+ keyboard
1493
+ queen
1494
+ computer monitor
1495
+ writing
1496
+ crown
1497
+ basilica
1498
+ kiss
1499
+ house
1500
+ parking
1501
+ football competition
1502
+ shell
1503
+ sport equipment
1504
+ comedy
1505
+ baboon
1506
+ vendor
1507
+ rise building
1508
+ wrap
1509
+ food truck
1510
+ cat bed
1511
+ rickshaw
1512
+ flare
1513
+ teal
1514
+ nectar
1515
+ eclipse
1516
+ vehicle
1517
+ steam locomotive
1518
+ gorge
1519
+ cow
1520
+ christmas card
1521
+ demonstrator
1522
+ memorial
1523
+ towel
1524
+ jewellery
1525
+ train
1526
+ frisbee
1527
+ baseball game
1528
+ fur
1529
+ afternoon sun
1530
+ community
1531
+ sparkler
1532
+ bandage
1533
+ firework
1534
+ dollar
1535
+ pasture
1536
+ video
1537
+ bus
1538
+ tree house
1539
+ seashore
1540
+ field
1541
+ hamburger
1542
+ souvenir
1543
+ hedgehog
1544
+ worm
1545
+ pine cone
1546
+ osprey
1547
+ dinosaur
1548
+ vegetable
1549
+ junk
1550
+ poster
1551
+ army
1552
+ winger
1553
+ bundle
1554
+ stage
1555
+ growth
1556
+ wedding party
1557
+ service
1558
+ blanket
1559
+ ruler
1560
+ eye
1561
+ credit card
1562
+ castle
1563
+ diner
1564
+ hut
1565
+ elk
1566
+ hard rock artist
1567
+ nun
1568
+ dog breed
1569
+ nest
1570
+ drama film
1571
+ number icon
1572
+ water tank
1573
+ giraffe
1574
+ altar
1575
+ pavilion
1576
+ tv personality
1577
+ suv
1578
+ street vendor
1579
+ street sign
1580
+ ditch
1581
+ debris
1582
+ foam
1583
+ takeoff
1584
+ spice
1585
+ mountain lake
1586
+ tea
1587
+ orchestra
1588
+ spacecraft
1589
+ counter
1590
+ abbey
1591
+ mountain
1592
+ hydrangea
1593
+ racer
1594
+ orange tree
1595
+ tide
1596
+ cowboy hat
1597
+ rapid
1598
+ town
1599
+ wild
1600
+ herd
1601
+ vein
1602
+ driveway
1603
+ jar
1604
+ bark
1605
+ illustration
1606
+ horror film
1607
+ corn
1608
+ stroller
1609
+ industry
1610
+ mountain stream
1611
+ gym
1612
+ neckline
1613
+ pan
1614
+ client
1615
+ spectator
1616
+ eggplant
1617
+ camper
1618
+ fawn
1619
+ hoodie
1620
+ meat
1621
+ lemonade
1622
+ food market
1623
+ slum
1624
+ comic book character
1625
+ flower market
1626
+ love
1627
+ palace
1628
+ gun
1629
+ heel
1630
+ shopping street
1631
+ shooting basketball guard
1632
+ family photo
1633
+ rooftop
1634
+ laundry basket
1635
+ airport runway
1636
+ horn
1637
+ face mask
1638
+ flight
1639
+ appetizer
1640
+ violet
1641
+ country lane
1642
+ cement
1643
+ instrument
1644
+ tv actor
1645
+ spark
1646
+ celebrity
1647
+ award
1648
+ country house
1649
+ standing
1650
+ auction
1651
+ date
1652
+ engagement
1653
+ puck
1654
+ advertisement
1655
+ chair
1656
+ zebra
1657
+ driftwood
1658
+ bumblebee
1659
+ maple leaf
1660
+ bonnet
1661
+ orange
1662
+ water tower
1663
+ door
1664
+ singer
1665
+ floor plan
1666
+ discussion
1667
+ theatre
1668
+ pilgrim
1669
+ mug
1670
+ branch
1671
+ window sill
1672
+ baseball pitcher
1673
+ bakery
1674
+ lollipop
1675
+ basketball player
1676
+ toilet paper
1677
+ chalkboard
1678
+ cabin
1679
+ sign
1680
+ night sky
1681
+ cannon
1682
+ fishing net
1683
+ submarine
1684
+ suit
1685
+ fur coat
1686
+ wine bottle
1687
+ folder
1688
+ street art
1689
+ suspension bridge
1690
+ evening sky
1691
+ billboard
1692
+ postage stamp
1693
+ newspaper
1694
+ transportation
1695
+ surgeon
1696
+ light
1697
+ park
1698
+ horizon
1699
+ road
1700
+ sand bar
1701
+ trumpet
1702
+ lounge
1703
+ cloud forest
1704
+ birthday celebration
1705
+ balcony
1706
+ anime
1707
+ beehive
1708
+ umbrella
1709
+ goldfish
1710
+ baseball cap
1711
+ waterhole
1712
+ ceiling
1713
+ carousel
1714
+ backpack
1715
+ plant pot
1716
+ atmosphere
1717
+ sunflower field
1718
+ spire
1719
+ vision
1720
+ woodpecker
1721
+ chip
1722
+ pool table
1723
+ lotus flower
1724
+ cone
1725
+ humpback whale
1726
+ reservoir
1727
+ hunt
1728
+ piano
1729
+ plate
1730
+ dining area
1731
+ luggage
1732
+ skier
1733
+ dance floor
1734
+ crow
1735
+ stair
1736
+ overpass
1737
+ opera house
1738
+ bear
1739
+ jazz artist
1740
+ water
1741
+ vessel
1742
+ cast
1743
+ yard
1744
+ cathedral
1745
+ basketball hoop
1746
+ graveyard
1747
+ sound
1748
+ berry
1749
+ onlooker
1750
+ fauna
1751
+ birch tree
1752
+ retail
1753
+ hill
1754
+ skeleton
1755
+ journalist
1756
+ frost
1757
+ basket
1758
+ nail
1759
+ dusk
1760
+ trash
1761
+ dawn
1762
+ clover
1763
+ hen
1764
+ volcano
1765
+ basketball coach
1766
+ home decor
1767
+ charge
1768
+ haircut
1769
+ sense
1770
+ university
1771
+ lizard
1772
+ daisy
1773
+ tablet computer
1774
+ grass field
1775
+ prison
1776
+ metal artist
1777
+ bathroom mirror
1778
+ window frame
1779
+ chest
1780
+ flavor
1781
+ pop country artist
1782
+ market square
1783
+ monkey
1784
+ blog
1785
+ deer
1786
+ speech bubble
1787
+ dog
1788
+ independence day
1789
+ girl
1790
+ boy
1791
+ tartan
1792
+ furniture
1793
+ appliance
1794
+ office window
1795
+ fish boat
1796
+ sand box
1797
+ tv sitcom
1798
+ drama
1799
+ sleigh
1800
+ depression
1801
+ paper towel
1802
+ baseball
1803
+ protestor
1804
+ grape
1805
+ wedding cake
1806
+ invitation
1807
+ accessory
1808
+ pick
1809
+ grandparent
1810
+ racket
1811
+ tea plantation
1812
+ outdoors
1813
+ egg
1814
+ glass bowl
1815
+ sun
1816
+ organization
1817
+ lion
1818
+ panel
1819
+ station
1820
+ wallpaper
1821
+ helicopter
1822
+ salt
1823
+ vanity
1824
+ patio
1825
+ lunch
1826
+ street performer
1827
+ mountain range
1828
+ soup
1829
+ bacon
1830
+ power station
1831
+ cantilever bridge
1832
+ hummingbird
1833
+ shirt
1834
+ rope
1835
+ hip
1836
+ chalk
1837
+ pendant
1838
+ choir
1839
+ tv
1840
+ lichen
1841
+ railway bridge
1842
+ art gallery
1843
+ bartender
1844
+ wagon
1845
+ baby elephant
1846
+ accordion
1847
+ horseshoe
1848
+ building site
1849
+ clutch
1850
+ harvest
1851
+ savanna
1852
+ geranium
1853
+ business woman
1854
+ paddock
1855
+ patch
1856
+ beech tree
1857
+ war
1858
+ suburbs
1859
+ hospital bed
1860
+ motorcycle racer
1861
+ moss
1862
+ gravel
1863
+ government agency
1864
+ dollar bill
1865
+ father
1866
+ fjord
1867
+ concert
1868
+ nut
1869
+ wedding photography
1870
+ finish line
1871
+ home plate
1872
+ food
1873
+ nose
1874
+ thumb
1875
+ village
1876
+ dining room table
1877
+ bumper
1878
+ monster
1879
+ blackberry
1880
+ lime
1881
+ conflict
1882
+ gala
1883
+ wallet
1884
+ wrist
1885
+ hug
1886
+ mermaid
1887
+ lava
1888
+ lawyer
1889
+ folk rock artist
1890
+ arena
1891
+ onion
1892
+ toothbrush
1893
+ fashion
1894
+ perfume
1895
+ flip
1896
+ triangle
1897
+ woodland
1898
+ mail
1899
+ grasshopper
1900
+ studio
1901
+ wood floor
1902
+ den
1903
+ racquet
1904
+ cello
1905
+ lemur
1906
+ astronaut
1907
+ glass table
1908
+ blood
1909
+ dvd
1910
+ planter
1911
+ silver
1912
+ leash
1913
+ master bedroom
1914
+ forest
1915
+ batter
1916
+ shoe
1917
+ engraving
1918
+ opening
1919
+ product
1920
+ toe
1921
+ cocktail
1922
+ mallard duck
1923
+ bike ride
1924
+ oasis
1925
+ wedding ring
1926
+ cinematographer
1927
+ holly
1928
+ autograph
1929
+ fence
1930
+ ice cube
1931
+ cove
1932
+ pineapple
1933
+ aurora
1934
+ glass bead
1935
+ produce
1936
+ apartment building
1937
+ cob
1938
+ miniature
1939
+ cockpit
1940
+ flashlight
1941
+ frog
1942
+ sheep
1943
+ groom
1944
+ steel
1945
+ watermelon
1946
+ clip art
1947
+ paper plate
1948
+ ostrich
1949
+ contour
1950
+ mural
1951
+ cub
1952
+ paisley bandanna
1953
+ winery
1954
+ turn
1955
+ handle
1956
+ satellite
1957
+ post
1958
+ pork
1959
+ child
1960
+ asphalt
1961
+ grocery store
1962
+ vulture
1963
+ trolley
1964
+ nightclub
1965
+ brick
1966
+ trailer
1967
+ compass
1968
+ cereal
1969
+ cafe
1970
+ cartoon character
1971
+ sugar
1972
+ fiction book
1973
+ glass floor
1974
+ umpire
1975
+ guitar
1976
+ hamster
1977
+ protester
1978
+ airplane
1979
+ garment
1980
+ blazer
1981
+ railway line
1982
+ wedding
1983
+ shoe box
1984
+ parking lot
1985
+ construction
1986
+ graduation ceremony
1987
+ tram
1988
+ telescope
1989
+ copper
1990
+ pain
1991
+ autumn forest
1992
+ guest house
1993
+ partner
1994
+ crayon
1995
+ dip
1996
+ boot
1997
+ corridor
1998
+ computer keyboard
1999
+ hockey player
2000
+ chicken coop
2001
+ bus station
2002
+ gathering
2003
+ ankle
2004
+ bunk bed
2005
+ wood table
2006
+ football coach
2007
+ monarch
2008
+ pharmacy
2009
+ legging
2010
+ mannequin
2011
+ female
2012
+ train track
2013
+ stack
2014
+ canopy
2015
+ design element
2016
+ grandmother
2017
+ symbol
2018
+ beach hut
2019
+ zucchini
2020
+ bomb
2021
+ businessman
2022
+ skyscraper
2023
+ tongue
2024
+ case
2025
+ sparkle
2026
+ highland
2027
+ ballroom
2028
+ prom
2029
+ estate
2030
+ customer
2031
+ archipelago
2032
+ cheese
2033
+ debate
2034
+ carriage
2035
+ bulldozer
2036
+ pumpkin
2037
+ sitting room
2038
+ gas station
2039
+ wedding reception
2040
+ camp
2041
+ dog bed
2042
+ tower
2043
+ property
2044
+ river bed
2045
+ pop latin artist
2046
+ fridge
2047
+ wine glass
2048
+ coast
2049
+ beer
2050
+ tow truck
2051
+ fire truck
2052
+ mountain bike
2053
+ thigh
2054
+ heron
2055
+ boat ride
2056
+ gondola
2057
+ turquoise
2058
+ lake
2059
+ llama
2060
+ kitty
2061
+ tin
2062
+ waiting room
2063
+ coffee cup
2064
+ socialite
2065
+ guard
2066
+ tap
2067
+ waterway
2068
+ forehead
2069
+ list
2070
+ erosion
2071
+ box
2072
+ sea lion
2073
+ pollen
2074
+ dam
2075
+ wasp
2076
+ salon
2077
+ tennis tournament
2078
+ flower box
2079
+ aquarium
2080
+ rain cloud
2081
+ clothing store
2082
+ lead singer
2083
+ cupcake
2084
+ tortoise
2085
+ lettering
2086
+ sport facility
2087
+ dance
2088
+ dog house
2089
+ nature
2090
+ football
2091
+ rooster
2092
+ footballer
2093
+ railway track
2094
+ crowd
2095
+ fishing rod
2096
+ silhouette
2097
+ wind turbine
2098
+ sari
2099
+ bus window
2100
+ cloud
2101
+ charity
2102
+ medal
2103
+ yoga
2104
+ event
2105
+ veil
2106
+ fashion menswear milan week
2107
+ news
2108
+ knife
2109
+ print
2110
+ screen tv
2111
+ walnut
2112
+ fungus
2113
+ ice cream
2114
+ computer mouse
2115
+ play
2116
+ tribe
2117
+ picture
2118
+ video game
2119
+ business card
2120
+ music festival
2121
+ rack
2122
+ envelope
2123
+ shower
2124
+ dirt road
2125
+ mine
2126
+ oyster
2127
+ monarch butterfly
2128
+ dude
2129
+ fruit salad
2130
+ podium
2131
+ fork
2132
+ lace
2133
+ test match
2134
+ boulder
2135
+ cricket player
2136
+ staircase
2137
+ peninsula
2138
+ shopping
2139
+ popcorn
2140
+ oak
2141
+ market stall
2142
+ pine tree
2143
+ mountaineer
2144
+ student
2145
+ closet
2146
+ hood
2147
+ handstand
2148
+ centerpiece
2149
+ insect
2150
+ patient
2151
+ makeover
2152
+ tennis player
2153
+ sheet
2154
+ park bench
2155
+ apple
2156
+ organism
2157
+ hook
2158
+ turkey
2159
+ tangerine
2160
+ sibling
2161
+ shopping mall
2162
+ bird
2163
+ scarf
2164
+ smoothie
2165
+ net
2166
+ grass
2167
+ napkin
2168
+ ray
2169
+ eyebrow
2170
+ laptop keyboard
2171
+ motorbike
2172
+ woman hand
2173
+ oven
2174
+ book cover
2175
+ easter egg
2176
+ microwave
2177
+ sand
2178
+ snapshot
2179
+ soccer ball
2180
+ makeup
2181
+ knight
2182
+ bowling ball
2183
+ shower curtain
2184
+ flame
2185
+ lightning
2186
+ running
2187
+ power plant
2188
+ crib
2189
+ cartoon
2190
+ moat
2191
+ fashion girl
2192
+ wedding invitation
2193
+ bottle
2194
+ cliff
2195
+ monastery
2196
+ file photo
2197
+ apartment
2198
+ casino
2199
+ cream
2200
+ sweatshirt
2201
+ storm
2202
+ cruise
2203
+ teddy bear
2204
+ shovel
2205
+ wind farm
2206
+ writer
2207
+ dock
2208
+ professional
2209
+ hotel room
2210
+ job
2211
+ monitor
2212
+ donkey
2213
+ pass
2214
+ interview
2215
+ duchess
2216
+ mark
2217
+ plank
2218
+ beard
2219
+ zombie
2220
+ trio
2221
+ channel
2222
+ cricket team
2223
+ windmill
2224
+ vest
2225
+ diagram
2226
+ cable
2227
+ winter scene
2228
+ golden gate bridge
2229
+ buffalo
2230
+ studio portrait
2231
+ pagoda
2232
+ whiskey
2233
+ freight train
2234
+ kite
2235
+ future
2236
+ steam train
2237
+ phone box
2238
+ headset
2239
+ wood
2240
+ snowboarder
2241
+ paper bag
2242
+ slide
2243
+ grapefruit
2244
+ seating
2245
+ morning
2246
+ bronze sculpture
2247
+ theatre actor
2248
+ stump
2249
+ jean
2250
+ landmark
2251
+ jam
2252
+ waist
2253
+ watercolor
2254
+ hammock
2255
+ light fixture
2256
+ ice
2257
+ basin
2258
+ beverage
2259
+ shelter
2260
+ premiere
2261
+ mound
2262
+ ear
2263
+ bronze
2264
+ sunlight
2265
+ street
2266
+ energy
2267
+ barn door
2268
+ hike
2269
+ fleet
2270
+ claw
2271
+ beach
2272
+ pepperoni
2273
+ bin
2274
+ trainer
2275
+ buffet
2276
+ archive
2277
+ toddler
2278
+ referee
2279
+ bay window
2280
+ dove
2281
+ production company
2282
+ evening light
2283
+ gate
2284
+ farm
2285
+ reed
2286
+ fruit stand
2287
+ explorer
2288
+ snow storm
2289
+ throw pillow
2290
+ button
2291
+ display case
2292
+ bookcase
2293
+ lead
2294
+ lipstick
2295
+ basketball court
2296
+ cargo
2297
+ ensemble
2298
+ pope
2299
+ clock tower
2300
+ teen
2301
+ speaker
2302
+ rat
2303
+ laptop
2304
+ ski
2305
+ mess
2306
+ stadium
2307
+ ferry boat
2308
+ bunny
2309
+ waterfront
2310
+ downtown
2311
+ sink
2312
+ press conference
2313
+ dinner
2314
+ condiment
2315
+ thread
2316
+ audience
2317
+ grid
2318
+ car
2319
+ plastic
2320
+ people
2321
+ barbecue
2322
+ pigeon
2323
+ urinal
2324
+ seagull
2325
+ volunteer
2326
+ hockey
2327
+ fir tree
2328
+ pollution
2329
+ trial
2330
+ collar
2331
+ area
2332
+ meeting room
2333
+ circus
2334
+ yogurt
2335
+ orangutan
2336
+ viaduct
2337
+ comedian
2338
+ drone
2339
+ scissor
2340
+ pop rock artist
2341
+ biscuit
2342
+ panda
2343
+ water feature
2344
+ air balloon
2345
+ remote control
2346
+ watercolor painting
2347
+ show
2348
+ walk
2349
+ post office
2350
+ bike path
2351
+ rap gangsta artist
2352
+ microphone
2353
+ crack
2354
+ sunset sky
2355
+ glass
2356
+ tv show
2357
+ cartoon style
2358
+ stripe
2359
+ foyer
2360
+ signal
2361
+ calligraphy
2362
+ bulb
2363
+ gardener
2364
+ coffee bean
2365
+ spider
2366
+ tapestry
2367
+ city skyline
2368
+ necklace
2369
+ kitten
2370
+ traveler
2371
+ veteran
2372
+ frosting
2373
+ fry
2374
+ tennis court
2375
+ tank top
2376
+ butterfly house
2377
+ mist
2378
+ drummer
2379
+ water level
2380
+ scale
2381
+ baseball glove
2382
+ music video performer
2383
+ champagne
2384
+ camping
2385
+ clothing
2386
+ water drop
2387
+ telephone box
2388
+ pen
2389
+ morning mist
2390
+ fire engine
2391
+ porch
2392
+ opening ceremony
2393
+ style
2394
+ palm tree
2395
+ fashion show
2396
+ universe
2397
+ scratch
2398
+ axe
2399
+ ottoman
2400
+ explosion
2401
+ rib
2402
+ boutique
2403
+ game
2404
+ cucumber
2405
+ fruit
2406
+ stone bridge
2407
+ nature reserve
2408
+ track
2409
+ train window
2410
+ punch
2411
+ telephone pole
2412
+ velvet
2413
+ sauce
2414
+ moon
2415
+ contrast
2416
+ flamingo
2417
+ bat
2418
+ vending machine
2419
+ ship
2420
+ equestrian
2421
+ shade
2422
+ comforter
2423
+ pallet
2424
+ sparrow
2425
+ wii
2426
+ glaze
2427
+ grocery
2428
+ steeple
2429
+ soccer player
2430
+ contract
2431
+ advertising
2432
+ runner
2433
+ chimpanzee
2434
+ world
2435
+ seat
2436
+ project
2437
+ chihuahua
2438
+ bubble
2439
+ willow
2440
+ pedestal
2441
+ soul hip hop artist
2442
+ curb
2443
+ drawer
2444
+ leaf
2445
+ banner
2446
+ launch party
2447
+ coach
2448
+ government
2449
+ snowball
2450
+ toy
2451
+ portrait
2452
+ doctor
2453
+ whiteboard
2454
+ electronic
2455
+ tiger
2456
+ graffiti
2457
+ column
2458
+ nightstand
2459
+ whistle
2460
+ maxi dress
2461
+ bench
2462
+ wetsuit
2463
+ bird feeder
2464
+ football game
2465
+ basketball
2466
+ class
2467
+ bathroom door
2468
+ store window
2469
+ text message
2470
+ wreath
2471
+ street view
2472
+ binocular
2473
+ pet
2474
+ facade
2475
+ drought
2476
+ lemon
2477
+ new year
2478
+ night view
2479
+ airplane window
2480
+ specie
2481
+ rule
2482
+ jaw
2483
+ wheat field
2484
+ diet
2485
+ pop artist
2486
+ habitat
2487
+ screenshot
2488
+ scoreboard
2489
+ shore
2490
+ mane
2491
+ quilt
2492
+ ski lift
2493
+ orchid
2494
+ turban
2495
+ christmas
2496
+ airport
2497
+ marina
2498
+ glass door
2499
+ glass bottle
2500
+ restaurant
2501
+ conductor
2502
+ logo
2503
+ sleep
2504
+ tape
2505
+ tomato
2506
+ river bank
2507
+ lilac
2508
+ tooth
2509
+ training
2510
+ pottery
2511
+ shop
2512
+ steam engine
2513
+ mason jar
2514
+ base
2515
+ procession
2516
+ border
2517
+ shoot
2518
+ footprint
2519
+ hotdog
2520
+ bull
2521
+ stocking
2522
+ recreation
2523
+ automobile model
2524
+ design
2525
+ country pop artist
2526
+ river
2527
+ retriever
2528
+ department store
2529
+ auditorium
2530
+ sport car
2531
+ supermarket
2532
+ belt
2533
+ cricket
2534
+ window box
2535
+ dress shirt
2536
+ letter
2537
+ residence
2538
+ megaphone
2539
+ pant
2540
+ wildfire
2541
+ bird nest
2542
+ crab
2543
+ swimsuit
2544
+ candle
2545
+ funeral
2546
+ mill
2547
+ national park
2548
+ plant
2549
+ cop
2550
+ power line
2551
+ perch
2552
+ blue
2553
+ finger
2554
+ ferris wheel
2555
+ globe
2556
+ skateboard
2557
+ helmet
2558
+ movie theater
2559
+ uniform
2560
+ hammer
2561
+ material
2562
+ kid
2563
+ well
2564
+ butterfly
2565
+ sideline
2566
+ fashion fall show
2567
+ planet earth
2568
+ lift
2569
+ male
2570
+ sauna
2571
+ gray
2572
+ flour
2573
+ sand sculpture
2574
+ program
2575
+ cabinet
2576
+ infant
2577
+ wheel
2578
+ aircraft model
2579
+ dough
2580
+ garlic
2581
+ skate
2582
+ arrow
2583
+ wrapping paper
2584
+ ripple
2585
+ lamp
2586
+ iron
2587
+ banknote
2588
+ beaver
2589
+ ferry
2590
+ courtyard
2591
+ bassist
2592
+ countryside
2593
+ steak
2594
+ comfort
2595
+ boxer
2596
+ laundry room
2597
+ campsite
2598
+ brick building
2599
+ golf
2600
+ subway
2601
+ headphone
2602
+ fort
2603
+ handbag
2604
+ drum
2605
+ flood
2606
+ saddle
2607
+ bass
2608
+ labyrinth
2609
+ needle
2610
+ sun ray
2611
+ app
2612
+ menu
2613
+ president
2614
+ cardigan
2615
+ dandelion
2616
+ wetland
2617
+ ice hockey player
2618
+ number
2619
+ city hall
2620
+ fishing
2621
+ portrait session
2622
+ pug
2623
+ key
2624
+ art print
2625
+ minister
2626
+ hurdle
2627
+ emergency
2628
+ painting artist
2629
+ flag pole
2630
+ evening
2631
+ purse
2632
+ recipe
2633
+ golf ball
2634
+ coloring book
2635
+ mountain peak
2636
+ senior
2637
+ holiday
2638
+ bud
2639
+ cousin
2640
+ pantry
2641
+ lap
2642
+ skin
2643
+ flag
2644
+ tissue paper
2645
+ ridge
2646
+ wire fence
2647
+ surfer
2648
+ climber
2649
+ photograph
2650
+ sewing machine
2651
+ cooler
2652
+ actress
2653
+ apple tree
2654
+ cancer
2655
+ starfish
2656
+ automobile make
2657
+ dumbbell
2658
+ brace
2659
+ tunnel
2660
+ window
2661
+ paint artist
2662
+ composition
2663
+ school student
2664
+ condo
2665
+ convertible
2666
+ cushion
2667
+ selfie
2668
+ territory
2669
+ guide
2670
+ tree
2671
+ court
2672
+ shrimp
2673
+ stone house
2674
+ dress
2675
+ eyelash
2676
+ juice
2677
+ broccoli
2678
+ chain
2679
+ tourism
2680
+ mountain top
2681
+ concept car
2682
+ film premiere
2683
+ light bulb
2684
+ cafeteria
2685
+ badge
2686
+ flower bed
2687
+ theater
2688
+ root
2689
+ racecar driver
2690
+ basketball boy game
2691
+ glove
2692
+ skyline
2693
+ wall
2694
+ glacier
2695
+ airport terminal
2696
+ bug
2697
+ trim
2698
+ railway station
2699
+ briefcase
2700
+ flat
2701
+ fountain
2702
+ person
2703
+ lane
2704
+ asparagus
2705
+ art
2706
+ lantern
2707
+ dishwasher
2708
+ director
2709
+ snake
2710
+ lecture
2711
+ game controller
2712
+ tree branch
2713
+ pub
2714
+ bathing suit
2715
+ queue
2716
+ belly
2717
+ poppy
2718
+ bow
2719
+ pitcher
2720
+ ice cream cone
2721
+ cave
2722
+ candy
2723
+ road bridge
2724
+ host
2725
+ traffic jam
2726
+ earring
2727
+ file
2728
+ foot
2729
+ watermark overlay stamp
2730
+ mailbox
2731
+ supercar
2732
+ railing
2733
+ bedroom
2734
+ seafood
2735
+ waffle
2736
+ bronze statue
2737
+ plan
2738
+ flow
2739
+ marble
2740
+ basketball game
2741
+ automobile
2742
+ scene
2743
+ cypress tree
2744
+ soldier
2745
+ skateboarder
2746
+ glass building
2747
+ cherry tree
2748
+ pump
2749
+ grain
2750
+ wildebeest
2751
+ loop
2752
+ frame
2753
+ bathtub
2754
+ saxophone
2755
+ diver
2756
+ stalk
2757
+ lily
2758
+ bead
2759
+ alley
2760
+ flock
2761
+ family room
2762
+ manufacturing
2763
+ pointer
2764
+ worker
2765
+ navy
2766
+ potato
2767
+ teacher
2768
+ photography
2769
+ dolly
2770
+ boardwalk
2771
+ water fountain
2772
+ athlete
2773
+ side dish
2774
+ bay
2775
+ ice hockey
2776
+ phone
2777
+ hero
2778
+ face
2779
+ gold medal
2780
+ blind
2781
+ swamp
2782
+ researcher
2783
+ swim
2784
+ meatball
2785
+ iguana
2786
+ leather jacket
2787
+ jellyfish
2788
+ site
2789
+ smoke
2790
+ traffic signal
2791
+ melon
2792
+ beetle
2793
+ calculator
2794
+ skirt
2795
+ plantation
2796
+ sculptor
2797
+ barrier
2798
+ catcher
2799
+ security guard
2800
+ sketch
2801
+ awning
2802
+ steering wheel
2803
+ mountain view
2804
+ bus stop
2805
+ pool
2806
+ leg
2807
+ spotlight
2808
+ apron
2809
+ mineral
2810
+ inlet
2811
+ sleeve
2812
+ torch
2813
+ emotion
2814
+ march
2815
+ police officer
2816
+ performance
2817
+ lamp post
2818
+ fishing boat
2819
+ summer
2820
+ presentation
2821
+ saucer
2822
+ suitcase
2823
+ supermodel
2824
+ goalkeeper
2825
+ shrub
2826
+ rock artist
2827
+ document
2828
+ beach house
2829
+ man
2830
+ blue artist
2831
+ cigar
2832
+ railroad track
2833
+ gown
2834
+ mosaic
2835
+ bungalow
2836
+ alphabet
2837
+ baseball field
2838
+ shed
2839
+ pedestrian
2840
+ rail
2841
+ soap
2842
+ kitchen counter
2843
+ dessert
2844
+ dunk
2845
+ blossom
2846
+ conversation
2847
+ fruit market
2848
+ glass jar
2849
+ military
2850
+ beer bottle
2851
+ photographer
2852
+ tennis racket
2853
+ competition
2854
+ escalator
2855
+ bell tower
2856
+ stilt
2857
+ ballerina
2858
+ television
2859
+ feather
2860
+ fence post
2861
+ rear
2862
+ dahlia
2863
+ red carpet
2864
+ tub
2865
+ hole
2866
+ fortress
2867
+ pack
2868
+ telephone
2869
+ cardboard
2870
+ city park
2871
+ platform
2872
+ college student
2873
+ arch bridge
2874
+ wind
2875
+ blender
2876
+ bloom
2877
+ ice rink
2878
+ birthday
2879
+ raven
2880
+ fairy
2881
+ embankment
2882
+ hall
2883
+ flower shop
2884
+ suburb
2885
+ barrel
2886
+ biker
2887
+ steam
2888
+ dragonfly
2889
+ formation
2890
+ electricity
2891
+ business people
2892
+ symmetry
2893
+ walkway
2894
+ fisherman
2895
+ gas mask
2896
+ loch
2897
+ youth
2898
+ hanger
2899
+ dot
2900
+ fish
2901
+ street market
2902
+ animation film
2903
+ crime fiction film
2904
+ boar
2905
+ emblem
2906
+ halloween costume
2907
+ kangaroo
2908
+ couple
2909
+ spoon
2910
+ squirrel
2911
+ neon sign
2912
+ sky
2913
+ office desk
2914
+ beauty salon
2915
+ breakwater
2916
+ fashion look
2917
+ toaster
2918
+ author
2919
+ news conference
2920
+ outdoor
2921
+ canoe
2922
+ dragon
2923
+ tool
2924
+ shopping centre
2925
+ ladybug
2926
+ swimming pool
2927
+ landscaping
2928
+ ski pole
2929
+ red
2930
+ truck
2931
+ fly
2932
+ temple
2933
+ level
2934
+ sunday
2935
+ railroad bridge
2936
+ car mirror
2937
+ lawn mower
2938
+ flute
2939
+ aircraft carrier
2940
+ fashion menswear london week
2941
+ sunshine
2942
+ tile floor
2943
+ skull
2944
+ fossil
2945
+ flower arrangement
2946
+ diaper
2947
+ sea turtle
2948
+ cherry blossom
2949
+ fireman
2950
+ shack
2951
+ lens
2952
+ waiter
2953
+ animal
2954
+ basement
2955
+ snow
2956
+ autumn park
2957
+ glass box
2958
+ kick
2959
+ head
2960
+ anniversary
2961
+ vine
2962
+ back
2963
+ paper lantern
2964
+ fish tank
2965
+ cellphone
2966
+ silk
2967
+ coral
2968
+ notebook
2969
+ photo
2970
+ gazebo
2971
+ ketchup
2972
+ driver
2973
+ farmer
2974
+ bonfire
2975
+ chestnut
2976
+ photoshoot
2977
+ football field
2978
+ olive tree
2979
+ pheasant
2980
+ sandal
2981
+ toilet
2982
+ fireplace
2983
+ music
2984
+ deity
2985
+ fish market
2986
+ fig
2987
+ bell
2988
+ neck
2989
+ grave
2990
+ villa
2991
+ cyclist
2992
+ crate
2993
+ grey
2994
+ asphalt road
2995
+ soccer
2996
+ hostel
2997
+ municipality
2998
+ courthouse
2999
+ roof
3000
+ end table
3001
+ pot
3002
+ sedan
3003
+ structure
3004
+ folk artist
3005
+ sport
3006
+ sport team
3007
+ protest
3008
+ syringe
3009
+ fashion designer
3010
+ jersey
3011
+ heart shape
3012
+ kayak
3013
+ stare
3014
+ sit with
3015
+ direct
3016
+ read
3017
+ photograph
3018
+ spin
3019
+ teach
3020
+ laugh
3021
+ carve
3022
+ grow on
3023
+ warm
3024
+ watch
3025
+ stretch
3026
+ smell
3027
+ decorate
3028
+ shine
3029
+ light
3030
+ dance
3031
+ send
3032
+ park
3033
+ chase
3034
+ collect
3035
+ lead
3036
+ kiss
3037
+ lead to
3038
+ lick
3039
+ smile
3040
+ cheer
3041
+ sit
3042
+ point
3043
+ block
3044
+ rock
3045
+ drop
3046
+ cut
3047
+ ski
3048
+ wrap
3049
+ lose
3050
+ serve
3051
+ provide
3052
+ sleep
3053
+ dress
3054
+ embrace
3055
+ burn
3056
+ pack
3057
+ stir
3058
+ create
3059
+ touch
3060
+ wash
3061
+ stick
3062
+ reveal
3063
+ shop
3064
+ train
3065
+ paint
3066
+ groom
3067
+ hunt
3068
+ bloom
3069
+ play
3070
+ pay
3071
+ brush
3072
+ shoot
3073
+ hold
3074
+ picture
3075
+ carry
3076
+ sip
3077
+ contain
3078
+ turn
3079
+ pour
3080
+ pitch
3081
+ give
3082
+ add
3083
+ blow
3084
+ look in
3085
+ show
3086
+ walk
3087
+ illuminate
3088
+ kneel
3089
+ cover
3090
+ drag
3091
+ post
3092
+ present
3093
+ fit
3094
+ operate
3095
+ fish
3096
+ race
3097
+ write
3098
+ deliver
3099
+ peel
3100
+ push
3101
+ run
3102
+ sit around
3103
+ buy
3104
+ jump
3105
+ walk on
3106
+ attend
3107
+ clean
3108
+ sell
3109
+ ride on
3110
+ mount
3111
+ host
3112
+ dry
3113
+ plant
3114
+ sing
3115
+ row
3116
+ shake
3117
+ perch
3118
+ ride
3119
+ fight
3120
+ skateboard
3121
+ live
3122
+ call
3123
+ surround
3124
+ practice
3125
+ play on
3126
+ work on
3127
+ step
3128
+ relax
3129
+ hit
3130
+ fall in
3131
+ flow
3132
+ greet
3133
+ launch
3134
+ wear
3135
+ hang on
3136
+ drive
3137
+ sit in
3138
+ break
3139
+ learn
3140
+ fly
3141
+ connect
3142
+ display
3143
+ locate
3144
+ compete
3145
+ go for
3146
+ sail
3147
+ lift
3148
+ toast
3149
+ help
3150
+ run on
3151
+ reflect
3152
+ pose
3153
+ scratch
3154
+ frame
3155
+ dribble
3156
+ herd
3157
+ enter
3158
+ exit
3159
+ place
3160
+ inspect
3161
+ build
3162
+ pick
3163
+ fill
3164
+ grind
3165
+ skate
3166
+ offer
3167
+ float
3168
+ sit by
3169
+ stand
3170
+ release
3171
+ rest
3172
+ singe
3173
+ climb
3174
+ tie
3175
+ mark
3176
+ lay
3177
+ stand around
3178
+ capture
3179
+ set
3180
+ land
3181
+ swinge
3182
+ run in
3183
+ kick
3184
+ lean
3185
+ head
3186
+ sign
3187
+ approach
3188
+ swim
3189
+ close
3190
+ crash
3191
+ control
3192
+ fall
3193
+ remove
3194
+ repair
3195
+ open
3196
+ appear
3197
+ travel
3198
+ load
3199
+ miss
3200
+ check
3201
+ surf
3202
+ moor
3203
+ smoke
3204
+ drink
3205
+ board
3206
+ seat
3207
+ feed
3208
+ rise
3209
+ sit on
3210
+ swing
3211
+ grow
3212
+ strike
3213
+ date
3214
+ slide
3215
+ share
3216
+ graze
3217
+ jump in
3218
+ lie
3219
+ extrude
3220
+ roll
3221
+ move
3222
+ gather
3223
+ eat
3224
+ pull
3225
+ run through
3226
+ squeeze
3227
+ lay on
3228
+ draw
3229
+ play with
3230
+ wave
3231
+ assemble
3232
+ perform
3233
+ march
3234
+ score
3235
+ attach
3236
+ adjust
3237
+ hang
3238
+ hug
3239
+ sleep on
3240
+ throw
3241
+ live in
3242
+ talk
3243
+ pet
3244
+ work
3245
+ run with
3246
+ see
3247
+ flip
3248
+ catch
3249
+ cook
3250
+ receive
3251
+ celebrate
3252
+ look
3253
+ classic
3254
+ bridal
3255
+ indoor
3256
+ industrial
3257
+ teenage
3258
+ mini
3259
+ grassy
3260
+ aged
3261
+ long
3262
+ warm
3263
+ light
3264
+ handsome
3265
+ happy
3266
+ three
3267
+ pregnant
3268
+ circular
3269
+ urban
3270
+ silver
3271
+ ceramic
3272
+ 3d
3273
+ green
3274
+ blonde
3275
+ golden
3276
+ dark
3277
+ tropical
3278
+ ripe
3279
+ deep
3280
+ fat
3281
+ musical
3282
+ giant
3283
+ medical
3284
+ medieval
3285
+ bare
3286
+ stunning
3287
+ bold
3288
+ geographical
3289
+ huge
3290
+ plastic
3291
+ foggy
3292
+ stormy
3293
+ gothic
3294
+ biological
3295
+ empty
3296
+ clear
3297
+ antique
3298
+ pink
3299
+ steep
3300
+ brown
3301
+ striped
3302
+ aerial
3303
+ rainy
3304
+ cool
3305
+ flying
3306
+ commercial
3307
+ purple
3308
+ trendy
3309
+ blank
3310
+ haired
3311
+ dead
3312
+ wooden
3313
+ flat
3314
+ high
3315
+ beige
3316
+ panoramic
3317
+ angry
3318
+ dozen
3319
+ rural
3320
+ solar
3321
+ big
3322
+ small
3323
+ stained
3324
+ thick
3325
+ many
3326
+ fresh
3327
+ clean
3328
+ strong
3329
+ abstract
3330
+ crowded
3331
+ retro
3332
+ dry
3333
+ gorgeous
3334
+ martial
3335
+ modern
3336
+ blue
3337
+ cloudy
3338
+ low
3339
+ four
3340
+ outdoor
3341
+ single
3342
+ much
3343
+ beautiful
3344
+ snowy
3345
+ pretty
3346
+ new
3347
+ short
3348
+ sunny
3349
+ closed
3350
+ rocky
3351
+ red
3352
+ two
3353
+ double
3354
+ male
3355
+ gray
3356
+ five
3357
+ colorful
3358
+ automotive
3359
+ various
3360
+ one
3361
+ old
3362
+ rusty
3363
+ tall
3364
+ wild
3365
+ narrow
3366
+ natural
3367
+ several
3368
+ frozen
3369
+ textured
3370
+ lush
3371
+ young
3372
+ hot
3373
+ mixed
3374
+ white
3375
+ float
3376
+ quiet
3377
+ round
3378
+ bright
3379
+ religious
3380
+ female
3381
+ historical
3382
+ shiny
3383
+ traditional
3384
+ tourist
3385
+ yellow
3386
+ bald
3387
+ coastal
3388
+ lovely
3389
+ little
3390
+ broken
3391
+ romantic
3392
+ wide
3393
+ royal
3394
+ rich
3395
+ open
3396
+ cute
3397
+ ancient
3398
+ cold
3399
+ political
3400
+ elderly
3401
+ gold
3402
+ full
3403
+ rustic
3404
+ metallic
3405
+ floral
3406
+ sad
3407
+ wet
3408
+ fancy
3409
+ senior
3410
+ tiny
3411
+ stylish
3412
+ large
3413
+ frosty
3414
+ orange
3415
+ transparent
3416
+ electronic
3417
+ shallow
3418
+ scared
3419
+ armed
3420
+ dirty
3421
+ historic
3422
+ black
3423
+ few
3424
+ windy
3425
+ some
3426
+ square
3427
+ ornamental
3428
+ sandy
3429
+ thin
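The vocabulary above is consumed as a flat, zero-indexed tag list; the delete_tag_index constant in tag2text/inference.py below refers to positions in this list (counting and orientation words such as "two" or "back" that tend to disturb captioning). A minimal sketch of filtering such a list, assuming the tags are stored one per line in a hypothetical tag_list.txt:

    # Hypothetical sketch: drop tag indices that may disturb captioning.
    # "tag_list.txt" is an assumed file name; the indices mirror
    # delete_tag_index in tag2text/inference.py below.
    delete_tag_index = {127, 2961, 3351, 3265, 3338, 3355, 3359}

    with open("tag_list.txt") as f:
        tags = [line.strip() for line in f]

    kept = [t for i, t in enumerate(tags) if i not in delete_tag_index]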
tag2text/inference.py ADDED
@@ -0,0 +1,102 @@
1
+ """
2
+ * Tag2Text
3
+ * Written by Xinyu Huang
4
+ """
5
+ import argparse
6
+ import random
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torchvision.transforms as transforms
11
+ from models.tag2text import tag2text_caption
12
+ from PIL import Image
13
+
14
+ parser = argparse.ArgumentParser(
15
+ description="Tag2Text inferece for tagging and captioning"
16
+ )
17
+ parser.add_argument(
18
+ "--image",
19
+ metavar="DIR",
20
+ help="path to dataset",
21
+ default="images/1641173_2291260800.jpg",
22
+ )
23
+ parser.add_argument(
24
+ "--pretrained",
25
+ metavar="DIR",
26
+ help="path to pretrained model",
27
+ default="pretrained/tag2text_swin_14m.pth",
28
+ )
29
+ parser.add_argument(
30
+ "--image-size",
31
+ default=384,
32
+ type=int,
33
+ metavar="N",
34
+ help="input image size (default: 448)",
35
+ )
36
+ parser.add_argument(
37
+ "--thre", default=0.68, type=float, metavar="N", help="threshold value"
38
+ )
39
+ parser.add_argument(
40
+ "--specified-tags", default="None", help="User input specified tags"
41
+ )
42
+
43
+
44
+ def inference(image, model, input_tag="None"):
45
+ with torch.no_grad():
46
+ caption, tag_predict = model.generate(
47
+ image, tag_input=None, max_length=50, return_tag_predict=True
48
+ )
49
+
50
+ if input_tag == "" or input_tag == "none" or input_tag == "None":
51
+ return tag_predict[0], None, caption[0]
52
+
53
+ # If user input specified tags:
54
+ else:
55
+ input_tag_list = []
56
+ input_tag_list.append(input_tag.replace(",", " | "))
57
+
58
+ with torch.no_grad():
59
+ caption, input_tag = model.generate(
60
+ image, tag_input=input_tag_list, max_length=50, return_tag_predict=True
61
+ )
62
+
63
+ return tag_predict[0], input_tag[0], caption[0]
64
+
65
+
66
+ if __name__ == "__main__":
67
+ args = parser.parse_args()
68
+
69
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
+ normalize = transforms.Normalize(
71
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
72
+ )
73
+ transform = transforms.Compose(
74
+ [
75
+ transforms.Resize((args.image_size, args.image_size)),
76
+ transforms.ToTensor(),
77
+ normalize,
78
+ ]
79
+ )
80
+
81
+ # delete some tags that may disturb captioning
82
+ # 127: "quarter"; 2961: "back", 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
83
+ delete_tag_index = [127, 2961, 3351, 3265, 3338, 3355, 3359]
84
+
85
+ #######load model
86
+ model = tag2text_caption(
87
+ pretrained=args.pretrained,
88
+ image_size=args.image_size,
89
+ vit="swin_b",
90
+ delete_tag_index=delete_tag_index,
91
+ )
92
+ model.threshold = args.thre # threshold for tagging
93
+ model.eval()
94
+
95
+ model = model.to(device)
96
+ raw_image = Image.open(args.image).resize((args.image_size, args.image_size))
97
+ image = transform(raw_image).unsqueeze(0).to(device)
98
+
99
+ res = inference(image, model, args.specified_tags)
100
+ print("Model Identified Tags: ", res[0])
101
+ print("User Specified Tags: ", res[1])
102
+ print("Image Caption: ", res[2])
tag2text/models/bert.py ADDED
@@ -0,0 +1,1157 @@
1
+ """
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+ import math
11
+ import os
12
+ import warnings
13
+ from dataclasses import dataclass
14
+ from typing import Optional
15
+ from typing import Tuple
16
+
17
+ import torch.nn.functional as F
18
+ import torch.utils.checkpoint
19
+ from torch import device
20
+ from torch import dtype
21
+ from torch import nn
22
+ from torch import Tensor
23
+ from torch.nn import CrossEntropyLoss
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
29
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
30
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
31
+ from transformers.modeling_outputs import MaskedLMOutput
32
+ from transformers.modeling_outputs import MultipleChoiceModelOutput
33
+ from transformers.modeling_outputs import NextSentencePredictorOutput
34
+ from transformers.modeling_outputs import QuestionAnsweringModelOutput
35
+ from transformers.modeling_outputs import SequenceClassifierOutput
36
+ from transformers.modeling_outputs import TokenClassifierOutput
37
+ from transformers.modeling_utils import apply_chunking_to_forward
38
+ from transformers.modeling_utils import find_pruneable_heads_and_indices
39
+ from transformers.modeling_utils import PreTrainedModel
40
+ from transformers.modeling_utils import prune_linear_layer
41
+ from transformers.models.bert.configuration_bert import BertConfig
42
+ from transformers.utils import logging
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+
47
+ class BertEmbeddings_nopos(nn.Module):
48
+ """Construct the embeddings from word and position embeddings."""
49
+
50
+ def __init__(self, config):
51
+ super().__init__()
52
+ self.word_embeddings = nn.Embedding(
53
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
54
+ )
55
+ # self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
56
+
57
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
58
+ # any TensorFlow checkpoint file
59
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
60
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
61
+
62
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
63
+ # self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
64
+ # self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
65
+
66
+ self.config = config
67
+
68
+ def forward(
69
+ self,
70
+ input_ids=None,
71
+ position_ids=None,
72
+ inputs_embeds=None,
73
+ past_key_values_length=0,
74
+ ):
75
+ if input_ids is not None:
76
+ input_shape = input_ids.size()
77
+ else:
78
+ input_shape = inputs_embeds.size()[:-1]
79
+
80
+ seq_length = input_shape[1]
81
+
82
+ # if position_ids is None:
83
+ # position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
84
+
85
+ if inputs_embeds is None:
86
+ inputs_embeds = self.word_embeddings(input_ids)
87
+
88
+ embeddings = inputs_embeds
89
+
90
+ # if self.position_embedding_type == "absolute":
91
+ # position_embeddings = self.position_embeddings(position_ids)
92
+ # # print('add position_embeddings!!!!')
93
+ # embeddings += position_embeddings
94
+ embeddings = self.LayerNorm(embeddings)
95
+ embeddings = self.dropout(embeddings)
96
+ return embeddings
97
+
98
+
99
+ class BertEmbeddings(nn.Module):
100
+ """Construct the embeddings from word and position embeddings."""
101
+
102
+ def __init__(self, config):
103
+ super().__init__()
104
+ self.word_embeddings = nn.Embedding(
105
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
106
+ )
107
+ self.position_embeddings = nn.Embedding(
108
+ config.max_position_embeddings, config.hidden_size
109
+ )
110
+
111
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
112
+ # any TensorFlow checkpoint file
113
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
114
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
115
+
116
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
117
+ self.register_buffer(
118
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
119
+ )
120
+ self.position_embedding_type = getattr(
121
+ config, "position_embedding_type", "absolute"
122
+ )
123
+
124
+ self.config = config
125
+
126
+ def forward(
127
+ self,
128
+ input_ids=None,
129
+ position_ids=None,
130
+ inputs_embeds=None,
131
+ past_key_values_length=0,
132
+ ):
133
+ if input_ids is not None:
134
+ input_shape = input_ids.size()
135
+ else:
136
+ input_shape = inputs_embeds.size()[:-1]
137
+
138
+ seq_length = input_shape[1]
139
+
140
+ if position_ids is None:
141
+ position_ids = self.position_ids[
142
+ :, past_key_values_length : seq_length + past_key_values_length
143
+ ]
144
+
145
+ if inputs_embeds is None:
146
+ inputs_embeds = self.word_embeddings(input_ids)
147
+
148
+ embeddings = inputs_embeds
149
+
150
+ if self.position_embedding_type == "absolute":
151
+ position_embeddings = self.position_embeddings(position_ids)
152
+ # print('add position_embeddings!!!!')
153
+ embeddings += position_embeddings
154
+ embeddings = self.LayerNorm(embeddings)
155
+ embeddings = self.dropout(embeddings)
156
+ return embeddings
157
+
158
+
159
+ class BertSelfAttention(nn.Module):
160
+ def __init__(self, config, is_cross_attention):
161
+ super().__init__()
162
+ self.config = config
163
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
164
+ config, "embedding_size"
165
+ ):
166
+ raise ValueError(
167
+ "The hidden size (%d) is not a multiple of the number of attention "
168
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
169
+ )
170
+
171
+ self.num_attention_heads = config.num_attention_heads
172
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
173
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
174
+
175
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
176
+ if is_cross_attention:
177
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
178
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
179
+ else:
180
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
181
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
182
+
183
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
184
+ self.position_embedding_type = getattr(
185
+ config, "position_embedding_type", "absolute"
186
+ )
187
+ if (
188
+ self.position_embedding_type == "relative_key"
189
+ or self.position_embedding_type == "relative_key_query"
190
+ ):
191
+ self.max_position_embeddings = config.max_position_embeddings
192
+ self.distance_embedding = nn.Embedding(
193
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
194
+ )
195
+ self.save_attention = False
196
+
197
+ def save_attn_gradients(self, attn_gradients):
198
+ self.attn_gradients = attn_gradients
199
+
200
+ def get_attn_gradients(self):
201
+ return self.attn_gradients
202
+
203
+ def save_attention_map(self, attention_map):
204
+ self.attention_map = attention_map
205
+
206
+ def get_attention_map(self):
207
+ return self.attention_map
208
+
209
+ def transpose_for_scores(self, x):
210
+ new_x_shape = x.size()[:-1] + (
211
+ self.num_attention_heads,
212
+ self.attention_head_size,
213
+ )
214
+ x = x.view(*new_x_shape)
215
+ return x.permute(0, 2, 1, 3)
216
+
217
+ def forward(
218
+ self,
219
+ hidden_states,
220
+ attention_mask=None,
221
+ head_mask=None,
222
+ encoder_hidden_states=None,
223
+ encoder_attention_mask=None,
224
+ past_key_value=None,
225
+ output_attentions=False,
226
+ ):
227
+ mixed_query_layer = self.query(hidden_states)
228
+
229
+ # If this is instantiated as a cross-attention module, the keys
230
+ # and values come from an encoder; the attention mask needs to be
231
+ # such that the encoder's padding tokens are not attended to.
232
+ is_cross_attention = encoder_hidden_states is not None
233
+
234
+ if is_cross_attention:
235
+ # print(self.key.weight.shape)
236
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
237
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
238
+ attention_mask = encoder_attention_mask
239
+ elif past_key_value is not None:
240
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
241
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
242
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
243
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
244
+ else:
245
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
246
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
247
+
248
+ query_layer = self.transpose_for_scores(mixed_query_layer)
249
+
250
+ past_key_value = (key_layer, value_layer)
251
+
252
+ # Compatibility workaround for newer transformers versions: if the cached keys/values carry a larger batch dimension than the current queries, truncate them (and the mask) to match.
253
+ if key_layer.shape[0] > query_layer.shape[0]:
254
+ key_layer = key_layer[: query_layer.shape[0], :, :, :]
255
+ attention_mask = attention_mask[: query_layer.shape[0], :, :]
256
+ value_layer = value_layer[: query_layer.shape[0], :, :, :]
257
+
258
+ # Take the dot product between "query" and "key" to get the raw attention scores.
259
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
260
+
261
+ if (
262
+ self.position_embedding_type == "relative_key"
263
+ or self.position_embedding_type == "relative_key_query"
264
+ ):
265
+ seq_length = hidden_states.size()[1]
266
+ position_ids_l = torch.arange(
267
+ seq_length, dtype=torch.long, device=hidden_states.device
268
+ ).view(-1, 1)
269
+ position_ids_r = torch.arange(
270
+ seq_length, dtype=torch.long, device=hidden_states.device
271
+ ).view(1, -1)
272
+ distance = position_ids_l - position_ids_r
273
+ positional_embedding = self.distance_embedding(
274
+ distance + self.max_position_embeddings - 1
275
+ )
276
+ positional_embedding = positional_embedding.to(
277
+ dtype=query_layer.dtype
278
+ ) # fp16 compatibility
279
+
280
+ if self.position_embedding_type == "relative_key":
281
+ relative_position_scores = torch.einsum(
282
+ "bhld,lrd->bhlr", query_layer, positional_embedding
283
+ )
284
+ attention_scores = attention_scores + relative_position_scores
285
+ elif self.position_embedding_type == "relative_key_query":
286
+ relative_position_scores_query = torch.einsum(
287
+ "bhld,lrd->bhlr", query_layer, positional_embedding
288
+ )
289
+ relative_position_scores_key = torch.einsum(
290
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
291
+ )
292
+ attention_scores = (
293
+ attention_scores
294
+ + relative_position_scores_query
295
+ + relative_position_scores_key
296
+ )
297
+
298
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
299
+ if attention_mask is not None:
300
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
301
+ attention_scores = attention_scores + attention_mask
302
+
303
+ # Normalize the attention scores to probabilities.
304
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
305
+
306
+ if is_cross_attention and self.save_attention:
307
+ self.save_attention_map(attention_probs)
308
+ attention_probs.register_hook(self.save_attn_gradients)
309
+
310
+ # This is actually dropping out entire tokens to attend to, which might
311
+ # seem a bit unusual, but is taken from the original Transformer paper.
312
+ attention_probs_dropped = self.dropout(attention_probs)
313
+
314
+ # Mask heads if we want to
315
+ if head_mask is not None:
316
+ attention_probs_dropped = attention_probs_dropped * head_mask
317
+
318
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
319
+
320
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
321
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
322
+ context_layer = context_layer.view(*new_context_layer_shape)
323
+
324
+ outputs = (
325
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
326
+ )
327
+
328
+ outputs = outputs + (past_key_value,)
329
+ return outputs
330
+
331
+
332
+ class BertSelfOutput(nn.Module):
333
+ def __init__(self, config):
334
+ super().__init__()
335
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
336
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
337
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
338
+
339
+ def forward(self, hidden_states, input_tensor):
340
+ hidden_states = self.dense(hidden_states)
341
+ hidden_states = self.dropout(hidden_states)
342
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
343
+ return hidden_states
344
+
345
+
346
+ class BertAttention(nn.Module):
347
+ def __init__(self, config, is_cross_attention=False):
348
+ super().__init__()
349
+ self.self = BertSelfAttention(config, is_cross_attention)
350
+ self.output = BertSelfOutput(config)
351
+ self.pruned_heads = set()
352
+
353
+ def prune_heads(self, heads):
354
+ if len(heads) == 0:
355
+ return
356
+ heads, index = find_pruneable_heads_and_indices(
357
+ heads,
358
+ self.self.num_attention_heads,
359
+ self.self.attention_head_size,
360
+ self.pruned_heads,
361
+ )
362
+
363
+ # Prune linear layers
364
+ self.self.query = prune_linear_layer(self.self.query, index)
365
+ self.self.key = prune_linear_layer(self.self.key, index)
366
+ self.self.value = prune_linear_layer(self.self.value, index)
367
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
368
+
369
+ # Update hyper params and store pruned heads
370
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
371
+ self.self.all_head_size = (
372
+ self.self.attention_head_size * self.self.num_attention_heads
373
+ )
374
+ self.pruned_heads = self.pruned_heads.union(heads)
375
+
376
+ def forward(
377
+ self,
378
+ hidden_states,
379
+ attention_mask=None,
380
+ head_mask=None,
381
+ encoder_hidden_states=None,
382
+ encoder_attention_mask=None,
383
+ past_key_value=None,
384
+ output_attentions=False,
385
+ ):
386
+ self_outputs = self.self(
387
+ hidden_states,
388
+ attention_mask,
389
+ head_mask,
390
+ encoder_hidden_states,
391
+ encoder_attention_mask,
392
+ past_key_value,
393
+ output_attentions,
394
+ )
395
+ attention_output = self.output(self_outputs[0], hidden_states)
396
+ outputs = (attention_output,) + self_outputs[
397
+ 1:
398
+ ] # add attentions if we output them
399
+ return outputs
400
+
401
+
402
+ class BertIntermediate(nn.Module):
403
+ def __init__(self, config):
404
+ super().__init__()
405
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
406
+ if isinstance(config.hidden_act, str):
407
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
408
+ else:
409
+ self.intermediate_act_fn = config.hidden_act
410
+
411
+ def forward(self, hidden_states):
412
+ hidden_states = self.dense(hidden_states)
413
+ hidden_states = self.intermediate_act_fn(hidden_states)
414
+ return hidden_states
415
+
416
+
417
+ class BertOutput(nn.Module):
418
+ def __init__(self, config):
419
+ super().__init__()
420
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
421
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
422
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
423
+
424
+ def forward(self, hidden_states, input_tensor):
425
+ hidden_states = self.dense(hidden_states)
426
+ hidden_states = self.dropout(hidden_states)
427
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
428
+ return hidden_states
429
+
430
+
431
+ class BertLayer(nn.Module):
432
+ def __init__(self, config, layer_num):
433
+ super().__init__()
434
+ self.config = config
435
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
436
+ self.seq_len_dim = 1
437
+ self.attention = BertAttention(config)
438
+ self.layer_num = layer_num
439
+ if self.config.add_cross_attention:
440
+ self.crossattention = BertAttention(
441
+ config, is_cross_attention=self.config.add_cross_attention
442
+ )
443
+ self.intermediate = BertIntermediate(config)
444
+ self.output = BertOutput(config)
445
+
446
+ def forward(
447
+ self,
448
+ hidden_states,
449
+ attention_mask=None,
450
+ head_mask=None,
451
+ encoder_hidden_states=None,
452
+ encoder_attention_mask=None,
453
+ past_key_value=None,
454
+ output_attentions=False,
455
+ mode=None,
456
+ ):
457
+ if mode == "tagging":
458
+ assert (
459
+ encoder_hidden_states is not None
460
+ ), "encoder_hidden_states must be given for cross-attention layers"
461
+
462
+ cross_attention_outputs = self.crossattention(
463
+ hidden_states,
464
+ attention_mask,
465
+ head_mask,
466
+ encoder_hidden_states,
467
+ encoder_attention_mask,
468
+ output_attentions=output_attentions,
469
+ )
470
+ attention_output = cross_attention_outputs[0]
471
+ outputs = cross_attention_outputs[
472
+ 1:-1
473
+ ] # add cross attentions if we output attention weights
474
+
475
+ present_key_value = cross_attention_outputs[-1]
476
+
477
+ else:
478
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
479
+ self_attn_past_key_value = (
480
+ past_key_value[:2] if past_key_value is not None else None
481
+ )
482
+ self_attention_outputs = self.attention(
483
+ hidden_states,
484
+ attention_mask,
485
+ head_mask,
486
+ output_attentions=output_attentions,
487
+ past_key_value=self_attn_past_key_value,
488
+ )
489
+ attention_output = self_attention_outputs[0]
490
+
491
+ outputs = self_attention_outputs[1:-1]
492
+ present_key_value = self_attention_outputs[-1]
493
+
494
+ if mode == "multimodal":
495
+ assert (
496
+ encoder_hidden_states is not None
497
+ ), "encoder_hidden_states must be given for cross-attention layers"
498
+
499
+ cross_attention_outputs = self.crossattention(
500
+ attention_output,
501
+ attention_mask,
502
+ head_mask,
503
+ encoder_hidden_states,
504
+ encoder_attention_mask,
505
+ output_attentions=output_attentions,
506
+ )
507
+ attention_output = cross_attention_outputs[0]
508
+ outputs = (
509
+ outputs + cross_attention_outputs[1:-1]
510
+ ) # add cross attentions if we output attention weights
511
+ layer_output = apply_chunking_to_forward(
512
+ self.feed_forward_chunk,
513
+ self.chunk_size_feed_forward,
514
+ self.seq_len_dim,
515
+ attention_output,
516
+ )
517
+ outputs = (layer_output,) + outputs
518
+
519
+ outputs = outputs + (present_key_value,)
520
+
521
+ return outputs
522
+
523
+ def feed_forward_chunk(self, attention_output):
524
+ intermediate_output = self.intermediate(attention_output)
525
+ layer_output = self.output(intermediate_output, attention_output)
526
+ return layer_output
527
+
528
+
529
+ class BertEncoder(nn.Module):
530
+ def __init__(self, config):
531
+ super().__init__()
532
+ self.config = config
533
+ self.layer = nn.ModuleList(
534
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
535
+ )
536
+ self.gradient_checkpointing = False
537
+
538
+ def forward(
539
+ self,
540
+ hidden_states,
541
+ attention_mask=None,
542
+ head_mask=None,
543
+ encoder_hidden_states=None,
544
+ encoder_attention_mask=None,
545
+ past_key_values=None,
546
+ use_cache=None,
547
+ output_attentions=False,
548
+ output_hidden_states=False,
549
+ return_dict=True,
550
+ mode="multimodal",
551
+ ):
552
+ all_hidden_states = () if output_hidden_states else None
553
+ all_self_attentions = () if output_attentions else None
554
+ all_cross_attentions = (
555
+ () if output_attentions and self.config.add_cross_attention else None
556
+ )
557
+
558
+ next_decoder_cache = () if use_cache else None
559
+
560
+ for i in range(self.config.num_hidden_layers):
561
+ layer_module = self.layer[i]
562
+ if output_hidden_states:
563
+ all_hidden_states = all_hidden_states + (hidden_states,)
564
+
565
+ layer_head_mask = head_mask[i] if head_mask is not None else None
566
+ past_key_value = past_key_values[i] if past_key_values is not None else None
567
+
568
+ if self.gradient_checkpointing and self.training:
569
+ if use_cache:
570
+ logger.warning(
571
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
572
+ )
573
+ use_cache = False
574
+
575
+ def create_custom_forward(module):
576
+ def custom_forward(*inputs):
577
+ return module(*inputs, past_key_value, output_attentions)
578
+
579
+ return custom_forward
580
+
581
+ layer_outputs = torch.utils.checkpoint.checkpoint(
582
+ create_custom_forward(layer_module),
583
+ hidden_states,
584
+ attention_mask,
585
+ layer_head_mask,
586
+ encoder_hidden_states,
587
+ encoder_attention_mask,
588
+ mode=mode,
589
+ )
590
+ else:
591
+ layer_outputs = layer_module(
592
+ hidden_states,
593
+ attention_mask,
594
+ layer_head_mask,
595
+ encoder_hidden_states,
596
+ encoder_attention_mask,
597
+ past_key_value,
598
+ output_attentions,
599
+ mode=mode,
600
+ )
601
+
602
+ hidden_states = layer_outputs[0]
603
+ if use_cache:
604
+ next_decoder_cache += (layer_outputs[-1],)
605
+ if output_attentions:
606
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
607
+
608
+ if output_hidden_states:
609
+ all_hidden_states = all_hidden_states + (hidden_states,)
610
+
611
+ if not return_dict:
612
+ return tuple(
613
+ v
614
+ for v in [
615
+ hidden_states,
616
+ next_decoder_cache,
617
+ all_hidden_states,
618
+ all_self_attentions,
619
+ all_cross_attentions,
620
+ ]
621
+ if v is not None
622
+ )
623
+ return BaseModelOutputWithPastAndCrossAttentions(
624
+ last_hidden_state=hidden_states,
625
+ past_key_values=next_decoder_cache,
626
+ hidden_states=all_hidden_states,
627
+ attentions=all_self_attentions,
628
+ cross_attentions=all_cross_attentions,
629
+ )
630
+
631
+
632
+ class BertPooler(nn.Module):
633
+ def __init__(self, config):
634
+ super().__init__()
635
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
636
+ self.activation = nn.Tanh()
637
+
638
+ def forward(self, hidden_states):
639
+ # We "pool" the model by simply taking the hidden state corresponding
640
+ # to the first token.
641
+ first_token_tensor = hidden_states[:, 0]
642
+ pooled_output = self.dense(first_token_tensor)
643
+ pooled_output = self.activation(pooled_output)
644
+ return pooled_output
645
+
646
+
647
+ class BertPredictionHeadTransform(nn.Module):
648
+ def __init__(self, config):
649
+ super().__init__()
650
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
651
+ if isinstance(config.hidden_act, str):
652
+ self.transform_act_fn = ACT2FN[config.hidden_act]
653
+ else:
654
+ self.transform_act_fn = config.hidden_act
655
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
656
+
657
+ def forward(self, hidden_states):
658
+ hidden_states = self.dense(hidden_states)
659
+ hidden_states = self.transform_act_fn(hidden_states)
660
+ hidden_states = self.LayerNorm(hidden_states)
661
+ return hidden_states
662
+
663
+
664
+ class BertLMPredictionHead(nn.Module):
665
+ def __init__(self, config):
666
+ super().__init__()
667
+ self.transform = BertPredictionHeadTransform(config)
668
+
669
+ # The output weights are the same as the input embeddings, but there is
670
+ # an output-only bias for each token.
671
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
672
+
673
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
674
+
675
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
676
+ self.decoder.bias = self.bias
677
+
678
+ def forward(self, hidden_states):
679
+ hidden_states = self.transform(hidden_states)
680
+ hidden_states = self.decoder(hidden_states)
681
+ return hidden_states
682
+
683
+
684
+ class BertOnlyMLMHead(nn.Module):
685
+ def __init__(self, config):
686
+ super().__init__()
687
+ self.predictions = BertLMPredictionHead(config)
688
+
689
+ def forward(self, sequence_output):
690
+ prediction_scores = self.predictions(sequence_output)
691
+ return prediction_scores
692
+
693
+
694
+ class BertPreTrainedModel(PreTrainedModel):
695
+ """An abstract class to handle weights initialization and a simple interface for downloading and loading
696
+ pretrained models."""
697
+
698
+ config_class = BertConfig
699
+ base_model_prefix = "bert"
700
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
701
+
702
+ def _init_weights(self, module):
703
+ """Initialize the weights."""
704
+ if isinstance(module, (nn.Linear, nn.Embedding)):
705
+ # Slightly different from the TF version which uses truncated_normal for initialization
706
+ # cf https://github.com/pytorch/pytorch/pull/5617
707
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
708
+ elif isinstance(module, nn.LayerNorm):
709
+ module.bias.data.zero_()
710
+ module.weight.data.fill_(1.0)
711
+ if isinstance(module, nn.Linear) and module.bias is not None:
712
+ module.bias.data.zero_()
713
+
714
+
715
+ class BertModel(BertPreTrainedModel):
716
+ """The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
717
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention
718
+ is all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob
719
+ Uszkoreit, Llion Jones, Aidan N.
720
+
721
+ Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an
722
+ :obj:`encoder_hidden_states` is then expected as an input to the forward pass.
723
+ """
724
+
725
+ def __init__(self, config, add_pooling_layer=True):
726
+ super().__init__(config)
727
+ self.config = config
728
+
729
+ self.embeddings = BertEmbeddings(config)
730
+
731
+ self.encoder = BertEncoder(config)
732
+
733
+ self.pooler = BertPooler(config) if add_pooling_layer else None
734
+
735
+ self.init_weights()
736
+
737
+ def get_input_embeddings(self):
738
+ return self.embeddings.word_embeddings
739
+
740
+ def set_input_embeddings(self, value):
741
+ self.embeddings.word_embeddings = value
742
+
743
+ def _prune_heads(self, heads_to_prune):
744
+ """Prunes heads of the model.
745
+
746
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
747
+ class PreTrainedModel
748
+ """
749
+ for layer, heads in heads_to_prune.items():
750
+ self.encoder.layer[layer].attention.prune_heads(heads)
751
+
752
+ def get_extended_attention_mask(
753
+ self,
754
+ attention_mask: Tensor,
755
+ input_shape: Tuple[int],
756
+ device: device,
757
+ is_decoder: bool,
758
+ ) -> Tensor:
759
+ """Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
760
+
761
+ Arguments:
762
+ attention_mask (:obj:`torch.Tensor`):
763
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
764
+ input_shape (:obj:`Tuple[int]`):
765
+ The shape of the input to the model.
766
+ device: (:obj:`torch.device`):
767
+ The device of the input to the model.
768
+
769
+ Returns:
770
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
771
+ """
772
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
773
+ # ourselves in which case we just need to make it broadcastable to all heads.
774
+ if attention_mask.dim() == 3:
775
+ extended_attention_mask = attention_mask[:, None, :, :]
776
+ elif attention_mask.dim() == 2:
777
+ # Provided a padding mask of dimensions [batch_size, seq_length]
778
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
779
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
780
+ if is_decoder:
781
+ batch_size, seq_length = input_shape
782
+
783
+ seq_ids = torch.arange(seq_length, device=device)
784
+ causal_mask = (
785
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
786
+ <= seq_ids[None, :, None]
787
+ )
788
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
789
+ # causal and attention masks must have same type with pytorch version < 1.3
790
+ causal_mask = causal_mask.to(attention_mask.dtype)
791
+
792
+ if causal_mask.shape[1] < attention_mask.shape[1]:
793
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
794
+ causal_mask = torch.cat(
795
+ [
796
+ torch.ones(
797
+ (batch_size, seq_length, prefix_seq_len),
798
+ device=device,
799
+ dtype=causal_mask.dtype,
800
+ ),
801
+ causal_mask,
802
+ ],
803
+ axis=-1,
804
+ )
805
+
806
+ extended_attention_mask = (
807
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
808
+ )
809
+ else:
810
+ extended_attention_mask = attention_mask[:, None, None, :]
811
+ else:
812
+ raise ValueError(
813
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
814
+ input_shape, attention_mask.shape
815
+ )
816
+ )
817
+
818
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
819
+ # masked positions, this operation will create a tensor which is 0.0 for
820
+ # positions we want to attend and -10000.0 for masked positions.
821
+ # Since we are adding it to the raw scores before the softmax, this is
822
+ # effectively the same as removing these entirely.
823
+ extended_attention_mask = extended_attention_mask.to(
824
+ dtype=self.dtype
825
+ ) # fp16 compatibility
826
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
827
+ return extended_attention_mask
828
+
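+ # Illustrative sketch (added for clarity, not part of the original file): for an
+ # encoder, a 2D padding mask becomes a broadcastable additive mask:
+ #   mask = torch.tensor([[1, 1, 0]], dtype=torch.float)
+ #   ext = mask[:, None, None, :]          # shape (1, 1, 1, 3)
+ #   (1.0 - ext) * -10000.0                # 0 for real tokens, -10000 for padding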
829
+ def forward(
830
+ self,
831
+ input_ids=None,
832
+ attention_mask=None,
833
+ position_ids=None,
834
+ head_mask=None,
835
+ inputs_embeds=None,
836
+ encoder_embeds=None,
837
+ encoder_hidden_states=None,
838
+ encoder_attention_mask=None,
839
+ past_key_values=None,
840
+ use_cache=None,
841
+ output_attentions=None,
842
+ output_hidden_states=None,
843
+ return_dict=None,
844
+ is_decoder=False,
845
+ mode="multimodal",
846
+ ):
847
+ r"""
848
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
849
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
850
+ the model is configured as a decoder.
851
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
852
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
853
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
854
+ - 1 for tokens that are **not masked**,
855
+ - 0 for tokens that are **masked**.
856
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
857
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
858
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
859
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
860
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
861
+ use_cache (:obj:`bool`, `optional`):
862
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
863
+ decoding (see :obj:`past_key_values`).
864
+ """
865
+ output_attentions = (
866
+ output_attentions
867
+ if output_attentions is not None
868
+ else self.config.output_attentions
869
+ )
870
+ output_hidden_states = (
871
+ output_hidden_states
872
+ if output_hidden_states is not None
873
+ else self.config.output_hidden_states
874
+ )
875
+ return_dict = (
876
+ return_dict if return_dict is not None else self.config.use_return_dict
877
+ )
878
+
879
+ if is_decoder:
880
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
881
+ else:
882
+ use_cache = False
883
+
884
+ if input_ids is not None and inputs_embeds is not None:
885
+ raise ValueError(
886
+ "You cannot specify both input_ids and inputs_embeds at the same time"
887
+ )
888
+ elif input_ids is not None:
889
+ input_shape = input_ids.size()
890
+ batch_size, seq_length = input_shape
891
+ device = input_ids.device
892
+ elif inputs_embeds is not None:
893
+ input_shape = inputs_embeds.size()[:-1]
894
+ batch_size, seq_length = input_shape
895
+ device = inputs_embeds.device
896
+ elif encoder_embeds is not None:
897
+ input_shape = encoder_embeds.size()[:-1]
898
+ batch_size, seq_length = input_shape
899
+ device = encoder_embeds.device
900
+ else:
901
+ raise ValueError(
902
+ "You have to specify either input_ids or inputs_embeds or encoder_embeds"
903
+ )
904
+
905
+ # past_key_values_length
906
+ past_key_values_length = (
907
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
908
+ )
909
+
910
+ if attention_mask is None:
911
+ attention_mask = torch.ones(
912
+ ((batch_size, seq_length + past_key_values_length)), device=device
913
+ )
914
+
915
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
916
+ # ourselves in which case we just need to make it broadcastable to all heads.
917
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
918
+ attention_mask, input_shape, device, is_decoder
919
+ )
920
+
921
+ # If a 2D or 3D attention mask is provided for the cross-attention
922
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
923
+ if encoder_hidden_states is not None:
924
+ if type(encoder_hidden_states) == list:
925
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
926
+ 0
927
+ ].size()
928
+ else:
929
+ (
930
+ encoder_batch_size,
931
+ encoder_sequence_length,
932
+ _,
933
+ ) = encoder_hidden_states.size()
934
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
935
+
936
+ if type(encoder_attention_mask) == list:
937
+ encoder_extended_attention_mask = [
938
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
939
+ ]
940
+ elif encoder_attention_mask is None:
941
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
942
+ encoder_extended_attention_mask = self.invert_attention_mask(
943
+ encoder_attention_mask
944
+ )
945
+ else:
946
+ encoder_extended_attention_mask = self.invert_attention_mask(
947
+ encoder_attention_mask
948
+ )
949
+ else:
950
+ encoder_extended_attention_mask = None
951
+
952
+ # Prepare head mask if needed
953
+ # 1.0 in head_mask indicate we keep the head
954
+ # attention_probs has shape bsz x n_heads x N x N
955
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
956
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
957
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
958
+
959
+ if encoder_embeds is None:
960
+ embedding_output = self.embeddings(
961
+ input_ids=input_ids,
962
+ position_ids=position_ids,
963
+ inputs_embeds=inputs_embeds,
964
+ past_key_values_length=past_key_values_length,
965
+ )
966
+ else:
967
+ embedding_output = encoder_embeds
968
+
969
+ encoder_outputs = self.encoder(
970
+ embedding_output,
971
+ attention_mask=extended_attention_mask,
972
+ head_mask=head_mask,
973
+ encoder_hidden_states=encoder_hidden_states,
974
+ encoder_attention_mask=encoder_extended_attention_mask,
975
+ past_key_values=past_key_values,
976
+ use_cache=use_cache,
977
+ output_attentions=output_attentions,
978
+ output_hidden_states=output_hidden_states,
979
+ return_dict=return_dict,
980
+ mode=mode,
981
+ )
982
+ sequence_output = encoder_outputs[0]
983
+ pooled_output = (
984
+ self.pooler(sequence_output) if self.pooler is not None else None
985
+ )
986
+
987
+ if not return_dict:
988
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
989
+
990
+ return BaseModelOutputWithPoolingAndCrossAttentions(
991
+ last_hidden_state=sequence_output,
992
+ pooler_output=pooled_output,
993
+ past_key_values=encoder_outputs.past_key_values,
994
+ hidden_states=encoder_outputs.hidden_states,
995
+ attentions=encoder_outputs.attentions,
996
+ cross_attentions=encoder_outputs.cross_attentions,
997
+ )
998
+
999
+
1000
+ class BertLMHeadModel(BertPreTrainedModel):
1001
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1002
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1003
+
1004
+ def __init__(self, config):
1005
+ super().__init__(config)
1006
+
1007
+ self.bert = BertModel(config, add_pooling_layer=False)
1008
+ self.cls = BertOnlyMLMHead(config)
1009
+
1010
+ self.init_weights()
1011
+
1012
+ def get_output_embeddings(self):
1013
+ return self.cls.predictions.decoder
1014
+
1015
+ def set_output_embeddings(self, new_embeddings):
1016
+ self.cls.predictions.decoder = new_embeddings
1017
+
1018
+ def forward(
1019
+ self,
1020
+ input_ids=None,
1021
+ attention_mask=None,
1022
+ position_ids=None,
1023
+ head_mask=None,
1024
+ inputs_embeds=None,
1025
+ encoder_hidden_states=None,
1026
+ encoder_attention_mask=None,
1027
+ labels=None,
1028
+ past_key_values=None,
1029
+ use_cache=None,
1030
+ output_attentions=None,
1031
+ output_hidden_states=None,
1032
+ return_dict=None,
1033
+ return_logits=False,
1034
+ is_decoder=True,
1035
+ reduction="mean",
1036
+ mode="multimodal",
1037
+ ):
1038
+ r"""
1039
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1040
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1041
+ the model is configured as a decoder.
1042
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1043
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1044
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1045
+ - 1 for tokens that are **not masked**,
1046
+ - 0 for tokens that are **masked**.
1047
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1048
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1049
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1050
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1051
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1052
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1053
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1054
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1055
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1056
+ use_cache (:obj:`bool`, `optional`):
1057
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1058
+ decoding (see :obj:`past_key_values`).
1059
+ Returns:
1060
+ Example::
1061
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1062
+ >>> import torch
1063
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1064
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1065
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1066
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1067
+ >>> outputs = model(**inputs)
1068
+ >>> prediction_logits = outputs.logits
1069
+ """
1070
+ return_dict = (
1071
+ return_dict if return_dict is not None else self.config.use_return_dict
1072
+ )
1073
+ if labels is not None:
1074
+ use_cache = False
1075
+
1076
+ outputs = self.bert(
1077
+ input_ids,
1078
+ attention_mask=attention_mask,
1079
+ position_ids=position_ids,
1080
+ head_mask=head_mask,
1081
+ inputs_embeds=inputs_embeds,
1082
+ encoder_hidden_states=encoder_hidden_states,
1083
+ encoder_attention_mask=encoder_attention_mask,
1084
+ past_key_values=past_key_values,
1085
+ use_cache=use_cache,
1086
+ output_attentions=output_attentions,
1087
+ output_hidden_states=output_hidden_states,
1088
+ return_dict=return_dict,
1089
+ is_decoder=is_decoder,
1090
+ mode=mode,
1091
+ )
1092
+
1093
+ sequence_output = outputs[0]
1094
+ prediction_scores = self.cls(sequence_output)
1095
+ # sequence_output.shape torch.Size([85, 30, 768])
1096
+ # prediction_scores.shape torch.Size([85, 30, 30524])
1097
+ # labels.shape torch.Size([85, 30])
1098
+
1099
+ if return_logits:
1100
+ return prediction_scores[:, :-1, :].contiguous()
1101
+
1102
+ lm_loss = None
1103
+ if labels is not None:
1104
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1105
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1106
+ labels = labels[:, 1:].contiguous()
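+ # Illustrative alignment (added for clarity, not part of the original file): for
+ # labels [t1, t2, t3, t4], the scores at positions 0..2 are compared against
+ # targets [t2, t3, t4], i.e. each position is trained to predict the next token.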
1107
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1108
+ lm_loss = loss_fct(
1109
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1110
+ labels.view(-1),
1111
+ )
1112
+ if reduction == "none":
1113
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1114
+
1115
+ if not return_dict:
1116
+ output = (prediction_scores,) + outputs[2:]
1117
+ return ((lm_loss,) + output) if lm_loss is not None else output
1118
+
1119
+ return CausalLMOutputWithCrossAttentions(
1120
+ loss=lm_loss,
1121
+ logits=prediction_scores,
1122
+ past_key_values=outputs.past_key_values,
1123
+ hidden_states=outputs.hidden_states,
1124
+ attentions=outputs.attentions,
1125
+ cross_attentions=outputs.cross_attentions,
1126
+ )
1127
+
1128
+ def prepare_inputs_for_generation(
1129
+ self, input_ids, past=None, attention_mask=None, **model_kwargs
1130
+ ):
1131
+ input_shape = input_ids.shape
1132
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1133
+ if attention_mask is None:
1134
+ attention_mask = input_ids.new_ones(input_shape)
1135
+
1136
+ # cut decoder_input_ids if past is used
1137
+ if past is not None:
1138
+ input_ids = input_ids[:, -1:]
1139
+
1140
+ return {
1141
+ "input_ids": input_ids,
1142
+ "attention_mask": attention_mask,
1143
+ "past_key_values": past,
1144
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1145
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1146
+ "is_decoder": True,
1147
+ }
1148
+
1149
+ def _reorder_cache(self, past, beam_idx):
1150
+ reordered_past = ()
1151
+ for layer_past in past:
1152
+ reordered_past += (
1153
+ tuple(
1154
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1155
+ ),
1156
+ )
1157
+ return reordered_past
tag2text/models/swin_transformer.py ADDED
@@ -0,0 +1,831 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.utils.checkpoint as checkpoint
11
+ from scipy import interpolate
12
+ from timm.models.layers import DropPath
13
+ from timm.models.layers import to_2tuple
14
+ from timm.models.layers import trunc_normal_
15
+
16
+
17
+ class Mlp(nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_features,
21
+ hidden_features=None,
22
+ out_features=None,
23
+ act_layer=nn.GELU,
24
+ drop=0.0,
25
+ ):
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x):
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
41
+
42
+
43
+ def window_partition(x, window_size):
44
+ """
45
+ Args:
46
+ x: (B, H, W, C)
47
+ window_size (int): window size
48
+
49
+ Returns:
50
+ windows: (num_windows*B, window_size, window_size, C)
51
+ """
52
+ B, H, W, C = x.shape
53
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
54
+ windows = (
55
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
56
+ )
57
+ return windows
58
+
59
+
60
+ def window_reverse(windows, window_size, H, W):
61
+ """
62
+ Args:
63
+ windows: (num_windows*B, window_size, window_size, C)
64
+ window_size (int): Window size
65
+ H (int): Height of image
66
+ W (int): Width of image
67
+
68
+ Returns:
69
+ x: (B, H, W, C)
70
+ """
71
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
72
+ x = windows.view(
73
+ B, H // window_size, W // window_size, window_size, window_size, -1
74
+ )
75
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
76
+ return x
77
+
78
+
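+ # Illustrative round-trip (added for clarity, not part of the original file),
+ # assuming a 56x56 feature map with 96 channels and window_size=7:
+ #   x = torch.randn(2, 56, 56, 96)
+ #   windows = window_partition(x, 7)   # (2*8*8, 7, 7, 96) = (128, 7, 7, 96)
+ #   assert window_reverse(windows, 7, 56, 56).shape == x.shape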
79
+ class WindowAttention(nn.Module):
80
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
81
+ It supports both shifted and non-shifted windows.
82
+
83
+ Args:
84
+ dim (int): Number of input channels.
85
+ window_size (tuple[int]): The height and width of the window.
86
+ num_heads (int): Number of attention heads.
87
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
88
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
89
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
90
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ dim,
96
+ window_size,
97
+ num_heads,
98
+ qkv_bias=True,
99
+ qk_scale=None,
100
+ attn_drop=0.0,
101
+ proj_drop=0.0,
102
+ ):
103
+ super().__init__()
104
+ self.dim = dim
105
+ self.window_size = window_size # Wh, Ww
106
+ self.num_heads = num_heads
107
+ head_dim = dim // num_heads
108
+ self.scale = qk_scale or head_dim**-0.5
109
+
110
+ # define a parameter table of relative position bias
111
+ self.relative_position_bias_table = nn.Parameter(
112
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
113
+ ) # 2*Wh-1 * 2*Ww-1, nH
114
+
115
+ # get pair-wise relative position index for each token inside the window
116
+ coords_h = torch.arange(self.window_size[0])
117
+ coords_w = torch.arange(self.window_size[1])
118
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
119
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
120
+ relative_coords = (
121
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
122
+ ) # 2, Wh*Ww, Wh*Ww
123
+ relative_coords = relative_coords.permute(
124
+ 1, 2, 0
125
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
126
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
127
+ relative_coords[:, :, 1] += self.window_size[1] - 1
128
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
129
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
130
+ self.register_buffer("relative_position_index", relative_position_index)
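+ # Illustrative note (added for clarity, not part of the original file): for
+ # window_size=(7, 7) the shifted relative coordinates lie in [0, 12] per axis, so
+ # the flattened index ranges over 0..12*13+12 = 168, matching the 13*13 = 169 rows
+ # of relative_position_bias_table.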
131
+
132
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
133
+ self.attn_drop = nn.Dropout(attn_drop)
134
+ self.proj = nn.Linear(dim, dim)
135
+ self.proj_drop = nn.Dropout(proj_drop)
136
+
137
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
138
+ self.softmax = nn.Softmax(dim=-1)
139
+
140
+ def forward(self, x, mask=None):
141
+ """
142
+ Args:
143
+ x: input features with shape of (num_windows*B, N, C)
144
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
145
+ """
146
+ B_, N, C = x.shape
147
+ qkv = (
148
+ self.qkv(x)
149
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
150
+ .permute(2, 0, 3, 1, 4)
151
+ )
152
+ q, k, v = (
153
+ qkv[0],
154
+ qkv[1],
155
+ qkv[2],
156
+ ) # make torchscript happy (cannot use tensor as tuple)
157
+
158
+ q = q * self.scale
159
+ attn = q @ k.transpose(-2, -1)
160
+
161
+ relative_position_bias = self.relative_position_bias_table[
162
+ self.relative_position_index.view(-1)
163
+ ].view(
164
+ self.window_size[0] * self.window_size[1],
165
+ self.window_size[0] * self.window_size[1],
166
+ -1,
167
+ ) # Wh*Ww,Wh*Ww,nH
168
+ relative_position_bias = relative_position_bias.permute(
169
+ 2, 0, 1
170
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
171
+ attn = attn + relative_position_bias.unsqueeze(0)
172
+
173
+ if mask is not None:
174
+ nW = mask.shape[0]
175
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
176
+ 1
177
+ ).unsqueeze(0)
178
+ attn = attn.view(-1, self.num_heads, N, N)
179
+ attn = self.softmax(attn)
180
+ else:
181
+ attn = self.softmax(attn)
182
+
183
+ attn = self.attn_drop(attn)
184
+
185
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
186
+ x = self.proj(x)
187
+ x = self.proj_drop(x)
188
+ return x
189
+
190
+ def extra_repr(self) -> str:
191
+ return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
192
+
193
+ def flops(self, N):
194
+ # calculate flops for 1 window with token length of N
195
+ flops = 0
196
+ # qkv = self.qkv(x)
197
+ flops += N * self.dim * 3 * self.dim
198
+ # attn = (q @ k.transpose(-2, -1))
199
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
200
+ # x = (attn @ v)
201
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
202
+ # x = self.proj(x)
203
+ flops += N * self.dim * self.dim
204
+ return flops
205
+
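+ # Minimal usage sketch (illustrative only, not part of the original file), assuming
+ # dim=96, 7x7 windows and 3 heads:
+ #   attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
+ #   x = torch.randn(128, 49, 96)   # (num_windows*B, Wh*Ww, C)
+ #   attn(x).shape                  # torch.Size([128, 49, 96])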
206
+
207
+ class SwinTransformerBlock(nn.Module):
208
+ r"""Swin Transformer Block.
209
+
210
+ Args:
211
+ dim (int): Number of input channels.
212
+ input_resolution (tuple[int]): Input resolution.
213
+ num_heads (int): Number of attention heads.
214
+ window_size (int): Window size.
215
+ shift_size (int): Shift size for SW-MSA.
216
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
217
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
218
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
219
+ drop (float, optional): Dropout rate. Default: 0.0
220
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
221
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
222
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
223
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
224
+ """
225
+
226
+ def __init__(
227
+ self,
228
+ dim,
229
+ input_resolution,
230
+ num_heads,
231
+ window_size=7,
232
+ shift_size=0,
233
+ mlp_ratio=4.0,
234
+ qkv_bias=True,
235
+ qk_scale=None,
236
+ drop=0.0,
237
+ attn_drop=0.0,
238
+ drop_path=0.0,
239
+ act_layer=nn.GELU,
240
+ norm_layer=nn.LayerNorm,
241
+ ):
242
+ super().__init__()
243
+ self.dim = dim
244
+ self.input_resolution = input_resolution
245
+ self.num_heads = num_heads
246
+ self.window_size = window_size
247
+ self.shift_size = shift_size
248
+ self.mlp_ratio = mlp_ratio
249
+ if min(self.input_resolution) <= self.window_size:
250
+ # if window size is larger than input resolution, we don't partition windows
251
+ self.shift_size = 0
252
+ self.window_size = min(self.input_resolution)
253
+ assert (
254
+ 0 <= self.shift_size < self.window_size
255
+ ), "shift_size must be in 0-window_size"
256
+
257
+ self.norm1 = norm_layer(dim)
258
+ self.attn = WindowAttention(
259
+ dim,
260
+ window_size=to_2tuple(self.window_size),
261
+ num_heads=num_heads,
262
+ qkv_bias=qkv_bias,
263
+ qk_scale=qk_scale,
264
+ attn_drop=attn_drop,
265
+ proj_drop=drop,
266
+ )
267
+
268
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
269
+ self.norm2 = norm_layer(dim)
270
+ mlp_hidden_dim = int(dim * mlp_ratio)
271
+ self.mlp = Mlp(
272
+ in_features=dim,
273
+ hidden_features=mlp_hidden_dim,
274
+ act_layer=act_layer,
275
+ drop=drop,
276
+ )
277
+
278
+ if self.shift_size > 0:
279
+ # calculate attention mask for SW-MSA
280
+ H, W = self.input_resolution
281
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
282
+ h_slices = (
283
+ slice(0, -self.window_size),
284
+ slice(-self.window_size, -self.shift_size),
285
+ slice(-self.shift_size, None),
286
+ )
287
+ w_slices = (
288
+ slice(0, -self.window_size),
289
+ slice(-self.window_size, -self.shift_size),
290
+ slice(-self.shift_size, None),
291
+ )
292
+ cnt = 0
293
+ for h in h_slices:
294
+ for w in w_slices:
295
+ img_mask[:, h, w, :] = cnt
296
+ cnt += 1
297
+
298
+ mask_windows = window_partition(
299
+ img_mask, self.window_size
300
+ ) # nW, window_size, window_size, 1
301
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
302
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
303
+ attn_mask = attn_mask.masked_fill(
304
+ attn_mask != 0, float(-100.0)
305
+ ).masked_fill(attn_mask == 0, float(0.0))
306
+ else:
307
+ attn_mask = None
308
+
309
+ self.register_buffer("attn_mask", attn_mask)
310
+
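+ # Illustrative note (added for clarity, not part of the original file): with
+ # input_resolution=(56, 56), window_size=7 and shift_size=3, img_mask labels nine
+ # regions and attn_mask has shape (nW, 49, 49) = (64, 49, 49), holding 0 where two
+ # positions come from the same region and -100 otherwise, which suppresses
+ # cross-region attention after softmax.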
311
+ def forward(self, x):
312
+ H, W = self.input_resolution
313
+ B, L, C = x.shape
314
+ assert L == H * W, "input feature has wrong size"
315
+
316
+ shortcut = x
317
+ x = self.norm1(x)
318
+ x = x.view(B, H, W, C)
319
+
320
+ # cyclic shift
321
+ if self.shift_size > 0:
322
+ shifted_x = torch.roll(
323
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
324
+ )
325
+ else:
326
+ shifted_x = x
327
+
328
+ # partition windows
329
+ x_windows = window_partition(
330
+ shifted_x, self.window_size
331
+ ) # nW*B, window_size, window_size, C
332
+ x_windows = x_windows.view(
333
+ -1, self.window_size * self.window_size, C
334
+ ) # nW*B, window_size*window_size, C
335
+
336
+ # W-MSA/SW-MSA
337
+ attn_windows = self.attn(
338
+ x_windows, mask=self.attn_mask
339
+ ) # nW*B, window_size*window_size, C
340
+
341
+ # merge windows
342
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
343
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
344
+
345
+ # reverse cyclic shift
346
+ if self.shift_size > 0:
347
+ x = torch.roll(
348
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
349
+ )
350
+ else:
351
+ x = shifted_x
352
+ x = x.view(B, H * W, C)
353
+
354
+ # FFN
355
+ x = shortcut + self.drop_path(x)
356
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
357
+
358
+ return x
359
+
360
+ def extra_repr(self) -> str:
361
+ return (
362
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
363
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
364
+ )
365
+
366
+ def flops(self):
367
+ flops = 0
368
+ H, W = self.input_resolution
369
+ # norm1
370
+ flops += self.dim * H * W
371
+ # W-MSA/SW-MSA
372
+ nW = H * W / self.window_size / self.window_size
373
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
374
+ # mlp
375
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
376
+ # norm2
377
+ flops += self.dim * H * W
378
+ return flops
379
+
380
+
381
+ class PatchMerging(nn.Module):
382
+ r"""Patch Merging Layer.
383
+
384
+ Args:
385
+ input_resolution (tuple[int]): Resolution of input feature.
386
+ dim (int): Number of input channels.
387
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
388
+ """
389
+
390
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
391
+ super().__init__()
392
+ self.input_resolution = input_resolution
393
+ self.dim = dim
394
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
395
+ self.norm = norm_layer(4 * dim)
396
+
397
+ def forward(self, x):
398
+ """
399
+ x: B, H*W, C
400
+ """
401
+ H, W = self.input_resolution
402
+ B, L, C = x.shape
403
+ assert L == H * W, "input feature has wrong size"
404
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
405
+
406
+ x = x.view(B, H, W, C)
407
+
408
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
409
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
410
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
411
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
412
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
413
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
414
+
415
+ x = self.norm(x)
416
+ x = self.reduction(x)
417
+
418
+ return x
419
+
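+ # Illustrative shape sketch (added for clarity, not part of the original file):
+ # PatchMerging halves the spatial resolution and doubles the channels, e.g. with
+ # input_resolution=(56, 56) and dim=96:
+ #   merge = PatchMerging((56, 56), dim=96)
+ #   merge(torch.randn(2, 56 * 56, 96)).shape   # torch.Size([2, 784, 192])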
420
+ def extra_repr(self) -> str:
421
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
422
+
423
+ def flops(self):
424
+ H, W = self.input_resolution
425
+ flops = H * W * self.dim
426
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
427
+ return flops
428
+
429
+
430
+ class BasicLayer(nn.Module):
431
+ """A basic Swin Transformer layer for one stage.
432
+
433
+ Args:
434
+ dim (int): Number of input channels.
435
+ input_resolution (tuple[int]): Input resolution.
436
+ depth (int): Number of blocks.
437
+ num_heads (int): Number of attention heads.
438
+ window_size (int): Local window size.
439
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
440
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
441
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
442
+ drop (float, optional): Dropout rate. Default: 0.0
443
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
444
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
445
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
446
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
447
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
448
+ """
449
+
450
+ def __init__(
451
+ self,
452
+ dim,
453
+ input_resolution,
454
+ depth,
455
+ num_heads,
456
+ window_size,
457
+ mlp_ratio=4.0,
458
+ qkv_bias=True,
459
+ qk_scale=None,
460
+ drop=0.0,
461
+ attn_drop=0.0,
462
+ drop_path=0.0,
463
+ norm_layer=nn.LayerNorm,
464
+ downsample=None,
465
+ use_checkpoint=False,
466
+ ):
467
+ super().__init__()
468
+ self.dim = dim
469
+ self.input_resolution = input_resolution
470
+ self.depth = depth
471
+ self.use_checkpoint = use_checkpoint
472
+
473
+ # build blocks
474
+ self.blocks = nn.ModuleList(
475
+ [
476
+ SwinTransformerBlock(
477
+ dim=dim,
478
+ input_resolution=input_resolution,
479
+ num_heads=num_heads,
480
+ window_size=window_size,
481
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
482
+ mlp_ratio=mlp_ratio,
483
+ qkv_bias=qkv_bias,
484
+ qk_scale=qk_scale,
485
+ drop=drop,
486
+ attn_drop=attn_drop,
487
+ drop_path=drop_path[i]
488
+ if isinstance(drop_path, list)
489
+ else drop_path,
490
+ norm_layer=norm_layer,
491
+ )
492
+ for i in range(depth)
493
+ ]
494
+ )
495
+
496
+ # patch merging layer
497
+ if downsample is not None:
498
+ self.downsample = downsample(
499
+ input_resolution, dim=dim, norm_layer=norm_layer
500
+ )
501
+ else:
502
+ self.downsample = None
503
+
504
+ def forward(self, x):
505
+ for blk in self.blocks:
506
+ if self.use_checkpoint:
507
+ x = checkpoint.checkpoint(blk, x)
508
+ else:
509
+ x = blk(x)
510
+ if self.downsample is not None:
511
+ x = self.downsample(x)
512
+ return x
513
+
514
+ def extra_repr(self) -> str:
515
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
516
+
517
+ def flops(self):
518
+ flops = 0
519
+ for blk in self.blocks:
520
+ flops += blk.flops()
521
+ if self.downsample is not None:
522
+ flops += self.downsample.flops()
523
+ return flops
524
+
525
+
526
+ class PatchEmbed(nn.Module):
527
+ r"""Image to Patch Embedding
528
+
529
+ Args:
530
+ img_size (int): Image size. Default: 224.
531
+ patch_size (int): Patch token size. Default: 4.
532
+ in_chans (int): Number of input image channels. Default: 3.
533
+ embed_dim (int): Number of linear projection output channels. Default: 96.
534
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
535
+ """
536
+
537
+ def __init__(
538
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
539
+ ):
540
+ super().__init__()
541
+ img_size = to_2tuple(img_size)
542
+ patch_size = to_2tuple(patch_size)
543
+ patches_resolution = [
544
+ img_size[0] // patch_size[0],
545
+ img_size[1] // patch_size[1],
546
+ ]
547
+ self.img_size = img_size
548
+ self.patch_size = patch_size
549
+ self.patches_resolution = patches_resolution
550
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
551
+
552
+ self.in_chans = in_chans
553
+ self.embed_dim = embed_dim
554
+
555
+ self.proj = nn.Conv2d(
556
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
557
+ )
558
+ if norm_layer is not None:
559
+ self.norm = norm_layer(embed_dim)
560
+ else:
561
+ self.norm = None
562
+
563
+ def forward(self, x):
564
+ B, C, H, W = x.shape
565
+ # FIXME look at relaxing size constraints
566
+ assert (
567
+ H == self.img_size[0] and W == self.img_size[1]
568
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
569
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
570
+ if self.norm is not None:
571
+ x = self.norm(x)
572
+ return x
573
+
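+ # Illustrative shape sketch (added for clarity, not part of the original file):
+ # with the defaults (img_size=224, patch_size=4, embed_dim=96) a batch of RGB
+ # images becomes 56*56 = 3136 patch tokens:
+ #   embed = PatchEmbed()
+ #   embed(torch.randn(2, 3, 224, 224)).shape   # torch.Size([2, 3136, 96])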
574
+ def flops(self):
575
+ Ho, Wo = self.patches_resolution
576
+ flops = (
577
+ Ho
578
+ * Wo
579
+ * self.embed_dim
580
+ * self.in_chans
581
+ * (self.patch_size[0] * self.patch_size[1])
582
+ )
583
+ if self.norm is not None:
584
+ flops += Ho * Wo * self.embed_dim
585
+ return flops
586
+
587
+
588
+ class SwinTransformer(nn.Module):
589
+ r"""Swin Transformer
590
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
591
+ https://arxiv.org/pdf/2103.14030
592
+
593
+ Args:
594
+ img_size (int | tuple(int)): Input image size. Default 224
595
+ patch_size (int | tuple(int)): Patch size. Default: 4
596
+ in_chans (int): Number of input image channels. Default: 3
597
+ num_classes (int): Number of classes for classification head. Default: 1000
598
+ embed_dim (int): Patch embedding dimension. Default: 96
599
+ depths (tuple(int)): Depth of each Swin Transformer layer.
600
+ num_heads (tuple(int)): Number of attention heads in different layers.
601
+ window_size (int): Window size. Default: 7
602
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
603
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
604
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
605
+ drop_rate (float): Dropout rate. Default: 0
606
+ attn_drop_rate (float): Attention dropout rate. Default: 0
607
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
608
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
609
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
610
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
611
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
612
+ """
613
+
614
+ def __init__(
615
+ self,
616
+ img_size=224,
617
+ patch_size=4,
618
+ in_chans=3,
619
+ num_classes=1000,
620
+ embed_dim=96,
621
+ depths=[2, 2, 6, 2],
622
+ num_heads=[3, 6, 12, 24],
623
+ window_size=7,
624
+ mlp_ratio=4.0,
625
+ qkv_bias=True,
626
+ qk_scale=None,
627
+ drop_rate=0.0,
628
+ attn_drop_rate=0.0,
629
+ drop_path_rate=0.1,
630
+ norm_layer=nn.LayerNorm,
631
+ ape=False,
632
+ patch_norm=True,
633
+ use_checkpoint=False,
634
+ **kwargs,
635
+ ):
636
+ super().__init__()
637
+
638
+ self.num_classes = num_classes
639
+ self.num_layers = len(depths)
640
+ self.embed_dim = embed_dim
641
+ self.ape = ape
642
+ self.patch_norm = patch_norm
643
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
644
+ self.mlp_ratio = mlp_ratio
645
+
646
+ # split image into non-overlapping patches
647
+ self.patch_embed = PatchEmbed(
648
+ img_size=img_size,
649
+ patch_size=patch_size,
650
+ in_chans=in_chans,
651
+ embed_dim=embed_dim,
652
+ norm_layer=norm_layer if self.patch_norm else None,
653
+ )
654
+ num_patches = self.patch_embed.num_patches
655
+ patches_resolution = self.patch_embed.patches_resolution
656
+ self.patches_resolution = patches_resolution
657
+
658
+ # absolute position embedding
659
+ if self.ape:
660
+ self.absolute_pos_embed = nn.Parameter(
661
+ torch.zeros(1, num_patches, embed_dim)
662
+ )
663
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
664
+
665
+ self.pos_drop = nn.Dropout(p=drop_rate)
666
+
667
+ # stochastic depth
668
+ dpr = [
669
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
670
+ ] # stochastic depth decay rule
671
+
672
+ # build layers
673
+ self.layers = nn.ModuleList()
674
+ for i_layer in range(self.num_layers):
675
+ layer = BasicLayer(
676
+ dim=int(embed_dim * 2**i_layer),
677
+ input_resolution=(
678
+ patches_resolution[0] // (2**i_layer),
679
+ patches_resolution[1] // (2**i_layer),
680
+ ),
681
+ depth=depths[i_layer],
682
+ num_heads=num_heads[i_layer],
683
+ window_size=window_size,
684
+ mlp_ratio=self.mlp_ratio,
685
+ qkv_bias=qkv_bias,
686
+ qk_scale=qk_scale,
687
+ drop=drop_rate,
688
+ attn_drop=attn_drop_rate,
689
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
690
+ norm_layer=norm_layer,
691
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
692
+ use_checkpoint=use_checkpoint,
693
+ )
694
+ self.layers.append(layer)
695
+
696
+ self.norm = norm_layer(self.num_features)
697
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
698
+ # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
699
+
700
+ self.apply(self._init_weights)
701
+
702
+ def _init_weights(self, m):
703
+ if isinstance(m, nn.Linear):
704
+ trunc_normal_(m.weight, std=0.02)
705
+ if isinstance(m, nn.Linear) and m.bias is not None:
706
+ nn.init.constant_(m.bias, 0)
707
+ elif isinstance(m, nn.LayerNorm):
708
+ nn.init.constant_(m.bias, 0)
709
+ nn.init.constant_(m.weight, 1.0)
710
+
711
+ @torch.jit.ignore
712
+ def no_weight_decay(self):
713
+ return {"absolute_pos_embed"}
714
+
715
+ @torch.jit.ignore
716
+ def no_weight_decay_keywords(self):
717
+ return {"relative_position_bias_table"}
718
+
719
+ def forward(self, x, idx_to_group_img=None, image_atts=None, **kwargs):
720
+ x = self.patch_embed(x)
721
+ if self.ape:
722
+ x = x + self.absolute_pos_embed
723
+ x = self.pos_drop(x)
724
+
725
+ for layer in self.layers:
726
+ x = layer(x)
727
+
728
+ x = self.norm(x) # B L C
729
+
730
+ x_cls = self.avgpool(x.transpose(1, 2)) # B C 1
731
+
732
+ if idx_to_group_img is None:
733
+ return torch.cat([x_cls.transpose(1, 2), x], dim=1)
734
+ else:
735
+ x_bs = torch.gather(
736
+ x,
737
+ dim=0,
738
+ index=idx_to_group_img.view(-1, 1, 1).expand(
739
+ -1, x.shape[1], x.shape[2]
740
+ ),
741
+ )
742
+ weights = image_atts[:, 1:].unsqueeze(2) # B L 1
743
+ x_bs_cls = torch.sum(
744
+ (weights * x_bs).transpose(1, 2), dim=-1, keepdim=True
745
+ ) # B C 1
746
+ x_bs_cls = x_bs_cls / torch.sum(
747
+ weights.transpose(1, 2), dim=-1, keepdim=True
748
+ ) # avgpool
749
+
750
+ return torch.cat([x_bs_cls.transpose(1, 2), x_bs], dim=1), torch.cat(
751
+ [x_cls.transpose(1, 2), x], dim=1
752
+ )
753
+
754
+ def flops(self):
755
+ flops = 0
756
+ flops += self.patch_embed.flops()
757
+ for i, layer in enumerate(self.layers):
758
+ flops += layer.flops()
759
+ flops += (
760
+ self.num_features
761
+ * self.patches_resolution[0]
762
+ * self.patches_resolution[1]
763
+ // (2**self.num_layers)
764
+ )
765
+ flops += self.num_features * self.num_classes
766
+ return flops
767
+
768
+
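The forward pass above average-pools the final patch features and prepends the pooled vector as a cls-like token, so the returned sequence has 1 + L tokens. A hedged shape sketch, assuming the default (Swin-T-like) configuration:

    # hypothetical sketch: the 56x56 patch grid is merged down to 7x7 = 49 tokens,
    # and num_features = 96 * 2**(4 - 1) = 768, so the output is (B, 1 + 49, 768)
    model = SwinTransformer(img_size=224, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])
    feats = model(torch.randn(1, 3, 224, 224))
    assert feats.shape == (1, 50, 768)    # feats[:, 0] is the avg-pooled global feature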
769
+ def interpolate_relative_pos_embed(rel_pos_bias, dst_num_pos, param_name=""):
770
+ # from: https://github.com/microsoft/unilm/blob/8a0a1c1f4e7326938ea7580a00d56d7f17d65612/beit/run_class_finetuning.py#L348
771
+
772
+ # rel_pos_bias: relative_position_bias_table
773
+ src_num_pos, num_attn_heads = rel_pos_bias.size()
774
+
775
+ num_extra_tokens = 0
776
+ src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
777
+ dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
778
+ if src_size != dst_size:
779
+ print(
780
+ "Position interpolate %s from %dx%d to %dx%d"
781
+ % (param_name, src_size, src_size, dst_size, dst_size)
782
+ )
783
+
784
+ # extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
785
+ # rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
786
+
787
+ def geometric_progression(a, r, n):
788
+ return a * (1.0 - r**n) / (1.0 - r)
789
+
790
+ left, right = 1.01, 1.5
791
+ while right - left > 1e-6:
792
+ q = (left + right) / 2.0
793
+ gp = geometric_progression(1, q, src_size // 2)
794
+ if gp > dst_size // 2:
795
+ right = q
796
+ else:
797
+ left = q
798
+
799
+ # if q > 1.090307:
800
+ # q = 1.090307
801
+
802
+ dis = []
803
+ cur = 1
804
+ for i in range(src_size // 2):
805
+ dis.append(cur)
806
+ cur += q ** (i + 1)
807
+
808
+ r_ids = [-_ for _ in reversed(dis)]
809
+
810
+ x = r_ids + [0] + dis
811
+ y = r_ids + [0] + dis
812
+
813
+ t = dst_size // 2.0
814
+ dx = np.arange(-t, t + 0.1, 1.0)
815
+ dy = np.arange(-t, t + 0.1, 1.0)
816
+
817
+ # print("Original positions = %s" % str(x))
818
+ # print("Target positions = %s" % str(dx))
819
+
820
+ all_rel_pos_bias = []
821
+
822
+ for i in range(num_attn_heads):
823
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
824
+ f = interpolate.interp2d(x, y, z, kind="cubic")
825
+ all_rel_pos_bias.append(
826
+ torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)
827
+ )
828
+
829
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
830
+
831
+ return rel_pos_bias
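A short sketch of how the interpolation above is typically used when the window size changes (values are illustrative; it assumes the table lives on the CPU, as it does when loaded from a checkpoint):

    # hypothetical sketch: grow a bias table from window 7 ((2*7-1)**2 = 169 rows)
    # to window 12 ((2*12-1)**2 = 529 rows), keeping the 4 attention heads
    old_table = torch.randn(169, 4)       # (src_num_pos, num_attn_heads)
    new_table = interpolate_relative_pos_embed(old_table, dst_num_pos=529, param_name="demo")
    assert new_table.shape == (529, 4)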
tag2text/models/tag2text.py ADDED
@@ -0,0 +1,274 @@
1
+ """
2
+ * Tag2Text
3
+ * Written by Xinyu Huang
4
+ """
5
+ import json
6
+ import warnings
7
+
8
+ import numpy as np
9
+ import torch
10
+ from models.bert import BertConfig
11
+ from models.bert import BertLMHeadModel
12
+ from models.bert import BertModel
13
+ from models.swin_transformer import SwinTransformer
14
+ from models.utils import *
15
+ from models.vit import VisionTransformer
16
+ from torch import nn
17
+
18
+ warnings.filterwarnings("ignore")
19
+
20
+
21
+ class Tag2Text_Caption(nn.Module):
22
+ def __init__(
23
+ self,
24
+ med_config=f"{CONFIG_PATH}/configs/med_config.json",
25
+ image_size=384,
26
+ vit="base",
27
+ vit_grad_ckpt=False,
28
+ vit_ckpt_layer=0,
29
+ prompt="a picture of ",
30
+ threshold=0.68,
31
+ delete_tag_index=[],
32
+ tag_list=f"{CONFIG_PATH}/data/tag_list.txt",
33
+ ):
34
+ r"""Tag2Text inference module, both captioning and tagging are included.
35
+ Tag2Text is an efficient and controllable vision-language pre-training framework.
36
+ Described in the paper "Tag2Text: Guiding Vision-Language Model via Image Tagging" https://arxiv.org/abs/2303.05657
37
+
38
+ Args:
39
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
40
+ image_size (int): input image size
41
+ vit (str): model size of vision transformer
42
+ threshold (float): tagging threshold. Default: 0.68
43
+ delete_tag_index (list): delete some tags that may disturb captioning
44
+ """
45
+ super().__init__()
46
+
47
+ # create image encoder
48
+ if vit == "swin_b":
49
+ if image_size == 224:
50
+ vision_config_path = f"{CONFIG_PATH}/configs/swin/config_swinB_224.json"
51
+ elif image_size == 384:
52
+ vision_config_path = f"{CONFIG_PATH}/configs/swin/config_swinB_384.json"
53
+ vision_config = read_json(vision_config_path)
54
+ assert image_size == vision_config["image_res"]
55
+ # assert config['patch_size'] == 32
56
+ vision_width = vision_config["vision_width"]
57
+
58
+ self.visual_encoder = SwinTransformer(
59
+ img_size=vision_config["image_res"],
60
+ patch_size=4,
61
+ in_chans=3,
62
+ embed_dim=vision_config["embed_dim"],
63
+ depths=vision_config["depths"],
64
+ num_heads=vision_config["num_heads"],
65
+ window_size=vision_config["window_size"],
66
+ mlp_ratio=4.0,
67
+ qkv_bias=True,
68
+ drop_rate=0.0,
69
+ drop_path_rate=0.1,
70
+ ape=False,
71
+ patch_norm=True,
72
+ use_checkpoint=False,
73
+ )
74
+
75
+ else:
76
+ self.visual_encoder, vision_width = create_vit(
77
+ vit, image_size, vit_grad_ckpt, vit_ckpt_layer
78
+ )
79
+
80
+ # create tokenizer
81
+ self.tokenizer = init_tokenizer()
82
+
83
+ # Tag2Text employs an encoder-decoder architecture for image-tag-text generation: an image-tag interaction encoder and an image-tag-text decoder
84
+ # create image-tag interaction encoder
85
+ encoder_config = BertConfig.from_json_file(med_config)
86
+ encoder_config.encoder_width = vision_width
87
+ self.tag_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
88
+
89
+ # create image-tag-text decoder
90
+ decoder_config = BertConfig.from_json_file(med_config)
91
+ self.text_decoder = BertLMHeadModel(config=decoder_config)
92
+
93
+ self.delete_tag_index = delete_tag_index
94
+ self.prompt = prompt
95
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
96
+
97
+ # load tag list
98
+ self.tag_list = self.load_tag_list(tag_list)
99
+
100
+ # create image-tag recognition decoder
101
+ self.threshold = threshold
102
+ self.num_class = len(self.tag_list)
103
+ q2l_config = BertConfig.from_json_file(f"{CONFIG_PATH}/configs/q2l_config.json")
104
+ q2l_config.encoder_width = vision_width
105
+ self.tagging_head = BertModel(config=q2l_config, add_pooling_layer=False)
106
+ self.tagging_head.resize_token_embeddings(len(self.tokenizer))
107
+ self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size)
108
+ self.fc = GroupWiseLinear(self.num_class, q2l_config.hidden_size, bias=True)
109
+ self.del_selfattention()
110
+
111
+ # share the weights of the lowest 2 layers of the "image-tag interaction encoder" with the "image-tag recognition decoder"
112
+ tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, "", " ")
113
+
114
+ def load_tag_list(self, tag_list_file):
115
+ with open(tag_list_file) as f:
116
+ tag_list = f.read().splitlines()
117
+ tag_list = np.array(tag_list)
118
+ return tag_list
119
+
120
+ # delete the self-attention layers of the image-tag recognition decoder to reduce computation, following Query2Label
121
+ def del_selfattention(self):
122
+ del self.tagging_head.embeddings
123
+ for layer in self.tagging_head.encoder.layer:
124
+ del layer.attention
125
+
126
+ def generate(
127
+ self,
128
+ image,
129
+ sample=False,
130
+ num_beams=3,
131
+ max_length=30,
132
+ min_length=10,
133
+ top_p=0.9,
134
+ repetition_penalty=1.0,
135
+ tag_input=None,
136
+ return_tag_predict=False,
137
+ ):
138
+ image_embeds = self.visual_encoder(image)
139
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
140
+ image.device
141
+ )
142
+
143
+ # if the user did not specify tags, recognize image tags using the image-tag recognition decoder
144
+ if tag_input is None:
145
+ image_cls_embeds = image_embeds[:, 0, :]
146
+ image_spatial_embeds = image_embeds[:, 1:, :]
147
+
148
+ bs = image_spatial_embeds.shape[0]
149
+ label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
150
+ tagging_embed = self.tagging_head(
151
+ encoder_embeds=label_embed,
152
+ encoder_hidden_states=image_embeds,
153
+ encoder_attention_mask=image_atts,
154
+ return_dict=False,
155
+ mode="tagging",
156
+ )
157
+
158
+ logits = self.fc(tagging_embed[0])
159
+
160
+ targets = torch.where(
161
+ torch.sigmoid(logits) > self.threshold,
162
+ torch.tensor(1.0).to(image.device),
163
+ torch.zeros(self.num_class).to(image.device),
164
+ )
165
+
166
+ tag = targets.cpu().numpy()
167
+
168
+ # delete some tags that may disturb captioning
169
+ tag[:, self.delete_tag_index] = 0
170
+
171
+ tag_input = []
172
+ for b in range(bs):
173
+ index = np.argwhere(tag[b] == 1)
174
+ token = self.tag_list[index].squeeze(axis=1)
175
+ tag_input.append(" | ".join(token))
176
+
177
+ tag_output = tag_input
178
+
179
+ # beam search for text generation (default)
180
+ if not sample:
181
+ image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
182
+ tag_input_temp = []
183
+ for tag in tag_input:
184
+ for i in range(num_beams):
185
+ tag_input_temp.append(tag)
186
+ tag_input = tag_input_temp
187
+
188
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
189
+ image.device
190
+ )
191
+
192
+ # tokenize input tags
193
+ tag_input_tokenzier = self.tokenizer(
194
+ tag_input,
195
+ padding="max_length",
196
+ truncation=True,
197
+ max_length=40,
198
+ return_tensors="pt",
199
+ ).to(image.device)
200
+ encoder_input_ids = tag_input_tokenzier.input_ids
201
+ encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
202
+
203
+ # put input tag into image-tag interaction encoder to interact with image embeddings
204
+ output_tagembedding = self.tag_encoder(
205
+ encoder_input_ids,
206
+ attention_mask=tag_input_tokenzier.attention_mask,
207
+ encoder_hidden_states=image_embeds,
208
+ encoder_attention_mask=image_atts,
209
+ return_dict=True,
210
+ )
211
+
212
+ # prompt trick for better captioning, following BLIP
213
+ prompt = [self.prompt] * image.size(0)
214
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
215
+ image.device
216
+ )
217
+ input_ids[:, 0] = self.tokenizer.bos_token_id
218
+ input_ids = input_ids[:, :-1]
219
+
220
+ if sample:
221
+ # nucleus sampling
222
+ model_kwargs = {
223
+ "encoder_hidden_states": output_tagembedding.last_hidden_state,
224
+ "encoder_attention_mask": None,
225
+ }
226
+ outputs = self.text_decoder.generate(
227
+ input_ids=input_ids,
228
+ max_length=max_length,
229
+ min_length=min_length,
230
+ do_sample=True,
231
+ top_p=top_p,
232
+ num_return_sequences=1,
233
+ eos_token_id=self.tokenizer.sep_token_id,
234
+ pad_token_id=self.tokenizer.pad_token_id,
235
+ repetition_penalty=1.1,
236
+ **model_kwargs,
237
+ )
238
+ else:
239
+ # beam search (default)
240
+ model_kwargs = {
241
+ "encoder_hidden_states": output_tagembedding.last_hidden_state,
242
+ "encoder_attention_mask": None,
243
+ }
244
+ outputs = self.text_decoder.generate(
245
+ input_ids=input_ids,
246
+ max_length=max_length,
247
+ min_length=min_length,
248
+ num_beams=num_beams,
249
+ eos_token_id=self.tokenizer.sep_token_id,
250
+ pad_token_id=self.tokenizer.pad_token_id,
251
+ repetition_penalty=repetition_penalty,
252
+ **model_kwargs,
253
+ )
254
+
255
+ captions = []
256
+ for output in outputs:
257
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
258
+ captions.append(caption[len(self.prompt) :])
259
+ if return_tag_predict:
260
+ return captions, tag_output
261
+ return captions
262
+
263
+
264
+ # load pretrained model parameters
265
+ def tag2text_caption(pretrained="", **kwargs):
266
+ model = Tag2Text_Caption(**kwargs)
267
+ if pretrained:
268
+ if kwargs["vit"] == "swin_b":
269
+ model, msg = load_checkpoint_swinbase(model, pretrained, kwargs)
270
+ else:
271
+ model, msg = load_checkpoint(model, pretrained)
272
+ print("vit:", kwargs["vit"])
273
+ print("msg", msg)
274
+ return model
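A hedged usage sketch of the factory above; the checkpoint path is a placeholder (not shipped in this commit) and the image tensor stands in for a properly preprocessed 384x384 batch:

    # hypothetical sketch, not a definitive recipe
    model = tag2text_caption(
        pretrained="path/to/tag2text_swin_checkpoint.pth",  # placeholder path
        image_size=384,
        vit="swin_b",
    )
    model.eval()
    image = torch.randn(1, 3, 384, 384)   # stand-in for a normalized image batch
    with torch.no_grad():
        captions, tags = model.generate(image, return_tag_predict=True)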
tag2text/models/utils.py ADDED
@@ -0,0 +1,241 @@
1
+ import json
2
+ import math
3
+ import os
4
+ from pathlib import Path
5
+ from typing import List
6
+ from urllib.parse import urlparse
7
+
8
+ import torch
9
+ from models.swin_transformer import interpolate_relative_pos_embed
10
+ from models.vit import interpolate_pos_embed
11
+ from timm.models.hub import download_cached_file
12
+ from torch import nn
13
+ from transformers import BertTokenizer
14
+
15
+ CONFIG_PATH = Path(__file__).resolve().parents[1]
16
+
17
+
18
+ def read_json(rpath):
19
+ with open(rpath) as f:
20
+ return json.load(f)
21
+
22
+
23
+ def tie_encoder_decoder_weights(
24
+ encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key: str
25
+ ):
26
+ uninitialized_encoder_weights: List[str] = []
27
+ if decoder.__class__ != encoder.__class__:
28
+ print(  # note: no 'logger' is defined in this module; print keeps the warning visible
29
+ f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
30
+ )
31
+
32
+ def tie_encoder_to_decoder_recursively(
33
+ decoder_pointer: nn.Module,
34
+ encoder_pointer: nn.Module,
35
+ module_name: str,
36
+ uninitialized_encoder_weights: List[str],
37
+ skip_key: str,
38
+ depth=0,
39
+ ):
40
+ assert isinstance(decoder_pointer, nn.Module) and isinstance(
41
+ encoder_pointer, nn.Module
42
+ ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
43
+ if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
44
+ assert hasattr(encoder_pointer, "weight")
45
+ encoder_pointer.weight = decoder_pointer.weight
46
+ if hasattr(decoder_pointer, "bias"):
47
+ assert hasattr(encoder_pointer, "bias")
48
+ encoder_pointer.bias = decoder_pointer.bias
49
+ print(module_name + " is tied")
50
+ return
51
+
52
+ encoder_modules = encoder_pointer._modules
53
+ decoder_modules = decoder_pointer._modules
54
+ if len(decoder_modules) > 0:
55
+ assert (
56
+ len(encoder_modules) > 0
57
+ ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
58
+
59
+ all_encoder_weights = {
60
+ module_name + "/" + sub_name for sub_name in encoder_modules.keys()
61
+ }
62
+ encoder_layer_pos = 0
63
+ for name, module in decoder_modules.items():
64
+ if name.isdigit():
65
+ encoder_name = str(int(name) + encoder_layer_pos)
66
+ decoder_name = name
67
+ if not isinstance(
68
+ decoder_modules[decoder_name],
69
+ type(encoder_modules[encoder_name]),
70
+ ) and len(encoder_modules) != len(decoder_modules):
71
+ # this can happen if the name corresponds to the position in a module list of layers
72
+ # in this case the decoder has added a cross-attention that the encoder does not have
73
+ # thus skip this step and subtract one layer pos from encoder
74
+ encoder_layer_pos -= 1
75
+ continue
76
+ elif name not in encoder_modules:
77
+ continue
78
+ elif depth > 500:
79
+ raise ValueError(
80
+ "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
81
+ )
82
+ else:
83
+ decoder_name = encoder_name = name
84
+ tie_encoder_to_decoder_recursively(
85
+ decoder_modules[decoder_name],
86
+ encoder_modules[encoder_name],
87
+ module_name + "/" + name,
88
+ uninitialized_encoder_weights,
89
+ skip_key,
90
+ depth=depth + 1,
91
+ )
92
+ all_encoder_weights.remove(module_name + "/" + encoder_name)
93
+
94
+ uninitialized_encoder_weights += list(all_encoder_weights)
95
+
96
+ # tie weights recursively
97
+ tie_encoder_to_decoder_recursively(
98
+ decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key
99
+ )
100
+
101
+
102
+ class GroupWiseLinear(nn.Module):
103
+ # could be changed to:
104
+ # output = torch.einsum('ijk,zjk->ij', x, self.W)
105
+ # or output = torch.einsum('ijk,jk->ij', x, self.W[0])
106
+ def __init__(self, num_class, hidden_dim, bias=True):
107
+ super().__init__()
108
+ self.num_class = num_class
109
+ self.hidden_dim = hidden_dim
110
+ self.bias = bias
111
+
112
+ self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim))
113
+ if bias:
114
+ self.b = nn.Parameter(torch.Tensor(1, num_class))
115
+ self.reset_parameters()
116
+
117
+ def reset_parameters(self):
118
+ stdv = 1.0 / math.sqrt(self.W.size(2))
119
+ for i in range(self.num_class):
120
+ self.W[0][i].data.uniform_(-stdv, stdv)
121
+ if self.bias:
122
+ for i in range(self.num_class):
123
+ self.b[0][i].data.uniform_(-stdv, stdv)
124
+
125
+ def forward(self, x):
126
+ # x: B,K,d
127
+ x = (self.W * x).sum(-1)
128
+ if self.bias:
129
+ x = x + self.b
130
+ return x
131
+
132
+
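An illustrative check (not part of the model) that the einsum mentioned in the class comment matches what forward() computes:

    gwl = GroupWiseLinear(num_class=5, hidden_dim=8, bias=False)
    gwl.W.data.uniform_(-0.1, 0.1)        # make the weights well-defined for the comparison
    feats = torch.randn(2, 5, 8)          # (B, num_class, hidden_dim)
    via_forward = gwl(feats)              # (B, num_class), per-class dot products
    via_einsum = torch.einsum("ijk,zjk->ij", feats, gwl.W)
    assert torch.allclose(via_forward, via_einsum, atol=1e-6)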
133
+ def init_tokenizer():
134
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
135
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
136
+ tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]})
137
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
138
+ return tokenizer
139
+
140
+
141
+ def create_vit(
142
+ vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0
143
+ ):
144
+ assert vit in ["base", "large"], "vit parameter must be base or large"
145
+ if vit == "base":
146
+ vision_width = 768
147
+ visual_encoder = VisionTransformer(
148
+ img_size=image_size,
149
+ patch_size=16,
150
+ embed_dim=vision_width,
151
+ depth=12,
152
+ num_heads=12,
153
+ use_grad_checkpointing=use_grad_checkpointing,
154
+ ckpt_layer=ckpt_layer,
155
+ drop_path_rate=0 or drop_path_rate,
156
+ )
157
+ elif vit == "large":
158
+ vision_width = 1024
159
+ visual_encoder = VisionTransformer(
160
+ img_size=image_size,
161
+ patch_size=16,
162
+ embed_dim=vision_width,
163
+ depth=24,
164
+ num_heads=16,
165
+ use_grad_checkpointing=use_grad_checkpointing,
166
+ ckpt_layer=ckpt_layer,
167
+ drop_path_rate=drop_path_rate or 0.1,  # default to 0.1 unless an explicit rate is passed
168
+ )
169
+ return visual_encoder, vision_width
170
+
171
+
172
+ def is_url(url_or_filename):
173
+ parsed = urlparse(url_or_filename)
174
+ return parsed.scheme in ("http", "https")
175
+
176
+
177
+ def load_checkpoint(model, url_or_filename):
178
+ if is_url(url_or_filename):
179
+ cached_file = download_cached_file(
180
+ url_or_filename, check_hash=False, progress=True
181
+ )
182
+ checkpoint = torch.load(cached_file, map_location="cpu")
183
+ elif os.path.isfile(url_or_filename):
184
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
185
+ else:
186
+ raise RuntimeError("checkpoint url or path is invalid")
187
+
188
+ state_dict = checkpoint["model"]
189
+
190
+ state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed(
191
+ state_dict["visual_encoder.pos_embed"], model.visual_encoder
192
+ )
193
+ if "visual_encoder_m.pos_embed" in model.state_dict().keys():
194
+ state_dict["visual_encoder_m.pos_embed"] = interpolate_pos_embed(
195
+ state_dict["visual_encoder_m.pos_embed"], model.visual_encoder_m
196
+ )
197
+ for key in model.state_dict().keys():
198
+ if key in state_dict.keys():
199
+ if state_dict[key].shape != model.state_dict()[key].shape:
200
+ del state_dict[key]
201
+
202
+ msg = model.load_state_dict(state_dict, strict=False)
203
+ print("load checkpoint from %s" % url_or_filename)
204
+ return model, msg
205
+
206
+
207
+ def load_checkpoint_swinbase(model, url_or_filename, kwargs):
208
+ if kwargs["image_size"] == 224:
209
+ vision_config_path = f"{CONFIG_PATH}/configs/swin/config_swinB_224.json"
210
+ elif kwargs["image_size"] == 384:
211
+ vision_config_path = f"{CONFIG_PATH}/configs/swin/config_swinB_384.json"
212
+ window_size = read_json(vision_config_path)["window_size"]
213
+ print("--------------")
214
+ print(url_or_filename)
215
+ print("--------------")
216
+ if is_url(url_or_filename):
217
+ cached_file = download_cached_file(
218
+ url_or_filename, check_hash=False, progress=True
219
+ )
220
+ checkpoint = torch.load(cached_file, map_location="cpu")
221
+ elif os.path.isfile(url_or_filename):
222
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
223
+ else:
224
+ raise RuntimeError("checkpoint url or path is invalid")
225
+
226
+ state_dict = checkpoint["model"]
227
+
228
+ for k in list(state_dict.keys()):
229
+ if "relative_position_bias_table" in k:
230
+ dst_num_pos = (2 * window_size - 1) ** 2
231
+ state_dict[k] = interpolate_relative_pos_embed(
232
+ state_dict[k], dst_num_pos, param_name=k
233
+ )
234
+ elif ("relative_position_index" in k) or ("attn_mask" in k):
235
+ del state_dict[k]
236
+ elif "vision_multi" in k:
237
+ state_dict[k.replace("vision_multi", "tagging_head")] = state_dict.pop(k)
238
+
239
+ msg = model.load_state_dict(state_dict, strict=False)
240
+ print("load checkpoint from %s" % url_or_filename)
241
+ return model, msg
tag2text/models/vit.py ADDED
@@ -0,0 +1,430 @@
1
+ """
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on timm code base
8
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ """
10
+ from functools import partial
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
16
+ from timm.models.helpers import adapt_input_conv
17
+ from timm.models.helpers import named_apply
18
+ from timm.models.layers import DropPath
19
+ from timm.models.layers import trunc_normal_
20
+ from timm.models.registry import register_model
21
+ from timm.models.vision_transformer import _cfg
22
+ from timm.models.vision_transformer import PatchEmbed
+ from timm.models.vision_transformer import resize_pos_embed  # needed by _load_weights when pos_embed shapes differ
23
+
24
+
25
+ class Mlp(nn.Module):
26
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks."""
27
+
28
+ def __init__(
29
+ self,
30
+ in_features,
31
+ hidden_features=None,
32
+ out_features=None,
33
+ act_layer=nn.GELU,
34
+ drop=0.0,
35
+ ):
36
+ super().__init__()
37
+ out_features = out_features or in_features
38
+ hidden_features = hidden_features or in_features
39
+ self.fc1 = nn.Linear(in_features, hidden_features)
40
+ self.act = act_layer()
41
+ self.fc2 = nn.Linear(hidden_features, out_features)
42
+ self.drop = nn.Dropout(drop)
43
+
44
+ def forward(self, x):
45
+ x = self.fc1(x)
46
+ x = self.act(x)
47
+ x = self.drop(x)
48
+ x = self.fc2(x)
49
+ x = self.drop(x)
50
+ return x
51
+
52
+
53
+ class Attention(nn.Module):
54
+ def __init__(
55
+ self,
56
+ dim,
57
+ num_heads=8,
58
+ qkv_bias=False,
59
+ qk_scale=None,
60
+ attn_drop=0.0,
61
+ proj_drop=0.0,
62
+ ):
63
+ super().__init__()
64
+ self.num_heads = num_heads
65
+ head_dim = dim // num_heads
66
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
67
+ self.scale = qk_scale or head_dim**-0.5
68
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
69
+ self.attn_drop = nn.Dropout(attn_drop)
70
+ self.proj = nn.Linear(dim, dim)
71
+ self.proj_drop = nn.Dropout(proj_drop)
72
+ self.attn_gradients = None
73
+ self.attention_map = None
74
+
75
+ def save_attn_gradients(self, attn_gradients):
76
+ self.attn_gradients = attn_gradients
77
+
78
+ def get_attn_gradients(self):
79
+ return self.attn_gradients
80
+
81
+ def save_attention_map(self, attention_map):
82
+ self.attention_map = attention_map
83
+
84
+ def get_attention_map(self):
85
+ return self.attention_map
86
+
87
+ def forward(self, x, register_hook=False):
88
+ B, N, C = x.shape
89
+ qkv = (
90
+ self.qkv(x)
91
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
92
+ .permute(2, 0, 3, 1, 4)
93
+ )
94
+ q, k, v = (
95
+ qkv[0],
96
+ qkv[1],
97
+ qkv[2],
98
+ ) # make torchscript happy (cannot use tensor as tuple)
99
+
100
+ attn = (q @ k.transpose(-2, -1)) * self.scale
101
+ attn = attn.softmax(dim=-1)
102
+ attn = self.attn_drop(attn)
103
+
104
+ if register_hook:
105
+ self.save_attention_map(attn)
106
+ attn.register_hook(self.save_attn_gradients)
107
+
108
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
109
+ x = self.proj(x)
110
+ x = self.proj_drop(x)
111
+ return x
112
+
113
+
114
+ class Block(nn.Module):
115
+ def __init__(
116
+ self,
117
+ dim,
118
+ num_heads,
119
+ mlp_ratio=4.0,
120
+ qkv_bias=False,
121
+ qk_scale=None,
122
+ drop=0.0,
123
+ attn_drop=0.0,
124
+ drop_path=0.0,
125
+ act_layer=nn.GELU,
126
+ norm_layer=nn.LayerNorm,
127
+ use_grad_checkpointing=False,
128
+ ):
129
+ super().__init__()
130
+ self.norm1 = norm_layer(dim)
131
+ self.attn = Attention(
132
+ dim,
133
+ num_heads=num_heads,
134
+ qkv_bias=qkv_bias,
135
+ qk_scale=qk_scale,
136
+ attn_drop=attn_drop,
137
+ proj_drop=drop,
138
+ )
139
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
140
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
141
+ self.norm2 = norm_layer(dim)
142
+ mlp_hidden_dim = int(dim * mlp_ratio)
143
+ self.mlp = Mlp(
144
+ in_features=dim,
145
+ hidden_features=mlp_hidden_dim,
146
+ act_layer=act_layer,
147
+ drop=drop,
148
+ )
149
+
150
+ if use_grad_checkpointing:
151
+ self.attn = checkpoint_wrapper(self.attn)
152
+ self.mlp = checkpoint_wrapper(self.mlp)
153
+
154
+ def forward(self, x, register_hook=False):
155
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
156
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
157
+ return x
158
+
159
+
160
+ class VisionTransformer(nn.Module):
161
+ """Vision Transformer
162
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
163
+ https://arxiv.org/abs/2010.11929
164
+ """
165
+
166
+ def __init__(
167
+ self,
168
+ img_size=224,
169
+ patch_size=16,
170
+ in_chans=3,
171
+ num_classes=1000,
172
+ embed_dim=768,
173
+ depth=12,
174
+ num_heads=12,
175
+ mlp_ratio=4.0,
176
+ qkv_bias=True,
177
+ qk_scale=None,
178
+ representation_size=None,
179
+ drop_rate=0.0,
180
+ attn_drop_rate=0.0,
181
+ drop_path_rate=0.0,
182
+ norm_layer=None,
183
+ use_grad_checkpointing=False,
184
+ ckpt_layer=0,
185
+ ):
186
+ """
187
+ Args:
188
+ img_size (int, tuple): input image size
189
+ patch_size (int, tuple): patch size
190
+ in_chans (int): number of input channels
191
+ num_classes (int): number of classes for classification head
192
+ embed_dim (int): embedding dimension
193
+ depth (int): depth of transformer
194
+ num_heads (int): number of attention heads
195
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
196
+ qkv_bias (bool): enable bias for qkv if True
197
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
198
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
199
+ drop_rate (float): dropout rate
200
+ attn_drop_rate (float): attention dropout rate
201
+ drop_path_rate (float): stochastic depth rate
202
+ norm_layer: (nn.Module): normalization layer
203
+ """
204
+ super().__init__()
205
+ self.num_features = (
206
+ self.embed_dim
207
+ ) = embed_dim # num_features for consistency with other models
208
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
209
+
210
+ self.patch_embed = PatchEmbed(
211
+ img_size=img_size,
212
+ patch_size=patch_size,
213
+ in_chans=in_chans,
214
+ embed_dim=embed_dim,
215
+ )
216
+
217
+ num_patches = self.patch_embed.num_patches
218
+
219
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
220
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
221
+ self.pos_drop = nn.Dropout(p=drop_rate)
222
+
223
+ dpr = [
224
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
225
+ ] # stochastic depth decay rule
226
+ self.blocks = nn.ModuleList(
227
+ [
228
+ Block(
229
+ dim=embed_dim,
230
+ num_heads=num_heads,
231
+ mlp_ratio=mlp_ratio,
232
+ qkv_bias=qkv_bias,
233
+ qk_scale=qk_scale,
234
+ drop=drop_rate,
235
+ attn_drop=attn_drop_rate,
236
+ drop_path=dpr[i],
237
+ norm_layer=norm_layer,
238
+ use_grad_checkpointing=(
239
+ use_grad_checkpointing and i >= depth - ckpt_layer
240
+ ),
241
+ )
242
+ for i in range(depth)
243
+ ]
244
+ )
245
+ self.norm = norm_layer(embed_dim)
246
+
247
+ trunc_normal_(self.pos_embed, std=0.02)
248
+ trunc_normal_(self.cls_token, std=0.02)
249
+ self.apply(self._init_weights)
250
+
251
+ def _init_weights(self, m):
252
+ if isinstance(m, nn.Linear):
253
+ trunc_normal_(m.weight, std=0.02)
254
+ if isinstance(m, nn.Linear) and m.bias is not None:
255
+ nn.init.constant_(m.bias, 0)
256
+ elif isinstance(m, nn.LayerNorm):
257
+ nn.init.constant_(m.bias, 0)
258
+ nn.init.constant_(m.weight, 1.0)
259
+
260
+ @torch.jit.ignore
261
+ def no_weight_decay(self):
262
+ return {"pos_embed", "cls_token"}
263
+
264
+ def forward(self, x, register_blk=-1):
265
+ B = x.shape[0]
266
+ x = self.patch_embed(x)
267
+
268
+ cls_tokens = self.cls_token.expand(
269
+ B, -1, -1
270
+ ) # stole cls_tokens impl from Phil Wang, thanks
271
+ x = torch.cat((cls_tokens, x), dim=1)
272
+
273
+ x = x + self.pos_embed[:, : x.size(1), :]
274
+ x = self.pos_drop(x)
275
+
276
+ for i, blk in enumerate(self.blocks):
277
+ x = blk(x, register_blk == i)
278
+ x = self.norm(x)
279
+
280
+ return x
281
+
282
+ @torch.jit.ignore()
283
+ def load_pretrained(self, checkpoint_path, prefix=""):
284
+ _load_weights(self, checkpoint_path, prefix)
285
+
286
+
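A minimal shape sketch for the class above, assuming the base configuration (patch 16, width 768):

    # hypothetical sketch: 224x224 input -> 14x14 = 196 patch tokens plus one cls token
    vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
    out = vit(torch.randn(1, 3, 224, 224))
    assert out.shape == (1, 197, 768)     # out[:, 0] is the cls-token feature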
287
+ @torch.no_grad()
288
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ""):
289
+ """Load weights from .npz checkpoints for official Google Brain Flax implementation."""
290
+ import numpy as np
291
+
292
+ def _n2p(w, t=True):
293
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
294
+ w = w.flatten()
295
+ if t:
296
+ if w.ndim == 4:
297
+ w = w.transpose([3, 2, 0, 1])
298
+ elif w.ndim == 3:
299
+ w = w.transpose([2, 0, 1])
300
+ elif w.ndim == 2:
301
+ w = w.transpose([1, 0])
302
+ return torch.from_numpy(w)
303
+
304
+ w = np.load(checkpoint_path)
305
+ if not prefix and "opt/target/embedding/kernel" in w:
306
+ prefix = "opt/target/"
307
+
308
+ if hasattr(model.patch_embed, "backbone"):
309
+ # hybrid
310
+ backbone = model.patch_embed.backbone
311
+ stem_only = not hasattr(backbone, "stem")
312
+ stem = backbone if stem_only else backbone.stem
313
+ stem.conv.weight.copy_(
314
+ adapt_input_conv(
315
+ stem.conv.weight.shape[1], _n2p(w[f"{prefix}conv_root/kernel"])
316
+ )
317
+ )
318
+ stem.norm.weight.copy_(_n2p(w[f"{prefix}gn_root/scale"]))
319
+ stem.norm.bias.copy_(_n2p(w[f"{prefix}gn_root/bias"]))
320
+ if not stem_only:
321
+ for i, stage in enumerate(backbone.stages):
322
+ for j, block in enumerate(stage.blocks):
323
+ bp = f"{prefix}block{i + 1}/unit{j + 1}/"
324
+ for r in range(3):
325
+ getattr(block, f"conv{r + 1}").weight.copy_(
326
+ _n2p(w[f"{bp}conv{r + 1}/kernel"])
327
+ )
328
+ getattr(block, f"norm{r + 1}").weight.copy_(
329
+ _n2p(w[f"{bp}gn{r + 1}/scale"])
330
+ )
331
+ getattr(block, f"norm{r + 1}").bias.copy_(
332
+ _n2p(w[f"{bp}gn{r + 1}/bias"])
333
+ )
334
+ if block.downsample is not None:
335
+ block.downsample.conv.weight.copy_(
336
+ _n2p(w[f"{bp}conv_proj/kernel"])
337
+ )
338
+ block.downsample.norm.weight.copy_(
339
+ _n2p(w[f"{bp}gn_proj/scale"])
340
+ )
341
+ block.downsample.norm.bias.copy_(_n2p(w[f"{bp}gn_proj/bias"]))
342
+ embed_conv_w = _n2p(w[f"{prefix}embedding/kernel"])
343
+ else:
344
+ embed_conv_w = adapt_input_conv(
345
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f"{prefix}embedding/kernel"])
346
+ )
347
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
348
+ model.patch_embed.proj.bias.copy_(_n2p(w[f"{prefix}embedding/bias"]))
349
+ model.cls_token.copy_(_n2p(w[f"{prefix}cls"], t=False))
350
+ pos_embed_w = _n2p(w[f"{prefix}Transformer/posembed_input/pos_embedding"], t=False)
351
+ if pos_embed_w.shape != model.pos_embed.shape:
352
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
353
+ pos_embed_w,
354
+ model.pos_embed,
355
+ getattr(model, "num_tokens", 1),
356
+ model.patch_embed.grid_size,
357
+ )
358
+ model.pos_embed.copy_(pos_embed_w)
359
+ model.norm.weight.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/scale"]))
360
+ model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"]))
361
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
362
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
363
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
364
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
365
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
366
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
367
+ for i, block in enumerate(model.blocks.children()):
368
+ block_prefix = f"{prefix}Transformer/encoderblock_{i}/"
369
+ mha_prefix = block_prefix + "MultiHeadDotProductAttention_1/"
370
+ block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"]))
371
+ block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"]))
372
+ block.attn.qkv.weight.copy_(
373
+ torch.cat(
374
+ [
375
+ _n2p(w[f"{mha_prefix}{n}/kernel"], t=False).flatten(1).T
376
+ for n in ("query", "key", "value")
377
+ ]
378
+ )
379
+ )
380
+ block.attn.qkv.bias.copy_(
381
+ torch.cat(
382
+ [
383
+ _n2p(w[f"{mha_prefix}{n}/bias"], t=False).reshape(-1)
384
+ for n in ("query", "key", "value")
385
+ ]
386
+ )
387
+ )
388
+ block.attn.proj.weight.copy_(_n2p(w[f"{mha_prefix}out/kernel"]).flatten(1))
389
+ block.attn.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"]))
390
+ for r in range(2):
391
+ getattr(block.mlp, f"fc{r + 1}").weight.copy_(
392
+ _n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/kernel"])
393
+ )
394
+ getattr(block.mlp, f"fc{r + 1}").bias.copy_(
395
+ _n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/bias"])
396
+ )
397
+ block.norm2.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/scale"]))
398
+ block.norm2.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/bias"]))
399
+
400
+
401
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
402
+ # interpolate position embedding
403
+ embedding_size = pos_embed_checkpoint.shape[-1]
404
+ num_patches = visual_encoder.patch_embed.num_patches
405
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
406
+ # height (== width) for the checkpoint position embedding
407
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
408
+ # height (== width) for the new position embedding
409
+ new_size = int(num_patches**0.5)
410
+
411
+ if orig_size != new_size:
412
+ # class_token and dist_token are kept unchanged
413
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
414
+ # only the position tokens are interpolated
415
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
416
+ pos_tokens = pos_tokens.reshape(
417
+ -1, orig_size, orig_size, embedding_size
418
+ ).permute(0, 3, 1, 2)
419
+ pos_tokens = torch.nn.functional.interpolate(
420
+ pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
421
+ )
422
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
423
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
424
+ print(
425
+ "reshape position embedding from %d to %d" % (orig_size**2, new_size**2)
426
+ )
427
+
428
+ return new_pos_embed
429
+ else:
430
+ return pos_embed_checkpoint
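A hedged sketch of the call made by load_checkpoint in models/utils.py: adapting a 224-resolution ViT-B/16 position embedding to a 384-resolution encoder (the tensors here are random stand-ins for checkpoint values):

    vit384 = VisionTransformer(img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12)
    ckpt_pos_embed = torch.randn(1, 1 + 14 * 14, 768)   # cls token + 14x14 grid from a 224px model
    resized = interpolate_pos_embed(ckpt_pos_embed, vit384)
    assert resized.shape == (1, 1 + 24 * 24, 768)        # cls token + 24x24 grid for 384px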
tag2text/requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ timm==0.4.12
2
+ transformers==4.15.0
3
+ fairscale==0.4.4
4
+ pycocoevalcap
5
+ torch
6
+ torchvision
7
+ Pillow
8
+ scipy
utils.py ADDED
@@ -0,0 +1,263 @@
1
+ import random
2
+ import sys
3
+ from typing import Dict
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ import supervision as sv
8
+ import torch
9
+ import torchvision
10
+ import torchvision.transforms as T
11
+ from groundingdino.models import build_model
12
+ from groundingdino.util.inference import Model as DinoModel
13
+ from groundingdino.util.slconfig import SLConfig
14
+ from groundingdino.util.utils import clean_state_dict
15
+ from huggingface_hub import hf_hub_download
16
+ from PIL import Image
17
+ from segment_anything import SamPredictor
18
+
19
+ # segment anything
20
+
21
+ sys.path.append("tag2text")
22
+
23
+ from tag2text.inference import inference as tag2text_inference
24
+
25
+
26
+ def load_model_hf(repo_id, filename, ckpt_config_filename, device="cpu"):
27
+ cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
28
+
29
+ args = SLConfig.fromfile(cache_config_file)
30
+ args.device = device
31
+ model = build_model(args)
32
+
33
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
34
+ checkpoint = torch.load(cache_file, map_location=device)
35
+ model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
36
+ model.eval()
37
+ return model
38
+
39
+
40
+ def download_file_hf(repo_id, filename, cache_dir="./cache"):
41
+ cache_file = hf_hub_download(
42
+ repo_id=repo_id, filename=filename, force_filename=filename, cache_dir=cache_dir
43
+ )
44
+ return cache_file
45
+
46
+
47
+ def transform_image_tag2text(image_pil: Image) -> torch.Tensor:
48
+ transform = T.Compose(
49
+ [
50
+ T.Resize((384, 384)),
51
+ T.ToTensor(),
52
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
53
+ ]
54
+ )
55
+ image = transform(image_pil) # 3, h, w
56
+ return image
57
+
58
+
59
+ def show_anns_sam(anns: List[Dict]):
60
+ """Extracts the mask annotations from the Segment Anything model output and plots them.
61
+ https://github.com/facebookresearch/segment-anything.
62
+
63
+ Arguments:
64
+ anns (List[Dict]): Segment Anything model output.
65
+
66
+ Returns:
67
+ (np.ndarray): Masked image.
68
+ (np.ndarray): annotation encoding from https://github.com/LUSSeg/ImageNet-S
69
+ """
70
+ if len(anns) == 0:
71
+ return
72
+ sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
73
+ full_img = None
74
+
75
+ # for ann in sorted_anns:
76
+ for i in range(len(sorted_anns)):
77
+ ann = sorted_anns[i]  # follow the area-sorted order computed above
78
+ m = ann["segmentation"]
79
+ if full_img is None:
80
+ full_img = np.zeros((m.shape[0], m.shape[1], 3))
81
+ map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
82
+ map[m != 0] = i + 1
83
+ color_mask = np.random.random((1, 3)).tolist()[0]
84
+ full_img[m != 0] = color_mask
85
+ full_img = full_img * 255
86
+
87
+ # anno encoding from https://github.com/LUSSeg/ImageNet-S
88
+ res = np.zeros((map.shape[0], map.shape[1], 3))
89
+ res[:, :, 0] = map % 256
90
+ res[:, :, 1] = map // 256
91
+ res.astype(np.float32)
92
+ full_img = np.uint8(full_img)
93
+ return full_img, res
94
+
95
+
96
+ def show_anns_sv(detections: sv.Detections):
97
+ """Extracts the mask annotations from the Supervision Detections object.
98
+ https://roboflow.github.io/supervision/detection/core/.
99
+
100
+ Arguments:
101
+ detections (sv.Detections): Detections object containing the predicted masks.
102
+
103
+ Returns:
104
+ (np.ndarray): Masked image.
105
+ (np.ndarray): annotation encoding from https://github.com/LUSSeg/ImageNet-S
106
+ """
107
+ if detections.mask is None:
108
+ return
109
+ full_img = None
110
+
111
+ for i in np.flip(np.argsort(detections.area)):
112
+ m = detections.mask[i]
113
+ if full_img is None:
114
+ full_img = np.zeros((m.shape[0], m.shape[1], 3))
115
+ map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
116
+ map[m != 0] = i + 1
117
+ color_mask = np.random.random((1, 3)).tolist()[0]
118
+ full_img[m != 0] = color_mask
119
+ full_img = full_img * 255
120
+
121
+ # anno encoding from https://github.com/LUSSeg/ImageNet-S
122
+ res = np.zeros((map.shape[0], map.shape[1], 3))
123
+ res[:, :, 0] = map % 256
124
+ res[:, :, 1] = map // 256
125
+ res.astype(np.float32)
126
+ full_img = np.uint8(full_img)
127
+ return full_img, res
128
+
129
+
130
+ def generate_tags(tag2text_model, image, specified_tags, device="cpu"):
131
+ """Generate image tags and caption using Tag2Text model.
132
+
133
+ Arguments:
134
+ tag2text_model (nn.Module): Tag2Text model to use for prediction.
135
+ image (PIL.Image): Input image to tag and caption.
137
+ specified_tags (str): User-specified tags to condition the caption on
138
+
139
+ Returns:
140
+ (List[str]): Predicted image tags.
141
+ (str): Predicted image caption
142
+ """
143
+ image = transform_image_tag2text(image).unsqueeze(0).to(device)
144
+ res = tag2text_inference(image, tag2text_model, specified_tags)
145
+ tags = res[0].split(" | ")
146
+ caption = res[2]
147
+ return tags, caption
148
+
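A hedged sketch of how generate_tags feeds the detector below; tag2text_model and pil_image are placeholders, and passing "None" as the specified tags follows the convention assumed by the Tag2Text demo (an assumption, not verified here):

    tags, caption = generate_tags(tag2text_model, pil_image, specified_tags="None", device="cpu")
    dino_caption = " . ".join(tags)       # e.g. "dog . grass . frisbee" - the format detect() expects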
149
+
150
+ def detect(
151
+ grounding_dino_model: DinoModel,
152
+ image: np.ndarray,
153
+ caption: str,
154
+ box_threshold: float = 0.3,
155
+ text_threshold: float = 0.25,
156
+ iou_threshold: float = 0.5,
157
+ post_process: bool = True,
158
+ ):
159
+ """Detect bounding boxes for the given image, using the input caption.
160
+
161
+ Arguments:
162
+ grounding_dino_model (DinoModel): The model to use for detection.
163
+ image (np.ndarray): The image to run detection on. Expects an
164
+ image in HWC uint8 format, with pixel values in [0, 255].
165
+ caption (str): Input caption containing the object names to detect. To detect multiple objects, separate the names with '.', e.g.: cat . dog . chair
166
+ box_threshold (float): Box confidence threshold
167
+ text_threshold (float): Text confidence threshold
168
+ iou_threshold (float): IOU score threshold for post processing
169
+ post_process (bool): If True, run NMS to remove duplicate detections.
170
+
171
+ Returns:
172
+ (sv.Detections): Detections for the input image.
173
+ (List[str]): Predicted phrases.
174
+ (List[str]): Predicted classes.
175
+ """
176
+ detections, phrases = grounding_dino_model.predict_with_caption(
177
+ image=image,
178
+ caption=caption,
179
+ box_threshold=box_threshold,
180
+ text_threshold=text_threshold,
181
+ )
182
+ classes = list(map(lambda x: x.strip(), caption.split(".")))
183
+ detections.class_id = DinoModel.phrases2classes(phrases=phrases, classes=classes)
184
+
185
+ # NMS post process
186
+ if post_process:
187
+ # print(f"Before NMS: {len(detections.xyxy)} boxes")
188
+ nms_idx = (
189
+ torchvision.ops.nms(
190
+ torch.from_numpy(detections.xyxy),
191
+ torch.from_numpy(detections.confidence),
192
+ iou_threshold,
193
+ )
194
+ .numpy()
195
+ .tolist()
196
+ )
197
+
198
+ phrases = [phrases[idx] for idx in nms_idx]
199
+ detections.xyxy = detections.xyxy[nms_idx]
200
+ detections.confidence = detections.confidence[nms_idx]
201
+ detections.class_id = detections.class_id[nms_idx]
202
+
203
+ # print(f"After NMS: {len(detections.xyxy)} boxes")
204
+
205
+ return detections, phrases, classes
206
+
207
+
208
+ def segment(sam_model: SamPredictor, image: np.ndarray, boxes: np.ndarray):
209
+ """Predict masks for the given input boxes, using the currently set image.
210
+
211
+ Arguments:
212
+ sam_model (SamPredictor): The model to use for mask prediction.
213
+ image (np.ndarray): The image for calculating masks. Expects an
214
+ image in HWC uint8 format, with pixel values in [0, 255].
215
+ boxes (np.ndarray or None): A Bx4 array given a box prompt to the
216
+ model, in XYXY format.
217
219
+
220
+ Returns:
221
+ (np.ndarray): The output masks in BxHxW format, where B is the
+ number of input boxes and (H, W) is the original image size.
+ (np.ndarray): An array of shape (B,) containing the model's
+ predictions for the quality of each mask.
228
+ """
229
+ sam_model.set_image(image)
230
+ transformed_boxes = None
231
+ if boxes is not None:
232
+ boxes = torch.from_numpy(boxes)
233
+
234
+ transformed_boxes = sam_model.transform.apply_boxes_torch(
235
+ boxes.to(sam_model.device), image.shape[:2]
236
+ )
237
+
238
+ masks, scores, _ = sam_model.predict_torch(
239
+ point_coords=None,
240
+ point_labels=None,
241
+ boxes=transformed_boxes,
242
+ multimask_output=False,
243
+ )
244
+ masks = masks[:, 0, :, :]
245
+ scores = scores[:, 0]
246
+ return masks.cpu().numpy(), scores.cpu().numpy()
247
+
248
+
249
+ def draw_mask(mask, draw, random_color=False):
250
+ if random_color:
251
+ color = (
252
+ random.randint(0, 255),
253
+ random.randint(0, 255),
254
+ random.randint(0, 255),
255
+ 153,
256
+ )
257
+ else:
258
+ color = (30, 144, 255, 153)
259
+
260
+ nonzero_coords = np.transpose(np.nonzero(mask))
261
+
262
+ for coord in nonzero_coords:
263
+ draw.point(coord[::-1], fill=color)
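A hedged end-to-end sketch of how the helpers in this file are meant to be chained; grounding_dino_model (DinoModel), sam_predictor (SamPredictor) and image (an HWC uint8 array) are assumed to be prepared elsewhere:

    # hypothetical sketch, not the app's exact wiring
    detections, phrases, classes = detect(
        grounding_dino_model,
        image,
        caption="dog . frisbee",          # e.g. built from Tag2Text tags via " . ".join(tags)
        box_threshold=0.3,
        text_threshold=0.25,
    )
    masks, scores = segment(sam_predictor, image, detections.xyxy)
    detections.mask = masks               # attach masks so show_anns_sv() can render them
    overlay, encoding = show_anns_sv(detections)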