saicharan1234 committed on
Commit
264c242
1 Parent(s): 6742cd5

Upload 44 files

Files changed (44)
  1. minigpt4/__init__.py +31 -0
  2. minigpt4/common/__init__.py +0 -0
  3. minigpt4/common/config.py +468 -0
  4. minigpt4/common/dist_utils.py +138 -0
  5. minigpt4/common/gradcam.py +24 -0
  6. minigpt4/common/logger.py +195 -0
  7. minigpt4/common/optims.py +119 -0
  8. minigpt4/common/registry.py +329 -0
  9. minigpt4/common/utils.py +424 -0
  10. minigpt4/configs/datasets/cc_combine/align.yaml +16 -0
  11. minigpt4/configs/datasets/cc_combine/defaults.yaml +11 -0
  12. minigpt4/configs/datasets/laion/defaults.yaml +13 -0
  13. minigpt4/configs/default.yaml +10 -0
  14. minigpt4/configs/models/minigpt4.yaml +39 -0
  15. minigpt4/conversation/__init__.py +0 -0
  16. minigpt4/conversation/conversation.py +201 -0
  17. minigpt4/datasets/__init__.py +0 -0
  18. minigpt4/datasets/builders/__init__.py +72 -0
  19. minigpt4/datasets/builders/base_dataset_builder.py +235 -0
  20. minigpt4/datasets/builders/image_text_pair_builder.py +86 -0
  21. minigpt4/datasets/data_utils.py +196 -0
  22. minigpt4/datasets/datasets/__init__.py +0 -0
  23. minigpt4/datasets/datasets/base_dataset.py +68 -0
  24. minigpt4/datasets/datasets/caption_datasets.py +85 -0
  25. minigpt4/datasets/datasets/cc_combine_dataset.py +53 -0
  26. minigpt4/datasets/datasets/dataloader_utils.py +162 -0
  27. minigpt4/datasets/datasets/laion_dataset.py +31 -0
  28. minigpt4/models/Qformer.py +1216 -0
  29. minigpt4/models/__init__.py +200 -0
  30. minigpt4/models/base_model.py +247 -0
  31. minigpt4/models/blip2.py +221 -0
  32. minigpt4/models/blip2_outputs.py +110 -0
  33. minigpt4/models/eva_vit.py +442 -0
  34. minigpt4/models/mini_gpt4.py +263 -0
  35. minigpt4/models/modeling_llama.py +772 -0
  36. minigpt4/processors/__init__.py +33 -0
  37. minigpt4/processors/base_processor.py +26 -0
  38. minigpt4/processors/blip_processors.py +141 -0
  39. minigpt4/processors/randaugment.py +398 -0
  40. minigpt4/runners/__init__.py +10 -0
  41. minigpt4/runners/runner_base.py +658 -0
  42. minigpt4/tasks/__init__.py +26 -0
  43. minigpt4/tasks/base_task.py +286 -0
  44. minigpt4/tasks/image_text_pretrain.py +18 -0
minigpt4/__init__.py ADDED
@@ -0,0 +1,31 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
import sys

from omegaconf import OmegaConf

from minigpt4.common.registry import registry

from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.tasks import *


root_dir = os.path.dirname(os.path.abspath(__file__))
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))

registry.register_path("library_root", root_dir)
repo_root = os.path.join(root_dir, "..")
registry.register_path("repo_root", repo_root)
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
registry.register_path("cache_root", cache_root)

registry.register("MAX_INT", sys.maxsize)
registry.register("SPLIT_NAMES", ["train", "val", "test"])
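Note: this package-level __init__ registers library paths and a few global constants in the shared registry at import time. As a rough illustration (not part of the commit, and assuming the package is importable and configs/default.yaml defines env.cache_root), the registered entries can be read back like this:

# Hypothetical usage sketch; importing the package triggers the registrations above.
import minigpt4
from minigpt4.common.registry import registry

print(registry.get_path("library_root"))   # .../minigpt4
print(registry.get_path("cache_root"))     # resolved from configs/default.yaml
print(registry.get("MAX_INT"))             # sys.maxsize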
minigpt4/common/__init__.py ADDED
File without changes
minigpt4/common/config.py ADDED
@@ -0,0 +1,468 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import json
from typing import Dict

from omegaconf import OmegaConf
from minigpt4.common.registry import registry


class Config:
    def __init__(self, args):
        self.config = {}

        self.args = args

        # Register the config and configuration for setup
        registry.register("configuration", self)

        user_config = self._build_opt_list(self.args.options)

        config = OmegaConf.load(self.args.cfg_path)

        runner_config = self.build_runner_config(config)
        model_config = self.build_model_config(config, **user_config)
        dataset_config = self.build_dataset_config(config)

        # Validate the user-provided runner configuration
        # model and dataset configuration are supposed to be validated by the respective classes
        # [TODO] validate the model/dataset configuration
        # self._validate_runner_config(runner_config)

        # Override the default configuration with user options.
        self.config = OmegaConf.merge(
            runner_config, model_config, dataset_config, user_config
        )

    def _validate_runner_config(self, runner_config):
        """
        This method validates the configuration, such that
            1) all the user specified options are valid;
            2) no type mismatches between the user specified options and the config.
        """
        runner_config_validator = create_runner_config_validator()
        runner_config_validator.validate(runner_config)

    def _build_opt_list(self, opts):
        opts_dot_list = self._convert_to_dot_list(opts)
        return OmegaConf.from_dotlist(opts_dot_list)

    @staticmethod
    def build_model_config(config, **kwargs):
        model = config.get("model", None)
        assert model is not None, "Missing model configuration file."

        model_cls = registry.get_model_class(model.arch)
        assert model_cls is not None, f"Model '{model.arch}' has not been registered."

        model_type = kwargs.get("model.model_type", None)
        if not model_type:
            model_type = model.get("model_type", None)
        # else use the model type selected by user.

        assert model_type is not None, "Missing model_type."

        model_config_path = model_cls.default_config_path(model_type=model_type)

        model_config = OmegaConf.create()
        # hierarchy override, customized config > default config
        model_config = OmegaConf.merge(
            model_config,
            OmegaConf.load(model_config_path),
            {"model": config["model"]},
        )

        return model_config

    @staticmethod
    def build_runner_config(config):
        return {"run": config.run}

    @staticmethod
    def build_dataset_config(config):
        datasets = config.get("datasets", None)
        if datasets is None:
            raise KeyError(
                "Expecting 'datasets' as the root key for dataset configuration."
            )

        dataset_config = OmegaConf.create()

        for dataset_name in datasets:
            builder_cls = registry.get_builder_class(dataset_name)

            dataset_config_type = datasets[dataset_name].get("type", "default")
            dataset_config_path = builder_cls.default_config_path(
                type=dataset_config_type
            )

            # hierarchy override, customized config > default config
            dataset_config = OmegaConf.merge(
                dataset_config,
                OmegaConf.load(dataset_config_path),
                {"datasets": {dataset_name: config["datasets"][dataset_name]}},
            )

        return dataset_config

    def _convert_to_dot_list(self, opts):
        if opts is None:
            opts = []

        if len(opts) == 0:
            return opts

        has_equal = opts[0].find("=") != -1

        if has_equal:
            return opts

        return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]

    def get_config(self):
        return self.config

    @property
    def run_cfg(self):
        return self.config.run

    @property
    def datasets_cfg(self):
        return self.config.datasets

    @property
    def model_cfg(self):
        return self.config.model

    def pretty_print(self):
        logging.info("\n===== Running Parameters =====")
        logging.info(self._convert_node_to_json(self.config.run))

        logging.info("\n====== Dataset Attributes ======")
        datasets = self.config.datasets

        for dataset in datasets:
            if dataset in self.config.datasets:
                logging.info(f"\n======== {dataset} =======")
                dataset_config = self.config.datasets[dataset]
                logging.info(self._convert_node_to_json(dataset_config))
            else:
                logging.warning(f"No dataset named '{dataset}' in config. Skipping")

        logging.info(f"\n====== Model Attributes ======")
        logging.info(self._convert_node_to_json(self.config.model))

    def _convert_node_to_json(self, node):
        container = OmegaConf.to_container(node, resolve=True)
        return json.dumps(container, indent=4, sort_keys=True)

    def to_dict(self):
        return OmegaConf.to_container(self.config)


def node_to_dict(node):
    return OmegaConf.to_container(node)


class ConfigValidator:
    """
    This is a preliminary implementation to centralize and validate the configuration.
    May be altered in the future.

    A helper class to validate configurations from yaml file.

    This serves the following purposes:
        1. Ensure all the options in the yaml are defined, raise error if not.
        2. when type mismatches are found, the validator will raise an error.
        3. a central place to store and display helpful messages for supported configurations.

    """

    class _Argument:
        def __init__(self, name, choices=None, type=None, help=None):
            self.name = name
            self.val = None
            self.choices = choices
            self.type = type
            self.help = help

        def __str__(self):
            s = f"{self.name}={self.val}"
            if self.type is not None:
                s += f", ({self.type})"
            if self.choices is not None:
                s += f", choices: {self.choices}"
            if self.help is not None:
                s += f", ({self.help})"
            return s

    def __init__(self, description):
        self.description = description

        self.arguments = dict()

        self.parsed_args = None

    def __getitem__(self, key):
        assert self.parsed_args is not None, "No arguments parsed yet."

        return self.parsed_args[key]

    def __str__(self) -> str:
        return self.format_help()

    def add_argument(self, *args, **kwargs):
        """
        Assume the first argument is the name of the argument.
        """
        self.arguments[args[0]] = self._Argument(*args, **kwargs)

    def validate(self, config=None):
        """
        Convert yaml config (dict-like) to list, required by argparse.
        """
        for k, v in config.items():
            assert (
                k in self.arguments
            ), f"""{k} is not a valid argument. Supported arguments are {self.format_arguments()}."""

            if self.arguments[k].type is not None:
                try:
                    self.arguments[k].val = self.arguments[k].type(v)
                except ValueError:
                    raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")

            if self.arguments[k].choices is not None:
                assert (
                    v in self.arguments[k].choices
                ), f"""{k} must be one of {self.arguments[k].choices}."""

        return config

    def format_arguments(self):
        return str([f"{k}" for k in sorted(self.arguments.keys())])

    def format_help(self):
        # description + key-value pair string for each argument
        help_msg = str(self.description)
        return help_msg + ", available arguments: " + self.format_arguments()

    def print_help(self):
        # display help message
        print(self.format_help())


def create_runner_config_validator():
    validator = ConfigValidator(description="Runner configurations")

    validator.add_argument(
        "runner",
        type=str,
        choices=["runner_base", "runner_iter"],
        help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
            runner runs based on iters. Default: runner_base""",
    )
    # add arguments for training dataset ratios
    validator.add_argument(
        "train_dataset_ratios",
        type=Dict[str, float],
        help="""Ratios of training dataset. This is used in iteration-based runner.
        Do not support for epoch-based runner because how to define an epoch becomes tricky.
        Default: None""",
    )
    validator.add_argument(
        "max_iters",
        type=float,
        help="Maximum number of iterations to run.",
    )
    validator.add_argument(
        "max_epoch",
        type=int,
        help="Maximum number of epochs to run.",
    )
    # add arguments for iters_per_inner_epoch
    validator.add_argument(
        "iters_per_inner_epoch",
        type=float,
        help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
    )
    lr_scheds_choices = registry.list_lr_schedulers()
    validator.add_argument(
        "lr_sched",
        type=str,
        choices=lr_scheds_choices,
        help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
    )
    task_choices = registry.list_tasks()
    validator.add_argument(
        "task",
        type=str,
        choices=task_choices,
        help="Task to use, from {}".format(task_choices),
    )
    # add arguments for init_lr
    validator.add_argument(
        "init_lr",
        type=float,
        help="Initial learning rate. This will be the learning rate after warmup and before decay.",
    )
    # add arguments for min_lr
    validator.add_argument(
        "min_lr",
        type=float,
        help="Minimum learning rate (after decay).",
    )
    # add arguments for warmup_lr
    validator.add_argument(
        "warmup_lr",
        type=float,
        help="Starting learning rate for warmup.",
    )
    # add arguments for learning rate decay rate
    validator.add_argument(
        "lr_decay_rate",
        type=float,
        help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
    )
    # add arguments for weight decay
    validator.add_argument(
        "weight_decay",
        type=float,
        help="Weight decay rate.",
    )
    # add arguments for training batch size
    validator.add_argument(
        "batch_size_train",
        type=int,
        help="Training batch size.",
    )
    # add arguments for evaluation batch size
    validator.add_argument(
        "batch_size_eval",
        type=int,
        help="Evaluation batch size, including validation and testing.",
    )
    # add arguments for number of workers for data loading
    validator.add_argument(
        "num_workers",
        help="Number of workers for data loading.",
    )
    # add arguments for warm up steps
    validator.add_argument(
        "warmup_steps",
        type=int,
        help="Number of warmup steps. Required if a warmup schedule is used.",
    )
    # add arguments for random seed
    validator.add_argument(
        "seed",
        type=int,
        help="Random seed.",
    )
    # add arguments for output directory
    validator.add_argument(
        "output_dir",
        type=str,
        help="Output directory to save checkpoints and logs.",
    )
    # add arguments for whether only use evaluation
    validator.add_argument(
        "evaluate",
        help="Whether to only evaluate the model. If true, training will not be performed.",
    )
    # add arguments for splits used for training, e.g. ["train", "val"]
    validator.add_argument(
        "train_splits",
        type=list,
        help="Splits to use for training.",
    )
    # add arguments for splits used for validation, e.g. ["val"]
    validator.add_argument(
        "valid_splits",
        type=list,
        help="Splits to use for validation. If not provided, will skip the validation.",
    )
    # add arguments for splits used for testing, e.g. ["test"]
    validator.add_argument(
        "test_splits",
        type=list,
        help="Splits to use for testing. If not provided, will skip the testing.",
    )
    # add arguments for accumulating gradient for iterations
    validator.add_argument(
        "accum_grad_iters",
        type=int,
        help="Number of iterations to accumulate gradient for.",
    )

    # ====== distributed training ======
    validator.add_argument(
        "device",
        type=str,
        choices=["cpu", "cuda"],
        help="Device to use. Support 'cuda' or 'cpu' as for now.",
    )
    validator.add_argument(
        "world_size",
        type=int,
        help="Number of processes participating in the job.",
    )
    validator.add_argument("dist_url", type=str)
    validator.add_argument("distributed", type=bool)
    # add arguments to opt using distributed sampler during evaluation or not
    validator.add_argument(
        "use_dist_eval_sampler",
        type=bool,
        help="Whether to use distributed sampler during evaluation or not.",
    )

    # ====== task specific ======
    # generation task specific arguments
    # add arguments for maximal length of text output
    validator.add_argument(
        "max_len",
        type=int,
        help="Maximal length of text output.",
    )
    # add arguments for minimal length of text output
    validator.add_argument(
        "min_len",
        type=int,
        help="Minimal length of text output.",
    )
    # add arguments number of beams
    validator.add_argument(
        "num_beams",
        type=int,
        help="Number of beams used for beam search.",
    )

    # vqa task specific arguments
    # add arguments for number of answer candidates
    validator.add_argument(
        "num_ans_candidates",
        type=int,
        help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
    )
    # add arguments for inference method
    validator.add_argument(
        "inference_method",
        type=str,
        choices=["genearte", "rank"],
        help="""Inference method to use for question answering. If rank, requires an answer list.""",
    )

    # ====== model specific ======
    validator.add_argument(
        "k_test",
        type=int,
        help="Number of top k most similar samples from ITC/VTC selection to be tested.",
    )

    return validator
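Note: Config only assumes an argparse-style namespace with cfg_path and options attributes. A minimal sketch of how it might be constructed (the config path and option string are hypothetical examples, not part of this commit):

# Sketch only: building a Config from command-line style arguments.
import argparse
from minigpt4.common.config import Config

parser = argparse.ArgumentParser()
parser.add_argument("--cfg-path", dest="cfg_path", default="train_configs/minigpt4_stage2_finetune.yaml")  # hypothetical path
parser.add_argument("--options", nargs="+", default=None)  # e.g. ["run.batch_size_train=4"]
args = parser.parse_args()

cfg = Config(args)   # merges runner, model and dataset configs, user options win
cfg.pretty_print()   # logs the merged configuration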
minigpt4/common/dist_utils.py ADDED
@@ -0,0 +1,138 @@

"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import functools
import os

import torch
import torch.distributed as dist
import timm.models.hub as timm_hub


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def init_distributed_mode(args):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(
        "| distributed init (rank {}, world {}): {}".format(
            args.rank, args.world_size, args.dist_url
        ),
        flush=True,
    )
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=datetime.timedelta(
            days=365
        ),  # allow auto-downloading and de-compressing
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def get_dist_info():
    if torch.__version__ < "1.0":
        initialized = dist._initialized
    else:
        initialized = dist.is_initialized()
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:  # non-distributed training
        rank = 0
        world_size = 1
    return rank, world_size


def main_process(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper


def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
    """

    def get_cached_file_path():
        # a hack to sync the file path across processes
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)

        return cached_file

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()
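Note: init_distributed_mode reads RANK, WORLD_SIZE and LOCAL_RANK (or SLURM_PROCID) from the environment, so it is normally driven by a launcher such as torchrun. A hedged sketch of how these helpers fit together (the namespace fields mirror what the function itself accesses):

# Sketch: `torchrun --nproc_per_node=NGPU train.py` sets RANK/WORLD_SIZE/LOCAL_RANK;
# "env://" is the conventional dist_url for that case.
import argparse
from minigpt4.common.dist_utils import init_distributed_mode, get_rank, is_main_process

args = argparse.Namespace(dist_url="env://", distributed=True)
init_distributed_mode(args)   # falls back to single-process mode if the env vars are absent
if is_main_process():
    print("running on rank", get_rank())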
minigpt4/common/gradcam.py ADDED
@@ -0,0 +1,24 @@
import numpy as np
from matplotlib import pyplot as plt
from scipy.ndimage import filters
from skimage import transform as skimage_transform


def getAttMap(img, attMap, blur=True, overlap=True):
    attMap -= attMap.min()
    if attMap.max() > 0:
        attMap /= attMap.max()
    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
    if blur:
        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
        attMap -= attMap.min()
        attMap /= attMap.max()
    cmap = plt.get_cmap("jet")
    attMapV = cmap(attMap)
    attMapV = np.delete(attMapV, 3, 2)
    if overlap:
        attMap = (
            1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
            + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
        )
    return attMap
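Note: getAttMap blends a normalized attention map over an RGB image with values in [0, 1]. A tiny sketch with random inputs (shapes are illustrative only):

# Sketch: overlaying a random attention map on a random image.
import numpy as np
from minigpt4.common.gradcam import getAttMap

img = np.random.rand(224, 224, 3)   # RGB image scaled to [0, 1]
att = np.random.rand(14, 14)        # e.g. a 14x14 attention grid
overlay = getAttMap(img, att)       # resized, blurred and blended; shape (224, 224, 3)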
minigpt4/common/logger.py ADDED
@@ -0,0 +1,195 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import logging
import time
from collections import defaultdict, deque

import torch
import torch.distributed as dist

from minigpt4.common import dist_utils


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not dist_utils.is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def global_avg(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        log_msg = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if torch.cuda.is_available():
            log_msg.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(
            "{} Total time: {} ({:.4f} s / it)".format(
                header, total_time_str, total_time / len(iterable)
            )
        )


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def setup_logger():
    logging.basicConfig(
        level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler()],
    )
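Note: MetricLogger wraps named SmoothedValue meters and yields from log_every, so it drops straight into a training loop. A small sketch (the loss value is a stand-in for a real forward/backward pass):

# Sketch: wrapping an iterable with MetricLogger.log_every inside a training loop.
from minigpt4.common.logger import MetricLogger, SmoothedValue

metric_logger = MetricLogger(delimiter="  ")
metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

for batch in metric_logger.log_every(range(100), print_freq=10, header="Train:"):
    loss = 0.5  # placeholder for an actual training step
    metric_logger.update(loss=loss, lr=1e-4)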
minigpt4/common/optims.py ADDED
@@ -0,0 +1,119 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import math

from minigpt4.common.registry import registry


@registry.register_lr_scheduler("linear_warmup_step_lr")
class LinearWarmupStepLRScheduler:
    def __init__(
        self,
        optimizer,
        max_epoch,
        min_lr,
        init_lr,
        decay_rate=1,
        warmup_start_lr=-1,
        warmup_steps=0,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr

        self.decay_rate = decay_rate

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        if cur_epoch == 0:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            step_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
                decay_rate=self.decay_rate,
            )


@registry.register_lr_scheduler("linear_warmup_cosine_lr")
class LinearWarmupCosineLRScheduler:
    def __init__(
        self,
        optimizer,
        max_epoch,
        iters_per_epoch,
        min_lr,
        init_lr,
        warmup_steps=0,
        warmup_start_lr=-1,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.iters_per_epoch = iters_per_epoch
        self.min_lr = min_lr

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
        if total_cur_step < self.warmup_steps:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            cosine_lr_schedule(
                epoch=total_cur_step,
                optimizer=self.optimizer,
                max_epoch=self.max_epoch * self.iters_per_epoch,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
            )


def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Decay the learning rate"""
    lr = (init_lr - min_lr) * 0.5 * (
        1.0 + math.cos(math.pi * epoch / max_epoch)
    ) + min_lr
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Warmup the learning rate"""
    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Decay the learning rate"""
    lr = max(min_lr, init_lr * (decay_rate**epoch))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
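Note: because the schedulers register themselves by name, the runner can build them from the run config; they can also be stepped directly. A sketch with a dummy optimizer (the hyperparameter values are illustrative):

# Sketch: stepping the warmup + cosine scheduler once per iteration.
import torch
from minigpt4.common.optims import LinearWarmupCosineLRScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)
sched = LinearWarmupCosineLRScheduler(
    optimizer, max_epoch=4, iters_per_epoch=100, min_lr=1e-5, init_lr=1e-4,
    warmup_steps=200, warmup_start_lr=1e-6,
)
for epoch in range(4):
    for it in range(100):
        sched.step(cur_epoch=epoch, cur_step=it)  # updates param_group["lr"] in place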
minigpt4/common/registry.py ADDED
@@ -0,0 +1,329 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""


class Registry:
    mapping = {
        "builder_name_mapping": {},
        "task_name_mapping": {},
        "processor_name_mapping": {},
        "model_name_mapping": {},
        "lr_scheduler_name_mapping": {},
        "runner_name_mapping": {},
        "state": {},
        "paths": {},
    }

    @classmethod
    def register_builder(cls, name):
        r"""Register a dataset builder to registry with key 'name'

        Args:
            name: Key with which the builder will be registered.

        Usage:

            from minigpt4.common.registry import registry
            from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder
        """

        def wrap(builder_cls):
            from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder

            assert issubclass(
                builder_cls, BaseDatasetBuilder
            ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
                builder_cls
            )
            if name in cls.mapping["builder_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["builder_name_mapping"][name]
                    )
                )
            cls.mapping["builder_name_mapping"][name] = builder_cls
            return builder_cls

        return wrap

    @classmethod
    def register_task(cls, name):
        r"""Register a task to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(task_cls):
            from minigpt4.tasks.base_task import BaseTask

            assert issubclass(
                task_cls, BaseTask
            ), "All tasks must inherit BaseTask class"
            if name in cls.mapping["task_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["task_name_mapping"][name]
                    )
                )
            cls.mapping["task_name_mapping"][name] = task_cls
            return task_cls

        return wrap

    @classmethod
    def register_model(cls, name):
        r"""Register a model to registry with key 'name'

        Args:
            name: Key with which the model will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(model_cls):
            from minigpt4.models import BaseModel

            assert issubclass(
                model_cls, BaseModel
            ), "All models must inherit BaseModel class"
            if name in cls.mapping["model_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["model_name_mapping"][name]
                    )
                )
            cls.mapping["model_name_mapping"][name] = model_cls
            return model_cls

        return wrap

    @classmethod
    def register_processor(cls, name):
        r"""Register a processor to registry with key 'name'

        Args:
            name: Key with which the processor will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(processor_cls):
            from minigpt4.processors import BaseProcessor

            assert issubclass(
                processor_cls, BaseProcessor
            ), "All processors must inherit BaseProcessor class"
            if name in cls.mapping["processor_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["processor_name_mapping"][name]
                    )
                )
            cls.mapping["processor_name_mapping"][name] = processor_cls
            return processor_cls

        return wrap

    @classmethod
    def register_lr_scheduler(cls, name):
        r"""Register a learning rate scheduler to registry with key 'name'

        Args:
            name: Key with which the scheduler will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(lr_sched_cls):
            if name in cls.mapping["lr_scheduler_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["lr_scheduler_name_mapping"][name]
                    )
                )
            cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
            return lr_sched_cls

        return wrap

    @classmethod
    def register_runner(cls, name):
        r"""Register a runner to registry with key 'name'

        Args:
            name: Key with which the runner will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(runner_cls):
            if name in cls.mapping["runner_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["runner_name_mapping"][name]
                    )
                )
            cls.mapping["runner_name_mapping"][name] = runner_cls
            return runner_cls

        return wrap

    @classmethod
    def register_path(cls, name, path):
        r"""Register a path to registry with key 'name'

        Args:
            name: Key with which the path will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """
        assert isinstance(path, str), "All paths must be str."
        if name in cls.mapping["paths"]:
            raise KeyError("Name '{}' already registered.".format(name))
        cls.mapping["paths"][name] = path

    @classmethod
    def register(cls, name, obj):
        r"""Register an item to registry with key 'name'

        Args:
            name: Key with which the item will be registered.

        Usage::

            from minigpt4.common.registry import registry

            registry.register("config", {})
        """
        path = name.split(".")
        current = cls.mapping["state"]

        for part in path[:-1]:
            if part not in current:
                current[part] = {}
            current = current[part]

        current[path[-1]] = obj

    # @classmethod
    # def get_trainer_class(cls, name):
    #     return cls.mapping["trainer_name_mapping"].get(name, None)

    @classmethod
    def get_builder_class(cls, name):
        return cls.mapping["builder_name_mapping"].get(name, None)

    @classmethod
    def get_model_class(cls, name):
        return cls.mapping["model_name_mapping"].get(name, None)

    @classmethod
    def get_task_class(cls, name):
        return cls.mapping["task_name_mapping"].get(name, None)

    @classmethod
    def get_processor_class(cls, name):
        return cls.mapping["processor_name_mapping"].get(name, None)

    @classmethod
    def get_lr_scheduler_class(cls, name):
        return cls.mapping["lr_scheduler_name_mapping"].get(name, None)

    @classmethod
    def get_runner_class(cls, name):
        return cls.mapping["runner_name_mapping"].get(name, None)

    @classmethod
    def list_runners(cls):
        return sorted(cls.mapping["runner_name_mapping"].keys())

    @classmethod
    def list_models(cls):
        return sorted(cls.mapping["model_name_mapping"].keys())

    @classmethod
    def list_tasks(cls):
        return sorted(cls.mapping["task_name_mapping"].keys())

    @classmethod
    def list_processors(cls):
        return sorted(cls.mapping["processor_name_mapping"].keys())

    @classmethod
    def list_lr_schedulers(cls):
        return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())

    @classmethod
    def list_datasets(cls):
        return sorted(cls.mapping["builder_name_mapping"].keys())

    @classmethod
    def get_path(cls, name):
        return cls.mapping["paths"].get(name, None)

    @classmethod
    def get(cls, name, default=None, no_warning=False):
        r"""Get an item from registry with key 'name'

        Args:
            name (string): Key whose value needs to be retrieved.
            default: If passed and key is not in registry, default value will
                     be returned with a warning. Default: None
            no_warning (bool): If passed as True, warning when key doesn't exist
                               will not be generated. Useful for MMF's
                               internal operations. Default: False
        """
        original_name = name
        name = name.split(".")
        value = cls.mapping["state"]
        for subname in name:
            value = value.get(subname, default)
            if value is default:
                break

        if (
            "writer" in cls.mapping["state"]
            and value == default
            and no_warning is False
        ):
            cls.mapping["state"]["writer"].warning(
                "Key {} is not present in registry, returning default value "
                "of {}".format(original_name, default)
            )
        return value

    @classmethod
    def unregister(cls, name):
        r"""Remove an item from registry with key 'name'

        Args:
            name: Key which needs to be removed.
        Usage::

            from mmf.common.registry import registry

            config = registry.unregister("config")
        """
        return cls.mapping["state"].pop(name, None)


registry = Registry()
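Note: the registry is the glue between config names and classes; components self-register via the decorators above and are looked up by string keys. A minimal sketch using the generic register/get pair (the key name is an arbitrary example):

# Sketch: generic key/value registration; models, tasks, processors and schedulers
# are registered the same way but through the decorator methods shown above.
from minigpt4.common.registry import registry

registry.register("my.nested.flag", True)   # stored under mapping["state"]
print(registry.get("my.nested.flag"))       # True
print(registry.list_lr_schedulers())        # e.g. ['linear_warmup_cosine_lr', 'linear_warmup_step_lr']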
minigpt4/common/utils.py ADDED
@@ -0,0 +1,424 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import io
import json
import logging
import os
import pickle
import re
import shutil
import urllib
import urllib.error
import urllib.request
from typing import Optional
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import yaml
from iopath.common.download import download
from iopath.common.file_io import file_lock, g_pathmgr
from minigpt4.common.registry import registry
from torch.utils.model_zoo import tqdm
from torchvision.datasets.utils import (
    check_integrity,
    download_file_from_google_drive,
    extract_archive,
)


def now():
    from datetime import datetime

    return datetime.now().strftime("%Y%m%d%H%M")[:-1]


def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def get_cache_path(rel_path):
    return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))


def get_abs_path(rel_path):
    return os.path.join(registry.get_path("library_root"), rel_path)


def load_json(filename):
    with open(filename, "r") as f:
        return json.load(f)


# The following are adapted from torchvision and vissl
# torchvision: https://github.com/pytorch/vision
# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py


def makedir(dir_path):
    """
    Create the directory if it does not exist.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except BaseException:
        print(f"Error creating directory: {dir_path}")
    return is_success


def get_redirected_url(url: str):
    """
    Given a URL, returns the URL it redirects to or the
    original URL in case of no indirection
    """
    import requests

    with requests.Session() as session:
        with session.get(url, stream=True, allow_redirects=True) as response:
            if response.history:
                return response.url
            else:
                return url


def to_google_drive_download_url(view_url: str) -> str:
    """
    Utility function to transform a view URL of google drive
    to a download URL for google drive
    Example input:
        https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
    Example output:
        https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
    """
    splits = view_url.split("/")
    assert splits[-1] == "view"
    file_id = splits[-2]
    return f"https://drive.google.com/uc?export=download&id={file_id}"


def download_google_drive_url(url: str, output_path: str, output_file_name: str):
    """
    Download a file from google drive
    Downloading an URL from google drive requires confirmation when
    the file of the size is too big (google drive notifies that
    anti-viral checks cannot be performed on such files)
    """
    import requests

    with requests.Session() as session:

        # First get the confirmation token and append it to the URL
        with session.get(url, stream=True, allow_redirects=True) as response:
            for k, v in response.cookies.items():
                if k.startswith("download_warning"):
                    url = url + "&confirm=" + v

        # Then download the content of the file
        with session.get(url, stream=True, verify=True) as response:
            makedir(output_path)
            path = os.path.join(output_path, output_file_name)
            total_size = int(response.headers.get("Content-length", 0))
            with open(path, "wb") as file:
                from tqdm import tqdm

                with tqdm(total=total_size) as progress_bar:
                    for block in response.iter_content(
                        chunk_size=io.DEFAULT_BUFFER_SIZE
                    ):
                        file.write(block)
                        progress_bar.update(len(block))


def _get_google_drive_file_id(url: str) -> Optional[str]:
    parts = urlparse(url)

    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
        return None

    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
    if match is None:
        return None

    return match.group("id")


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    with open(filename, "wb") as fh:
        with urllib.request.urlopen(
            urllib.request.Request(url, headers={"User-Agent": "vissl"})
        ) as response:
            with tqdm(total=response.length) as pbar:
                for chunk in iter(lambda: response.read(chunk_size), ""):
                    if not chunk:
                        break
                    pbar.update(chunk_size)
                    fh.write(chunk)


def download_url(
    url: str,
    root: str,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
) -> None:
    """Download a file from a url and place it in root.
    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under.
            If None, use the basename of the URL.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir(root)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    # expand redirect chain if needed
    url = get_redirected_url(url)

    # check if file is located on Google Drive
    file_id = _get_google_drive_file_id(url)
    if file_id is not None:
        return download_file_from_google_drive(file_id, root, filename, md5)

    # download the file
    try:
        print("Downloading " + url + " to " + fpath)
        _urlretrieve(url, fpath)
    except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
        if url[:5] == "https":
            url = url.replace("https:", "http:")
            print(
                "Failed download. Trying https -> http instead."
                " Downloading " + url + " to " + fpath
            )
            _urlretrieve(url, fpath)
        else:
            raise e

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")


def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)


def cache_url(url: str, cache_dir: str) -> str:
    """
    This implementation downloads the remote resource and caches it locally.
    The resource will only be downloaded if not previously requested.
    """
    parsed_url = urlparse(url)
    dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
    makedir(dirname)
    filename = url.split("/")[-1]
    cached = os.path.join(dirname, filename)
    with file_lock(cached):
        if not os.path.isfile(cached):
            logging.info(f"Downloading {url} to {cached} ...")
            cached = download(url, dirname, filename=filename)
    logging.info(f"URL {url} cached in {cached}")
    return cached


# TODO (prigoyal): convert this into RAII-style API
def create_file_symlink(file1, file2):
    """
    Simply create the symlinks for a given file1 to file2.
    Useful during model checkpointing to symlinks to the
    latest successful checkpoint.
    """
    try:
        if g_pathmgr.exists(file2):
            g_pathmgr.rm(file2)
        g_pathmgr.symlink(file1, file2)
    except Exception as e:
        logging.info(f"Could NOT create symlink. Error: {e}")


def save_file(data, filename, append_to_json=True, verbose=True):
    """
    Common i/o utility to handle saving data to various file formats.
    Supported:
        .pkl, .pickle, .npy, .json
    Specifically for .json, users have the option to either append (default)
    or rewrite by passing in Boolean value to append_to_json.
    """
    if verbose:
        logging.info(f"Saving data to file: {filename}")
    file_ext = os.path.splitext(filename)[1]
    if file_ext in [".pkl", ".pickle"]:
        with g_pathmgr.open(filename, "wb") as fopen:
            pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
    elif file_ext == ".npy":
        with g_pathmgr.open(filename, "wb") as fopen:
            np.save(fopen, data)
    elif file_ext == ".json":
        if append_to_json:
            with g_pathmgr.open(filename, "a") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
        else:
            with g_pathmgr.open(filename, "w") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "w") as fopen:
            dump = yaml.dump(data)
            fopen.write(dump)
            fopen.flush()
    else:
        raise Exception(f"Saving {file_ext} is not supported yet")

    if verbose:
        logging.info(f"Saved data to file: {filename}")


def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
    """
    Common i/o utility to handle loading data from various file formats.
    Supported:
        .pkl, .pickle, .npy, .json
    For the npy files, we support reading the files in mmap_mode.
    If the mmap_mode of reading is not successful, we load data without the
    mmap_mode.
    """
    if verbose:
        logging.info(f"Loading data from file: {filename}")

    file_ext = os.path.splitext(filename)[1]
    if file_ext == ".txt":
        with g_pathmgr.open(filename, "r") as fopen:
            data = fopen.readlines()
    elif file_ext in [".pkl", ".pickle"]:
        with g_pathmgr.open(filename, "rb") as fopen:
            data = pickle.load(fopen, encoding="latin1")
    elif file_ext == ".npy":
        if mmap_mode:
            try:
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(
                        fopen,
                        allow_pickle=allow_pickle,
                        encoding="latin1",
                        mmap_mode=mmap_mode,
                    )
            except ValueError as e:
                logging.info(
                    f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
                )
                data = np.load(
                    filename,
                    allow_pickle=allow_pickle,
                    encoding="latin1",
                    mmap_mode=mmap_mode,
                )
                logging.info("Successfully loaded without g_pathmgr")
            except Exception:
                logging.info("Could not mmap without g_pathmgr. Trying without mmap")
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
        else:
            with g_pathmgr.open(filename, "rb") as fopen:
                data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
    elif file_ext == ".json":
        with g_pathmgr.open(filename, "r") as fopen:
            data = json.load(fopen)
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "r") as fopen:
            data = yaml.load(fopen, Loader=yaml.FullLoader)
    elif file_ext == ".csv":
        with g_pathmgr.open(filename, "r") as fopen:
            data = pd.read_csv(fopen)
    else:
        raise Exception(f"Reading from {file_ext} is not supported yet")
    return data


def abspath(resource_path: str):
    """
    Make a path absolute, but take into account prefixes like
    "http://" or "manifold://"
    """
    regex = re.compile(r"^\w+://")
    if regex.match(resource_path) is None:
        return os.path.abspath(resource_path)
    else:
        return resource_path


def makedir(dir_path):
    """
    Create the directory if it does not exist.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except BaseException:
        logging.info(f"Error creating directory: {dir_path}")
    return is_success


def is_url(input_url):
    """
    Check if an input string is a url. look for http(s):// and ignoring the case
    """
    is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
    return is_url


def cleanup_dir(dir):
    """
    Utility for deleting a directory. Useful for cleaning the storage space
    that contains various training artifacts like checkpoints, data etc.
    """
    if os.path.exists(dir):
        logging.info(f"Deleting directory: {dir}")
        shutil.rmtree(dir)
    logging.info(f"Deleted contents of directory: {dir}")


def get_file_size(filename):
    """
    Given a file, get the size of file in MB
    """
    size_in_mb = os.path.getsize(filename) / float(1024**2)
    return size_in_mb
minigpt4/configs/datasets/cc_combine/align.yaml ADDED
@@ -0,0 +1,16 @@
+ # Copyright (c) 2022, salesforce.com, inc.
+ # All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+
+ datasets:
+   cc_align:
+     data_type: images
+     build_info:
+       # Be careful not to append minus sign (-) before split to avoid itemizing
+       annotations:
+         train:
+           url: placeholder
+           storage: /ibex/project/c2133/blip_dataset/image_alignment_cc/filter_cap.json
+       images:
+         storage: /ibex/project/c2133/blip_dataset/image_alignment_cc/
minigpt4/configs/datasets/cc_combine/defaults.yaml ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) 2022, salesforce.com, inc.
+ # All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+
+ datasets:
+   cc_combine:
+     data_type: images
+     build_info:
+       # Be careful not to append minus sign (-) before split to avoid itemizing
+       storage: /ibex/project/c2133/blip_dataset/cc3m/cc3m_cc12m_sbu/{00000..01255}.tar
minigpt4/configs/datasets/laion/defaults.yaml ADDED
@@ -0,0 +1,13 @@
+ # Copyright (c) 2022, salesforce.com, inc.
+ # All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+
+ datasets:
+   laion:
+
+     data_type: images
+
+     build_info:
+       # Be careful not to append minus sign (-) before split to avoid itemizing
+       storage: /ibex/project/c2133/blip_dataset/laion_1b/laion_gpu/{00000..10488}.tar
minigpt4/configs/default.yaml ADDED
@@ -0,0 +1,10 @@
+ # Copyright (c) 2022, salesforce.com, inc.
+ # All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+
+ env:
+   # For default users
+   # cache_root: "cache"
+   # For internal use with persistent storage
+   cache_root: "/export/home/.cache/minigpt4"
minigpt4/configs/models/minigpt4.yaml ADDED
@@ -0,0 +1,39 @@
+ # Copyright (c) 2022, salesforce.com, inc.
+ # All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+
+ model:
+   arch: mini_gpt4
+
+   # vit encoder
+   image_size: 224
+   drop_path_rate: 0
+   use_grad_checkpoint: False
+   vit_precision: "fp16"
+   freeze_vit: True
+   freeze_qformer: True
+
+   # Q-Former
+   num_query_token: 32
+
+   # Vicuna
+   llama_model: "vicuna"
+
+   # generation configs
+   prompt: ""
+
+
+ preprocess:
+   vis_processor:
+     train:
+       name: "blip2_image_train"
+       image_size: 224
+     eval:
+       name: "blip2_image_eval"
+       image_size: 224
+   text_processor:
+     train:
+       name: "blip_caption"
+     eval:
+       name: "blip_caption"
minigpt4/conversation/__init__.py ADDED
File without changes
minigpt4/conversation/conversation.py ADDED
@@ -0,0 +1,201 @@
1
+ import argparse
2
+ import time
3
+ from PIL import Image
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
7
+ from transformers import StoppingCriteria, StoppingCriteriaList
8
+
9
+ import dataclasses
10
+ from enum import auto, Enum
11
+ from typing import List, Tuple, Any
12
+
13
+ from minigpt4.common.registry import registry
14
+
15
+
16
+ class SeparatorStyle(Enum):
17
+ """Different separator style."""
18
+ SINGLE = auto()
19
+ TWO = auto()
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class Conversation:
24
+ """A class that keeps all conversation history."""
25
+ system: str
26
+ roles: List[str]
27
+ messages: List[List[str]]
28
+ offset: int
29
+ # system_img: List[Image.Image] = []
30
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
31
+ sep: str = "###"
32
+ sep2: str = None
33
+
34
+ skip_next: bool = False
35
+ conv_id: Any = None
36
+
37
+ def get_prompt(self):
38
+ if self.sep_style == SeparatorStyle.SINGLE:
39
+ ret = self.system + self.sep
40
+ for role, message in self.messages:
41
+ if message:
42
+ ret += role + ": " + message + self.sep
43
+ else:
44
+ ret += role + ":"
45
+ return ret
46
+ elif self.sep_style == SeparatorStyle.TWO:
47
+ seps = [self.sep, self.sep2]
48
+ ret = self.system + seps[0]
49
+ for i, (role, message) in enumerate(self.messages):
50
+ if message:
51
+ ret += role + ": " + message + seps[i % 2]
52
+ else:
53
+ ret += role + ":"
54
+ return ret
55
+ else:
56
+ raise ValueError(f"Invalid style: {self.sep_style}")
57
+
58
+ def append_message(self, role, message):
59
+ self.messages.append([role, message])
60
+
61
+ def to_gradio_chatbot(self):
62
+ ret = []
63
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
64
+ if i % 2 == 0:
65
+ ret.append([msg, None])
66
+ else:
67
+ ret[-1][-1] = msg
68
+ return ret
69
+
70
+ def copy(self):
71
+ return Conversation(
72
+ system=self.system,
73
+ # system_img=self.system_img,
74
+ roles=self.roles,
75
+ messages=[[x, y] for x, y in self.messages],
76
+ offset=self.offset,
77
+ sep_style=self.sep_style,
78
+ sep=self.sep,
79
+ sep2=self.sep2,
80
+ conv_id=self.conv_id)
81
+
82
+ def dict(self):
83
+ return {
84
+ "system": self.system,
85
+ # "system_img": self.system_img,
86
+ "roles": self.roles,
87
+ "messages": self.messages,
88
+ "offset": self.offset,
89
+ "sep": self.sep,
90
+ "sep2": self.sep2,
91
+ "conv_id": self.conv_id,
92
+ }
93
+
94
+
95
+ class StoppingCriteriaSub(StoppingCriteria):
96
+
97
+ def __init__(self, stops=[], encounters=1):
98
+ super().__init__()
99
+ self.stops = stops
100
+
101
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
102
+ for stop in self.stops:
103
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
104
+ return True
105
+
106
+ return False
107
+
108
+
109
+ CONV_VISION = Conversation(
110
+ system="Give the following image: <Img>ImageContent</Img>. "
111
+ "You will be able to see the image once I provide it to you. Please answer my questions.",
112
+ roles=("Human", "Assistant"),
113
+ messages=[],
114
+ offset=2,
115
+ sep_style=SeparatorStyle.SINGLE,
116
+ sep="###",
117
+ )
118
+
119
+
120
+
121
+ class Chat:
122
+ def __init__(self, model, vis_processor, device='cuda:0'):
123
+ self.device = device
124
+ self.model = model
125
+ self.vis_processor = vis_processor
126
+ stop_words_ids = [torch.tensor([835]).to(self.device),
127
+ torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
128
+ self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
129
+
130
+ def ask(self, text, conv):
131
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
132
+ and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
133
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
134
+ else:
135
+ conv.append_message(conv.roles[0], text)
136
+
137
+ def answer(self, conv, img_list, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
138
+ repetition_penalty=1.0, length_penalty=1, temperature=1, max_length=2000):
139
+ conv.append_message(conv.roles[1], None)
140
+ embs = self.get_context_emb(conv, img_list)
141
+
142
+ # current_max_len = embs.shape[1] + max_new_tokens + 100
143
+ # begin_idx = max(0, current_max_len - max_length)
144
+ # embs = embs[:, begin_idx:]
145
+ outputs = self.model.llama_model.generate(
146
+ inputs_embeds=embs,
147
+ max_new_tokens=max_new_tokens,
148
+ stopping_criteria=self.stopping_criteria,
149
+ num_beams=num_beams,
150
+ min_length=min_length,
151
+ top_p=top_p,
152
+ repetition_penalty=repetition_penalty,
153
+ length_penalty=length_penalty,
154
+ temperature=temperature,
155
+ )
156
+ output_token = outputs[0]
157
+ if output_token[0] == 0:
158
+ output_token = output_token[1:]
159
+ output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
160
+ output_text = output_text.split('###')[0] # remove the stop sign '###'
161
+ output_text = output_text.split('Assistant:')[-1].strip()
162
+ conv.messages[-1][1] = output_text
163
+ return output_text, output_token.cpu().numpy()
164
+
165
+ def upload_img(self, image, conv, img_list):
166
+ if isinstance(image, str): # is a image path
167
+ raw_image = Image.open(image).convert('RGB')
168
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
169
+ elif isinstance(image, Image.Image):
170
+ raw_image = image
171
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
172
+ elif isinstance(image, torch.Tensor):
173
+ if len(image.shape) == 3:
174
+ image = image.unsqueeze(0)
175
+ image = image.to(self.device)
176
+
177
+ image_emb, _ = self.model.encode_img(image)
178
+ img_list.append(image_emb)
179
+ conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
180
+ msg = "Received."
181
+ # self.conv.append_message(self.conv.roles[1], msg)
182
+ return msg
183
+
184
+
185
+
186
+ def get_context_emb(self, conv, img_list):
187
+ prompt = conv.get_prompt()
188
+ prompt_segs = prompt.split('<ImageHere>')
189
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
190
+ seg_tokens = [
191
+ self.model.llama_tokenizer(
192
+ seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
193
+ # only add bos to the first seg
194
+ for i, seg in enumerate(prompt_segs)
195
+ ]
196
+ seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
197
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
198
+ mixed_embs = torch.cat(mixed_embs, dim=1)
199
+ return mixed_embs
200
+
201
+
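A minimal usage sketch for the Chat wrapper above (not part of the commit; `model` and `vis_processor` are assumed to come from the registered MiniGPT-4 model and processor classes, and the image path is a placeholder):

chat = Chat(model, vis_processor, device='cuda:0')
conv = CONV_VISION.copy()
img_list = []
chat.upload_img('example.jpg', conv, img_list)   # encodes the image and appends the <Img> placeholder message
chat.ask('What is shown in this image?', conv)   # appends a Human turn
answer, _ = chat.answer(conv, img_list)          # generates the Assistant reply and stores it in the conversation
print(answer)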
minigpt4/datasets/__init__.py ADDED
File without changes
minigpt4/datasets/builders/__init__.py ADDED
@@ -0,0 +1,72 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
+ from minigpt4.datasets.builders.image_text_pair_builder import (
+     CCCombineBuilder,
+     LaionBuilder,
+     CCAlignBuilder
+ )
+ from minigpt4.common.registry import registry
+
+ __all__ = [
+     "CCCombineBuilder",
+     "LaionBuilder",
+     "CCAlignBuilder"
+ ]
+
+
+ def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
+     """
+     Example
+
+     >>> dataset = load_dataset("coco_caption", cfg=None)
+     >>> splits = dataset.keys()
+     >>> print([len(dataset[split]) for split in splits])
+
+     """
+     if cfg_path is None:
+         cfg = None
+     else:
+         cfg = load_dataset_config(cfg_path)
+
+     try:
+         builder = registry.get_builder_class(name)(cfg)
+     except TypeError:
+         print(
+             f"Dataset {name} not found. Available datasets:\n"
+             + ", ".join([str(k) for k in dataset_zoo.get_names()])
+         )
+         exit(1)
+
+     if vis_path is not None:
+         if data_type is None:
+             # use default data type in the config
+             data_type = builder.config.data_type
+
+         assert (
+             data_type in builder.config.build_info
+         ), f"Invalid data_type {data_type} for {name}."
+
+         builder.config.build_info.get(data_type).storage = vis_path
+
+     dataset = builder.build_datasets()
+     return dataset
+
+
+ class DatasetZoo:
+     def __init__(self) -> None:
+         self.dataset_zoo = {
+             k: list(v.DATASET_CONFIG_DICT.keys())
+             for k, v in sorted(registry.mapping["builder_name_mapping"].items())
+         }
+
+     def get_names(self):
+         return list(self.dataset_zoo.keys())
+
+
+ dataset_zoo = DatasetZoo()
minigpt4/datasets/builders/base_dataset_builder.py ADDED
@@ -0,0 +1,235 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import os
10
+ import shutil
11
+ import warnings
12
+
13
+ from omegaconf import OmegaConf
14
+ import torch.distributed as dist
15
+ from torchvision.datasets.utils import download_url
16
+
17
+ import minigpt4.common.utils as utils
18
+ from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process
19
+ from minigpt4.common.registry import registry
20
+ from minigpt4.processors.base_processor import BaseProcessor
21
+
22
+
23
+
24
+ class BaseDatasetBuilder:
25
+ train_dataset_cls, eval_dataset_cls = None, None
26
+
27
+ def __init__(self, cfg=None):
28
+ super().__init__()
29
+
30
+ if cfg is None:
31
+ # help to create datasets from default config.
32
+ self.config = load_dataset_config(self.default_config_path())
33
+ elif isinstance(cfg, str):
34
+ self.config = load_dataset_config(cfg)
35
+ else:
36
+ # when called from task.build_dataset()
37
+ self.config = cfg
38
+
39
+ self.data_type = self.config.data_type
40
+
41
+ self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
42
+ self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
43
+
44
+ def build_datasets(self):
45
+ # download, split, etc...
46
+ # only called on 1 GPU/TPU in distributed
47
+
48
+ if is_main_process():
49
+ self._download_data()
50
+
51
+ if is_dist_avail_and_initialized():
52
+ dist.barrier()
53
+
54
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
55
+ logging.info("Building datasets...")
56
+ datasets = self.build() # dataset['train'/'val'/'test']
57
+
58
+ return datasets
59
+
60
+ def build_processors(self):
61
+ vis_proc_cfg = self.config.get("vis_processor")
62
+ txt_proc_cfg = self.config.get("text_processor")
63
+
64
+ if vis_proc_cfg is not None:
65
+ vis_train_cfg = vis_proc_cfg.get("train")
66
+ vis_eval_cfg = vis_proc_cfg.get("eval")
67
+
68
+ self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
69
+ self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
70
+
71
+ if txt_proc_cfg is not None:
72
+ txt_train_cfg = txt_proc_cfg.get("train")
73
+ txt_eval_cfg = txt_proc_cfg.get("eval")
74
+
75
+ self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
76
+ self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
77
+
78
+ @staticmethod
79
+ def _build_proc_from_cfg(cfg):
80
+ return (
81
+ registry.get_processor_class(cfg.name).from_config(cfg)
82
+ if cfg is not None
83
+ else None
84
+ )
85
+
86
+ @classmethod
87
+ def default_config_path(cls, type="default"):
88
+ return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
89
+
90
+ def _download_data(self):
91
+ self._download_ann()
92
+ self._download_vis()
93
+
94
+ def _download_ann(self):
95
+ """
96
+ Download annotation files if necessary.
97
+ All the vision-language datasets should have annotations of unified format.
98
+
99
+ storage_path can be:
100
+ (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
101
+ (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
102
+
103
+ Local annotation paths should be relative.
104
+ """
105
+ anns = self.config.build_info.annotations
106
+
107
+ splits = anns.keys()
108
+
109
+ cache_root = registry.get_path("cache_root")
110
+
111
+ for split in splits:
112
+ info = anns[split]
113
+
114
+ urls, storage_paths = info.get("url", None), info.storage
115
+
116
+ if isinstance(urls, str):
117
+ urls = [urls]
118
+ if isinstance(storage_paths, str):
119
+ storage_paths = [storage_paths]
120
+
121
+ assert len(urls) == len(storage_paths)
122
+
123
+ for url_or_filename, storage_path in zip(urls, storage_paths):
124
+ # if storage_path is relative, make it full by prefixing with cache_root.
125
+ if not os.path.isabs(storage_path):
126
+ storage_path = os.path.join(cache_root, storage_path)
127
+
128
+ dirname = os.path.dirname(storage_path)
129
+ if not os.path.exists(dirname):
130
+ os.makedirs(dirname)
131
+
132
+ if os.path.isfile(url_or_filename):
133
+ src, dst = url_or_filename, storage_path
134
+ if not os.path.exists(dst):
135
+ shutil.copyfile(src=src, dst=dst)
136
+ else:
137
+ logging.info("Using existing file {}.".format(dst))
138
+ else:
139
+ if os.path.isdir(storage_path):
140
+ # if only dirname is provided, suffix with basename of URL.
141
+ raise ValueError(
142
+ "Expecting storage_path to be a file path, got directory {}".format(
143
+ storage_path
144
+ )
145
+ )
146
+ else:
147
+ filename = os.path.basename(storage_path)
148
+
149
+ download_url(url=url_or_filename, root=dirname, filename=filename)
150
+
151
+ def _download_vis(self):
152
+
153
+ storage_path = self.config.build_info.get(self.data_type).storage
154
+ storage_path = utils.get_cache_path(storage_path)
155
+
156
+ if not os.path.exists(storage_path):
157
+ warnings.warn(
158
+ f"""
159
+ The specified path {storage_path} for visual inputs does not exist.
160
+ Please provide a correct path to the visual inputs or
161
+ refer to datasets/download_scripts/README.md for downloading instructions.
162
+ """
163
+ )
164
+
165
+ def build(self):
166
+ """
167
+ Create by split datasets inheriting torch.utils.data.Datasets.
168
+
169
+ # build() can be dataset-specific. Overwrite to customize.
170
+ """
171
+ self.build_processors()
172
+
173
+ build_info = self.config.build_info
174
+
175
+ ann_info = build_info.annotations
176
+ vis_info = build_info.get(self.data_type)
177
+
178
+ datasets = dict()
179
+ for split in ann_info.keys():
180
+ if split not in ["train", "val", "test"]:
181
+ continue
182
+
183
+ is_train = split == "train"
184
+
185
+ # processors
186
+ vis_processor = (
187
+ self.vis_processors["train"]
188
+ if is_train
189
+ else self.vis_processors["eval"]
190
+ )
191
+ text_processor = (
192
+ self.text_processors["train"]
193
+ if is_train
194
+ else self.text_processors["eval"]
195
+ )
196
+
197
+ # annotation path
198
+ ann_paths = ann_info.get(split).storage
199
+ if isinstance(ann_paths, str):
200
+ ann_paths = [ann_paths]
201
+
202
+ abs_ann_paths = []
203
+ for ann_path in ann_paths:
204
+ if not os.path.isabs(ann_path):
205
+ ann_path = utils.get_cache_path(ann_path)
206
+ abs_ann_paths.append(ann_path)
207
+ ann_paths = abs_ann_paths
208
+
209
+ # visual data storage path
210
+ vis_path = os.path.join(vis_info.storage, split)
211
+
212
+ if not os.path.isabs(vis_path):
213
+ # vis_path = os.path.join(utils.get_cache_path(), vis_path)
214
+ vis_path = utils.get_cache_path(vis_path)
215
+
216
+ if not os.path.exists(vis_path):
217
+ warnings.warn("storage path {} does not exist.".format(vis_path))
218
+
219
+ # create datasets
220
+ dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
221
+ datasets[split] = dataset_cls(
222
+ vis_processor=vis_processor,
223
+ text_processor=text_processor,
224
+ ann_paths=ann_paths,
225
+ vis_root=vis_path,
226
+ )
227
+
228
+ return datasets
229
+
230
+
231
+ def load_dataset_config(cfg_path):
232
+ cfg = OmegaConf.load(cfg_path).datasets
233
+ cfg = cfg[list(cfg.keys())[0]]
234
+
235
+ return cfg
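A hypothetical usage sketch for the builder above (not part of the commit): with no config argument, a registered builder falls back to its default YAML via default_config_path() and returns per-split datasets from build_datasets().

builder = registry.get_builder_class('cc_align')()   # 'cc_align' is one of the builders registered in this repo
datasets = builder.build_datasets()                   # e.g. {'train': <dataset>}
print(len(datasets['train']))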
minigpt4/datasets/builders/image_text_pair_builder.py ADDED
@@ -0,0 +1,86 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import os
+
+ from minigpt4.common.registry import registry
+ from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+ from minigpt4.datasets.datasets.laion_dataset import LaionDataset
+ from minigpt4.datasets.datasets.cc_combine_dataset import CCCombineDataset, CCAlignDataset
+
+
+ @registry.register_builder("cc_combine")
+ class CCCombineBuilder(BaseDatasetBuilder):
+     train_dataset_cls = CCCombineDataset
+
+     DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_combine/defaults.yaml"}
+
+     def _download_ann(self):
+         pass
+
+     def _download_vis(self):
+         pass
+
+     def build(self):
+         self.build_processors()
+
+         build_info = self.config.build_info
+
+         datasets = dict()
+         split = "train"
+
+         # create datasets
+         # [NOTE] return inner_datasets (wds.DataPipeline)
+         dataset_cls = self.train_dataset_cls
+         datasets[split] = dataset_cls(
+             vis_processor=self.vis_processors[split],
+             text_processor=self.text_processors[split],
+             location=build_info.storage,
+         ).inner_dataset
+
+         return datasets
+
+
+ @registry.register_builder("laion")
+ class LaionBuilder(BaseDatasetBuilder):
+     train_dataset_cls = LaionDataset
+
+     DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
+
+     def _download_ann(self):
+         pass
+
+     def _download_vis(self):
+         pass
+
+     def build(self):
+         self.build_processors()
+
+         build_info = self.config.build_info
+
+         datasets = dict()
+         split = "train"
+
+         # create datasets
+         # [NOTE] return inner_datasets (wds.DataPipeline)
+         dataset_cls = self.train_dataset_cls
+         datasets[split] = dataset_cls(
+             vis_processor=self.vis_processors[split],
+             text_processor=self.text_processors[split],
+             location=build_info.storage,
+         ).inner_dataset
+
+         return datasets
+
+
+ @registry.register_builder("cc_align")
+ class CCAlignBuilder(BaseDatasetBuilder):
+     train_dataset_cls = CCAlignDataset
+
+     DATASET_CONFIG_DICT = {
+         "default": "configs/datasets/cc_combine/align.yaml",
+     }
minigpt4/datasets/data_utils.py ADDED
@@ -0,0 +1,196 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import gzip
9
+ import logging
10
+ import os
11
+ import random as rnd
12
+ import tarfile
13
+ import zipfile
14
+ import random
15
+ from typing import List
16
+ from tqdm import tqdm
17
+
18
+ import decord
19
+ from decord import VideoReader
20
+ import webdataset as wds
21
+ import numpy as np
22
+ import torch
23
+ from torch.utils.data.dataset import IterableDataset
24
+
25
+ from minigpt4.common.registry import registry
26
+ from minigpt4.datasets.datasets.base_dataset import ConcatDataset
27
+
28
+
29
+ decord.bridge.set_bridge("torch")
30
+ MAX_INT = registry.get("MAX_INT")
31
+
32
+
33
+ class ChainDataset(wds.DataPipeline):
34
+ r"""Dataset for chaining multiple :class:`DataPipeline` s.
35
+
36
+ This class is useful to assemble different existing dataset streams. The
37
+ chaining operation is done on-the-fly, so concatenating large-scale
38
+ datasets with this class will be efficient.
39
+
40
+ Args:
41
+ datasets (iterable of IterableDataset): datasets to be chained together
42
+ """
43
+ def __init__(self, datasets: List[wds.DataPipeline]) -> None:
44
+ super().__init__()
45
+ self.datasets = datasets
46
+ self.prob = []
47
+ self.names = []
48
+ for dataset in self.datasets:
49
+ if hasattr(dataset, 'name'):
50
+ self.names.append(dataset.name)
51
+ else:
52
+ self.names.append('Unknown')
53
+ if hasattr(dataset, 'sample_ratio'):
54
+ self.prob.append(dataset.sample_ratio)
55
+ else:
56
+ self.prob.append(1)
57
+ logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")
58
+
59
+ def __iter__(self):
60
+ datastreams = [iter(dataset) for dataset in self.datasets]
61
+ while True:
62
+ select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
63
+ yield next(select_datastream)
64
+
65
+
66
+ def apply_to_sample(f, sample):
67
+ if len(sample) == 0:
68
+ return {}
69
+
70
+ def _apply(x):
71
+ if torch.is_tensor(x):
72
+ return f(x)
73
+ elif isinstance(x, dict):
74
+ return {key: _apply(value) for key, value in x.items()}
75
+ elif isinstance(x, list):
76
+ return [_apply(x) for x in x]
77
+ else:
78
+ return x
79
+
80
+ return _apply(sample)
81
+
82
+
83
+ def move_to_cuda(sample):
84
+ def _move_to_cuda(tensor):
85
+ return tensor.cuda()
86
+
87
+ return apply_to_sample(_move_to_cuda, sample)
88
+
89
+
90
+ def prepare_sample(samples, cuda_enabled=True):
91
+ if cuda_enabled:
92
+ samples = move_to_cuda(samples)
93
+
94
+ # TODO fp16 support
95
+
96
+ return samples
97
+
98
+
99
+ def reorg_datasets_by_split(datasets):
100
+ """
101
+ Organizes datasets by split.
102
+
103
+ Args:
104
+ datasets: dict of torch.utils.data.Dataset objects by name.
105
+
106
+ Returns:
107
+ Dict of datasets by split {split_name: List[Datasets]}.
108
+ """
109
+ # if len(datasets) == 1:
110
+ # return datasets[list(datasets.keys())[0]]
111
+ # else:
112
+ reorg_datasets = dict()
113
+
114
+ # reorganize by split
115
+ for _, dataset in datasets.items():
116
+ for split_name, dataset_split in dataset.items():
117
+ if split_name not in reorg_datasets:
118
+ reorg_datasets[split_name] = [dataset_split]
119
+ else:
120
+ reorg_datasets[split_name].append(dataset_split)
121
+
122
+ return reorg_datasets
123
+
124
+
125
+ def concat_datasets(datasets):
126
+ """
127
+ Concatenates multiple datasets into a single dataset.
128
+
129
+ It supports map-style datasets and DataPipeline from WebDataset. Currently, does not support
130
+ generic IterableDataset because it requires creating separate samplers.
131
+
132
+ Now only supports concatenating training datasets and assuming validation and testing
133
+ have only a single dataset. This is because metrics should not be computed on the concatenated
134
+ datasets.
135
+
136
+ Args:
137
+ datasets: dict of torch.utils.data.Dataset objects by split.
138
+
139
+ Returns:
140
+ Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
141
+ "val" and "test" remain the same.
142
+
143
+ If the input training datasets contain both map-style and DataPipeline datasets, returns
144
+ a tuple, where the first element is a concatenated map-style dataset and the second
145
+ element is a chained DataPipeline dataset.
146
+
147
+ """
148
+ # concatenate datasets in the same split
149
+ for split_name in datasets:
150
+ if split_name != "train":
151
+ assert (
152
+ len(datasets[split_name]) == 1
153
+ ), "Do not support multiple {} datasets.".format(split_name)
154
+ datasets[split_name] = datasets[split_name][0]
155
+ else:
156
+ iterable_datasets, map_datasets = [], []
157
+ for dataset in datasets[split_name]:
158
+ if isinstance(dataset, wds.DataPipeline):
159
+ logging.info(
160
+ "Dataset {} is IterableDataset, can't be concatenated.".format(
161
+ dataset
162
+ )
163
+ )
164
+ iterable_datasets.append(dataset)
165
+ elif isinstance(dataset, IterableDataset):
166
+ raise NotImplementedError(
167
+ "Do not support concatenation of generic IterableDataset."
168
+ )
169
+ else:
170
+ map_datasets.append(dataset)
171
+
172
+ # if len(iterable_datasets) > 0:
173
+ # concatenate map-style datasets and iterable-style datasets separately
174
+ if len(iterable_datasets) > 1:
175
+ chained_datasets = (
176
+ ChainDataset(iterable_datasets)
177
+ )
178
+ elif len(iterable_datasets) == 1:
179
+ chained_datasets = iterable_datasets[0]
180
+ else:
181
+ chained_datasets = None
182
+
183
+ concat_datasets = (
184
+ ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
185
+ )
186
+
187
+ train_datasets = concat_datasets, chained_datasets
188
+ train_datasets = tuple([x for x in train_datasets if x is not None])
189
+ train_datasets = (
190
+ train_datasets[0] if len(train_datasets) == 1 else train_datasets
191
+ )
192
+
193
+ datasets[split_name] = train_datasets
194
+
195
+ return datasets
196
+
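A hypothetical sketch of how the two helpers above compose (not part of the commit; `datasets_by_name` is an assumed dict mapping builder names to their per-split datasets):

by_split = reorg_datasets_by_split(datasets_by_name)   # {'train': [ds1, ds2, ...], 'val': [ds], ...}
by_split = concat_datasets(by_split)                    # 'train' becomes a ConcatDataset and/or ChainDataset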
minigpt4/datasets/datasets/__init__.py ADDED
File without changes
minigpt4/datasets/datasets/base_dataset.py ADDED
@@ -0,0 +1,68 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import json
+ from typing import Iterable
+
+ from torch.utils.data import Dataset, ConcatDataset
+ from torch.utils.data.dataloader import default_collate
+
+
+ class BaseDataset(Dataset):
+     def __init__(
+         self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
+     ):
+         """
+         vis_root (string): Root directory of images (e.g. coco/images/)
+         ann_root (string): directory to store the annotation file
+         """
+         self.vis_root = vis_root
+
+         self.annotation = []
+         for ann_path in ann_paths:
+             self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
+
+         self.vis_processor = vis_processor
+         self.text_processor = text_processor
+
+         self._add_instance_ids()
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def collater(self, samples):
+         return default_collate(samples)
+
+     def set_processors(self, vis_processor, text_processor):
+         self.vis_processor = vis_processor
+         self.text_processor = text_processor
+
+     def _add_instance_ids(self, key="instance_id"):
+         for idx, ann in enumerate(self.annotation):
+             ann[key] = str(idx)
+
+
+ class ConcatDataset(ConcatDataset):
+     def __init__(self, datasets: Iterable[Dataset]) -> None:
+         super().__init__(datasets)
+
+     def collater(self, samples):
+         # TODO For now only supports datasets with same underlying collater implementations
+
+         all_keys = set()
+         for s in samples:
+             all_keys.update(s)
+
+         shared_keys = all_keys
+         for s in samples:
+             shared_keys = shared_keys & set(s.keys())
+
+         samples_shared_keys = []
+         for s in samples:
+             samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
+
+         return self.datasets[0].collater(samples_shared_keys)
minigpt4/datasets/datasets/caption_datasets.py ADDED
@@ -0,0 +1,85 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import os
+ from collections import OrderedDict
+
+ from minigpt4.datasets.datasets.base_dataset import BaseDataset
+ from PIL import Image
+
+
+ class __DisplMixin:
+     def displ_item(self, index):
+         sample, ann = self.__getitem__(index), self.annotation[index]
+
+         return OrderedDict(
+             {
+                 "file": ann["image"],
+                 "caption": ann["caption"],
+                 "image": sample["image"],
+             }
+         )
+
+
+ class CaptionDataset(BaseDataset, __DisplMixin):
+     def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+         """
+         vis_root (string): Root directory of images (e.g. coco/images/)
+         ann_root (string): directory to store the annotation file
+         """
+         super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+         self.img_ids = {}
+         n = 0
+         for ann in self.annotation:
+             img_id = ann["image_id"]
+             if img_id not in self.img_ids.keys():
+                 self.img_ids[img_id] = n
+                 n += 1
+
+     def __getitem__(self, index):
+
+         # TODO this assumes image input, not general enough
+         ann = self.annotation[index]
+
+         img_file = '{:0>12}.jpg'.format(ann["image_id"])
+         image_path = os.path.join(self.vis_root, img_file)
+         image = Image.open(image_path).convert("RGB")
+
+         image = self.vis_processor(image)
+         caption = self.text_processor(ann["caption"])
+
+         return {
+             "image": image,
+             "text_input": caption,
+             "image_id": self.img_ids[ann["image_id"]],
+         }
+
+
+ class CaptionEvalDataset(BaseDataset, __DisplMixin):
+     def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+         """
+         vis_root (string): Root directory of images (e.g. coco/images/)
+         ann_root (string): directory to store the annotation file
+         split (string): val or test
+         """
+         super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         image_path = os.path.join(self.vis_root, ann["image"])
+         image = Image.open(image_path).convert("RGB")
+
+         image = self.vis_processor(image)
+
+         return {
+             "image": image,
+             "image_id": ann["image_id"],
+             "instance_id": ann["instance_id"],
+         }
minigpt4/datasets/datasets/cc_combine_dataset.py ADDED
@@ -0,0 +1,53 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+ import os
+ from PIL import Image
+ import webdataset as wds
+ from minigpt4.datasets.datasets.base_dataset import BaseDataset
+ from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
+
+
+ class CCCombineDataset(BaseDataset):
+     def __init__(self, vis_processor, text_processor, location):
+         super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+
+         self.inner_dataset = wds.DataPipeline(
+             wds.ResampledShards(location),
+             wds.tarfile_to_samples(handler=wds.warn_and_continue),
+             wds.shuffle(1000, handler=wds.warn_and_continue),
+             wds.decode("pilrgb", handler=wds.warn_and_continue),
+             wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
+             wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+             wds.map(self.to_dict, handler=wds.warn_and_continue),
+         )
+
+     def to_dict(self, sample):
+         return {
+             "image": sample[0],
+             "text_input": self.text_processor(sample[1]["caption"]),
+         }
+
+
+ class CCAlignDataset(CaptionDataset):
+
+     def __getitem__(self, index):
+
+         # TODO this assumes image input, not general enough
+         ann = self.annotation[index]
+
+         img_file = '{}.jpg'.format(ann["image_id"])
+         image_path = os.path.join(self.vis_root, img_file)
+         image = Image.open(image_path).convert("RGB")
+
+         image = self.vis_processor(image)
+         caption = ann["caption"]
+
+         return {
+             "image": image,
+             "text_input": caption,
+             "image_id": self.img_ids[ann["image_id"]],
+         }
minigpt4/datasets/datasets/dataloader_utils.py ADDED
@@ -0,0 +1,162 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import time
9
+ import random
10
+ import torch
11
+ from minigpt4.datasets.data_utils import move_to_cuda
12
+ from torch.utils.data import DataLoader
13
+
14
+
15
+ class MultiIterLoader:
16
+ """
17
+ A simple wrapper for iterating over multiple iterators.
18
+
19
+ Args:
20
+ loaders (List[Loader]): List of Iterator loaders.
21
+ ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
22
+ """
23
+
24
+ def __init__(self, loaders, ratios=None):
25
+ # assert all loaders have a __next__ method
26
+ for loader in loaders:
27
+ assert hasattr(
28
+ loader, "__next__"
29
+ ), "Loader {} has no __next__ method.".format(loader)
30
+
31
+ if ratios is None:
32
+ ratios = [1.0] * len(loaders)
33
+ else:
34
+ assert len(ratios) == len(loaders)
35
+ ratios = [float(ratio) / sum(ratios) for ratio in ratios]
36
+
37
+ self.loaders = loaders
38
+ self.ratios = ratios
39
+
40
+ def __next__(self):
41
+ # random sample from each loader by ratio
42
+ loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
43
+ return next(self.loaders[loader_idx])
44
+
45
+
46
+ class PrefetchLoader(object):
47
+ """
48
+ Modified from https://github.com/ChenRocks/UNITER.
49
+
50
+ overlap compute and cuda data transfer
51
+ (copied and then modified from nvidia apex)
52
+ """
53
+
54
+ def __init__(self, loader):
55
+ self.loader = loader
56
+ self.stream = torch.cuda.Stream()
57
+
58
+ def __iter__(self):
59
+ loader_it = iter(self.loader)
60
+ self.preload(loader_it)
61
+ batch = self.next(loader_it)
62
+ while batch is not None:
63
+ is_tuple = isinstance(batch, tuple)
64
+ if is_tuple:
65
+ task, batch = batch
66
+
67
+ if is_tuple:
68
+ yield task, batch
69
+ else:
70
+ yield batch
71
+ batch = self.next(loader_it)
72
+
73
+ def __len__(self):
74
+ return len(self.loader)
75
+
76
+ def preload(self, it):
77
+ try:
78
+ self.batch = next(it)
79
+ except StopIteration:
80
+ self.batch = None
81
+ return
82
+ # if record_stream() doesn't work, another option is to make sure
83
+ # device inputs are created on the main stream.
84
+ # self.next_input_gpu = torch.empty_like(self.next_input,
85
+ # device='cuda')
86
+ # self.next_target_gpu = torch.empty_like(self.next_target,
87
+ # device='cuda')
88
+ # Need to make sure the memory allocated for next_* is not still in use
89
+ # by the main stream at the time we start copying to next_*:
90
+ # self.stream.wait_stream(torch.cuda.current_stream())
91
+ with torch.cuda.stream(self.stream):
92
+ self.batch = move_to_cuda(self.batch)
93
+ # more code for the alternative if record_stream() doesn't work:
94
+ # copy_ will record the use of the pinned source tensor in this
95
+ # side stream.
96
+ # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
97
+ # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
98
+ # self.next_input = self.next_input_gpu
99
+ # self.next_target = self.next_target_gpu
100
+
101
+ def next(self, it):
102
+ torch.cuda.current_stream().wait_stream(self.stream)
103
+ batch = self.batch
104
+ if batch is not None:
105
+ record_cuda_stream(batch)
106
+ self.preload(it)
107
+ return batch
108
+
109
+ def __getattr__(self, name):
110
+ method = self.loader.__getattribute__(name)
111
+ return method
112
+
113
+
114
+ def record_cuda_stream(batch):
115
+ if isinstance(batch, torch.Tensor):
116
+ batch.record_stream(torch.cuda.current_stream())
117
+ elif isinstance(batch, list) or isinstance(batch, tuple):
118
+ for t in batch:
119
+ record_cuda_stream(t)
120
+ elif isinstance(batch, dict):
121
+ for t in batch.values():
122
+ record_cuda_stream(t)
123
+ else:
124
+ pass
125
+
126
+
127
+ class IterLoader:
128
+ """
129
+ A wrapper to convert DataLoader as an infinite iterator.
130
+
131
+ Modified from:
132
+ https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
133
+ """
134
+
135
+ def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
136
+ self._dataloader = dataloader
137
+ self.iter_loader = iter(self._dataloader)
138
+ self._use_distributed = use_distributed
139
+ self._epoch = 0
140
+
141
+ @property
142
+ def epoch(self) -> int:
143
+ return self._epoch
144
+
145
+ def __next__(self):
146
+ try:
147
+ data = next(self.iter_loader)
148
+ except StopIteration:
149
+ self._epoch += 1
150
+ if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
151
+ self._dataloader.sampler.set_epoch(self._epoch)
152
+ time.sleep(2) # Prevent possible deadlock during epoch transition
153
+ self.iter_loader = iter(self._dataloader)
154
+ data = next(self.iter_loader)
155
+
156
+ return data
157
+
158
+ def __iter__(self):
159
+ return self
160
+
161
+ def __len__(self):
162
+ return len(self._dataloader)
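A hypothetical sketch of how the loader wrappers above can be composed (not part of the commit; `train_dataset` is an assumed map-style dataset with a `collater` method, and a CUDA device is required for PrefetchLoader):

from torch.utils.data import DataLoader

loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=train_dataset.collater)
loader = IterLoader(PrefetchLoader(loader), use_distributed=False)   # infinite iterator with async host-to-GPU copies
batch = next(loader)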
minigpt4/datasets/datasets/laion_dataset.py ADDED
@@ -0,0 +1,31 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import webdataset as wds
+ from minigpt4.datasets.datasets.base_dataset import BaseDataset
+
+
+ class LaionDataset(BaseDataset):
+     def __init__(self, vis_processor, text_processor, location):
+         super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+
+         self.inner_dataset = wds.DataPipeline(
+             wds.ResampledShards(location),
+             wds.tarfile_to_samples(handler=wds.warn_and_continue),
+             wds.shuffle(1000, handler=wds.warn_and_continue),
+             wds.decode("pilrgb", handler=wds.warn_and_continue),
+             wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
+             wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+             wds.map(self.to_dict, handler=wds.warn_and_continue),
+         )
+
+     def to_dict(self, sample):
+         return {
+             "image": sample[0],
+             "text_input": self.text_processor(sample[1]["caption"]),
+         }
+
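A hypothetical instantiation sketch for the dataset above (not part of the commit; the shard pattern and the processors are placeholders):

laion = LaionDataset(
    vis_processor=vis_processor,
    text_processor=text_processor,
    location='/path/to/laion/{00000..10488}.tar',
)
sample = next(iter(laion.inner_dataset))   # {'image': <tensor>, 'text_input': <str>}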
minigpt4/models/Qformer.py ADDED
@@ -0,0 +1,1216 @@
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Dict, Any
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ class BertEmbeddings(nn.Module):
52
+ """Construct the embeddings from word and position embeddings."""
53
+
54
+ def __init__(self, config):
55
+ super().__init__()
56
+ self.word_embeddings = nn.Embedding(
57
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
58
+ )
59
+ self.position_embeddings = nn.Embedding(
60
+ config.max_position_embeddings, config.hidden_size
61
+ )
62
+
63
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
64
+ # any TensorFlow checkpoint file
65
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
66
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
67
+
68
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
69
+ self.register_buffer(
70
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
71
+ )
72
+ self.position_embedding_type = getattr(
73
+ config, "position_embedding_type", "absolute"
74
+ )
75
+
76
+ self.config = config
77
+
78
+ def forward(
79
+ self,
80
+ input_ids=None,
81
+ position_ids=None,
82
+ query_embeds=None,
83
+ past_key_values_length=0,
84
+ ):
85
+ if input_ids is not None:
86
+ seq_length = input_ids.size()[1]
87
+ else:
88
+ seq_length = 0
89
+
90
+ if position_ids is None:
91
+ position_ids = self.position_ids[
92
+ :, past_key_values_length : seq_length + past_key_values_length
93
+ ].clone()
94
+
95
+ if input_ids is not None:
96
+ embeddings = self.word_embeddings(input_ids)
97
+ if self.position_embedding_type == "absolute":
98
+ position_embeddings = self.position_embeddings(position_ids)
99
+ embeddings = embeddings + position_embeddings
100
+
101
+ if query_embeds is not None:
102
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
103
+ else:
104
+ embeddings = query_embeds
105
+
106
+ embeddings = self.LayerNorm(embeddings)
107
+ embeddings = self.dropout(embeddings)
108
+ return embeddings
109
+
110
+
111
+ class BertSelfAttention(nn.Module):
112
+ def __init__(self, config, is_cross_attention):
113
+ super().__init__()
114
+ self.config = config
115
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
116
+ config, "embedding_size"
117
+ ):
118
+ raise ValueError(
119
+ "The hidden size (%d) is not a multiple of the number of attention "
120
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
121
+ )
122
+
123
+ self.num_attention_heads = config.num_attention_heads
124
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
125
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
126
+
127
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
128
+ if is_cross_attention:
129
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
130
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
131
+ else:
132
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
133
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
134
+
135
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
136
+ self.position_embedding_type = getattr(
137
+ config, "position_embedding_type", "absolute"
138
+ )
139
+ if (
140
+ self.position_embedding_type == "relative_key"
141
+ or self.position_embedding_type == "relative_key_query"
142
+ ):
143
+ self.max_position_embeddings = config.max_position_embeddings
144
+ self.distance_embedding = nn.Embedding(
145
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
146
+ )
147
+ self.save_attention = False
148
+
149
+ def save_attn_gradients(self, attn_gradients):
150
+ self.attn_gradients = attn_gradients
151
+
152
+ def get_attn_gradients(self):
153
+ return self.attn_gradients
154
+
155
+ def save_attention_map(self, attention_map):
156
+ self.attention_map = attention_map
157
+
158
+ def get_attention_map(self):
159
+ return self.attention_map
160
+
161
+ def transpose_for_scores(self, x):
162
+ new_x_shape = x.size()[:-1] + (
163
+ self.num_attention_heads,
164
+ self.attention_head_size,
165
+ )
166
+ x = x.view(*new_x_shape)
167
+ return x.permute(0, 2, 1, 3)
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states,
172
+ attention_mask=None,
173
+ head_mask=None,
174
+ encoder_hidden_states=None,
175
+ encoder_attention_mask=None,
176
+ past_key_value=None,
177
+ output_attentions=False,
178
+ ):
179
+
180
+ # If this is instantiated as a cross-attention module, the keys
181
+ # and values come from an encoder; the attention mask needs to be
182
+ # such that the encoder's padding tokens are not attended to.
183
+ is_cross_attention = encoder_hidden_states is not None
184
+
185
+ if is_cross_attention:
186
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
187
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
188
+ attention_mask = encoder_attention_mask
189
+ elif past_key_value is not None:
190
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
191
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
192
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
193
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
194
+ else:
195
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
196
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
197
+
198
+ mixed_query_layer = self.query(hidden_states)
199
+
200
+ query_layer = self.transpose_for_scores(mixed_query_layer)
201
+
202
+ past_key_value = (key_layer, value_layer)
203
+
204
+ # Take the dot product between "query" and "key" to get the raw attention scores.
205
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
206
+
207
+ if (
208
+ self.position_embedding_type == "relative_key"
209
+ or self.position_embedding_type == "relative_key_query"
210
+ ):
211
+ seq_length = hidden_states.size()[1]
212
+ position_ids_l = torch.arange(
213
+ seq_length, dtype=torch.long, device=hidden_states.device
214
+ ).view(-1, 1)
215
+ position_ids_r = torch.arange(
216
+ seq_length, dtype=torch.long, device=hidden_states.device
217
+ ).view(1, -1)
218
+ distance = position_ids_l - position_ids_r
219
+ positional_embedding = self.distance_embedding(
220
+ distance + self.max_position_embeddings - 1
221
+ )
222
+ positional_embedding = positional_embedding.to(
223
+ dtype=query_layer.dtype
224
+ ) # fp16 compatibility
225
+
226
+ if self.position_embedding_type == "relative_key":
227
+ relative_position_scores = torch.einsum(
228
+ "bhld,lrd->bhlr", query_layer, positional_embedding
229
+ )
230
+ attention_scores = attention_scores + relative_position_scores
231
+ elif self.position_embedding_type == "relative_key_query":
232
+ relative_position_scores_query = torch.einsum(
233
+ "bhld,lrd->bhlr", query_layer, positional_embedding
234
+ )
235
+ relative_position_scores_key = torch.einsum(
236
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
237
+ )
238
+ attention_scores = (
239
+ attention_scores
240
+ + relative_position_scores_query
241
+ + relative_position_scores_key
242
+ )
243
+
244
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
245
+ if attention_mask is not None:
246
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
247
+ attention_scores = attention_scores + attention_mask
248
+
249
+ # Normalize the attention scores to probabilities.
250
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
251
+
252
+ if is_cross_attention and self.save_attention:
253
+ self.save_attention_map(attention_probs)
254
+ attention_probs.register_hook(self.save_attn_gradients)
255
+
256
+ # This is actually dropping out entire tokens to attend to, which might
257
+ # seem a bit unusual, but is taken from the original Transformer paper.
258
+ attention_probs_dropped = self.dropout(attention_probs)
259
+
260
+ # Mask heads if we want to
261
+ if head_mask is not None:
262
+ attention_probs_dropped = attention_probs_dropped * head_mask
263
+
264
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
265
+
266
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
267
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
268
+ context_layer = context_layer.view(*new_context_layer_shape)
269
+
270
+ outputs = (
271
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
272
+ )
273
+
274
+ outputs = outputs + (past_key_value,)
275
+ return outputs
276
+
277
+
278
+ class BertSelfOutput(nn.Module):
279
+ def __init__(self, config):
280
+ super().__init__()
281
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
282
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
283
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
284
+
285
+ def forward(self, hidden_states, input_tensor):
286
+ hidden_states = self.dense(hidden_states)
287
+ hidden_states = self.dropout(hidden_states)
288
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
289
+ return hidden_states
290
+
291
+
292
+ class BertAttention(nn.Module):
293
+ def __init__(self, config, is_cross_attention=False):
294
+ super().__init__()
295
+ self.self = BertSelfAttention(config, is_cross_attention)
296
+ self.output = BertSelfOutput(config)
297
+ self.pruned_heads = set()
298
+
299
+ def prune_heads(self, heads):
300
+ if len(heads) == 0:
301
+ return
302
+ heads, index = find_pruneable_heads_and_indices(
303
+ heads,
304
+ self.self.num_attention_heads,
305
+ self.self.attention_head_size,
306
+ self.pruned_heads,
307
+ )
308
+
309
+ # Prune linear layers
310
+ self.self.query = prune_linear_layer(self.self.query, index)
311
+ self.self.key = prune_linear_layer(self.self.key, index)
312
+ self.self.value = prune_linear_layer(self.self.value, index)
313
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
314
+
315
+ # Update hyper params and store pruned heads
316
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
317
+ self.self.all_head_size = (
318
+ self.self.attention_head_size * self.self.num_attention_heads
319
+ )
320
+ self.pruned_heads = self.pruned_heads.union(heads)
321
+
322
+ def forward(
323
+ self,
324
+ hidden_states,
325
+ attention_mask=None,
326
+ head_mask=None,
327
+ encoder_hidden_states=None,
328
+ encoder_attention_mask=None,
329
+ past_key_value=None,
330
+ output_attentions=False,
331
+ ):
332
+ self_outputs = self.self(
333
+ hidden_states,
334
+ attention_mask,
335
+ head_mask,
336
+ encoder_hidden_states,
337
+ encoder_attention_mask,
338
+ past_key_value,
339
+ output_attentions,
340
+ )
341
+ attention_output = self.output(self_outputs[0], hidden_states)
342
+
343
+ outputs = (attention_output,) + self_outputs[
344
+ 1:
345
+ ] # add attentions if we output them
346
+ return outputs
347
+
348
+
349
+ class BertIntermediate(nn.Module):
350
+ def __init__(self, config):
351
+ super().__init__()
352
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
353
+ if isinstance(config.hidden_act, str):
354
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
355
+ else:
356
+ self.intermediate_act_fn = config.hidden_act
357
+
358
+ def forward(self, hidden_states):
359
+ hidden_states = self.dense(hidden_states)
360
+ hidden_states = self.intermediate_act_fn(hidden_states)
361
+ return hidden_states
362
+
363
+
364
+ class BertOutput(nn.Module):
365
+ def __init__(self, config):
366
+ super().__init__()
367
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
368
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
369
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
370
+
371
+ def forward(self, hidden_states, input_tensor):
372
+ hidden_states = self.dense(hidden_states)
373
+ hidden_states = self.dropout(hidden_states)
374
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
375
+ return hidden_states
376
+
377
+
378
+ class BertLayer(nn.Module):
379
+ def __init__(self, config, layer_num):
380
+ super().__init__()
381
+ self.config = config
382
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
383
+ self.seq_len_dim = 1
384
+ self.attention = BertAttention(config)
385
+ self.layer_num = layer_num
386
+ if (
387
+ self.config.add_cross_attention
388
+ and layer_num % self.config.cross_attention_freq == 0
389
+ ):
390
+ self.crossattention = BertAttention(
391
+ config, is_cross_attention=self.config.add_cross_attention
392
+ )
393
+ self.has_cross_attention = True
394
+ else:
395
+ self.has_cross_attention = False
396
+ self.intermediate = BertIntermediate(config)
397
+ self.output = BertOutput(config)
398
+
399
+ self.intermediate_query = BertIntermediate(config)
400
+ self.output_query = BertOutput(config)
401
+
402
+ def forward(
403
+ self,
404
+ hidden_states,
405
+ attention_mask=None,
406
+ head_mask=None,
407
+ encoder_hidden_states=None,
408
+ encoder_attention_mask=None,
409
+ past_key_value=None,
410
+ output_attentions=False,
411
+ query_length=0,
412
+ ):
413
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
414
+ self_attn_past_key_value = (
415
+ past_key_value[:2] if past_key_value is not None else None
416
+ )
417
+ self_attention_outputs = self.attention(
418
+ hidden_states,
419
+ attention_mask,
420
+ head_mask,
421
+ output_attentions=output_attentions,
422
+ past_key_value=self_attn_past_key_value,
423
+ )
424
+ attention_output = self_attention_outputs[0]
425
+ outputs = self_attention_outputs[1:-1]
426
+
427
+ present_key_value = self_attention_outputs[-1]
428
+
429
+ if query_length > 0:
430
+ query_attention_output = attention_output[:, :query_length, :]
431
+
432
+ if self.has_cross_attention:
433
+ assert (
434
+ encoder_hidden_states is not None
435
+ ), "encoder_hidden_states must be given for cross-attention layers"
436
+ cross_attention_outputs = self.crossattention(
437
+ query_attention_output,
438
+ attention_mask,
439
+ head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ output_attentions=output_attentions,
443
+ )
444
+ query_attention_output = cross_attention_outputs[0]
445
+ outputs = (
446
+ outputs + cross_attention_outputs[1:-1]
447
+ ) # add cross attentions if we output attention weights
448
+
449
+ layer_output = apply_chunking_to_forward(
450
+ self.feed_forward_chunk_query,
451
+ self.chunk_size_feed_forward,
452
+ self.seq_len_dim,
453
+ query_attention_output,
454
+ )
455
+ if attention_output.shape[1] > query_length:
456
+ layer_output_text = apply_chunking_to_forward(
457
+ self.feed_forward_chunk,
458
+ self.chunk_size_feed_forward,
459
+ self.seq_len_dim,
460
+ attention_output[:, query_length:, :],
461
+ )
462
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
463
+ else:
464
+ layer_output = apply_chunking_to_forward(
465
+ self.feed_forward_chunk,
466
+ self.chunk_size_feed_forward,
467
+ self.seq_len_dim,
468
+ attention_output,
469
+ )
470
+ outputs = (layer_output,) + outputs
471
+
472
+ outputs = outputs + (present_key_value,)
473
+
474
+ return outputs
475
+
476
+ def feed_forward_chunk(self, attention_output):
477
+ intermediate_output = self.intermediate(attention_output)
478
+ layer_output = self.output(intermediate_output, attention_output)
479
+ return layer_output
480
+
481
+ def feed_forward_chunk_query(self, attention_output):
482
+ intermediate_output = self.intermediate_query(attention_output)
483
+ layer_output = self.output_query(intermediate_output, attention_output)
484
+ return layer_output
485
+
486
+
487
+ class BertEncoder(nn.Module):
488
+ def __init__(self, config):
489
+ super().__init__()
490
+ self.config = config
491
+ self.layer = nn.ModuleList(
492
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
493
+ )
494
+
495
+ def forward(
496
+ self,
497
+ hidden_states,
498
+ attention_mask=None,
499
+ head_mask=None,
500
+ encoder_hidden_states=None,
501
+ encoder_attention_mask=None,
502
+ past_key_values=None,
503
+ use_cache=None,
504
+ output_attentions=False,
505
+ output_hidden_states=False,
506
+ return_dict=True,
507
+ query_length=0,
508
+ ):
509
+ all_hidden_states = () if output_hidden_states else None
510
+ all_self_attentions = () if output_attentions else None
511
+ all_cross_attentions = (
512
+ () if output_attentions and self.config.add_cross_attention else None
513
+ )
514
+
515
+ next_decoder_cache = () if use_cache else None
516
+
517
+ for i in range(self.config.num_hidden_layers):
518
+ layer_module = self.layer[i]
519
+ if output_hidden_states:
520
+ all_hidden_states = all_hidden_states + (hidden_states,)
521
+
522
+ layer_head_mask = head_mask[i] if head_mask is not None else None
523
+ past_key_value = past_key_values[i] if past_key_values is not None else None
524
+
525
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
526
+
527
+ if use_cache:
528
+ logger.warn(
529
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
530
+ )
531
+ use_cache = False
532
+
533
+ def create_custom_forward(module):
534
+ def custom_forward(*inputs):
535
+ return module(
536
+ *inputs, past_key_value, output_attentions, query_length
537
+ )
538
+
539
+ return custom_forward
540
+
541
+ layer_outputs = torch.utils.checkpoint.checkpoint(
542
+ create_custom_forward(layer_module),
543
+ hidden_states,
544
+ attention_mask,
545
+ layer_head_mask,
546
+ encoder_hidden_states,
547
+ encoder_attention_mask,
548
+ )
549
+ else:
550
+ layer_outputs = layer_module(
551
+ hidden_states,
552
+ attention_mask,
553
+ layer_head_mask,
554
+ encoder_hidden_states,
555
+ encoder_attention_mask,
556
+ past_key_value,
557
+ output_attentions,
558
+ query_length,
559
+ )
560
+
561
+ hidden_states = layer_outputs[0]
562
+ if use_cache:
563
+ next_decoder_cache += (layer_outputs[-1],)
564
+ if output_attentions:
565
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
566
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
567
+
568
+ if output_hidden_states:
569
+ all_hidden_states = all_hidden_states + (hidden_states,)
570
+
571
+ if not return_dict:
572
+ return tuple(
573
+ v
574
+ for v in [
575
+ hidden_states,
576
+ next_decoder_cache,
577
+ all_hidden_states,
578
+ all_self_attentions,
579
+ all_cross_attentions,
580
+ ]
581
+ if v is not None
582
+ )
583
+ return BaseModelOutputWithPastAndCrossAttentions(
584
+ last_hidden_state=hidden_states,
585
+ past_key_values=next_decoder_cache,
586
+ hidden_states=all_hidden_states,
587
+ attentions=all_self_attentions,
588
+ cross_attentions=all_cross_attentions,
589
+ )
590
+
591
+
592
+ class BertPooler(nn.Module):
593
+ def __init__(self, config):
594
+ super().__init__()
595
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
596
+ self.activation = nn.Tanh()
597
+
598
+ def forward(self, hidden_states):
599
+ # We "pool" the model by simply taking the hidden state corresponding
600
+ # to the first token.
601
+ first_token_tensor = hidden_states[:, 0]
602
+ pooled_output = self.dense(first_token_tensor)
603
+ pooled_output = self.activation(pooled_output)
604
+ return pooled_output
605
+
606
+
607
+ class BertPredictionHeadTransform(nn.Module):
608
+ def __init__(self, config):
609
+ super().__init__()
610
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
611
+ if isinstance(config.hidden_act, str):
612
+ self.transform_act_fn = ACT2FN[config.hidden_act]
613
+ else:
614
+ self.transform_act_fn = config.hidden_act
615
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
616
+
617
+ def forward(self, hidden_states):
618
+ hidden_states = self.dense(hidden_states)
619
+ hidden_states = self.transform_act_fn(hidden_states)
620
+ hidden_states = self.LayerNorm(hidden_states)
621
+ return hidden_states
622
+
623
+
624
+ class BertLMPredictionHead(nn.Module):
625
+ def __init__(self, config):
626
+ super().__init__()
627
+ self.transform = BertPredictionHeadTransform(config)
628
+
629
+ # The output weights are the same as the input embeddings, but there is
630
+ # an output-only bias for each token.
631
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
632
+
633
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
634
+
635
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
636
+ self.decoder.bias = self.bias
637
+
638
+ def forward(self, hidden_states):
639
+ hidden_states = self.transform(hidden_states)
640
+ hidden_states = self.decoder(hidden_states)
641
+ return hidden_states
642
+
643
+
644
+ class BertOnlyMLMHead(nn.Module):
645
+ def __init__(self, config):
646
+ super().__init__()
647
+ self.predictions = BertLMPredictionHead(config)
648
+
649
+ def forward(self, sequence_output):
650
+ prediction_scores = self.predictions(sequence_output)
651
+ return prediction_scores
652
+
653
+
654
+ class BertPreTrainedModel(PreTrainedModel):
655
+ """
656
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
657
+ models.
658
+ """
659
+
660
+ config_class = BertConfig
661
+ base_model_prefix = "bert"
662
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
663
+
664
+ def _init_weights(self, module):
665
+ """Initialize the weights"""
666
+ if isinstance(module, (nn.Linear, nn.Embedding)):
667
+ # Slightly different from the TF version which uses truncated_normal for initialization
668
+ # cf https://github.com/pytorch/pytorch/pull/5617
669
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
670
+ elif isinstance(module, nn.LayerNorm):
671
+ module.bias.data.zero_()
672
+ module.weight.data.fill_(1.0)
673
+ if isinstance(module, nn.Linear) and module.bias is not None:
674
+ module.bias.data.zero_()
675
+
676
+
677
+ class BertModel(BertPreTrainedModel):
678
+ """
679
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
680
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
681
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
682
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
683
+ To be used as a decoder, the model needs to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
684
+ input to the forward pass.
685
+ """
686
+
687
+ def __init__(self, config, add_pooling_layer=False):
688
+ super().__init__(config)
689
+ self.config = config
690
+
691
+ self.embeddings = BertEmbeddings(config)
692
+
693
+ self.encoder = BertEncoder(config)
694
+
695
+ self.pooler = BertPooler(config) if add_pooling_layer else None
696
+
697
+ self.init_weights()
698
+
699
+ def get_input_embeddings(self):
700
+ return self.embeddings.word_embeddings
701
+
702
+ def set_input_embeddings(self, value):
703
+ self.embeddings.word_embeddings = value
704
+
705
+ def _prune_heads(self, heads_to_prune):
706
+ """
707
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
708
+ class PreTrainedModel
709
+ """
710
+ for layer, heads in heads_to_prune.items():
711
+ self.encoder.layer[layer].attention.prune_heads(heads)
712
+
713
+ def get_extended_attention_mask(
714
+ self,
715
+ attention_mask: Tensor,
716
+ input_shape: Tuple[int],
717
+ device: device,
718
+ is_decoder: bool,
719
+ has_query: bool = False,
720
+ ) -> Tensor:
721
+ """
722
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
723
+
724
+ Arguments:
725
+ attention_mask (:obj:`torch.Tensor`):
726
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
727
+ input_shape (:obj:`Tuple[int]`):
728
+ The shape of the input to the model.
729
+ device: (:obj:`torch.device`):
730
+ The device of the input to the model.
731
+
732
+ Returns:
733
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
734
+ """
735
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
736
+ # ourselves in which case we just need to make it broadcastable to all heads.
737
+ if attention_mask.dim() == 3:
738
+ extended_attention_mask = attention_mask[:, None, :, :]
739
+ elif attention_mask.dim() == 2:
740
+ # Provided a padding mask of dimensions [batch_size, seq_length]
741
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
742
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
743
+ if is_decoder:
744
+ batch_size, seq_length = input_shape
745
+
746
+ seq_ids = torch.arange(seq_length, device=device)
747
+ causal_mask = (
748
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
749
+ <= seq_ids[None, :, None]
750
+ )
751
+
752
+ # add a prefix ones mask to the causal mask
753
+ # causal and attention masks must have same type with pytorch version < 1.3
754
+ causal_mask = causal_mask.to(attention_mask.dtype)
755
+
756
+ if causal_mask.shape[1] < attention_mask.shape[1]:
757
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
758
+ if has_query: # UniLM style attention mask
759
+ causal_mask = torch.cat(
760
+ [
761
+ torch.zeros(
762
+ (batch_size, prefix_seq_len, seq_length),
763
+ device=device,
764
+ dtype=causal_mask.dtype,
765
+ ),
766
+ causal_mask,
767
+ ],
768
+ axis=1,
769
+ )
770
+ causal_mask = torch.cat(
771
+ [
772
+ torch.ones(
773
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
774
+ device=device,
775
+ dtype=causal_mask.dtype,
776
+ ),
777
+ causal_mask,
778
+ ],
779
+ axis=-1,
780
+ )
781
+ extended_attention_mask = (
782
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
783
+ )
784
+ else:
785
+ extended_attention_mask = attention_mask[:, None, None, :]
786
+ else:
787
+ raise ValueError(
788
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
789
+ input_shape, attention_mask.shape
790
+ )
791
+ )
792
+
793
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
794
+ # masked positions, this operation will create a tensor which is 0.0 for
795
+ # positions we want to attend and -10000.0 for masked positions.
796
+ # Since we are adding it to the raw scores before the softmax, this is
797
+ # effectively the same as removing these entirely.
798
+ extended_attention_mask = extended_attention_mask.to(
799
+ dtype=self.dtype
800
+ ) # fp16 compatibility
801
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
802
+ return extended_attention_mask
803
+
804
+ def forward(
805
+ self,
806
+ input_ids=None,
807
+ attention_mask=None,
808
+ position_ids=None,
809
+ head_mask=None,
810
+ query_embeds=None,
811
+ encoder_hidden_states=None,
812
+ encoder_attention_mask=None,
813
+ past_key_values=None,
814
+ use_cache=None,
815
+ output_attentions=None,
816
+ output_hidden_states=None,
817
+ return_dict=None,
818
+ is_decoder=False,
819
+ ):
820
+ r"""
821
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
822
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
823
+ the model is configured as a decoder.
824
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
825
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
826
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
827
+ - 1 for tokens that are **not masked**,
828
+ - 0 for tokens that are **masked**.
829
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
830
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
831
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
832
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
833
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
834
+ use_cache (:obj:`bool`, `optional`):
835
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
836
+ decoding (see :obj:`past_key_values`).
837
+ """
838
+ output_attentions = (
839
+ output_attentions
840
+ if output_attentions is not None
841
+ else self.config.output_attentions
842
+ )
843
+ output_hidden_states = (
844
+ output_hidden_states
845
+ if output_hidden_states is not None
846
+ else self.config.output_hidden_states
847
+ )
848
+ return_dict = (
849
+ return_dict if return_dict is not None else self.config.use_return_dict
850
+ )
851
+
852
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
853
+
854
+ if input_ids is None:
855
+ assert (
856
+ query_embeds is not None
857
+ ), "You have to specify query_embeds when input_ids is None"
858
+
859
+ # past_key_values_length
860
+ past_key_values_length = (
861
+ past_key_values[0][0].shape[2] - self.config.query_length
862
+ if past_key_values is not None
863
+ else 0
864
+ )
865
+
866
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
867
+
868
+ embedding_output = self.embeddings(
869
+ input_ids=input_ids,
870
+ position_ids=position_ids,
871
+ query_embeds=query_embeds,
872
+ past_key_values_length=past_key_values_length,
873
+ )
874
+
875
+ input_shape = embedding_output.size()[:-1]
876
+ batch_size, seq_length = input_shape
877
+ device = embedding_output.device
878
+
879
+ if attention_mask is None:
880
+ attention_mask = torch.ones(
881
+ ((batch_size, seq_length + past_key_values_length)), device=device
882
+ )
883
+
884
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
885
+ # ourselves in which case we just need to make it broadcastable to all heads.
886
+ if is_decoder:
887
+ extended_attention_mask = self.get_extended_attention_mask(
888
+ attention_mask,
889
+ input_ids.shape,
890
+ device,
891
+ is_decoder,
892
+ has_query=(query_embeds is not None),
893
+ )
894
+ else:
895
+ extended_attention_mask = self.get_extended_attention_mask(
896
+ attention_mask, input_shape, device, is_decoder
897
+ )
898
+
899
+ # If a 2D or 3D attention mask is provided for the cross-attention
900
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
901
+ if encoder_hidden_states is not None:
902
+ if type(encoder_hidden_states) == list:
903
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
904
+ 0
905
+ ].size()
906
+ else:
907
+ (
908
+ encoder_batch_size,
909
+ encoder_sequence_length,
910
+ _,
911
+ ) = encoder_hidden_states.size()
912
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
913
+
914
+ if type(encoder_attention_mask) == list:
915
+ encoder_extended_attention_mask = [
916
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
917
+ ]
918
+ elif encoder_attention_mask is None:
919
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
920
+ encoder_extended_attention_mask = self.invert_attention_mask(
921
+ encoder_attention_mask
922
+ )
923
+ else:
924
+ encoder_extended_attention_mask = self.invert_attention_mask(
925
+ encoder_attention_mask
926
+ )
927
+ else:
928
+ encoder_extended_attention_mask = None
929
+
930
+ # Prepare head mask if needed
931
+ # 1.0 in head_mask indicate we keep the head
932
+ # attention_probs has shape bsz x n_heads x N x N
933
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
934
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
935
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
936
+
937
+ encoder_outputs = self.encoder(
938
+ embedding_output,
939
+ attention_mask=extended_attention_mask,
940
+ head_mask=head_mask,
941
+ encoder_hidden_states=encoder_hidden_states,
942
+ encoder_attention_mask=encoder_extended_attention_mask,
943
+ past_key_values=past_key_values,
944
+ use_cache=use_cache,
945
+ output_attentions=output_attentions,
946
+ output_hidden_states=output_hidden_states,
947
+ return_dict=return_dict,
948
+ query_length=query_length,
949
+ )
950
+ sequence_output = encoder_outputs[0]
951
+ pooled_output = (
952
+ self.pooler(sequence_output) if self.pooler is not None else None
953
+ )
954
+
955
+ if not return_dict:
956
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
957
+
958
+ return BaseModelOutputWithPoolingAndCrossAttentions(
959
+ last_hidden_state=sequence_output,
960
+ pooler_output=pooled_output,
961
+ past_key_values=encoder_outputs.past_key_values,
962
+ hidden_states=encoder_outputs.hidden_states,
963
+ attentions=encoder_outputs.attentions,
964
+ cross_attentions=encoder_outputs.cross_attentions,
965
+ )
966
+
967
+
968
+ class BertLMHeadModel(BertPreTrainedModel):
969
+
970
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
971
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
972
+
973
+ def __init__(self, config):
974
+ super().__init__(config)
975
+
976
+ self.bert = BertModel(config, add_pooling_layer=False)
977
+ self.cls = BertOnlyMLMHead(config)
978
+
979
+ self.init_weights()
980
+
981
+ def get_output_embeddings(self):
982
+ return self.cls.predictions.decoder
983
+
984
+ def set_output_embeddings(self, new_embeddings):
985
+ self.cls.predictions.decoder = new_embeddings
986
+
987
+ def forward(
988
+ self,
989
+ input_ids=None,
990
+ attention_mask=None,
991
+ position_ids=None,
992
+ head_mask=None,
993
+ query_embeds=None,
994
+ encoder_hidden_states=None,
995
+ encoder_attention_mask=None,
996
+ labels=None,
997
+ past_key_values=None,
998
+ use_cache=True,
999
+ output_attentions=None,
1000
+ output_hidden_states=None,
1001
+ return_dict=None,
1002
+ return_logits=False,
1003
+ is_decoder=True,
1004
+ reduction="mean",
1005
+ ):
1006
+ r"""
1007
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1008
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1009
+ the model is configured as a decoder.
1010
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1011
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1012
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1013
+ - 1 for tokens that are **not masked**,
1014
+ - 0 for tokens that are **masked**.
1015
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1016
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1017
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1018
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1019
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1020
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1021
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1022
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1023
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1024
+ use_cache (:obj:`bool`, `optional`):
1025
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1026
+ decoding (see :obj:`past_key_values`).
1027
+ Returns:
1028
+ Example::
1029
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1030
+ >>> import torch
1031
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1032
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1033
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1034
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1035
+ >>> outputs = model(**inputs)
1036
+ >>> prediction_logits = outputs.logits
1037
+ """
1038
+ return_dict = (
1039
+ return_dict if return_dict is not None else self.config.use_return_dict
1040
+ )
1041
+ if labels is not None:
1042
+ use_cache = False
1043
+ if past_key_values is not None:
1044
+ query_embeds = None
1045
+
1046
+ outputs = self.bert(
1047
+ input_ids,
1048
+ attention_mask=attention_mask,
1049
+ position_ids=position_ids,
1050
+ head_mask=head_mask,
1051
+ query_embeds=query_embeds,
1052
+ encoder_hidden_states=encoder_hidden_states,
1053
+ encoder_attention_mask=encoder_attention_mask,
1054
+ past_key_values=past_key_values,
1055
+ use_cache=use_cache,
1056
+ output_attentions=output_attentions,
1057
+ output_hidden_states=output_hidden_states,
1058
+ return_dict=return_dict,
1059
+ is_decoder=is_decoder,
1060
+ )
1061
+
1062
+ sequence_output = outputs[0]
1063
+ if query_embeds is not None:
1064
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1065
+
1066
+ prediction_scores = self.cls(sequence_output)
1067
+
1068
+ if return_logits:
1069
+ return prediction_scores[:, :-1, :].contiguous()
1070
+
1071
+ lm_loss = None
1072
+ if labels is not None:
1073
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1074
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1075
+ labels = labels[:, 1:].contiguous()
1076
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1077
+ lm_loss = loss_fct(
1078
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1079
+ labels.view(-1),
1080
+ )
1081
+ if reduction == "none":
1082
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1083
+
1084
+ if not return_dict:
1085
+ output = (prediction_scores,) + outputs[2:]
1086
+ return ((lm_loss,) + output) if lm_loss is not None else output
1087
+
1088
+ return CausalLMOutputWithCrossAttentions(
1089
+ loss=lm_loss,
1090
+ logits=prediction_scores,
1091
+ past_key_values=outputs.past_key_values,
1092
+ hidden_states=outputs.hidden_states,
1093
+ attentions=outputs.attentions,
1094
+ cross_attentions=outputs.cross_attentions,
1095
+ )
1096
+
1097
+ def prepare_inputs_for_generation(
1098
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
1099
+ ):
1100
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1101
+ if attention_mask is None:
1102
+ attention_mask = input_ids.new_ones(input_ids.shape)
1103
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1104
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1105
+
1106
+ # cut decoder_input_ids if past is used
1107
+ if past is not None:
1108
+ input_ids = input_ids[:, -1:]
1109
+
1110
+ return {
1111
+ "input_ids": input_ids,
1112
+ "query_embeds": query_embeds,
1113
+ "attention_mask": attention_mask,
1114
+ "past_key_values": past,
1115
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1116
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1117
+ "is_decoder": True,
1118
+ }
1119
+
1120
+ def _reorder_cache(self, past, beam_idx):
1121
+ reordered_past = ()
1122
+ for layer_past in past:
1123
+ reordered_past += (
1124
+ tuple(
1125
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1126
+ ),
1127
+ )
1128
+ return reordered_past
1129
+
1130
+
1131
+ class BertForMaskedLM(BertPreTrainedModel):
1132
+
1133
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1134
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1135
+
1136
+ def __init__(self, config):
1137
+ super().__init__(config)
1138
+
1139
+ self.bert = BertModel(config, add_pooling_layer=False)
1140
+ self.cls = BertOnlyMLMHead(config)
1141
+
1142
+ self.init_weights()
1143
+
1144
+ def get_output_embeddings(self):
1145
+ return self.cls.predictions.decoder
1146
+
1147
+ def set_output_embeddings(self, new_embeddings):
1148
+ self.cls.predictions.decoder = new_embeddings
1149
+
1150
+ def forward(
1151
+ self,
1152
+ input_ids=None,
1153
+ attention_mask=None,
1154
+ position_ids=None,
1155
+ head_mask=None,
1156
+ query_embeds=None,
1157
+ encoder_hidden_states=None,
1158
+ encoder_attention_mask=None,
1159
+ labels=None,
1160
+ output_attentions=None,
1161
+ output_hidden_states=None,
1162
+ return_dict=None,
1163
+ return_logits=False,
1164
+ is_decoder=False,
1165
+ ):
1166
+ r"""
1167
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1168
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1169
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1170
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1171
+ """
1172
+
1173
+ return_dict = (
1174
+ return_dict if return_dict is not None else self.config.use_return_dict
1175
+ )
1176
+
1177
+ outputs = self.bert(
1178
+ input_ids,
1179
+ attention_mask=attention_mask,
1180
+ position_ids=position_ids,
1181
+ head_mask=head_mask,
1182
+ query_embeds=query_embeds,
1183
+ encoder_hidden_states=encoder_hidden_states,
1184
+ encoder_attention_mask=encoder_attention_mask,
1185
+ output_attentions=output_attentions,
1186
+ output_hidden_states=output_hidden_states,
1187
+ return_dict=return_dict,
1188
+ is_decoder=is_decoder,
1189
+ )
1190
+
1191
+ sequence_output = outputs[0]
+ if query_embeds is not None:
1192
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1193
+ prediction_scores = self.cls(sequence_output)
1194
+
1195
+ if return_logits:
1196
+ return prediction_scores
1197
+
1198
+ masked_lm_loss = None
1199
+ if labels is not None:
1200
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1201
+ masked_lm_loss = loss_fct(
1202
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1203
+ )
1204
+
1205
+ if not return_dict:
1206
+ output = (prediction_scores,) + outputs[2:]
1207
+ return (
1208
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1209
+ )
1210
+
1211
+ return MaskedLMOutput(
1212
+ loss=masked_lm_loss,
1213
+ logits=prediction_scores,
1214
+ hidden_states=outputs.hidden_states,
1215
+ attentions=outputs.attentions,
1216
+ )
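The BertLMHeadModel above is the Q-Former: it is normally driven through `query_embeds` (learned query tokens) that cross-attend to frozen image features, rather than through plain token ids. A minimal sketch of exercising it stand-alone; the config values and the random image-feature tensor here are illustrative assumptions, not values taken from this repo:

    import torch
    from minigpt4.models.Qformer import BertConfig, BertLMHeadModel

    config = BertConfig.from_pretrained("bert-base-uncased")
    config.add_cross_attention = True      # enable cross-attention blocks
    config.cross_attention_freq = 2        # cross-attention every other layer
    config.query_length = 32               # number of learned query tokens (assumed)
    config.encoder_width = 1408            # width of the frozen vision features (assumed)

    qformer = BertLMHeadModel(config=config)
    query_embeds = torch.zeros(1, config.query_length, config.hidden_size)
    image_feats = torch.randn(1, 257, config.encoder_width)   # stand-in visual features

    out = qformer.bert(
        query_embeds=query_embeds,
        encoder_hidden_states=image_feats,
        encoder_attention_mask=torch.ones(1, 257, dtype=torch.long),
        return_dict=True,
    )
    print(out.last_hidden_state.shape)     # (1, 32, hidden_size)

`init_Qformer` in minigpt4/models/blip2.py below builds this same configuration and adds the learned `query_tokens` parameter.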
minigpt4/models/__init__.py ADDED
@@ -0,0 +1,200 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import torch
10
+ from omegaconf import OmegaConf
11
+
12
+ from minigpt4.common.registry import registry
13
+ from minigpt4.models.base_model import BaseModel
14
+ from minigpt4.models.blip2 import Blip2Base
15
+ from minigpt4.models.mini_gpt4 import MiniGPT4
16
+ from minigpt4.processors.base_processor import BaseProcessor
17
+
18
+
19
+ __all__ = [
20
+ "load_model",
21
+ "BaseModel",
22
+ "Blip2Base",
23
+ "MiniGPT4",
24
+ ]
25
+
26
+
27
+ def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
28
+ """
29
+ Load supported models.
30
+
31
+ To list all available models and types in registry:
32
+ >>> from minigpt4.models import model_zoo
33
+ >>> print(model_zoo)
34
+
35
+ Args:
36
+ name (str): name of the model.
37
+ model_type (str): type of the model.
38
+ is_eval (bool): whether the model is in eval mode. Default: False.
39
+ device (str): device to use. Default: "cpu".
40
+ checkpoint (str): path or URL to the checkpoint. Default: None.
41
+ Note that the checkpoint is expected to have the same state_dict keys as the model.
42
+
43
+ Returns:
44
+ model (torch.nn.Module): model.
45
+ """
46
+
47
+ model = registry.get_model_class(name).from_pretrained(model_type=model_type)
48
+
49
+ if checkpoint is not None:
50
+ model.load_checkpoint(checkpoint)
51
+
52
+ if is_eval:
53
+ model.eval()
54
+
55
+ if device == "cpu":
56
+ model = model.float()
57
+
58
+ return model.to(device)
59
+
60
+
61
+ def load_preprocess(config):
62
+ """
63
+ Load preprocessor configs and construct preprocessors.
64
+
65
+ If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
66
+
67
+ Args:
68
+ config (dict): preprocessor configs.
69
+
70
+ Returns:
71
+ vis_processors (dict): preprocessors for visual inputs.
72
+ txt_processors (dict): preprocessors for text inputs.
73
+
74
+ Key is "train" or "eval" for processors used in training and evaluation respectively.
75
+ """
76
+
77
+ def _build_proc_from_cfg(cfg):
78
+ return (
79
+ registry.get_processor_class(cfg.name).from_config(cfg)
80
+ if cfg is not None
81
+ else BaseProcessor()
82
+ )
83
+
84
+ vis_processors = dict()
85
+ txt_processors = dict()
86
+
87
+ vis_proc_cfg = config.get("vis_processor")
88
+ txt_proc_cfg = config.get("text_processor")
89
+
90
+ if vis_proc_cfg is not None:
91
+ vis_train_cfg = vis_proc_cfg.get("train")
92
+ vis_eval_cfg = vis_proc_cfg.get("eval")
93
+ else:
94
+ vis_train_cfg = None
95
+ vis_eval_cfg = None
96
+
97
+ vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
98
+ vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
99
+
100
+ if txt_proc_cfg is not None:
101
+ txt_train_cfg = txt_proc_cfg.get("train")
102
+ txt_eval_cfg = txt_proc_cfg.get("eval")
103
+ else:
104
+ txt_train_cfg = None
105
+ txt_eval_cfg = None
106
+
107
+ txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
108
+ txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
109
+
110
+ return vis_processors, txt_processors
111
+
112
+
113
+ def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
114
+ """
115
+ Load model and its related preprocessors.
116
+
117
+ List all available models and types in registry:
118
+ >>> from minigpt4.models import model_zoo
119
+ >>> print(model_zoo)
120
+
121
+ Args:
122
+ name (str): name of the model.
123
+ model_type (str): type of the model.
124
+ is_eval (bool): whether the model is in eval mode. Default: False.
125
+ device (str): device to use. Default: "cpu".
126
+
127
+ Returns:
128
+ model (torch.nn.Module): model.
129
+ vis_processors (dict): preprocessors for visual inputs.
130
+ txt_processors (dict): preprocessors for text inputs.
131
+ """
132
+ model_cls = registry.get_model_class(name)
133
+
134
+ # load model
135
+ model = model_cls.from_pretrained(model_type=model_type)
136
+
137
+ if is_eval:
138
+ model.eval()
139
+
140
+ # load preprocess
141
+ cfg = OmegaConf.load(model_cls.default_config_path(model_type))
142
+ if cfg is not None:
143
+ preprocess_cfg = cfg.preprocess
144
+
145
+ vis_processors, txt_processors = load_preprocess(preprocess_cfg)
146
+ else:
147
+ vis_processors, txt_processors = None, None
148
+ logging.info(
149
+ f"""No default preprocess for model {name} ({model_type}).
150
+ This can happen if the model is not finetuned on downstream datasets,
151
+ or it is not intended for direct use without finetuning.
152
+ """
153
+ )
154
+
155
+ if device == "cpu" or device == torch.device("cpu"):
156
+ model = model.float()
157
+
158
+ return model.to(device), vis_processors, txt_processors
159
+
160
+
161
+ class ModelZoo:
162
+ """
163
+ A utility class to create string representation of available model architectures and types.
164
+
165
+ >>> from minigpt4.models import model_zoo
166
+ >>> # list all available models
167
+ >>> print(model_zoo)
168
+ >>> # show total number of models
169
+ >>> print(len(model_zoo))
170
+ """
171
+
172
+ def __init__(self) -> None:
173
+ self.model_zoo = {
174
+ k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
175
+ for k, v in registry.mapping["model_name_mapping"].items()
176
+ }
177
+
178
+ def __str__(self) -> str:
179
+ return (
180
+ "=" * 50
181
+ + "\n"
182
+ + f"{'Architectures':<30} {'Types'}\n"
183
+ + "=" * 50
184
+ + "\n"
185
+ + "\n".join(
186
+ [
187
+ f"{name:<30} {', '.join(types)}"
188
+ for name, types in self.model_zoo.items()
189
+ ]
190
+ )
191
+ )
192
+
193
+ def __iter__(self):
194
+ return iter(self.model_zoo.items())
195
+
196
+ def __len__(self):
197
+ return sum([len(v) for v in self.model_zoo.values()])
198
+
199
+
200
+ model_zoo = ModelZoo()
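Taken together, these helpers give the usual inference entry point. A hedged sketch; the "mini_gpt4" / "pretrain_vicuna" strings are assumed registry keys, so check `print(model_zoo)` for the names actually registered in your checkout:

    import torch
    from minigpt4.models import model_zoo, load_model_and_preprocess

    print(model_zoo)            # table of registered architectures and their model types

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, vis_processors, txt_processors = load_model_and_preprocess(
        name="mini_gpt4",              # assumed architecture name
        model_type="pretrain_vicuna",  # assumed model type
        is_eval=True,
        device=device,
    )
    # vis_processors["eval"] / txt_processors["eval"] turn a PIL image / raw string
    # into the tensors and text the model expects.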
minigpt4/models/base_model.py ADDED
@@ -0,0 +1,247 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import os
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
15
+ from minigpt4.common.utils import get_abs_path, is_url
16
+ from omegaconf import OmegaConf
17
+
18
+
19
+ class BaseModel(nn.Module):
20
+ """Base class for models."""
21
+
22
+ def __init__(self):
23
+ super().__init__()
24
+
25
+ @property
26
+ def device(self):
27
+ return list(self.parameters())[0].device
28
+
29
+ def load_checkpoint(self, url_or_filename):
30
+ """
31
+ Load from a finetuned checkpoint.
32
+
33
+ This should expect no mismatch in the model keys and the checkpoint keys.
34
+ """
35
+
36
+ if is_url(url_or_filename):
37
+ cached_file = download_cached_file(
38
+ url_or_filename, check_hash=False, progress=True
39
+ )
40
+ checkpoint = torch.load(cached_file, map_location="cpu")
41
+ elif os.path.isfile(url_or_filename):
42
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
43
+ else:
44
+ raise RuntimeError("checkpoint url or path is invalid")
45
+
46
+ if "model" in checkpoint.keys():
47
+ state_dict = checkpoint["model"]
48
+ else:
49
+ state_dict = checkpoint
50
+
51
+ msg = self.load_state_dict(state_dict, strict=False)
52
+
53
+ logging.info("Missing keys {}".format(msg.missing_keys))
54
+ logging.info("load checkpoint from %s" % url_or_filename)
55
+
56
+ return msg
57
+
58
+ @classmethod
59
+ def from_pretrained(cls, model_type):
60
+ """
61
+ Build a pretrained model from default configuration file, specified by model_type.
62
+
63
+ Args:
64
+ - model_type (str): model type, specifying architecture and checkpoints.
65
+
66
+ Returns:
67
+ - model (nn.Module): pretrained or finetuned model, depending on the configuration.
68
+ """
69
+ model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
70
+ model = cls.from_config(model_cfg)
71
+
72
+ return model
73
+
74
+ @classmethod
75
+ def default_config_path(cls, model_type):
76
+ assert (
77
+ model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
78
+ ), "Unknown model type {}".format(model_type)
79
+ return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
80
+
81
+ def load_checkpoint_from_config(self, cfg, **kwargs):
82
+ """
83
+ Load checkpoint as specified in the config file.
84
+
85
+ If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
86
+ When loading the pretrained model, each task-specific architecture may define its
87
+ own load_from_pretrained() method.
88
+ """
89
+ load_finetuned = cfg.get("load_finetuned", True)
90
+ if load_finetuned:
91
+ finetune_path = cfg.get("finetuned", None)
92
+ assert (
93
+ finetune_path is not None
94
+ ), "Found load_finetuned is True, but finetune_path is None."
95
+ self.load_checkpoint(url_or_filename=finetune_path)
96
+ else:
97
+ # load pre-trained weights
98
+ pretrain_path = cfg.get("pretrained", None)
99
+ assert pretrain_path is not None, "Found load_finetuned is False, but pretrain_path is None."
100
+ self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
101
+
102
+ def before_evaluation(self, **kwargs):
103
+ pass
104
+
105
+ def show_n_params(self, return_str=True):
106
+ tot = 0
107
+ for p in self.parameters():
108
+ w = 1
109
+ for x in p.shape:
110
+ w *= x
111
+ tot += w
112
+ if return_str:
113
+ if tot >= 1e6:
114
+ return "{:.1f}M".format(tot / 1e6)
115
+ else:
116
+ return "{:.1f}K".format(tot / 1e3)
117
+ else:
118
+ return tot
119
+
120
+
121
+ class BaseEncoder(nn.Module):
122
+ """
123
+ Base class for primitive encoders, such as ViT, TimeSformer, etc.
124
+ """
125
+
126
+ def __init__(self):
127
+ super().__init__()
128
+
129
+ def forward_features(self, samples, **kwargs):
130
+ raise NotImplementedError
131
+
132
+ @property
133
+ def device(self):
134
+ return list(self.parameters())[0].device
135
+
136
+
137
+ class SharedQueueMixin:
138
+ @torch.no_grad()
139
+ def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
140
+ # gather keys before updating queue
141
+ image_feats = concat_all_gather(image_feat)
142
+ text_feats = concat_all_gather(text_feat)
143
+
144
+ batch_size = image_feats.shape[0]
145
+
146
+ ptr = int(self.queue_ptr)
147
+ assert self.queue_size % batch_size == 0 # for simplicity
148
+
149
+ # replace the keys at ptr (dequeue and enqueue)
150
+ self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
151
+ self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
152
+
153
+ if idxs is not None:
154
+ idxs = concat_all_gather(idxs)
155
+ self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
156
+
157
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
158
+ self.queue_ptr[0] = ptr
159
+
160
+
161
+ class MomentumDistilationMixin:
162
+ @torch.no_grad()
163
+ def copy_params(self):
164
+ for model_pair in self.model_pairs:
165
+ for param, param_m in zip(
166
+ model_pair[0].parameters(), model_pair[1].parameters()
167
+ ):
168
+ param_m.data.copy_(param.data) # initialize
169
+ param_m.requires_grad = False # not update by gradient
170
+
171
+ @torch.no_grad()
172
+ def _momentum_update(self):
173
+ for model_pair in self.model_pairs:
174
+ for param, param_m in zip(
175
+ model_pair[0].parameters(), model_pair[1].parameters()
176
+ ):
177
+ param_m.data = param_m.data * self.momentum + param.data * (
178
+ 1.0 - self.momentum
179
+ )
180
+
181
+
182
+ class GatherLayer(torch.autograd.Function):
183
+ """
184
+ Gather tensors from all workers with support for backward propagation:
185
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
186
+ """
187
+
188
+ @staticmethod
189
+ def forward(ctx, x):
190
+ output = [
191
+ torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
192
+ ]
193
+ torch.distributed.all_gather(output, x)
194
+ return tuple(output)
195
+
196
+ @staticmethod
197
+ def backward(ctx, *grads):
198
+ all_gradients = torch.stack(grads)
199
+ torch.distributed.all_reduce(all_gradients)
200
+ return all_gradients[torch.distributed.get_rank()]
201
+
202
+
203
+ def all_gather_with_grad(tensors):
204
+ """
205
+ Performs all_gather operation on the provided tensors.
206
+ Graph remains connected for backward grad computation.
207
+ """
208
+ # Queue the gathered tensors
209
+ world_size = torch.distributed.get_world_size()
210
+ # There is no need for reduction in the single-proc case
211
+ if world_size == 1:
212
+ return tensors
213
+
214
+ # tensor_all = GatherLayer.apply(tensors)
215
+ tensor_all = GatherLayer.apply(tensors)
216
+
217
+ return torch.cat(tensor_all, dim=0)
218
+
219
+
220
+ @torch.no_grad()
221
+ def concat_all_gather(tensor):
222
+ """
223
+ Performs all_gather operation on the provided tensors.
224
+ *** Warning ***: torch.distributed.all_gather has no gradient.
225
+ """
226
+ # if use distributed training
227
+ if not is_dist_avail_and_initialized():
228
+ return tensor
229
+
230
+ tensors_gather = [
231
+ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
232
+ ]
233
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
234
+
235
+ output = torch.cat(tensors_gather, dim=0)
236
+ return output
237
+
238
+
239
+ def tile(x, dim, n_tile):
240
+ init_dim = x.size(dim)
241
+ repeat_idx = [1] * x.dim()
242
+ repeat_idx[dim] = n_tile
243
+ x = x.repeat(*(repeat_idx))
244
+ order_index = torch.LongTensor(
245
+ np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
246
+ )
247
+ return torch.index_select(x, dim, order_index.to(x.device))
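The `tile` helper repeats each slice along `dim` in place (an interleaved repeat), so it should behave like `torch.repeat_interleave`; a quick check:

    import torch
    from minigpt4.models.base_model import tile

    x = torch.tensor([1, 2, 3])
    print(tile(x, dim=0, n_tile=2))              # tensor([1, 1, 2, 2, 3, 3])
    print(torch.repeat_interleave(x, 2, dim=0))  # same result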
minigpt4/models/blip2.py ADDED
@@ -0,0 +1,221 @@
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import contextlib
8
+ import logging
9
+ import os
10
+ import time
11
+ import datetime
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.distributed as dist
16
+ import torch.nn.functional as F
17
+
18
+ import minigpt4.common.dist_utils as dist_utils
19
+ from minigpt4.common.dist_utils import download_cached_file
20
+ from minigpt4.common.utils import is_url
21
+ from minigpt4.common.logger import MetricLogger
22
+ from minigpt4.models.base_model import BaseModel
23
+ from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
24
+ from minigpt4.models.eva_vit import create_eva_vit_g
25
+ from transformers import BertTokenizer
26
+
27
+
28
+ class Blip2Base(BaseModel):
29
+ @classmethod
30
+ def init_tokenizer(cls):
31
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
32
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
33
+ return tokenizer
34
+
35
+ def maybe_autocast(self, dtype=torch.float16):
36
+ # if on cpu, don't use autocast
37
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
38
+ enable_autocast = self.device != torch.device("cpu")
39
+
40
+ if enable_autocast:
41
+ return torch.cuda.amp.autocast(dtype=dtype)
42
+ else:
43
+ return contextlib.nullcontext()
44
+
45
+ @classmethod
46
+ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
47
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
48
+ encoder_config.encoder_width = vision_width
49
+ # insert cross-attention layer every other block
50
+ encoder_config.add_cross_attention = True
51
+ encoder_config.cross_attention_freq = cross_attention_freq
52
+ encoder_config.query_length = num_query_token
53
+ Qformer = BertLMHeadModel(config=encoder_config)
54
+ query_tokens = nn.Parameter(
55
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
56
+ )
57
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
58
+ return Qformer, query_tokens
59
+
60
+ @classmethod
61
+ def init_vision_encoder(
62
+ cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
63
+ ):
64
+ assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
65
+ visual_encoder = create_eva_vit_g(
66
+ img_size, drop_path_rate, use_grad_checkpoint, precision
67
+ )
68
+
69
+ ln_vision = LayerNorm(visual_encoder.num_features)
70
+ return visual_encoder, ln_vision
71
+
72
+ def load_from_pretrained(self, url_or_filename):
73
+ if is_url(url_or_filename):
74
+ cached_file = download_cached_file(
75
+ url_or_filename, check_hash=False, progress=True
76
+ )
77
+ checkpoint = torch.load(cached_file, map_location="cpu")
78
+ elif os.path.isfile(url_or_filename):
79
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
80
+ else:
81
+ raise RuntimeError("checkpoint url or path is invalid")
82
+
83
+ state_dict = checkpoint["model"]
84
+
85
+ msg = self.load_state_dict(state_dict, strict=False)
86
+
87
+ # logging.info("Missing keys {}".format(msg.missing_keys))
88
+ logging.info("load checkpoint from %s" % url_or_filename)
89
+
90
+ return msg
91
+
92
+
93
+ def disabled_train(self, mode=True):
94
+ """Overwrite model.train with this function to make sure train/eval mode
95
+ does not change anymore."""
96
+ return self
97
+
98
+
99
+ class LayerNorm(nn.LayerNorm):
100
+ """Subclass torch's LayerNorm to handle fp16."""
101
+
102
+ def forward(self, x: torch.Tensor):
103
+ orig_type = x.dtype
104
+ ret = super().forward(x.type(torch.float32))
105
+ return ret.type(orig_type)
106
+
107
+
108
+ def compute_sim_matrix(model, data_loader, **kwargs):
109
+ k_test = kwargs.pop("k_test")
110
+
111
+ metric_logger = MetricLogger(delimiter=" ")
112
+ header = "Evaluation:"
113
+
114
+ logging.info("Computing features for evaluation...")
115
+ start_time = time.time()
116
+
117
+ texts = data_loader.dataset.text
118
+ num_text = len(texts)
119
+ text_bs = 256
120
+ text_ids = []
121
+ text_embeds = []
122
+ text_atts = []
123
+ for i in range(0, num_text, text_bs):
124
+ text = texts[i : min(num_text, i + text_bs)]
125
+ text_input = model.tokenizer(
126
+ text,
127
+ padding="max_length",
128
+ truncation=True,
129
+ max_length=35,
130
+ return_tensors="pt",
131
+ ).to(model.device)
132
+ text_feat = model.forward_text(text_input)
133
+ text_embed = F.normalize(model.text_proj(text_feat))
134
+ text_embeds.append(text_embed)
135
+ text_ids.append(text_input.input_ids)
136
+ text_atts.append(text_input.attention_mask)
137
+
138
+ text_embeds = torch.cat(text_embeds, dim=0)
139
+ text_ids = torch.cat(text_ids, dim=0)
140
+ text_atts = torch.cat(text_atts, dim=0)
141
+
142
+ vit_feats = []
143
+ image_embeds = []
144
+ for samples in data_loader:
145
+ image = samples["image"]
146
+
147
+ image = image.to(model.device)
148
+ image_feat, vit_feat = model.forward_image(image)
149
+ image_embed = model.vision_proj(image_feat)
150
+ image_embed = F.normalize(image_embed, dim=-1)
151
+
152
+ vit_feats.append(vit_feat.cpu())
153
+ image_embeds.append(image_embed)
154
+
155
+ vit_feats = torch.cat(vit_feats, dim=0)
156
+ image_embeds = torch.cat(image_embeds, dim=0)
157
+
158
+ sims_matrix = []
159
+ for image_embed in image_embeds:
160
+ sim_q2t = image_embed @ text_embeds.t()
161
+ sim_i2t, _ = sim_q2t.max(0)
162
+ sims_matrix.append(sim_i2t)
163
+ sims_matrix = torch.stack(sims_matrix, dim=0)
164
+
165
+ score_matrix_i2t = torch.full(
166
+ (len(data_loader.dataset.image), len(texts)), -100.0
167
+ ).to(model.device)
168
+
169
+ num_tasks = dist_utils.get_world_size()
170
+ rank = dist_utils.get_rank()
171
+ step = sims_matrix.size(0) // num_tasks + 1
172
+ start = rank * step
173
+ end = min(sims_matrix.size(0), start + step)
174
+
175
+ for i, sims in enumerate(
176
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
177
+ ):
178
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
179
+ image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
180
+ score = model.compute_itm(
181
+ image_inputs=image_inputs,
182
+ text_ids=text_ids[topk_idx],
183
+ text_atts=text_atts[topk_idx],
184
+ ).float()
185
+ score_matrix_i2t[start + i, topk_idx] = score + topk_sim
186
+
187
+ sims_matrix = sims_matrix.t()
188
+ score_matrix_t2i = torch.full(
189
+ (len(texts), len(data_loader.dataset.image)), -100.0
190
+ ).to(model.device)
191
+
192
+ step = sims_matrix.size(0) // num_tasks + 1
193
+ start = rank * step
194
+ end = min(sims_matrix.size(0), start + step)
195
+
196
+ for i, sims in enumerate(
197
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
198
+ ):
199
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
200
+ image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
201
+ score = model.compute_itm(
202
+ image_inputs=image_inputs,
203
+ text_ids=text_ids[start + i].repeat(k_test, 1),
204
+ text_atts=text_atts[start + i].repeat(k_test, 1),
205
+ ).float()
206
+ score_matrix_t2i[start + i, topk_idx] = score + topk_sim
207
+
208
+ if dist_utils.is_dist_avail_and_initialized():
209
+ dist.barrier()
210
+ torch.distributed.all_reduce(
211
+ score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
212
+ )
213
+ torch.distributed.all_reduce(
214
+ score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
215
+ )
216
+
217
+ total_time = time.time() - start_time
218
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
219
+ logging.info("Evaluation time {}".format(total_time_str))
220
+
221
+ return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
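A minimal, self-contained sketch of the two-stage scoring used by compute_sim_matrix above: coarse ITC similarities pick the top-k candidates per query, and only those entries are refined. The tensors and the ITM scores are dummies here, since a real run needs a trained BLIP-2 model and a data loader.

import torch

k_test = 4
sims = torch.randn(8, 16)               # dummy image-to-text similarity rows
scores = torch.full_like(sims, -100.0)  # same -100 fill value as the score matrices above
for i, row in enumerate(sims):
    topk_sim, topk_idx = row.topk(k=k_test)
    itm_score = torch.randn(k_test)     # stand-in for model.compute_itm(...)
    scores[i, topk_idx] = itm_score + topk_sim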
minigpt4/models/blip2_outputs.py ADDED
@@ -0,0 +1,110 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
+ import torch
12
+ from transformers.modeling_outputs import (
13
+ ModelOutput,
14
+ BaseModelOutputWithPoolingAndCrossAttentions,
15
+ CausalLMOutputWithCrossAttentions,
16
+ )
17
+
18
+
19
+ @dataclass
20
+ class BlipSimilarity(ModelOutput):
21
+ sim_i2t: torch.FloatTensor = None
22
+ sim_t2i: torch.FloatTensor = None
23
+
24
+ sim_i2t_m: Optional[torch.FloatTensor] = None
25
+ sim_t2i_m: Optional[torch.FloatTensor] = None
26
+
27
+ sim_i2t_targets: Optional[torch.FloatTensor] = None
28
+ sim_t2i_targets: Optional[torch.FloatTensor] = None
29
+
30
+
31
+ @dataclass
32
+ class BlipIntermediateOutput(ModelOutput):
33
+ """
34
+ Data class for intermediate outputs of BLIP models.
35
+
36
+ image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
37
+ text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).
38
+
39
+ image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
40
+ text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).
41
+
42
+ encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
43
+ encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.
44
+
45
+ decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
46
+ decoder_labels (torch.LongTensor): labels for the captioning loss.
47
+
48
+ itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
49
+ itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
50
+
51
+ """
52
+
53
+ # uni-modal features
54
+ image_embeds: torch.FloatTensor = None
55
+ text_embeds: Optional[torch.FloatTensor] = None
56
+
57
+ image_embeds_m: Optional[torch.FloatTensor] = None
58
+ text_embeds_m: Optional[torch.FloatTensor] = None
59
+
60
+ # intermediate outputs of multimodal encoder
61
+ encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
62
+ encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
63
+
64
+ itm_logits: Optional[torch.FloatTensor] = None
65
+ itm_labels: Optional[torch.LongTensor] = None
66
+
67
+ # intermediate outputs of multimodal decoder
68
+ decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
69
+ decoder_labels: Optional[torch.LongTensor] = None
70
+
71
+
72
+ @dataclass
73
+ class BlipOutput(ModelOutput):
74
+ # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
75
+ sims: Optional[BlipSimilarity] = None
76
+
77
+ intermediate_output: BlipIntermediateOutput = None
78
+
79
+ loss: Optional[torch.FloatTensor] = None
80
+
81
+ loss_itc: Optional[torch.FloatTensor] = None
82
+
83
+ loss_itm: Optional[torch.FloatTensor] = None
84
+
85
+ loss_lm: Optional[torch.FloatTensor] = None
86
+
87
+
88
+ @dataclass
89
+ class BlipOutputFeatures(ModelOutput):
90
+ """
91
+ Data class of features from BlipFeatureExtractor.
92
+
93
+ Args:
94
+ image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
95
+ image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
96
+ text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
97
+ text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
98
+
99
+ The first embedding or feature is for the [CLS] token.
100
+
101
+ Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
102
+ """
103
+
104
+ image_embeds: Optional[torch.FloatTensor] = None
105
+ image_embeds_proj: Optional[torch.FloatTensor] = None
106
+
107
+ text_embeds: Optional[torch.FloatTensor] = None
108
+ text_embeds_proj: Optional[torch.FloatTensor] = None
109
+
110
+ multimodal_embeds: Optional[torch.FloatTensor] = None
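A minimal sketch showing how these output dataclasses are populated; the shapes are illustrative dummies, not values required by the classes themselves.

import torch
from minigpt4.models.blip2_outputs import BlipIntermediateOutput, BlipOutput

out = BlipOutput(
    intermediate_output=BlipIntermediateOutput(
        image_embeds=torch.zeros(2, 257, 768),  # (batch_size, num_patches + [CLS], embed_dim)
    ),
    loss=torch.tensor(0.0),
    loss_itc=torch.tensor(0.0),
)
print(out.loss, out.intermediate_output.image_embeds.shape)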
minigpt4/models/eva_vit.py ADDED
@@ -0,0 +1,442 @@
1
+ # Based on EVA, BEIT, timm and DeiT code bases
2
+ # https://github.com/baaivision/EVA
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
4
+ # https://github.com/microsoft/unilm/tree/master/beit
5
+ # https://github.com/facebookresearch/deit/
6
+ # https://github.com/facebookresearch/dino
7
+ # --------------------------------------------------------'
8
+ import math
9
+ from functools import partial
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+ from timm.models.registry import register_model
17
+
18
+ from minigpt4.common.dist_utils import download_cached_file
19
+
20
+ def _cfg(url='', **kwargs):
21
+ return {
22
+ 'url': url,
23
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
24
+ 'crop_pct': .9, 'interpolation': 'bicubic',
25
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
26
+ **kwargs
27
+ }
28
+
29
+
30
+ class DropPath(nn.Module):
31
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
32
+ """
33
+ def __init__(self, drop_prob=None):
34
+ super(DropPath, self).__init__()
35
+ self.drop_prob = drop_prob
36
+
37
+ def forward(self, x):
38
+ return drop_path(x, self.drop_prob, self.training)
39
+
40
+ def extra_repr(self) -> str:
41
+ return 'p={}'.format(self.drop_prob)
42
+
43
+
44
+ class Mlp(nn.Module):
45
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
46
+ super().__init__()
47
+ out_features = out_features or in_features
48
+ hidden_features = hidden_features or in_features
49
+ self.fc1 = nn.Linear(in_features, hidden_features)
50
+ self.act = act_layer()
51
+ self.fc2 = nn.Linear(hidden_features, out_features)
52
+ self.drop = nn.Dropout(drop)
53
+
54
+ def forward(self, x):
55
+ x = self.fc1(x)
56
+ x = self.act(x)
57
+ # x = self.drop(x)
58
+ # dropout here is commented out to follow the original BERT implementation
59
+ x = self.fc2(x)
60
+ x = self.drop(x)
61
+ return x
62
+
63
+
64
+ class Attention(nn.Module):
65
+ def __init__(
66
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
67
+ proj_drop=0., window_size=None, attn_head_dim=None):
68
+ super().__init__()
69
+ self.num_heads = num_heads
70
+ head_dim = dim // num_heads
71
+ if attn_head_dim is not None:
72
+ head_dim = attn_head_dim
73
+ all_head_dim = head_dim * self.num_heads
74
+ self.scale = qk_scale or head_dim ** -0.5
75
+
76
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
77
+ if qkv_bias:
78
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
79
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
80
+ else:
81
+ self.q_bias = None
82
+ self.v_bias = None
83
+
84
+ if window_size:
85
+ self.window_size = window_size
86
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
87
+ self.relative_position_bias_table = nn.Parameter(
88
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
89
+ # cls to token & token 2 cls & cls to cls
90
+
91
+ # get pair-wise relative position index for each token inside the window
92
+ coords_h = torch.arange(window_size[0])
93
+ coords_w = torch.arange(window_size[1])
94
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
95
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
96
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
97
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
98
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
99
+ relative_coords[:, :, 1] += window_size[1] - 1
100
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
101
+ relative_position_index = \
102
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
103
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
104
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
105
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
106
+ relative_position_index[0, 0] = self.num_relative_distance - 1
107
+
108
+ self.register_buffer("relative_position_index", relative_position_index)
109
+ else:
110
+ self.window_size = None
111
+ self.relative_position_bias_table = None
112
+ self.relative_position_index = None
113
+
114
+ self.attn_drop = nn.Dropout(attn_drop)
115
+ self.proj = nn.Linear(all_head_dim, dim)
116
+ self.proj_drop = nn.Dropout(proj_drop)
117
+
118
+ def forward(self, x, rel_pos_bias=None):
119
+ B, N, C = x.shape
120
+ qkv_bias = None
121
+ if self.q_bias is not None:
122
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
123
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
124
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
125
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
126
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
127
+
128
+ q = q * self.scale
129
+ attn = (q @ k.transpose(-2, -1))
130
+
131
+ if self.relative_position_bias_table is not None:
132
+ relative_position_bias = \
133
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
134
+ self.window_size[0] * self.window_size[1] + 1,
135
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
136
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
137
+ attn = attn + relative_position_bias.unsqueeze(0)
138
+
139
+ if rel_pos_bias is not None:
140
+ attn = attn + rel_pos_bias
141
+
142
+ attn = attn.softmax(dim=-1)
143
+ attn = self.attn_drop(attn)
144
+
145
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
146
+ x = self.proj(x)
147
+ x = self.proj_drop(x)
148
+ return x
149
+
150
+
151
+ class Block(nn.Module):
152
+
153
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
154
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
155
+ window_size=None, attn_head_dim=None):
156
+ super().__init__()
157
+ self.norm1 = norm_layer(dim)
158
+ self.attn = Attention(
159
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
160
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
161
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
162
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
163
+ self.norm2 = norm_layer(dim)
164
+ mlp_hidden_dim = int(dim * mlp_ratio)
165
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
166
+
167
+ if init_values is not None and init_values > 0:
168
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
169
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
170
+ else:
171
+ self.gamma_1, self.gamma_2 = None, None
172
+
173
+ def forward(self, x, rel_pos_bias=None):
174
+ if self.gamma_1 is None:
175
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
176
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
177
+ else:
178
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
179
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
180
+ return x
181
+
182
+
183
+ class PatchEmbed(nn.Module):
184
+ """ Image to Patch Embedding
185
+ """
186
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
187
+ super().__init__()
188
+ img_size = to_2tuple(img_size)
189
+ patch_size = to_2tuple(patch_size)
190
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
191
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
192
+ self.img_size = img_size
193
+ self.patch_size = patch_size
194
+ self.num_patches = num_patches
195
+
196
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
197
+
198
+ def forward(self, x, **kwargs):
199
+ B, C, H, W = x.shape
200
+ # FIXME look at relaxing size constraints
201
+ assert H == self.img_size[0] and W == self.img_size[1], \
202
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
203
+ x = self.proj(x).flatten(2).transpose(1, 2)
204
+ return x
205
+
206
+
207
+ class RelativePositionBias(nn.Module):
208
+
209
+ def __init__(self, window_size, num_heads):
210
+ super().__init__()
211
+ self.window_size = window_size
212
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
213
+ self.relative_position_bias_table = nn.Parameter(
214
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
215
+ # cls to token & token 2 cls & cls to cls
216
+
217
+ # get pair-wise relative position index for each token inside the window
218
+ coords_h = torch.arange(window_size[0])
219
+ coords_w = torch.arange(window_size[1])
220
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
221
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
222
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
223
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
224
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
225
+ relative_coords[:, :, 1] += window_size[1] - 1
226
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
227
+ relative_position_index = \
228
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
229
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
230
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
231
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
232
+ relative_position_index[0, 0] = self.num_relative_distance - 1
233
+
234
+ self.register_buffer("relative_position_index", relative_position_index)
235
+
236
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
237
+
238
+ def forward(self):
239
+ relative_position_bias = \
240
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
241
+ self.window_size[0] * self.window_size[1] + 1,
242
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
243
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
244
+
245
+
246
+ class VisionTransformer(nn.Module):
247
+ """ Vision Transformer with support for patch or hybrid CNN input stage
248
+ """
249
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
250
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
251
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
252
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
253
+ use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
254
+ super().__init__()
255
+ self.image_size = img_size
256
+ self.num_classes = num_classes
257
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
258
+
259
+ self.patch_embed = PatchEmbed(
260
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
261
+ num_patches = self.patch_embed.num_patches
262
+
263
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
264
+ if use_abs_pos_emb:
265
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
266
+ else:
267
+ self.pos_embed = None
268
+ self.pos_drop = nn.Dropout(p=drop_rate)
269
+
270
+ if use_shared_rel_pos_bias:
271
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
272
+ else:
273
+ self.rel_pos_bias = None
274
+ self.use_checkpoint = use_checkpoint
275
+
276
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
277
+ self.use_rel_pos_bias = use_rel_pos_bias
278
+ self.blocks = nn.ModuleList([
279
+ Block(
280
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
281
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
282
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
283
+ for i in range(depth)])
284
+ # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
285
+ # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
286
+ # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
287
+
288
+ if self.pos_embed is not None:
289
+ trunc_normal_(self.pos_embed, std=.02)
290
+ trunc_normal_(self.cls_token, std=.02)
291
+ # trunc_normal_(self.mask_token, std=.02)
292
+ # if isinstance(self.head, nn.Linear):
293
+ # trunc_normal_(self.head.weight, std=.02)
294
+ self.apply(self._init_weights)
295
+ self.fix_init_weight()
296
+ # if isinstance(self.head, nn.Linear):
297
+ # self.head.weight.data.mul_(init_scale)
298
+ # self.head.bias.data.mul_(init_scale)
299
+
300
+ def fix_init_weight(self):
301
+ def rescale(param, layer_id):
302
+ param.div_(math.sqrt(2.0 * layer_id))
303
+
304
+ for layer_id, layer in enumerate(self.blocks):
305
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
306
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
307
+
308
+ def _init_weights(self, m):
309
+ if isinstance(m, nn.Linear):
310
+ trunc_normal_(m.weight, std=.02)
311
+ if isinstance(m, nn.Linear) and m.bias is not None:
312
+ nn.init.constant_(m.bias, 0)
313
+ elif isinstance(m, nn.LayerNorm):
314
+ nn.init.constant_(m.bias, 0)
315
+ nn.init.constant_(m.weight, 1.0)
316
+
317
+ def get_classifier(self):
318
+ return self.head
319
+
320
+ def reset_classifier(self, num_classes, global_pool=''):
321
+ self.num_classes = num_classes
322
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
323
+
324
+ def forward_features(self, x):
325
+ x = self.patch_embed(x)
326
+ batch_size, seq_len, _ = x.size()
327
+
328
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
329
+ x = torch.cat((cls_tokens, x), dim=1)
330
+ if self.pos_embed is not None:
331
+ x = x + self.pos_embed
332
+ x = self.pos_drop(x)
333
+
334
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
335
+ for blk in self.blocks:
336
+ if self.use_checkpoint:
337
+ x = checkpoint.checkpoint(blk, x, rel_pos_bias)
338
+ else:
339
+ x = blk(x, rel_pos_bias)
340
+ return x
341
+ # x = self.norm(x)
342
+
343
+ # if self.fc_norm is not None:
344
+ # t = x[:, 1:, :]
345
+ # return self.fc_norm(t.mean(1))
346
+ # else:
347
+ # return x[:, 0]
348
+
349
+ def forward(self, x):
350
+ x = self.forward_features(x)
351
+ # x = self.head(x)
352
+ return x
353
+
354
+ def get_intermediate_layers(self, x):
355
+ x = self.patch_embed(x)
356
+ batch_size, seq_len, _ = x.size()
357
+
358
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
359
+ x = torch.cat((cls_tokens, x), dim=1)
360
+ if self.pos_embed is not None:
361
+ x = x + self.pos_embed
362
+ x = self.pos_drop(x)
363
+
364
+ features = []
365
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
366
+ for blk in self.blocks:
367
+ x = blk(x, rel_pos_bias)
368
+ features.append(x)
369
+
370
+ return features
371
+
372
+
373
+ def interpolate_pos_embed(model, checkpoint_model):
374
+ if 'pos_embed' in checkpoint_model:
375
+ pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
376
+ embedding_size = pos_embed_checkpoint.shape[-1]
377
+ num_patches = model.patch_embed.num_patches
378
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
379
+ # height (== width) for the checkpoint position embedding
380
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
381
+ # height (== width) for the new position embedding
382
+ new_size = int(num_patches ** 0.5)
383
+ # class_token and dist_token are kept unchanged
384
+ if orig_size != new_size:
385
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
386
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
387
+ # only the position tokens are interpolated
388
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
389
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
390
+ pos_tokens = torch.nn.functional.interpolate(
391
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
392
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
393
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
394
+ checkpoint_model['pos_embed'] = new_pos_embed
395
+
396
+
397
+ def convert_weights_to_fp16(model: nn.Module):
398
+ """Convert applicable model parameters to fp16"""
399
+
400
+ def _convert_weights_to_fp16(l):
401
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
402
+ l.weight.data = l.weight.data.half()
403
+ if l.bias is not None:
404
+ l.bias.data = l.bias.data.half()
405
+
406
+ # if isinstance(l, (nn.MultiheadAttention, Attention)):
407
+ # for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
408
+ # tensor = getattr(l, attr)
409
+ # if tensor is not None:
410
+ # tensor.data = tensor.data.half()
411
+
412
+ model.apply(_convert_weights_to_fp16)
413
+
414
+
415
+ def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
416
+ model = VisionTransformer(
417
+ img_size=img_size,
418
+ patch_size=14,
419
+ use_mean_pooling=False,
420
+ embed_dim=1408,
421
+ depth=39,
422
+ num_heads=1408//88,
423
+ mlp_ratio=4.3637,
424
+ qkv_bias=True,
425
+ drop_path_rate=drop_path_rate,
426
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
427
+ use_checkpoint=use_checkpoint,
428
+ )
429
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
430
+ cached_file = download_cached_file(
431
+ url, check_hash=False, progress=True
432
+ )
433
+ state_dict = torch.load(cached_file, map_location="cpu")
434
+ interpolate_pos_embed(model,state_dict)
435
+
436
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
437
+ # print(incompatible_keys)
438
+
439
+ if precision == "fp16":
440
+ # model.to("cuda")
441
+ convert_weights_to_fp16(model)
442
+ return model
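A minimal usage sketch for the factory above. It assumes the EVA checkpoint URL is reachable; precision is left at fp32 so the snippet also runs on CPU, whereas the fp16 path is intended to be driven under autocast, as in MiniGPT4.encode_img.

import torch
from minigpt4.models.eva_vit import create_eva_vit_g

vit = create_eva_vit_g(img_size=224, precision="fp32").eval()
with torch.no_grad():
    tokens = vit(torch.randn(1, 3, 224, 224))
print(tokens.shape)  # (1, 257, 1408): one [CLS] token plus 16x16 patch tokens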
minigpt4/models/mini_gpt4.py ADDED
@@ -0,0 +1,263 @@
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import logging
8
+ import random
9
+ import os
10
+ import torch
11
+ from torch.cuda.amp import autocast as autocast
12
+ import torch.nn as nn
13
+
14
+ from minigpt4.common.registry import registry
15
+ from minigpt4.models.blip2 import Blip2Base, disabled_train
16
+ from minigpt4.models.modeling_llama import LlamaForCausalLM
17
+ from transformers import LlamaTokenizer
18
+
19
+
20
+ @registry.register_model("mini_gpt4")
21
+ class MiniGPT4(Blip2Base):
22
+ """
23
+ BLIP2 GPT-LLAMA model.
24
+ """
25
+
26
+ PRETRAINED_MODEL_CONFIG_DICT = {
27
+ "pretrain_vicuna": "configs/models/minigpt4.yaml",
28
+ }
29
+
30
+ def __init__(
31
+ self,
32
+ vit_model="eva_clip_g",
33
+ q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
34
+ img_size=224,
35
+ drop_path_rate=0,
36
+ use_grad_checkpoint=False,
37
+ vit_precision="fp16",
38
+ freeze_vit=True,
39
+ freeze_qformer=True,
40
+ num_query_token=32,
41
+ llama_model="",
42
+ llama_cache_dir='',
43
+ prompt_path="",
44
+ prompt_template="",
45
+ max_txt_len=32,
46
+ end_sym='\n',
47
+ ):
48
+ super().__init__()
49
+
50
+ self.tokenizer = self.init_tokenizer()
51
+
52
+ print('Loading VIT')
53
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
54
+ vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
55
+ )
56
+ if freeze_vit:
57
+ for name, param in self.visual_encoder.named_parameters():
58
+ param.requires_grad = False
59
+ self.visual_encoder = self.visual_encoder.eval()
60
+ self.visual_encoder.train = disabled_train
61
+ for name, param in self.ln_vision.named_parameters():
62
+ param.requires_grad = False
63
+ self.ln_vision = self.ln_vision.eval()
64
+ self.ln_vision.train = disabled_train
65
+ logging.info("freeze vision encoder")
66
+ print('Loading VIT Done')
67
+
68
+ print('Loading Q-Former')
69
+ self.Qformer, self.query_tokens = self.init_Qformer(
70
+ num_query_token, self.visual_encoder.num_features
71
+ )
72
+ self.Qformer.cls = None
73
+ self.Qformer.bert.embeddings.word_embeddings = None
74
+ self.Qformer.bert.embeddings.position_embeddings = None
75
+ for layer in self.Qformer.bert.encoder.layer:
76
+ layer.output = None
77
+ layer.intermediate = None
78
+ self.load_from_pretrained(url_or_filename=q_former_model)
79
+
80
+ if freeze_qformer:
81
+ for name, param in self.Qformer.named_parameters():
82
+ param.requires_grad = False
83
+ self.Qformer = self.Qformer.eval()
84
+ self.Qformer.train = disabled_train
85
+ self.query_tokens.requires_grad = False
86
+ logging.info("freeze Qformer")
87
+ print('Loading Q-Former Done')
88
+
89
+ print('Loading LLAMA')
90
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained('Vision-CAIR/vicuna-7b', use_fast=False, use_auth_token=True)
91
+ self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
92
+
93
+ if llama_cache_dir:
94
+ self.llama_model = LlamaForCausalLM.from_pretrained(
95
+ 'Vision-CAIR/vicuna-7b', load_in_8bit=True, torch_dtype=torch.float16, device_map="auto", use_auth_token=True
96
+ )
97
+ else:
98
+ self.llama_model = LlamaForCausalLM.from_pretrained(
99
+ 'Vision-CAIR/vicuna-7b', load_in_8bit=True, torch_dtype=torch.float16, device_map="auto", use_auth_token=True
100
+ )
101
+ for name, param in self.llama_model.named_parameters():
102
+ param.requires_grad = False
103
+ print('Loading LLAMA Done')
104
+
105
+ self.llama_proj = nn.Linear(
106
+ self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
107
+ )
108
+ self.max_txt_len = max_txt_len
109
+ self.end_sym = end_sym
110
+
111
+ if prompt_path:
112
+ with open(prompt_path, 'r') as f:
113
+ raw_prompts = f.read().splitlines()
114
+ filtered_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
115
+ self.prompt_list = [prompt_template.format(p) for p in filtered_prompts]
116
+ print('Load {} training prompts'.format(len(self.prompt_list)))
117
+ print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
118
+ else:
119
+ self.prompt_list = []
120
+
121
+ def vit_to_cpu(self):
122
+ self.ln_vision.to("cpu")
123
+ self.ln_vision.float()
124
+ self.visual_encoder.to("cpu")
125
+ self.visual_encoder.float()
126
+
127
+ def encode_img(self, image):
128
+ device = image.device
129
+ self.vit_to_cpu()
130
+ image = image.to("cpu")
131
+ with self.maybe_autocast():
132
+ image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
133
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
134
+
135
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
136
+ query_output = self.Qformer.bert(
137
+ query_embeds=query_tokens,
138
+ encoder_hidden_states=image_embeds,
139
+ encoder_attention_mask=image_atts,
140
+ return_dict=True,
141
+ )
142
+
143
+ inputs_llama = self.llama_proj(query_output.last_hidden_state)
144
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
145
+ return inputs_llama, atts_llama
146
+
147
+ def prompt_wrap(self, img_embeds, atts_img, prompt):
148
+ if prompt:
149
+ batch_size = img_embeds.shape[0]
150
+ p_before, p_after = prompt.split('<ImageHere>')
151
+ p_before_tokens = self.llama_tokenizer(
152
+ p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
153
+ p_after_tokens = self.llama_tokenizer(
154
+ p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
155
+ p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
156
+ p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
157
+ wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
158
+ wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
159
+ return wrapped_img_embeds, wrapped_atts_img
160
+ else:
161
+ return img_embeds, atts_img
162
+
163
+ def forward(self, samples):
164
+ image = samples["image"]
165
+ img_embeds, atts_img = self.encode_img(image)
166
+ if hasattr(samples, 'question_split'): # VQA dataset
167
+ print('VQA Batch')
168
+ vqa_prompt = '###Human: <Img><ImageHere></Img> '
169
+ img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
170
+ elif self.prompt_list:
171
+ prompt = random.choice(self.prompt_list)
172
+ img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)
173
+
174
+ self.llama_tokenizer.padding_side = "right"
175
+
176
+ text = [t + self.end_sym for t in samples["text_input"]]
177
+
178
+ to_regress_tokens = self.llama_tokenizer(
179
+ text,
180
+ return_tensors="pt",
181
+ padding="longest",
182
+ truncation=True,
183
+ max_length=self.max_txt_len,
184
+ add_special_tokens=False
185
+ ).to(image.device)
186
+
187
+ targets = to_regress_tokens.input_ids.masked_fill(
188
+ to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
189
+ )
190
+
191
+ empty_targets = (
192
+ torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
193
+ dtype=torch.long).to(image.device).fill_(-100) # plus one for bos
194
+ )
195
+ targets = torch.cat([empty_targets, targets], dim=1)
196
+
197
+ batch_size = img_embeds.shape[0]
198
+ bos = torch.ones([batch_size, 1],
199
+ dtype=to_regress_tokens.input_ids.dtype,
200
+ device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
201
+ bos_embeds = self.llama_model.model.embed_tokens(bos)
202
+ atts_bos = atts_img[:, :1]
203
+
204
+ to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
205
+ inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
206
+ attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
207
+
208
+ with self.maybe_autocast():
209
+ outputs = self.llama_model(
210
+ inputs_embeds=inputs_embeds,
211
+ attention_mask=attention_mask,
212
+ return_dict=True,
213
+ labels=targets,
214
+ )
215
+ loss = outputs.loss
216
+
217
+ return {"loss": loss}
218
+
219
+ @classmethod
220
+ def from_config(cls, cfg):
221
+ vit_model = cfg.get("vit_model", "eva_clip_g")
222
+ q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
223
+ img_size = cfg.get("image_size")
224
+ num_query_token = cfg.get("num_query_token")
225
+ llama_model = cfg.get("llama_model")
226
+
227
+ drop_path_rate = cfg.get("drop_path_rate", 0)
228
+ use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
229
+ vit_precision = cfg.get("vit_precision", "fp16")
230
+ freeze_vit = cfg.get("freeze_vit", True)
231
+ freeze_qformer = cfg.get("freeze_qformer", True)
232
+ llama_cache_dir = cfg.get("llama_cache_dir", "")
233
+
234
+ prompt_path = cfg.get("prompt_path", "")
235
+ prompt_template = cfg.get("prompt_template", "")
236
+ max_txt_len = cfg.get("max_txt_len", 32)
237
+ end_sym = cfg.get("end_sym", '\n')
238
+
239
+ model = cls(
240
+ vit_model=vit_model,
241
+ q_former_model=q_former_model,
242
+ img_size=img_size,
243
+ drop_path_rate=drop_path_rate,
244
+ use_grad_checkpoint=use_grad_checkpoint,
245
+ vit_precision=vit_precision,
246
+ freeze_vit=freeze_vit,
247
+ freeze_qformer=freeze_qformer,
248
+ llama_cache_dir=llama_cache_dir,
249
+ num_query_token=num_query_token,
250
+ llama_model=llama_model,
251
+ prompt_path=prompt_path,
252
+ prompt_template=prompt_template,
253
+ max_txt_len=max_txt_len,
254
+ end_sym=end_sym
255
+ )
256
+
257
+ ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
258
+ if ckpt_path:
259
+ print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
260
+ ckpt = torch.load(ckpt_path, map_location="cpu")
261
+ msg = model.load_state_dict(ckpt['model'], strict=False)
262
+
263
+ return model
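A minimal configuration sketch for from_config above. The keys mirror configs/models/minigpt4.yaml; building the model downloads the frozen ViT, Q-Former and Vicuna weights, so it needs the corresponding checkpoints, an auth token and enough GPU memory.

from omegaconf import OmegaConf
from minigpt4.models.mini_gpt4 import MiniGPT4

cfg = OmegaConf.create({
    "image_size": 224,
    "num_query_token": 32,
    "llama_model": "Vision-CAIR/vicuna-7b",  # required key (this version hard-codes the same repo id above)
    "prompt_path": "",
    "ckpt": "",                              # optional path to a MiniGPT-4 checkpoint
})
model = MiniGPT4.from_config(cfg)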
minigpt4/models/modeling_llama.py ADDED
@@ -0,0 +1,772 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import math
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
33
+ from transformers.models.llama.configuration_llama import LlamaConfig
34
+
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ _CONFIG_FOR_DOC = "LlamaConfig"
39
+
40
+
41
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
42
+ def _make_causal_mask(
43
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
44
+ ):
45
+ """
46
+ Make causal mask used for bi-directional self-attention.
47
+ """
48
+ bsz, tgt_len = input_ids_shape
49
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
50
+ mask_cond = torch.arange(mask.size(-1), device=device)
51
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
52
+ mask = mask.to(dtype)
53
+
54
+ if past_key_values_length > 0:
55
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
56
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
57
+
58
+
59
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
60
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
61
+ """
62
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
63
+ """
64
+ bsz, src_len = mask.size()
65
+ tgt_len = tgt_len if tgt_len is not None else src_len
66
+
67
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
68
+
69
+ inverted_mask = 1.0 - expanded_mask
70
+
71
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
72
+
73
+
74
+ class LlamaRMSNorm(nn.Module):
75
+ def __init__(self, hidden_size, eps=1e-6):
76
+ """
77
+ LlamaRMSNorm is equivalent to T5LayerNorm
78
+ """
79
+ super().__init__()
80
+ self.weight = nn.Parameter(torch.ones(hidden_size))
81
+ self.variance_epsilon = eps
82
+
83
+ def forward(self, hidden_states):
84
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
85
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
86
+
87
+ # convert into half-precision if necessary
88
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
89
+ hidden_states = hidden_states.to(self.weight.dtype)
90
+
91
+ return self.weight * hidden_states
92
+
93
+
94
+ class LlamaRotaryEmbedding(torch.nn.Module):
95
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
96
+ super().__init__()
97
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
98
+ self.register_buffer("inv_freq", inv_freq)
99
+
100
+ # Build here to make `torch.jit.trace` work.
101
+ self.max_seq_len_cached = max_position_embeddings
102
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
103
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
104
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
105
+ emb = torch.cat((freqs, freqs), dim=-1)
106
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
107
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
108
+
109
+ def forward(self, x, seq_len=None):
110
+ # x: [bs, num_attention_heads, seq_len, head_size]
111
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
112
+ if seq_len > self.max_seq_len_cached:
113
+ self.max_seq_len_cached = seq_len
114
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
115
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
116
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
117
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
118
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
119
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
120
+ return (
121
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
122
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
123
+ )
124
+
125
+
126
+ def rotate_half(x):
127
+ """Rotates half the hidden dims of the input."""
128
+ x1 = x[..., : x.shape[-1] // 2]
129
+ x2 = x[..., x.shape[-1] // 2 :]
130
+ return torch.cat((-x2, x1), dim=-1)
131
+
132
+
133
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
134
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
135
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
136
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
137
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
138
+ q_embed = (q * cos) + (rotate_half(q) * sin)
139
+ k_embed = (k * cos) + (rotate_half(k) * sin)
140
+ return q_embed, k_embed
141
+
142
+
143
+ class LlamaMLP(nn.Module):
144
+ def __init__(
145
+ self,
146
+ hidden_size: int,
147
+ intermediate_size: int,
148
+ hidden_act: str,
149
+ ):
150
+ super().__init__()
151
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
152
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
153
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
154
+ self.act_fn = ACT2FN[hidden_act]
155
+
156
+ def forward(self, x):
157
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
158
+
159
+
160
+ class LlamaAttention(nn.Module):
161
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
162
+
163
+ def __init__(self, config: LlamaConfig):
164
+ super().__init__()
165
+ self.config = config
166
+ self.hidden_size = config.hidden_size
167
+ self.num_heads = config.num_attention_heads
168
+ self.head_dim = self.hidden_size // self.num_heads
169
+ self.max_position_embeddings = config.max_position_embeddings
170
+
171
+ if (self.head_dim * self.num_heads) != self.hidden_size:
172
+ raise ValueError(
173
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
174
+ f" and `num_heads`: {self.num_heads})."
175
+ )
176
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
177
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
178
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
179
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
180
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
181
+
182
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
183
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
184
+
185
+ def forward(
186
+ self,
187
+ hidden_states: torch.Tensor,
188
+ attention_mask: Optional[torch.Tensor] = None,
189
+ position_ids: Optional[torch.LongTensor] = None,
190
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
191
+ output_attentions: bool = False,
192
+ use_cache: bool = False,
193
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
194
+ bsz, q_len, _ = hidden_states.size()
195
+
196
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
197
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
198
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
199
+
200
+ kv_seq_len = key_states.shape[-2]
201
+ if past_key_value is not None:
202
+ kv_seq_len += past_key_value[0].shape[-2]
203
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
204
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
205
+ # [bsz, nh, t, hd]
206
+
207
+ if past_key_value is not None:
208
+ # reuse k, v, self_attention
209
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
210
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
211
+
212
+ past_key_value = (key_states, value_states) if use_cache else None
213
+
214
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
215
+
216
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
217
+ raise ValueError(
218
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
219
+ f" {attn_weights.size()}"
220
+ )
221
+
222
+ if attention_mask is not None:
223
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
224
+ raise ValueError(
225
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
226
+ )
227
+ attn_weights = attn_weights + attention_mask
228
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
229
+
230
+ # upcast attention to fp32
231
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
232
+ attn_output = torch.matmul(attn_weights, value_states)
233
+
234
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
235
+ raise ValueError(
236
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
237
+ f" {attn_output.size()}"
238
+ )
239
+
240
+ attn_output = attn_output.transpose(1, 2)
241
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
242
+
243
+ attn_output = self.o_proj(attn_output)
244
+
245
+ if not output_attentions:
246
+ attn_weights = None
247
+
248
+ return attn_output, attn_weights, past_key_value
249
+
250
+
251
+ class LlamaDecoderLayer(nn.Module):
252
+ def __init__(self, config: LlamaConfig):
253
+ super().__init__()
254
+ self.hidden_size = config.hidden_size
255
+ self.self_attn = LlamaAttention(config=config)
256
+ self.mlp = LlamaMLP(
257
+ hidden_size=self.hidden_size,
258
+ intermediate_size=config.intermediate_size,
259
+ hidden_act=config.hidden_act,
260
+ )
261
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
262
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
263
+
264
+ def forward(
265
+ self,
266
+ hidden_states: torch.Tensor,
267
+ attention_mask: Optional[torch.Tensor] = None,
268
+ position_ids: Optional[torch.LongTensor] = None,
269
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
270
+ output_attentions: Optional[bool] = False,
271
+ use_cache: Optional[bool] = False,
272
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
273
+ """
274
+ Args:
275
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
276
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
277
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
278
+ output_attentions (`bool`, *optional*):
279
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
280
+ returned tensors for more detail.
281
+ use_cache (`bool`, *optional*):
282
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
283
+ (see `past_key_values`).
284
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
285
+ """
286
+
287
+ residual = hidden_states
288
+
289
+ hidden_states = self.input_layernorm(hidden_states)
290
+
291
+ # Self Attention
292
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
293
+ hidden_states=hidden_states,
294
+ attention_mask=attention_mask,
295
+ position_ids=position_ids,
296
+ past_key_value=past_key_value,
297
+ output_attentions=output_attentions,
298
+ use_cache=use_cache,
299
+ )
300
+ hidden_states = residual + hidden_states
301
+
302
+ # Fully Connected
303
+ residual = hidden_states
304
+ hidden_states = self.post_attention_layernorm(hidden_states)
305
+ hidden_states = self.mlp(hidden_states)
306
+ hidden_states = residual + hidden_states
307
+
308
+ outputs = (hidden_states,)
309
+
310
+ if output_attentions:
311
+ outputs += (self_attn_weights,)
312
+
313
+ if use_cache:
314
+ outputs += (present_key_value,)
315
+
316
+ return outputs
317
+
318
+
319
+ LLAMA_START_DOCSTRING = r"""
320
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
321
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
322
+ etc.)
323
+
324
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
325
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
326
+ and behavior.
327
+
328
+ Parameters:
329
+ config ([`LlamaConfig`]):
330
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
331
+ load the weights associated with the model, only the configuration. Check out the
332
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
333
+ """
334
+
335
+
336
+ @add_start_docstrings(
337
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
338
+ LLAMA_START_DOCSTRING,
339
+ )
340
+ class LlamaPreTrainedModel(PreTrainedModel):
341
+ config_class = LlamaConfig
342
+ base_model_prefix = "model"
343
+ supports_gradient_checkpointing = True
344
+ _no_split_modules = ["LlamaDecoderLayer"]
345
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
346
+
347
+ def _init_weights(self, module):
348
+ std = self.config.initializer_range
349
+ if isinstance(module, nn.Linear):
350
+ module.weight.data.normal_(mean=0.0, std=std)
351
+ if module.bias is not None:
352
+ module.bias.data.zero_()
353
+ elif isinstance(module, nn.Embedding):
354
+ module.weight.data.normal_(mean=0.0, std=std)
355
+ if module.padding_idx is not None:
356
+ module.weight.data[module.padding_idx].zero_()
357
+
358
+ def _set_gradient_checkpointing(self, module, value=False):
359
+ if isinstance(module, LlamaModel):
360
+ module.gradient_checkpointing = value
361
+
362
+
363
+ LLAMA_INPUTS_DOCSTRING = r"""
364
+ Args:
365
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
366
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
367
+ it.
368
+
369
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
370
+ [`PreTrainedTokenizer.__call__`] for details.
371
+
372
+ [What are input IDs?](../glossary#input-ids)
373
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
374
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
375
+
376
+ - 1 for tokens that are **not masked**,
377
+ - 0 for tokens that are **masked**.
378
+
379
+ [What are attention masks?](../glossary#attention-mask)
380
+
381
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
382
+ [`PreTrainedTokenizer.__call__`] for details.
383
+
384
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
385
+ `past_key_values`).
386
+
387
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
388
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
389
+ information on the default strategy.
390
+
391
+ - 1 indicates the head is **not masked**,
392
+ - 0 indicates the head is **masked**.
393
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
394
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
395
+ config.n_positions - 1]`.
396
+
397
+ [What are position IDs?](../glossary#position-ids)
398
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
399
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
400
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
401
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
402
+
403
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
404
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
405
+
406
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
407
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
408
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
409
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
410
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
411
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
412
+ model's internal embedding lookup matrix.
413
+ use_cache (`bool`, *optional*):
414
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
415
+ `past_key_values`).
416
+ output_attentions (`bool`, *optional*):
417
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
418
+ tensors for more detail.
419
+ output_hidden_states (`bool`, *optional*):
420
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
421
+ more detail.
422
+ return_dict (`bool`, *optional*):
423
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
424
+ """
425
+
426
+
427
+ @add_start_docstrings(
428
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
429
+ LLAMA_START_DOCSTRING,
430
+ )
431
+ class LlamaModel(LlamaPreTrainedModel):
432
+ """
433
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
434
+
435
+ Args:
436
+ config: LlamaConfig
437
+ """
438
+
439
+ def __init__(self, config: LlamaConfig):
440
+ super().__init__(config)
441
+ self.padding_idx = config.pad_token_id
442
+ self.vocab_size = config.vocab_size
443
+
444
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
445
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
446
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
447
+
448
+ self.gradient_checkpointing = False
449
+ # Initialize weights and apply final processing
450
+ self.post_init()
451
+
452
+ def get_input_embeddings(self):
453
+ return self.embed_tokens
454
+
455
+ def set_input_embeddings(self, value):
456
+ self.embed_tokens = value
457
+
458
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
459
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
460
+ # create causal mask
461
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
462
+ combined_attention_mask = None
463
+ if input_shape[-1] > 1:
464
+ combined_attention_mask = _make_causal_mask(
465
+ input_shape,
466
+ inputs_embeds.dtype,
467
+ device=inputs_embeds.device,
468
+ past_key_values_length=past_key_values_length,
469
+ )
470
+
471
+ if attention_mask is not None:
472
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
473
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
474
+ inputs_embeds.device
475
+ )
476
+ combined_attention_mask = (
477
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
478
+ )
479
+
480
+ return combined_attention_mask
481
+
482
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
483
+ def forward(
484
+ self,
485
+ input_ids: torch.LongTensor = None,
486
+ attention_mask: Optional[torch.Tensor] = None,
487
+ position_ids: Optional[torch.LongTensor] = None,
488
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
489
+ inputs_embeds: Optional[torch.FloatTensor] = None,
490
+ query_embeds: Optional[torch.FloatTensor] = None,
491
+ use_cache: Optional[bool] = None,
492
+ output_attentions: Optional[bool] = None,
493
+ output_hidden_states: Optional[bool] = None,
494
+ return_dict: Optional[bool] = None,
495
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
496
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
497
+ output_hidden_states = (
498
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
499
+ )
500
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
501
+
502
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
503
+
504
+ # retrieve input_ids and inputs_embeds
505
+ if input_ids is not None and inputs_embeds is not None:
506
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
507
+ elif input_ids is not None:
508
+ batch_size, seq_length = input_ids.shape
509
+ elif inputs_embeds is not None:
510
+ batch_size, seq_length, _ = inputs_embeds.shape
511
+ else:
512
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
513
+
514
+ if inputs_embeds is None:
515
+ inputs_embeds = self.embed_tokens(input_ids)
516
+ if query_embeds is not None:
517
+ inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
518
+ batch_size, seq_length, _ = inputs_embeds.shape
519
+
520
+ seq_length_with_past = seq_length
521
+ past_key_values_length = 0
522
+
523
+ if past_key_values is not None:
524
+ past_key_values_length = past_key_values[0][0].shape[2]
525
+ seq_length_with_past = seq_length_with_past + past_key_values_length
526
+
527
+ if position_ids is None:
528
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
529
+ position_ids = torch.arange(
530
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
531
+ )
532
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
533
+ else:
534
+ position_ids = position_ids.view(-1, seq_length).long()
535
+
536
+ # embed positions
537
+ if attention_mask is None:
538
+ attention_mask = torch.ones(
539
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
540
+ )
541
+ attention_mask = self._prepare_decoder_attention_mask(
542
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
543
+ )
544
+
545
+ hidden_states = inputs_embeds
546
+
547
+ if self.gradient_checkpointing and self.training:
548
+ if use_cache:
549
+ logger.warning_once(
550
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
551
+ )
552
+ use_cache = False
553
+
554
+ # decoder layers
555
+ all_hidden_states = () if output_hidden_states else None
556
+ all_self_attns = () if output_attentions else None
557
+ next_decoder_cache = () if use_cache else None
558
+
559
+ for idx, decoder_layer in enumerate(self.layers):
560
+ if output_hidden_states:
561
+ all_hidden_states += (hidden_states,)
562
+
563
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
564
+
565
+ if self.gradient_checkpointing and self.training:
566
+
567
+ def create_custom_forward(module):
568
+ def custom_forward(*inputs):
569
+ # None for past_key_value
570
+ return module(*inputs, output_attentions, None)
571
+
572
+ return custom_forward
573
+
574
+ layer_outputs = torch.utils.checkpoint.checkpoint(
575
+ create_custom_forward(decoder_layer),
576
+ hidden_states,
577
+ attention_mask,
578
+ position_ids,
579
+ None,
580
+ )
581
+ else:
582
+ layer_outputs = decoder_layer(
583
+ hidden_states,
584
+ attention_mask=attention_mask,
585
+ position_ids=position_ids,
586
+ past_key_value=past_key_value,
587
+ output_attentions=output_attentions,
588
+ use_cache=use_cache,
589
+ )
590
+
591
+ hidden_states = layer_outputs[0]
592
+
593
+ if use_cache:
594
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
595
+
596
+ if output_attentions:
597
+ all_self_attns += (layer_outputs[1],)
598
+
599
+ hidden_states = self.norm(hidden_states)
600
+
601
+ # add hidden states from the last decoder layer
602
+ if output_hidden_states:
603
+ all_hidden_states += (hidden_states,)
604
+
605
+ next_cache = next_decoder_cache if use_cache else None
606
+ if not return_dict:
607
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
608
+ return BaseModelOutputWithPast(
609
+ last_hidden_state=hidden_states,
610
+ past_key_values=next_cache,
611
+ hidden_states=all_hidden_states,
612
+ attentions=all_self_attns,
613
+ )
614
+
615
+
616
+ class LlamaForCausalLM(LlamaPreTrainedModel):
617
+ def __init__(self, config):
618
+ super().__init__(config)
619
+ self.model = LlamaModel(config)
620
+
621
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
622
+
623
+ # Initialize weights and apply final processing
624
+ self.post_init()
625
+
626
+ def get_input_embeddings(self):
627
+ return self.model.embed_tokens
628
+
629
+ def set_input_embeddings(self, value):
630
+ self.model.embed_tokens = value
631
+
632
+ def get_output_embeddings(self):
633
+ return self.lm_head
634
+
635
+ def set_output_embeddings(self, new_embeddings):
636
+ self.lm_head = new_embeddings
637
+
638
+ def set_decoder(self, decoder):
639
+ self.model = decoder
640
+
641
+ def get_decoder(self):
642
+ return self.model
643
+
644
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
645
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
646
+ def forward(
647
+ self,
648
+ input_ids: torch.LongTensor = None,
649
+ attention_mask: Optional[torch.Tensor] = None,
650
+ position_ids: Optional[torch.LongTensor] = None,
651
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
652
+ inputs_embeds: Optional[torch.FloatTensor] = None,
653
+ query_embeds: Optional[torch.FloatTensor] = None,
654
+ labels: Optional[torch.LongTensor] = None,
655
+ use_cache: Optional[bool] = None,
656
+ output_attentions: Optional[bool] = None,
657
+ output_hidden_states: Optional[bool] = None,
658
+ return_dict: Optional[bool] = None,
659
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
660
+ r"""
661
+ Args:
662
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
663
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
664
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
665
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
666
+
667
+ Returns:
668
+
669
+ Example:
670
+
671
+ ```python
672
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
673
+
674
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
675
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
676
+
677
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
678
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
679
+
680
+ >>> # Generate
681
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
682
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
683
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
684
+ ```"""
685
+
686
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
687
+ output_hidden_states = (
688
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
689
+ )
690
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
691
+
692
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
693
+ outputs = self.model(
694
+ input_ids=input_ids,
695
+ attention_mask=attention_mask,
696
+ position_ids=position_ids,
697
+ past_key_values=past_key_values,
698
+ inputs_embeds=inputs_embeds,
699
+ query_embeds=query_embeds,
700
+ use_cache=use_cache,
701
+ output_attentions=output_attentions,
702
+ output_hidden_states=output_hidden_states,
703
+ return_dict=return_dict,
704
+ )
705
+
706
+ hidden_states = outputs[0]
707
+ logits = self.lm_head(hidden_states)
708
+
709
+ loss = None
710
+ if labels is not None:
711
+ # Shift so that tokens < n predict n
712
+ shift_logits = logits[..., :-1, :].contiguous()
713
+ shift_labels = labels[..., 1:].contiguous()
714
+ # Flatten the tokens
715
+ loss_fct = CrossEntropyLoss()
716
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
717
+ shift_labels = shift_labels.view(-1)
718
+ # Enable model parallelism
719
+ shift_labels = shift_labels.to(shift_logits.device)
720
+ loss = loss_fct(shift_logits, shift_labels)
721
+
722
+ if not return_dict:
723
+ output = (logits,) + outputs[1:]
724
+ return (loss,) + output if loss is not None else output
725
+
726
+ return CausalLMOutputWithPast(
727
+ loss=loss,
728
+ logits=logits,
729
+ past_key_values=outputs.past_key_values,
730
+ hidden_states=outputs.hidden_states,
731
+ attentions=outputs.attentions,
732
+ )
733
+
734
+ def prepare_inputs_for_generation(
735
+ self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
736
+ ):
737
+ if past_key_values:
738
+ input_ids = input_ids[:, -1:]
739
+
740
+ position_ids = kwargs.get("position_ids", None)
741
+ if attention_mask is not None and position_ids is None:
742
+ # create position_ids on the fly for batch generation
743
+ position_ids = attention_mask.long().cumsum(-1) - 1
744
+ position_ids.masked_fill_(attention_mask == 0, 1)
745
+ if past_key_values:
746
+ position_ids = position_ids[:, -1].unsqueeze(-1)
747
+ query_embeds = None
748
+
749
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
750
+ if inputs_embeds is not None and past_key_values is None:
751
+ model_inputs = {"inputs_embeds": inputs_embeds}
752
+ else:
753
+ model_inputs = {"input_ids": input_ids}
754
+
755
+ model_inputs.update(
756
+ {
757
+ "position_ids": position_ids,
758
+ "query_embeds": query_embeds,
759
+ "past_key_values": past_key_values,
760
+ "use_cache": kwargs.get("use_cache"),
761
+ "attention_mask": attention_mask,
762
+ }
763
+ )
764
+ return model_inputs
765
+
766
+ @staticmethod
767
+ def _reorder_cache(past_key_values, beam_idx):
768
+ reordered_past = ()
769
+ for layer_past in past_key_values:
770
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
771
+ return reordered_past
772
+
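The modification that matters for MiniGPT-4 in this otherwise stock HuggingFace LLaMA port is the optional `query_embeds` argument: `LlamaModel.forward` prepends it to the token embeddings, and `prepare_inputs_for_generation` drops it once a KV cache exists. A minimal sketch of that prepending step with dummy tensors (the batch size, 32 query tokens, and 4096 hidden size below are illustrative, not read from any checkpoint):

```python
import torch

batch_size, num_query, text_len, hidden = 2, 32, 7, 4096

# stand-ins for the visual query embeddings and the embedded text prompt
query_embeds = torch.randn(batch_size, num_query, hidden)
text_embeds = torch.randn(batch_size, text_len, hidden)

# the same concatenation LlamaModel.forward performs when query_embeds is not None
inputs_embeds = torch.cat([query_embeds, text_embeds], dim=1)

# the attention mask has to cover the prepended query positions as well
attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool)

print(inputs_embeds.shape)   # torch.Size([2, 39, 4096])
print(attention_mask.shape)  # torch.Size([2, 39])
```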
minigpt4/processors/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from minigpt4.processors.base_processor import BaseProcessor
9
+ from minigpt4.processors.blip_processors import (
10
+ Blip2ImageTrainProcessor,
11
+ Blip2ImageEvalProcessor,
12
+ BlipCaptionProcessor,
13
+ )
14
+
15
+ from minigpt4.common.registry import registry
16
+
17
+ __all__ = [
18
+ "BaseProcessor",
19
+ "Blip2ImageTrainProcessor",
20
+ "Blip2ImageEvalProcessor",
21
+ "BlipCaptionProcessor",
22
+ ]
23
+
24
+
25
+ def load_processor(name, cfg=None):
26
+ """
27
+ Example
28
+
29
+ >>> processor = load_processor("blip2_image_train", cfg=None)
30
+ """
31
+ processor = registry.get_processor_class(name).from_config(cfg)
32
+
33
+ return processor
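A hedged usage sketch of `load_processor`, assuming the `minigpt4` package and its dependencies (torchvision, omegaconf, Pillow) are importable; `blip2_image_eval` is one of the names registered in `blip_processors.py` below:

```python
from PIL import Image

from minigpt4.processors import load_processor

# cfg=None lets the processor fall back to its defaults (224x224, CLIP mean/std)
eval_processor = load_processor("blip2_image_eval", cfg=None)

# the processor is callable on a PIL image and returns a normalized tensor
image = Image.new("RGB", (640, 480), color=(127, 127, 127))
tensor = eval_processor(image)
print(tensor.shape)  # torch.Size([3, 224, 224])
```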
minigpt4/processors/base_processor.py ADDED
@@ -0,0 +1,26 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from omegaconf import OmegaConf
9
+
10
+
11
+ class BaseProcessor:
12
+ def __init__(self):
13
+ self.transform = lambda x: x
14
+ return
15
+
16
+ def __call__(self, item):
17
+ return self.transform(item)
18
+
19
+ @classmethod
20
+ def from_config(cls, cfg=None):
21
+ return cls()
22
+
23
+ def build(self, **kwargs):
24
+ cfg = OmegaConf.create(kwargs)
25
+
26
+ return self.from_config(cfg)
minigpt4/processors/blip_processors.py ADDED
@@ -0,0 +1,141 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import re
9
+
10
+ from minigpt4.common.registry import registry
11
+ from minigpt4.processors.base_processor import BaseProcessor
12
+ from minigpt4.processors.randaugment import RandomAugment
13
+ from omegaconf import OmegaConf
14
+ from torchvision import transforms
15
+ from torchvision.transforms.functional import InterpolationMode
16
+
17
+
18
+ class BlipImageBaseProcessor(BaseProcessor):
19
+ def __init__(self, mean=None, std=None):
20
+ if mean is None:
21
+ mean = (0.48145466, 0.4578275, 0.40821073)
22
+ if std is None:
23
+ std = (0.26862954, 0.26130258, 0.27577711)
24
+
25
+ self.normalize = transforms.Normalize(mean, std)
26
+
27
+
28
+ @registry.register_processor("blip_caption")
29
+ class BlipCaptionProcessor(BaseProcessor):
30
+ def __init__(self, prompt="", max_words=50):
31
+ self.prompt = prompt
32
+ self.max_words = max_words
33
+
34
+ def __call__(self, caption):
35
+ caption = self.prompt + self.pre_caption(caption)
36
+
37
+ return caption
38
+
39
+ @classmethod
40
+ def from_config(cls, cfg=None):
41
+ if cfg is None:
42
+ cfg = OmegaConf.create()
43
+
44
+ prompt = cfg.get("prompt", "")
45
+ max_words = cfg.get("max_words", 50)
46
+
47
+ return cls(prompt=prompt, max_words=max_words)
48
+
49
+ def pre_caption(self, caption):
50
+ caption = re.sub(
51
+ r"([.!\"()*#:;~])",
52
+ " ",
53
+ caption.lower(),
54
+ )
55
+ caption = re.sub(
56
+ r"\s{2,}",
57
+ " ",
58
+ caption,
59
+ )
60
+ caption = caption.rstrip("\n")
61
+ caption = caption.strip(" ")
62
+
63
+ # truncate caption
64
+ caption_words = caption.split(" ")
65
+ if len(caption_words) > self.max_words:
66
+ caption = " ".join(caption_words[: self.max_words])
67
+
68
+ return caption
69
+
70
+
71
+ @registry.register_processor("blip2_image_train")
72
+ class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
73
+ def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
74
+ super().__init__(mean=mean, std=std)
75
+
76
+ self.transform = transforms.Compose(
77
+ [
78
+ transforms.RandomResizedCrop(
79
+ image_size,
80
+ scale=(min_scale, max_scale),
81
+ interpolation=InterpolationMode.BICUBIC,
82
+ ),
83
+ transforms.ToTensor(),
84
+ self.normalize,
85
+ ]
86
+ )
87
+
88
+ def __call__(self, item):
89
+ return self.transform(item)
90
+
91
+ @classmethod
92
+ def from_config(cls, cfg=None):
93
+ if cfg is None:
94
+ cfg = OmegaConf.create()
95
+
96
+ image_size = cfg.get("image_size", 224)
97
+
98
+ mean = cfg.get("mean", None)
99
+ std = cfg.get("std", None)
100
+
101
+ min_scale = cfg.get("min_scale", 0.5)
102
+ max_scale = cfg.get("max_scale", 1.0)
103
+
104
+ return cls(
105
+ image_size=image_size,
106
+ mean=mean,
107
+ std=std,
108
+ min_scale=min_scale,
109
+ max_scale=max_scale,
110
+ )
111
+
112
+
113
+ @registry.register_processor("blip2_image_eval")
114
+ class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
115
+ def __init__(self, image_size=224, mean=None, std=None):
116
+ super().__init__(mean=mean, std=std)
117
+
118
+ self.transform = transforms.Compose(
119
+ [
120
+ transforms.Resize(
121
+ (image_size, image_size), interpolation=InterpolationMode.BICUBIC
122
+ ),
123
+ transforms.ToTensor(),
124
+ self.normalize,
125
+ ]
126
+ )
127
+
128
+ def __call__(self, item):
129
+ return self.transform(item)
130
+
131
+ @classmethod
132
+ def from_config(cls, cfg=None):
133
+ if cfg is None:
134
+ cfg = OmegaConf.create()
135
+
136
+ image_size = cfg.get("image_size", 224)
137
+
138
+ mean = cfg.get("mean", None)
139
+ std = cfg.get("std", None)
140
+
141
+ return cls(image_size=image_size, mean=mean, std=std)
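In training these processors are normally built from the YAML-driven config, but they can also be constructed directly from an `OmegaConf` node. A small sketch with made-up config values, assuming `minigpt4` and `omegaconf` are installed:

```python
from omegaconf import OmegaConf

from minigpt4.processors.blip_processors import (
    Blip2ImageTrainProcessor,
    BlipCaptionProcessor,
)

vis_cfg = OmegaConf.create({"image_size": 224, "min_scale": 0.5, "max_scale": 1.0})
txt_cfg = OmegaConf.create({"prompt": "", "max_words": 50})

vis_processor = Blip2ImageTrainProcessor.from_config(vis_cfg)
txt_processor = BlipCaptionProcessor.from_config(txt_cfg)

# pre_caption lower-cases, strips the listed punctuation and truncates to max_words
print(txt_processor("A *very* noisy Caption!!"))  # -> "a very noisy caption"
```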
minigpt4/processors/randaugment.py ADDED
@@ -0,0 +1,398 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import cv2
9
+ import numpy as np
10
+
11
+ import torch
12
+
13
+
14
+ ## aug functions
15
+ def identity_func(img):
16
+ return img
17
+
18
+
19
+ def autocontrast_func(img, cutoff=0):
20
+ """
21
+ same output as PIL.ImageOps.autocontrast
22
+ """
23
+ n_bins = 256
24
+
25
+ def tune_channel(ch):
26
+ n = ch.size
27
+ cut = cutoff * n // 100
28
+ if cut == 0:
29
+ high, low = ch.max(), ch.min()
30
+ else:
31
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
32
+ low = np.argwhere(np.cumsum(hist) > cut)
33
+ low = 0 if low.shape[0] == 0 else low[0]
34
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
35
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
36
+ if high <= low:
37
+ table = np.arange(n_bins)
38
+ else:
39
+ scale = (n_bins - 1) / (high - low)
40
+ offset = -low * scale
41
+ table = np.arange(n_bins) * scale + offset
42
+ table[table < 0] = 0
43
+ table[table > n_bins - 1] = n_bins - 1
44
+ table = table.clip(0, 255).astype(np.uint8)
45
+ return table[ch]
46
+
47
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
48
+ out = cv2.merge(channels)
49
+ return out
50
+
51
+
52
+ def equalize_func(img):
53
+ """
54
+ same output as PIL.ImageOps.equalize
55
+ PIL's implementation is different from cv2.equalizeHist
56
+ """
57
+ n_bins = 256
58
+
59
+ def tune_channel(ch):
60
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
61
+ non_zero_hist = hist[hist != 0].reshape(-1)
62
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
63
+ if step == 0:
64
+ return ch
65
+ n = np.empty_like(hist)
66
+ n[0] = step // 2
67
+ n[1:] = hist[:-1]
68
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
69
+ return table[ch]
70
+
71
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
72
+ out = cv2.merge(channels)
73
+ return out
74
+
75
+
76
+ def rotate_func(img, degree, fill=(0, 0, 0)):
77
+ """
78
+ like PIL, rotate by degree, not radians
79
+ """
80
+ H, W = img.shape[0], img.shape[1]
81
+ center = W / 2, H / 2
82
+ M = cv2.getRotationMatrix2D(center, degree, 1)
83
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
84
+ return out
85
+
86
+
87
+ def solarize_func(img, thresh=128):
88
+ """
89
+ same output as PIL.ImageOps.solarize
90
+ """
91
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
92
+ table = table.clip(0, 255).astype(np.uint8)
93
+ out = table[img]
94
+ return out
95
+
96
+
97
+ def color_func(img, factor):
98
+ """
99
+ same output as PIL.ImageEnhance.Color
100
+ """
101
+ ## implementation according to PIL definition, quite slow
102
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
103
+ # out = blend(degenerate, img, factor)
104
+ # M = (
105
+ # np.eye(3) * factor
106
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
107
+ # )[np.newaxis, np.newaxis, :]
108
+ M = np.float32(
109
+ [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
110
+ ) * factor + np.float32([[0.114], [0.587], [0.299]])
111
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
112
+ return out
113
+
114
+
115
+ def contrast_func(img, factor):
116
+ """
117
+ same output as PIL.ImageEnhance.Contrast
118
+ """
119
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
120
+ table = (
121
+ np.array([(el - mean) * factor + mean for el in range(256)])
122
+ .clip(0, 255)
123
+ .astype(np.uint8)
124
+ )
125
+ out = table[img]
126
+ return out
127
+
128
+
129
+ def brightness_func(img, factor):
130
+ """
131
+ same output as PIL.ImageEnhance.Brightness
132
+ """
133
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
134
+ out = table[img]
135
+ return out
136
+
137
+
138
+ def sharpness_func(img, factor):
139
+ """
140
+ The differences between this result and PIL's are all on the 4 boundaries; the center
141
+ areas are the same
142
+ """
143
+ kernel = np.ones((3, 3), dtype=np.float32)
144
+ kernel[1][1] = 5
145
+ kernel /= 13
146
+ degenerate = cv2.filter2D(img, -1, kernel)
147
+ if factor == 0.0:
148
+ out = degenerate
149
+ elif factor == 1.0:
150
+ out = img
151
+ else:
152
+ out = img.astype(np.float32)
153
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
154
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
155
+ out = out.astype(np.uint8)
156
+ return out
157
+
158
+
159
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
160
+ H, W = img.shape[0], img.shape[1]
161
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
162
+ out = cv2.warpAffine(
163
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
164
+ ).astype(np.uint8)
165
+ return out
166
+
167
+
168
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
169
+ """
170
+ same output as PIL.Image.transform
171
+ """
172
+ H, W = img.shape[0], img.shape[1]
173
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
174
+ out = cv2.warpAffine(
175
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
176
+ ).astype(np.uint8)
177
+ return out
178
+
179
+
180
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
181
+ """
182
+ same output as PIL.Image.transform
183
+ """
184
+ H, W = img.shape[0], img.shape[1]
185
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
186
+ out = cv2.warpAffine(
187
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
188
+ ).astype(np.uint8)
189
+ return out
190
+
191
+
192
+ def posterize_func(img, bits):
193
+ """
194
+ same output as PIL.ImageOps.posterize
195
+ """
196
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
197
+ return out
198
+
199
+
200
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
201
+ H, W = img.shape[0], img.shape[1]
202
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
203
+ out = cv2.warpAffine(
204
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
205
+ ).astype(np.uint8)
206
+ return out
207
+
208
+
209
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
210
+ replace = np.array(replace, dtype=np.uint8)
211
+ H, W = img.shape[0], img.shape[1]
212
+ rh, rw = np.random.random(2)
213
+ pad_size = pad_size // 2
214
+ ch, cw = int(rh * H), int(rw * W)
215
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
216
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
217
+ out = img.copy()
218
+ out[x1:x2, y1:y2, :] = replace
219
+ return out
220
+
221
+
222
+ ### level to args
223
+ def enhance_level_to_args(MAX_LEVEL):
224
+ def level_to_args(level):
225
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
226
+
227
+ return level_to_args
228
+
229
+
230
+ def shear_level_to_args(MAX_LEVEL, replace_value):
231
+ def level_to_args(level):
232
+ level = (level / MAX_LEVEL) * 0.3
233
+ if np.random.random() > 0.5:
234
+ level = -level
235
+ return (level, replace_value)
236
+
237
+ return level_to_args
238
+
239
+
240
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
241
+ def level_to_args(level):
242
+ level = (level / MAX_LEVEL) * float(translate_const)
243
+ if np.random.random() > 0.5:
244
+ level = -level
245
+ return (level, replace_value)
246
+
247
+ return level_to_args
248
+
249
+
250
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
251
+ def level_to_args(level):
252
+ level = int((level / MAX_LEVEL) * cutout_const)
253
+ return (level, replace_value)
254
+
255
+ return level_to_args
256
+
257
+
258
+ def solarize_level_to_args(MAX_LEVEL):
259
+ def level_to_args(level):
260
+ level = int((level / MAX_LEVEL) * 256)
261
+ return (level,)
262
+
263
+ return level_to_args
264
+
265
+
266
+ def none_level_to_args(level):
267
+ return ()
268
+
269
+
270
+ def posterize_level_to_args(MAX_LEVEL):
271
+ def level_to_args(level):
272
+ level = int((level / MAX_LEVEL) * 4)
273
+ return (level,)
274
+
275
+ return level_to_args
276
+
277
+
278
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
279
+ def level_to_args(level):
280
+ level = (level / MAX_LEVEL) * 30
281
+ if np.random.random() < 0.5:
282
+ level = -level
283
+ return (level, replace_value)
284
+
285
+ return level_to_args
286
+
287
+
288
+ func_dict = {
289
+ "Identity": identity_func,
290
+ "AutoContrast": autocontrast_func,
291
+ "Equalize": equalize_func,
292
+ "Rotate": rotate_func,
293
+ "Solarize": solarize_func,
294
+ "Color": color_func,
295
+ "Contrast": contrast_func,
296
+ "Brightness": brightness_func,
297
+ "Sharpness": sharpness_func,
298
+ "ShearX": shear_x_func,
299
+ "TranslateX": translate_x_func,
300
+ "TranslateY": translate_y_func,
301
+ "Posterize": posterize_func,
302
+ "ShearY": shear_y_func,
303
+ }
304
+
305
+ translate_const = 10
306
+ MAX_LEVEL = 10
307
+ replace_value = (128, 128, 128)
308
+ arg_dict = {
309
+ "Identity": none_level_to_args,
310
+ "AutoContrast": none_level_to_args,
311
+ "Equalize": none_level_to_args,
312
+ "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
313
+ "Solarize": solarize_level_to_args(MAX_LEVEL),
314
+ "Color": enhance_level_to_args(MAX_LEVEL),
315
+ "Contrast": enhance_level_to_args(MAX_LEVEL),
316
+ "Brightness": enhance_level_to_args(MAX_LEVEL),
317
+ "Sharpness": enhance_level_to_args(MAX_LEVEL),
318
+ "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
319
+ "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
320
+ "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
321
+ "Posterize": posterize_level_to_args(MAX_LEVEL),
322
+ "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
323
+ }
324
+
325
+
326
+ class RandomAugment(object):
327
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
328
+ self.N = N
329
+ self.M = M
330
+ self.isPIL = isPIL
331
+ if augs:
332
+ self.augs = augs
333
+ else:
334
+ self.augs = list(arg_dict.keys())
335
+
336
+ def get_random_ops(self):
337
+ sampled_ops = np.random.choice(self.augs, self.N)
338
+ return [(op, 0.5, self.M) for op in sampled_ops]
339
+
340
+ def __call__(self, img):
341
+ if self.isPIL:
342
+ img = np.array(img)
343
+ ops = self.get_random_ops()
344
+ for name, prob, level in ops:
345
+ if np.random.random() > prob:
346
+ continue
347
+ args = arg_dict[name](level)
348
+ img = func_dict[name](img, *args)
349
+ return img
350
+
351
+
352
+ class VideoRandomAugment(object):
353
+ def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
354
+ self.N = N
355
+ self.M = M
356
+ self.p = p
357
+ self.tensor_in_tensor_out = tensor_in_tensor_out
358
+ if augs:
359
+ self.augs = augs
360
+ else:
361
+ self.augs = list(arg_dict.keys())
362
+
363
+ def get_random_ops(self):
364
+ sampled_ops = np.random.choice(self.augs, self.N, replace=False)
365
+ return [(op, self.M) for op in sampled_ops]
366
+
367
+ def __call__(self, frames):
368
+ assert (
369
+ frames.shape[-1] == 3
370
+ ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
371
+
372
+ if self.tensor_in_tensor_out:
373
+ frames = frames.numpy().astype(np.uint8)
374
+
375
+ num_frames = frames.shape[0]
376
+
377
+ ops = num_frames * [self.get_random_ops()]
378
+ apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
379
+
380
+ frames = torch.stack(
381
+ list(map(self._aug, frames, ops, apply_or_not)), dim=0
382
+ ).float()
383
+
384
+ return frames
385
+
386
+ def _aug(self, img, ops, apply_or_not):
387
+ for i, (name, level) in enumerate(ops):
388
+ if not apply_or_not[i]:
389
+ continue
390
+ args = arg_dict[name](level)
391
+ img = func_dict[name](img, *args)
392
+ return torch.from_numpy(img)
393
+
394
+
395
+ if __name__ == "__main__":
396
+ a = RandomAugment()
397
+ img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
398
+ a(img)
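For reference, a hedged usage sketch of `RandomAugment` on its own (it needs only `numpy`, `opencv-python` and this module; the N, M and op list below are arbitrary choices, and the op names must match keys of `arg_dict`):

```python
import numpy as np

from minigpt4.processors.randaugment import RandomAugment

# the ops expect HxWx3 uint8 images, e.g. np.array(pil_image) or a cv2 frame
img = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)

augment = RandomAugment(
    N=2,          # number of ops sampled per call
    M=7,          # magnitude level, mapped to per-op arguments by arg_dict
    isPIL=False,  # set True when passing PIL images instead of numpy arrays
    augs=["Identity", "Brightness", "ShearX", "TranslateY"],
)

out = augment(img)
print(out.shape, out.dtype)  # (224, 224, 3) uint8
```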
minigpt4/runners/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from minigpt4.runners.runner_base import RunnerBase
9
+
10
+ __all__ = ["RunnerBase"]
minigpt4/runners/runner_base.py ADDED
@@ -0,0 +1,658 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import json
10
+ import logging
11
+ import os
12
+ import time
13
+ from pathlib import Path
14
+
15
+ import torch
16
+ import torch.distributed as dist
17
+ import webdataset as wds
18
+ from minigpt4.common.dist_utils import (
19
+ download_cached_file,
20
+ get_rank,
21
+ get_world_size,
22
+ is_main_process,
23
+ main_process,
24
+ )
25
+ from minigpt4.common.registry import registry
26
+ from minigpt4.common.utils import is_url
27
+ from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset
28
+ from minigpt4.datasets.datasets.dataloader_utils import (
29
+ IterLoader,
30
+ MultiIterLoader,
31
+ PrefetchLoader,
32
+ )
33
+ from torch.nn.parallel import DistributedDataParallel as DDP
34
+ from torch.utils.data import DataLoader, DistributedSampler
35
+
36
+
37
+ @registry.register_runner("runner_base")
38
+ class RunnerBase:
39
+ """
40
+ A runner class to train and evaluate a model given a task and datasets.
41
+
42
+ The runner uses pytorch distributed data parallel by default. Future release
43
+ will support other distributed frameworks.
44
+ """
45
+
46
+ def __init__(self, cfg, task, model, datasets, job_id):
47
+ self.config = cfg
48
+ self.job_id = job_id
49
+
50
+ self.task = task
51
+ self.datasets = datasets
52
+
53
+ self._model = model
54
+
55
+ self._wrapped_model = None
56
+ self._device = None
57
+ self._optimizer = None
58
+ self._scaler = None
59
+ self._dataloaders = None
60
+ self._lr_sched = None
61
+
62
+ self.start_epoch = 0
63
+
64
+ # self.setup_seeds()
65
+ self.setup_output_dir()
66
+
67
+ @property
68
+ def device(self):
69
+ if self._device is None:
70
+ self._device = torch.device(self.config.run_cfg.device)
71
+
72
+ return self._device
73
+
74
+ @property
75
+ def use_distributed(self):
76
+ return self.config.run_cfg.distributed
77
+
78
+ @property
79
+ def model(self):
80
+ """
81
+ A property to get the DDP-wrapped model on the device.
82
+ """
83
+ # move model to device
84
+ if self._model.device != self.device:
85
+ self._model = self._model.to(self.device)
86
+
87
+ # distributed training wrapper
88
+ if self.use_distributed:
89
+ if self._wrapped_model is None:
90
+ self._wrapped_model = DDP(
91
+ self._model, device_ids=[self.config.run_cfg.gpu]
92
+ )
93
+ else:
94
+ self._wrapped_model = self._model
95
+
96
+ return self._wrapped_model
97
+
98
+ @property
99
+ def optimizer(self):
100
+ # TODO make optimizer class and configurations
101
+ if self._optimizer is None:
102
+ num_parameters = 0
103
+ p_wd, p_non_wd = [], []
104
+ for n, p in self.model.named_parameters():
105
+ if not p.requires_grad:
106
+ continue # frozen weights
107
+ print(n)
108
+ if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
109
+ p_non_wd.append(p)
110
+ else:
111
+ p_wd.append(p)
112
+ num_parameters += p.data.nelement()
113
+ logging.info("number of trainable parameters: %d" % num_parameters)
114
+ optim_params = [
115
+ {
116
+ "params": p_wd,
117
+ "weight_decay": float(self.config.run_cfg.weight_decay),
118
+ },
119
+ {"params": p_non_wd, "weight_decay": 0},
120
+ ]
121
+ beta2 = self.config.run_cfg.get("beta2", 0.999)
122
+ self._optimizer = torch.optim.AdamW(
123
+ optim_params,
124
+ lr=float(self.config.run_cfg.init_lr),
125
+ weight_decay=float(self.config.run_cfg.weight_decay),
126
+ betas=(0.9, beta2),
127
+ )
128
+
129
+ return self._optimizer
130
+
131
+ @property
132
+ def scaler(self):
133
+ amp = self.config.run_cfg.get("amp", False)
134
+
135
+ if amp:
136
+ if self._scaler is None:
137
+ self._scaler = torch.cuda.amp.GradScaler()
138
+
139
+ return self._scaler
140
+
141
+ @property
142
+ def lr_scheduler(self):
143
+ """
144
+ A property that lazily creates the learning rate scheduler on first access.
145
+ """
146
+ if self._lr_sched is None:
147
+ lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched)
148
+
149
+ # max_epoch = self.config.run_cfg.max_epoch
150
+ max_epoch = self.max_epoch
151
+ # min_lr = self.config.run_cfg.min_lr
152
+ min_lr = self.min_lr
153
+ # init_lr = self.config.run_cfg.init_lr
154
+ init_lr = self.init_lr
155
+
156
+ # optional parameters
157
+ decay_rate = self.config.run_cfg.get("lr_decay_rate", None)
158
+ warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1)
159
+ warmup_steps = self.config.run_cfg.get("warmup_steps", 0)
160
+ iters_per_epoch = self.config.run_cfg.get("iters_per_epoch", None)
161
+
162
+ if iters_per_epoch is None:
163
+ try:
164
+ iters_per_epoch = len(self.dataloaders['train'])
165
+ except (AttributeError, TypeError):
166
+ iters_per_epoch = 10000
167
+
168
+ self._lr_sched = lr_sched_cls(
169
+ optimizer=self.optimizer,
170
+ max_epoch=max_epoch,
171
+ iters_per_epoch=iters_per_epoch,
172
+ min_lr=min_lr,
173
+ init_lr=init_lr,
174
+ decay_rate=decay_rate,
175
+ warmup_start_lr=warmup_start_lr,
176
+ warmup_steps=warmup_steps,
177
+ )
178
+
179
+ return self._lr_sched
180
+
181
+ @property
182
+ def dataloaders(self) -> dict:
183
+ """
184
+ A property that lazily creates the dataloaders, keyed by split, on first access.
185
+
186
+ If no train_dataset_ratio is provided, concatenate map-style datasets and
187
+ chain wds.DataPipeline datasets separately. The training set becomes a tuple
188
+ (ConcatDataset, ChainDataset), both are optional but at least one of them is
189
+ required. The resultant ConcatDataset and ChainDataset will be sampled evenly.
190
+
191
+ If train_dataset_ratio is provided, create a MultiIterLoader to sample
192
+ each dataset by ratios during training.
193
+
194
+ Multiple datasets are currently not supported for validation and test.
195
+
196
+ Returns:
197
+ dict: {split_name: (tuples of) dataloader}
198
+ """
199
+ if self._dataloaders is None:
200
+
201
+ # concatenate map-style datasets and chain wds.DataPipeline datasets separately
202
+ # training set becomes a tuple (ConcatDataset, ChainDataset), both are
203
+ # optional but at least one of them is required. The resultant ConcatDataset
204
+ # and ChainDataset will be sampled evenly.
205
+ logging.info(
206
+ "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
207
+ )
208
+
209
+ datasets = reorg_datasets_by_split(self.datasets)
210
+ self.datasets = datasets
211
+ # self.datasets = concat_datasets(datasets)
212
+
213
+ # print dataset statistics after concatenation/chaining
214
+ for split_name in self.datasets:
215
+ if isinstance(self.datasets[split_name], tuple) or isinstance(
216
+ self.datasets[split_name], list
217
+ ):
218
+ # mixed wds.DataPipeline and torch.utils.data.Dataset
219
+ num_records = sum(
220
+ [
221
+ len(d)
222
+ if not type(d) in [wds.DataPipeline, ChainDataset]
223
+ else 0
224
+ for d in self.datasets[split_name]
225
+ ]
226
+ )
227
+
228
+ else:
229
+ if hasattr(self.datasets[split_name], "__len__"):
230
+ # a single map-style dataset
231
+ num_records = len(self.datasets[split_name])
232
+ else:
233
+ # a single wds.DataPipeline
234
+ num_records = -1
235
+ logging.info(
236
+ "Only a single wds.DataPipeline dataset, no __len__ attribute."
237
+ )
238
+
239
+ if num_records >= 0:
240
+ logging.info(
241
+ "Loaded {} records for {} split from the dataset.".format(
242
+ num_records, split_name
243
+ )
244
+ )
245
+
246
+ # create dataloaders
247
+ split_names = sorted(self.datasets.keys())
248
+
249
+ datasets = [self.datasets[split] for split in split_names]
250
+ is_trains = [split in self.train_splits for split in split_names]
251
+
252
+ batch_sizes = [
253
+ self.config.run_cfg.batch_size_train
254
+ if split == "train"
255
+ else self.config.run_cfg.batch_size_eval
256
+ for split in split_names
257
+ ]
258
+
259
+ collate_fns = []
260
+ for dataset in datasets:
261
+ if isinstance(dataset, tuple) or isinstance(dataset, list):
262
+ collate_fns.append([getattr(d, "collater", None) for d in dataset])
263
+ else:
264
+ collate_fns.append(getattr(dataset, "collater", None))
265
+
266
+ dataloaders = self.create_loaders(
267
+ datasets=datasets,
268
+ num_workers=self.config.run_cfg.num_workers,
269
+ batch_sizes=batch_sizes,
270
+ is_trains=is_trains,
271
+ collate_fns=collate_fns,
272
+ )
273
+
274
+ self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}
275
+
276
+ return self._dataloaders
277
+
278
+ @property
279
+ def cuda_enabled(self):
280
+ return self.device.type == "cuda"
281
+
282
+ @property
283
+ def max_epoch(self):
284
+ return int(self.config.run_cfg.max_epoch)
285
+
286
+ @property
287
+ def log_freq(self):
288
+ log_freq = self.config.run_cfg.get("log_freq", 50)
289
+ return int(log_freq)
290
+
291
+ @property
292
+ def init_lr(self):
293
+ return float(self.config.run_cfg.init_lr)
294
+
295
+ @property
296
+ def min_lr(self):
297
+ return float(self.config.run_cfg.min_lr)
298
+
299
+ @property
300
+ def accum_grad_iters(self):
301
+ return int(self.config.run_cfg.get("accum_grad_iters", 1))
302
+
303
+ @property
304
+ def valid_splits(self):
305
+ valid_splits = self.config.run_cfg.get("valid_splits", [])
306
+
307
+ if len(valid_splits) == 0:
308
+ logging.info("No validation splits found.")
309
+
310
+ return valid_splits
311
+
312
+ @property
313
+ def test_splits(self):
314
+ test_splits = self.config.run_cfg.get("test_splits", [])
315
+
316
+ return test_splits
317
+
318
+ @property
319
+ def train_splits(self):
320
+ train_splits = self.config.run_cfg.get("train_splits", [])
321
+
322
+ if len(train_splits) == 0:
323
+ logging.info("Empty train splits.")
324
+
325
+ return train_splits
326
+
327
+ @property
328
+ def evaluate_only(self):
329
+ """
330
+ Set to True to skip training.
331
+ """
332
+ return self.config.run_cfg.evaluate
333
+
334
+ @property
335
+ def use_dist_eval_sampler(self):
336
+ return self.config.run_cfg.get("use_dist_eval_sampler", True)
337
+
338
+ @property
339
+ def resume_ckpt_path(self):
340
+ return self.config.run_cfg.get("resume_ckpt_path", None)
341
+
342
+ @property
343
+ def train_loader(self):
344
+ train_dataloader = self.dataloaders["train"]
345
+
346
+ return train_dataloader
347
+
348
+ def setup_output_dir(self):
349
+ lib_root = Path(registry.get_path("library_root"))
350
+
351
+ output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id
352
+ result_dir = output_dir / "result"
353
+
354
+ output_dir.mkdir(parents=True, exist_ok=True)
355
+ result_dir.mkdir(parents=True, exist_ok=True)
356
+
357
+ registry.register_path("result_dir", str(result_dir))
358
+ registry.register_path("output_dir", str(output_dir))
359
+
360
+ self.result_dir = result_dir
361
+ self.output_dir = output_dir
362
+
363
+ def train(self):
364
+ start_time = time.time()
365
+ best_agg_metric = 0
366
+ best_epoch = 0
367
+
368
+ self.log_config()
369
+
370
+ # resume from checkpoint if specified
371
+ if not self.evaluate_only and self.resume_ckpt_path is not None:
372
+ self._load_checkpoint(self.resume_ckpt_path)
373
+
374
+ for cur_epoch in range(self.start_epoch, self.max_epoch):
375
+ # training phase
376
+ if not self.evaluate_only:
377
+ logging.info("Start training")
378
+ train_stats = self.train_epoch(cur_epoch)
379
+ self.log_stats(split_name="train", stats=train_stats)
380
+
381
+ # evaluation phase
382
+ if len(self.valid_splits) > 0:
383
+ for split_name in self.valid_splits:
384
+ logging.info("Evaluating on {}.".format(split_name))
385
+
386
+ val_log = self.eval_epoch(
387
+ split_name=split_name, cur_epoch=cur_epoch
388
+ )
389
+ if val_log is not None:
390
+ if is_main_process():
391
+ assert (
392
+ "agg_metrics" in val_log
393
+ ), "No agg_metrics found in validation log."
394
+
395
+ agg_metrics = val_log["agg_metrics"]
396
+ if agg_metrics > best_agg_metric and split_name == "val":
397
+ best_epoch, best_agg_metric = cur_epoch, agg_metrics
398
+
399
+ self._save_checkpoint(cur_epoch, is_best=True)
400
+
401
+ val_log.update({"best_epoch": best_epoch})
402
+ self.log_stats(val_log, split_name)
403
+
404
+ else:
405
+ # if no validation split is provided, we just save the checkpoint at the end of each epoch.
406
+ if not self.evaluate_only:
407
+ self._save_checkpoint(cur_epoch, is_best=False)
408
+
409
+ if self.evaluate_only:
410
+ break
411
+
412
+ if self.config.run_cfg.distributed:
413
+ dist.barrier()
414
+
415
+ # testing phase
416
+ test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch
417
+ self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)
418
+
419
+ total_time = time.time() - start_time
420
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
421
+ logging.info("Training time {}".format(total_time_str))
422
+
423
+ def evaluate(self, cur_epoch="best", skip_reload=False):
424
+ test_logs = dict()
425
+
426
+ if len(self.test_splits) > 0:
427
+ for split_name in self.test_splits:
428
+ test_logs[split_name] = self.eval_epoch(
429
+ split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload
430
+ )
431
+
432
+ return test_logs
433
+
434
+ def train_epoch(self, epoch):
435
+ # train
436
+ self.model.train()
437
+
438
+ return self.task.train_epoch(
439
+ epoch=epoch,
440
+ model=self.model,
441
+ data_loader=self.train_loader,
442
+ optimizer=self.optimizer,
443
+ scaler=self.scaler,
444
+ lr_scheduler=self.lr_scheduler,
445
+ cuda_enabled=self.cuda_enabled,
446
+ log_freq=self.log_freq,
447
+ accum_grad_iters=self.accum_grad_iters,
448
+ )
449
+
450
+ @torch.no_grad()
451
+ def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
452
+ """
453
+ Evaluate the model on a given split.
454
+
455
+ Args:
456
+ split_name (str): name of the split to evaluate on.
457
+ cur_epoch (int): current epoch.
458
+ skip_reload (bool): whether to skip reloading the best checkpoint.
459
+ During training, we will reload the best checkpoint for validation.
460
+ During testing, we will use provided weights and skip reloading the best checkpoint.
461
+ """
462
+ data_loader = self.dataloaders.get(split_name, None)
463
+ assert data_loader, "data_loader for split {} is None.".format(split_name)
464
+
465
+ # TODO In validation, you need to compute loss as well as metrics
466
+ # TODO consider moving to model.before_evaluation()
467
+ model = self.unwrap_dist_model(self.model)
468
+ if not skip_reload and cur_epoch == "best":
469
+ model = self._reload_best_model(model)
470
+ model.eval()
471
+
472
+ self.task.before_evaluation(
473
+ model=model,
474
+ dataset=self.datasets[split_name],
475
+ )
476
+ results = self.task.evaluation(model, data_loader)
477
+
478
+ if results is not None:
479
+ return self.task.after_evaluation(
480
+ val_result=results,
481
+ split_name=split_name,
482
+ epoch=cur_epoch,
483
+ )
484
+
485
+ def unwrap_dist_model(self, model):
486
+ if self.use_distributed:
487
+ return model.module
488
+ else:
489
+ return model
490
+
491
+ def create_loaders(
492
+ self,
493
+ datasets,
494
+ num_workers,
495
+ batch_sizes,
496
+ is_trains,
497
+ collate_fns,
498
+ dataset_ratios=None,
499
+ ):
500
+ """
501
+ Create dataloaders for training and validation.
502
+ """
503
+
504
+ def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
505
+ # create a single dataloader for each split
506
+ if isinstance(dataset, ChainDataset) or isinstance(
507
+ dataset, wds.DataPipeline
508
+ ):
509
+ # wds.WebDataset instances are chained together
510
+ # webdataset.DataPipeline has its own sampler and collate_fn
511
+ loader = iter(
512
+ DataLoader(
513
+ dataset,
514
+ batch_size=bsz,
515
+ num_workers=num_workers,
516
+ pin_memory=True,
517
+ )
518
+ )
519
+ else:
520
+ # map-style datasets are concatenated together
521
+ # setup distributed sampler
522
+ if self.use_distributed:
523
+ sampler = DistributedSampler(
524
+ dataset,
525
+ shuffle=is_train,
526
+ num_replicas=get_world_size(),
527
+ rank=get_rank(),
528
+ )
529
+ if not self.use_dist_eval_sampler:
530
+ # e.g. retrieval evaluation
531
+ sampler = sampler if is_train else None
532
+ else:
533
+ sampler = None
534
+
535
+ loader = DataLoader(
536
+ dataset,
537
+ batch_size=bsz,
538
+ num_workers=num_workers,
539
+ pin_memory=True,
540
+ sampler=sampler,
541
+ shuffle=sampler is None and is_train,
542
+ collate_fn=collate_fn,
543
+ drop_last=True if is_train else False,
544
+ )
545
+ loader = PrefetchLoader(loader)
546
+
547
+ if is_train:
548
+ loader = IterLoader(loader, use_distributed=self.use_distributed)
549
+
550
+ return loader
551
+
552
+ loaders = []
553
+
554
+ for dataset, bsz, is_train, collate_fn in zip(
555
+ datasets, batch_sizes, is_trains, collate_fns
556
+ ):
557
+ if isinstance(dataset, list) or isinstance(dataset, tuple):
558
+ if hasattr(dataset[0], 'sample_ratio') and dataset_ratios is None:
559
+ dataset_ratios = [d.sample_ratio for d in dataset]
560
+ loader = MultiIterLoader(
561
+ loaders=[
562
+ _create_loader(d, num_workers, bsz, is_train, collate_fn[i])
563
+ for i, d in enumerate(dataset)
564
+ ],
565
+ ratios=dataset_ratios,
566
+ )
567
+ else:
568
+ loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)
569
+
570
+ loaders.append(loader)
571
+
572
+ return loaders
573
+
574
+ @main_process
575
+ def _save_checkpoint(self, cur_epoch, is_best=False):
576
+ """
577
+ Save the checkpoint at the current epoch.
578
+ """
579
+ model_no_ddp = self.unwrap_dist_model(self.model)
580
+ param_grad_dic = {
581
+ k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
582
+ }
583
+ state_dict = model_no_ddp.state_dict()
584
+ for k in list(state_dict.keys()):
585
+ if k in param_grad_dic.keys() and not param_grad_dic[k]:
586
+ # delete parameters that do not require gradient
587
+ del state_dict[k]
588
+ save_obj = {
589
+ "model": state_dict,
590
+ "optimizer": self.optimizer.state_dict(),
591
+ "config": self.config.to_dict(),
592
+ "scaler": self.scaler.state_dict() if self.scaler else None,
593
+ "epoch": cur_epoch,
594
+ }
595
+ save_to = os.path.join(
596
+ self.output_dir,
597
+ "checkpoint_{}.pth".format("best" if is_best else cur_epoch),
598
+ )
599
+ logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
600
+ torch.save(save_obj, save_to)
601
+
602
+ def _reload_best_model(self, model):
603
+ """
604
+ Load the best checkpoint for evaluation.
605
+ """
606
+ checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")
607
+
608
+ logging.info("Loading checkpoint from {}.".format(checkpoint_path))
609
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
610
+ try:
611
+ model.load_state_dict(checkpoint["model"])
612
+ except RuntimeError as e:
613
+ logging.warning(
614
+ """
615
+ Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
616
+ Trying to load the model with strict=False.
617
+ """
618
+ )
619
+ model.load_state_dict(checkpoint["model"], strict=False)
620
+ return model
621
+
622
+ def _load_checkpoint(self, url_or_filename):
623
+ """
624
+ Resume from a checkpoint.
625
+ """
626
+ if is_url(url_or_filename):
627
+ cached_file = download_cached_file(
628
+ url_or_filename, check_hash=False, progress=True
629
+ )
630
+ checkpoint = torch.load(cached_file, map_location=self.device)
631
+ elif os.path.isfile(url_or_filename):
632
+ checkpoint = torch.load(url_or_filename, map_location=self.device)
633
+ else:
634
+ raise RuntimeError("checkpoint url or path is invalid")
635
+
636
+ state_dict = checkpoint["model"]
637
+ self.unwrap_dist_model(self.model).load_state_dict(state_dict)
638
+
639
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
640
+ if self.scaler and "scaler" in checkpoint:
641
+ self.scaler.load_state_dict(checkpoint["scaler"])
642
+
643
+ self.start_epoch = checkpoint["epoch"] + 1
644
+ logging.info("Resume checkpoint from {}".format(url_or_filename))
645
+
646
+ @main_process
647
+ def log_stats(self, stats, split_name):
648
+ if isinstance(stats, dict):
649
+ log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
650
+ with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
651
+ f.write(json.dumps(log_stats) + "\n")
652
+ elif isinstance(stats, list):
653
+ pass
654
+
655
+ @main_process
656
+ def log_config(self):
657
+ with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
658
+ f.write(json.dumps(self.config.to_dict(), indent=4) + "\n")
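The `optimizer` property above applies weight decay only to matrix-shaped weights and exempts biases and normalization parameters. A standalone sketch of that split on a toy module (learning rate, betas and the decay value are placeholders, not the project defaults):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))

p_wd, p_non_wd = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue  # frozen weights are skipped entirely
    # 1-D tensors (biases, LayerNorm/BatchNorm scales) get no weight decay
    if param.ndim < 2 or "bias" in name or "ln" in name or "bn" in name:
        p_non_wd.append(param)
    else:
        p_wd.append(param)

optimizer = torch.optim.AdamW(
    [
        {"params": p_wd, "weight_decay": 0.05},
        {"params": p_non_wd, "weight_decay": 0.0},
    ],
    lr=1e-4,
    betas=(0.9, 0.999),
)
print(len(p_wd), len(p_non_wd))  # 2 weight matrices, 4 bias/norm tensors
```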
minigpt4/tasks/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from minigpt4.common.registry import registry
9
+ from minigpt4.tasks.base_task import BaseTask
10
+ from minigpt4.tasks.image_text_pretrain import ImageTextPretrainTask
11
+
12
+
13
+ def setup_task(cfg):
14
+ assert "task" in cfg.run_cfg, "Task name must be provided."
15
+
16
+ task_name = cfg.run_cfg.task
17
+ task = registry.get_task_class(task_name).setup_task(cfg=cfg)
18
+ assert task is not None, "Task {} not properly registered.".format(task_name)
19
+
20
+ return task
21
+
22
+
23
+ __all__ = [
24
+ "BaseTask",
25
+ "ImageTextPretrainTask",
26
+ ]
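`setup_task` resolves the task name through the registry and calls the class-level `setup_task` factory. A hedged sketch of the expected config shape, assuming `minigpt4` is importable; the registry key `image_text_pretrain` is assumed from the module registered in `image_text_pretrain.py`, and the `OmegaConf` dict is only a minimal stand-in for the full `minigpt4.common.config.Config` object:

```python
from omegaconf import OmegaConf

from minigpt4.tasks import setup_task

# setup_task only reads cfg.run_cfg.task here
cfg = OmegaConf.create({"run_cfg": {"task": "image_text_pretrain"}})

task = setup_task(cfg)
print(type(task).__name__)  # expected: ImageTextPretrainTask
```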
minigpt4/tasks/base_task.py ADDED
@@ -0,0 +1,286 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import os
10
+
11
+ import torch
12
+ import torch.distributed as dist
13
+ from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
14
+ from minigpt4.common.logger import MetricLogger, SmoothedValue
15
+ from minigpt4.common.registry import registry
16
+ from minigpt4.datasets.data_utils import prepare_sample
17
+
18
+
19
+ class BaseTask:
20
+ def __init__(self, **kwargs):
21
+ super().__init__()
22
+
23
+ self.inst_id_key = "instance_id"
24
+
25
+ @classmethod
26
+ def setup_task(cls, **kwargs):
27
+ return cls()
28
+
29
+ def build_model(self, cfg):
30
+ model_config = cfg.model_cfg
31
+
32
+ model_cls = registry.get_model_class(model_config.arch)
33
+ return model_cls.from_config(model_config)
34
+
35
+ def build_datasets(self, cfg):
36
+ """
37
+ Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
38
+ Download datasets and annotations automatically if they do not exist.
39
+
40
+ Args:
41
+ cfg (common.config.Config): config whose datasets_cfg section lists the datasets to build.
42
+
43
+ Returns:
44
+ dict: Dictionary of torch.utils.data.Dataset objects by split.
45
+ """
46
+
47
+ datasets = dict()
48
+
49
+ datasets_config = cfg.datasets_cfg
50
+
51
+ assert len(datasets_config) > 0, "At least one dataset has to be specified."
52
+
53
+ for name in datasets_config:
54
+ dataset_config = datasets_config[name]
55
+
56
+ builder = registry.get_builder_class(name)(dataset_config)
57
+ dataset = builder.build_datasets()
58
+
59
+ dataset['train'].name = name
60
+ if 'sample_ratio' in dataset_config:
61
+ dataset['train'].sample_ratio = dataset_config.sample_ratio
62
+
63
+ datasets[name] = dataset
64
+
65
+ return datasets
66
+
67
+ def train_step(self, model, samples):
68
+ loss = model(samples)["loss"]
69
+ return loss
70
+
71
+ def valid_step(self, model, samples):
72
+ raise NotImplementedError
73
+
74
+ def before_evaluation(self, model, dataset, **kwargs):
75
+ model.before_evaluation(dataset=dataset, task_type=type(self))
76
+
77
+ def after_evaluation(self, **kwargs):
78
+ pass
79
+
80
+ def inference_step(self):
81
+ raise NotImplementedError
82
+
83
+ def evaluation(self, model, data_loader, cuda_enabled=True):
84
+ metric_logger = MetricLogger(delimiter=" ")
85
+ header = "Evaluation"
86
+ # TODO make it configurable
87
+ print_freq = 10
88
+
89
+ results = []
90
+
91
+ for samples in metric_logger.log_every(data_loader, print_freq, header):
92
+ samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
93
+
94
+ eval_output = self.valid_step(model=model, samples=samples)
95
+ results.extend(eval_output)
96
+
97
+ if is_dist_avail_and_initialized():
98
+ dist.barrier()
99
+
100
+ return results
101
+
102
+ def train_epoch(
103
+ self,
104
+ epoch,
105
+ model,
106
+ data_loader,
107
+ optimizer,
108
+ lr_scheduler,
109
+ scaler=None,
110
+ cuda_enabled=False,
111
+ log_freq=50,
112
+ accum_grad_iters=1,
113
+ ):
114
+ return self._train_inner_loop(
115
+ epoch=epoch,
116
+ iters_per_epoch=lr_scheduler.iters_per_epoch,
117
+ model=model,
118
+ data_loader=data_loader,
119
+ optimizer=optimizer,
120
+ scaler=scaler,
121
+ lr_scheduler=lr_scheduler,
122
+ log_freq=log_freq,
123
+ cuda_enabled=cuda_enabled,
124
+ accum_grad_iters=accum_grad_iters,
125
+ )
126
+
127
+ def train_iters(
128
+ self,
129
+ epoch,
130
+ start_iters,
131
+ iters_per_inner_epoch,
132
+ model,
133
+ data_loader,
134
+ optimizer,
135
+ lr_scheduler,
136
+ scaler=None,
137
+ cuda_enabled=False,
138
+ log_freq=50,
139
+ accum_grad_iters=1,
140
+ ):
141
+ return self._train_inner_loop(
142
+ epoch=epoch,
143
+ start_iters=start_iters,
144
+ iters_per_epoch=iters_per_inner_epoch,
145
+ model=model,
146
+ data_loader=data_loader,
147
+ optimizer=optimizer,
148
+ scaler=scaler,
149
+ lr_scheduler=lr_scheduler,
150
+ log_freq=log_freq,
151
+ cuda_enabled=cuda_enabled,
152
+ accum_grad_iters=accum_grad_iters,
153
+ )
154
+
155
+ def _train_inner_loop(
156
+ self,
157
+ epoch,
158
+ iters_per_epoch,
159
+ model,
160
+ data_loader,
161
+ optimizer,
162
+ lr_scheduler,
163
+ scaler=None,
164
+ start_iters=None,
165
+ log_freq=50,
166
+ cuda_enabled=False,
167
+ accum_grad_iters=1,
168
+ ):
169
+ """
170
+ An inner training loop compatible with both epoch-based and iter-based training.
171
+
172
+ In epoch-based training, the loop stops after one epoch; in iteration-based training,
173
+ it stops after iters_per_epoch iterations.
174
+ """
175
+ use_amp = scaler is not None
176
+
177
+ if not hasattr(data_loader, "__next__"):
178
+ # convert to iterator if not already
179
+ data_loader = iter(data_loader)
180
+
181
+ metric_logger = MetricLogger(delimiter=" ")
182
+ metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
183
+ metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
184
+
185
+ # if iter-based runner, schedule lr based on inner epoch.
186
+ logging.info(
187
+ "Start training epoch {}, {} iters per inner epoch.".format(
188
+ epoch, iters_per_epoch
189
+ )
190
+ )
191
+ header = "Train: data epoch: [{}]".format(epoch)
192
+ if start_iters is None:
193
+ # epoch-based runner
194
+ inner_epoch = epoch
195
+ else:
196
+ # In iter-based runner, we schedule the learning rate based on iterations.
197
+ inner_epoch = start_iters // iters_per_epoch
198
+ header = header + "; inner epoch [{}]".format(inner_epoch)
199
+
200
+ for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
201
+ # if using iter-based runner, we stop after iters_per_epoch iterations.
202
+ if i >= iters_per_epoch:
203
+ break
204
+
205
+ samples = next(data_loader)
206
+
207
+ samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
208
+ samples.update(
209
+ {
210
+ "epoch": inner_epoch,
211
+ "num_iters_per_epoch": iters_per_epoch,
212
+ "iters": i,
213
+ }
214
+ )
215
+
216
+ lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)
217
+
218
+ with torch.cuda.amp.autocast(enabled=use_amp):
219
+ loss = self.train_step(model=model, samples=samples)
220
+
221
+ # after_train_step()
222
+ if use_amp:
223
+ scaler.scale(loss).backward()
224
+ else:
225
+ loss.backward()
226
+
227
+ # update gradients every accum_grad_iters iterations
228
+ if (i + 1) % accum_grad_iters == 0:
229
+ if use_amp:
230
+ scaler.step(optimizer)
231
+ scaler.update()
232
+ else:
233
+ optimizer.step()
234
+ optimizer.zero_grad()
235
+
236
+ metric_logger.update(loss=loss.item())
237
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
238
+
239
+ # after train_epoch()
240
+ # gather the stats from all processes
241
+ metric_logger.synchronize_between_processes()
242
+ logging.info("Averaged stats: " + str(metric_logger.global_avg()))
243
+ return {
244
+ k: "{:.3f}".format(meter.global_avg)
245
+ for k, meter in metric_logger.meters.items()
246
+ }
247
+
248
+ @staticmethod
249
+ def save_result(result, result_dir, filename, remove_duplicate=""):
250
+ import json
251
+
252
+ result_file = os.path.join(
253
+ result_dir, "%s_rank%d.json" % (filename, get_rank())
254
+ )
255
+ final_result_file = os.path.join(result_dir, "%s.json" % filename)
256
+
257
+ json.dump(result, open(result_file, "w"))
258
+
259
+ if is_dist_avail_and_initialized():
260
+ dist.barrier()
261
+
262
+ if is_main_process():
263
+ logging.warning("rank %d starts merging results." % get_rank())
264
+ # combine results from all processes
265
+ result = []
266
+
267
+ for rank in range(get_world_size()):
268
+ result_file = os.path.join(
269
+ result_dir, "%s_rank%d.json" % (filename, rank)
270
+ )
271
+ res = json.load(open(result_file, "r"))
272
+ result += res
273
+
274
+ if remove_duplicate:
275
+ result_new = []
276
+ id_list = []
277
+ for res in result:
278
+ if res[remove_duplicate] not in id_list:
279
+ id_list.append(res[remove_duplicate])
280
+ result_new.append(res)
281
+ result = result_new
282
+
283
+ json.dump(result, open(final_result_file, "w"))
284
+ print("result file saved to %s" % final_result_file)
285
+
286
+ return final_result_file
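The optimizer cadence in _train_inner_loop above is: backward on every batch, step and zero_grad only every accum_grad_iters batches, with GradScaler wrapping both when AMP is on. Note that the loss is not rescaled by 1/accum_grad_iters, so accumulated updates use the summed (not averaged) gradient. A stripped-down sketch of that cadence with a toy model; the model, data, and objective are placeholders, and a CUDA device is assumed, as when cuda_enabled is set:

    # Sketch only: the AMP + gradient-accumulation pattern used by _train_inner_loop.
    import torch

    model = torch.nn.Linear(8, 1).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()
    accum_grad_iters = 2

    for i in range(4):                                  # a few toy iterations
        x = torch.randn(16, 8, device="cuda")
        with torch.cuda.amp.autocast(enabled=True):
            loss = model(x).pow(2).mean()               # placeholder objective
        scaler.scale(loss).backward()                   # gradients accumulate across batches
        if (i + 1) % accum_grad_iters == 0:             # update every accum_grad_iters batches
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()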
minigpt4/tasks/image_text_pretrain.py ADDED
@@ -0,0 +1,18 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from minigpt4.common.registry import registry
9
+ from minigpt4.tasks.base_task import BaseTask
10
+
11
+
12
+ @registry.register_task("image_text_pretrain")
13
+ class ImageTextPretrainTask(BaseTask):
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ def evaluation(self, model, data_loader, cuda_enabled=True):
18
+ pass
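ImageTextPretrainTask deliberately turns evaluation into a no-op. A task that does evaluate only has to implement valid_step so that it returns a list of per-sample records, since BaseTask.evaluation extends its results with whatever valid_step returns. A hypothetical sketch; the task name, the generate() call, and the record fields are invented for illustration:

    # Hypothetical example, not part of this commit.
    from minigpt4.common.registry import registry
    from minigpt4.tasks.base_task import BaseTask

    @registry.register_task("toy_caption_eval")        # illustrative task name
    class ToyCaptionEvalTask(BaseTask):
        def valid_step(self, model, samples):
            captions = model.generate(samples)          # assumes the model exposes generate()
            return [
                {"image_id": int(img_id), "caption": cap}
                for img_id, cap in zip(samples["image_id"], captions)
            ]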