saicharan1234 commited on
Commit
6742cd5
1 Parent(s): adbca6c

Delete minigpt4

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. minigpt4/__init__.py +0 -31
  2. minigpt4/common/__init__.py +0 -0
  3. minigpt4/common/config.py +0 -496
  4. minigpt4/common/dist_utils.py +0 -140
  5. minigpt4/common/eval_utils.py +0 -76
  6. minigpt4/common/gradcam.py +0 -24
  7. minigpt4/common/logger.py +0 -195
  8. minigpt4/common/optims.py +0 -119
  9. minigpt4/common/registry.py +0 -329
  10. minigpt4/common/utils.py +0 -424
  11. minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py +0 -89
  12. minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py +0 -1
  13. minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py +0 -192
  14. minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py +0 -73
  15. minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py +0 -1
  16. minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py +0 -179
  17. minigpt4/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt +0 -81
  18. minigpt4/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt +0 -65
  19. minigpt4/common/vqa_tools/VQA/README.md +0 -80
  20. minigpt4/common/vqa_tools/VQA/license.txt +0 -30
  21. minigpt4/common/vqa_tools/__init__.py +0 -8
  22. minigpt4/common/vqa_tools/vqa.py +0 -211
  23. minigpt4/common/vqa_tools/vqa_eval.py +0 -324
  24. minigpt4/configs/datasets/cc_combine/align.yaml +0 -16
  25. minigpt4/configs/datasets/cc_combine/defaults.yaml +0 -11
  26. minigpt4/configs/datasets/laion/defaults.yaml +0 -13
  27. minigpt4/configs/default.yaml +0 -5
  28. minigpt4/configs/models/minigpt4_vicuna0.yaml +0 -32
  29. minigpt4/conversation/__init__.py +0 -0
  30. minigpt4/conversation/conversation.py +0 -233
  31. minigpt4/datasets/__init__.py +0 -0
  32. minigpt4/datasets/builders/__init__.py +0 -72
  33. minigpt4/datasets/builders/base_dataset_builder.py +0 -236
  34. minigpt4/datasets/builders/image_text_pair_builder.py +0 -535
  35. minigpt4/datasets/data_utils.py +0 -199
  36. minigpt4/datasets/datasets/__init__.py +0 -0
  37. minigpt4/datasets/datasets/aok_vqa_datasets.py +0 -116
  38. minigpt4/datasets/datasets/base_dataset.py +0 -78
  39. minigpt4/datasets/datasets/caption_datasets.py +0 -151
  40. minigpt4/datasets/datasets/cc_sbu_dataset.py +0 -47
  41. minigpt4/datasets/datasets/coco_caption.py +0 -120
  42. minigpt4/datasets/datasets/coco_dataset.py +0 -348
  43. minigpt4/datasets/datasets/coco_vqa_datasets.py +0 -145
  44. minigpt4/datasets/datasets/dataloader_utils.py +0 -162
  45. minigpt4/datasets/datasets/flickr.py +0 -159
  46. minigpt4/datasets/datasets/gqa_datasets.py +0 -60
  47. minigpt4/datasets/datasets/laion_dataset.py +0 -31
  48. minigpt4/datasets/datasets/llava_dataset.py +0 -149
  49. minigpt4/datasets/datasets/multitask_conversation.py +0 -75
  50. minigpt4/datasets/datasets/ocrvqa_dataset.py +0 -77
minigpt4/__init__.py DELETED
@@ -1,31 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- import os
9
- import sys
10
-
11
- from omegaconf import OmegaConf
12
-
13
- from minigpt4.common.registry import registry
14
-
15
- from minigpt4.datasets.builders import *
16
- from minigpt4.models import *
17
- from minigpt4.processors import *
18
- from minigpt4.tasks import *
19
-
20
-
21
- root_dir = os.path.dirname(os.path.abspath(__file__))
22
- default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
23
-
24
- registry.register_path("library_root", root_dir)
25
- repo_root = os.path.join(root_dir, "..")
26
- registry.register_path("repo_root", repo_root)
27
- cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
28
- registry.register_path("cache_root", cache_root)
29
-
30
- registry.register("MAX_INT", sys.maxsize)
31
- registry.register("SPLIT_NAMES", ["train", "val", "test"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/__init__.py DELETED
File without changes
minigpt4/common/config.py DELETED
@@ -1,496 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- import logging
9
- import json
10
- from typing import Dict
11
-
12
- from omegaconf import OmegaConf
13
- from minigpt4.common.registry import registry
14
-
15
-
16
- class Config:
17
- def __init__(self, args):
18
- self.config = {}
19
-
20
- self.args = args
21
-
22
- # Register the config and configuration for setup
23
- registry.register("configuration", self)
24
-
25
- user_config = self._build_opt_list(self.args.options)
26
-
27
- config = OmegaConf.load(self.args.cfg_path)
28
-
29
- runner_config = self.build_runner_config(config)
30
- model_config = self.build_model_config(config, **user_config)
31
- dataset_config = self.build_dataset_config(config)
32
- evaluation_dataset_config = self.build_evaluation_dataset_config(config)
33
-
34
- # Validate the user-provided runner configuration
35
- # model and dataset configuration are supposed to be validated by the respective classes
36
- # [TODO] validate the model/dataset configuration
37
- # self._validate_runner_config(runner_config)
38
-
39
- # Override the default configuration with user options.
40
- self.config = OmegaConf.merge(
41
- runner_config, model_config, dataset_config,evaluation_dataset_config, user_config
42
- )
43
-
44
- def _validate_runner_config(self, runner_config):
45
- """
46
- This method validates the configuration, such that
47
- 1) all the user specified options are valid;
48
- 2) no type mismatches between the user specified options and the config.
49
- """
50
- runner_config_validator = create_runner_config_validator()
51
- runner_config_validator.validate(runner_config)
52
-
53
- def _build_opt_list(self, opts):
54
- opts_dot_list = self._convert_to_dot_list(opts)
55
- return OmegaConf.from_dotlist(opts_dot_list)
56
-
57
- @staticmethod
58
- def build_model_config(config, **kwargs):
59
- model = config.get("model", None)
60
- assert model is not None, "Missing model configuration file."
61
-
62
- model_cls = registry.get_model_class(model.arch)
63
- assert model_cls is not None, f"Model '{model.arch}' has not been registered."
64
-
65
- model_type = kwargs.get("model.model_type", None)
66
- if not model_type:
67
- model_type = model.get("model_type", None)
68
- # else use the model type selected by user.
69
-
70
- assert model_type is not None, "Missing model_type."
71
-
72
- model_config_path = model_cls.default_config_path(model_type=model_type)
73
-
74
- model_config = OmegaConf.create()
75
- # hierarchy override, customized config > default config
76
- model_config = OmegaConf.merge(
77
- model_config,
78
- OmegaConf.load(model_config_path),
79
- {"model": config["model"]},
80
- )
81
-
82
- return model_config
83
-
84
- @staticmethod
85
- def build_runner_config(config):
86
- return {"run": config.run}
87
-
88
- @staticmethod
89
- def build_dataset_config(config):
90
- datasets = config.get("datasets", None)
91
- if datasets is None:
92
- raise KeyError(
93
- "Expecting 'datasets' as the root key for dataset configuration."
94
- )
95
-
96
- dataset_config = OmegaConf.create()
97
-
98
- for dataset_name in datasets:
99
- builder_cls = registry.get_builder_class(dataset_name)
100
-
101
- dataset_config_type = datasets[dataset_name].get("type", "default")
102
- dataset_config_path = builder_cls.default_config_path(
103
- type=dataset_config_type
104
- )
105
-
106
- # hierarchy override, customized config > default config
107
- dataset_config = OmegaConf.merge(
108
- dataset_config,
109
- OmegaConf.load(dataset_config_path),
110
- {"datasets": {dataset_name: config["datasets"][dataset_name]}},
111
- )
112
-
113
- return dataset_config
114
-
115
-
116
- @staticmethod
117
- def build_evaluation_dataset_config(config):
118
- datasets = config.get("evaluation_datasets", None)
119
- # if datasets is None:
120
- # raise KeyError(
121
- # "Expecting 'datasets' as the root key for dataset configuration."
122
- # )
123
-
124
- dataset_config = OmegaConf.create()
125
-
126
- if datasets is not None:
127
- for dataset_name in datasets:
128
- builder_cls = registry.get_builder_class(dataset_name)
129
-
130
- # hierarchy override, customized config > default config
131
- dataset_config = OmegaConf.merge(
132
- dataset_config,
133
- {"evaluation_datasets": {dataset_name: config["evaluation_datasets"][dataset_name]}},
134
- )
135
-
136
- return dataset_config
137
-
138
- def _convert_to_dot_list(self, opts):
139
- if opts is None:
140
- opts = []
141
-
142
- if len(opts) == 0:
143
- return opts
144
-
145
- has_equal = opts[0].find("=") != -1
146
-
147
- if has_equal:
148
- return opts
149
-
150
- return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
151
-
152
- def get_config(self):
153
- return self.config
154
-
155
- @property
156
- def run_cfg(self):
157
- return self.config.run
158
-
159
- @property
160
- def datasets_cfg(self):
161
- return self.config.datasets
162
-
163
- @property
164
- def evaluation_datasets_cfg(self):
165
- return self.config.evaluation_datasets
166
-
167
- @property
168
- def model_cfg(self):
169
- return self.config.model
170
-
171
- def pretty_print(self):
172
- logging.info("\n===== Running Parameters =====")
173
- logging.info(self._convert_node_to_json(self.config.run))
174
-
175
- logging.info("\n====== Dataset Attributes ======")
176
- datasets = self.config.datasets
177
-
178
- for dataset in datasets:
179
- if dataset in self.config.datasets:
180
- logging.info(f"\n======== {dataset} =======")
181
- dataset_config = self.config.datasets[dataset]
182
- logging.info(self._convert_node_to_json(dataset_config))
183
- else:
184
- logging.warning(f"No dataset named '{dataset}' in config. Skipping")
185
-
186
- logging.info(f"\n====== Model Attributes ======")
187
- logging.info(self._convert_node_to_json(self.config.model))
188
-
189
- def _convert_node_to_json(self, node):
190
- container = OmegaConf.to_container(node, resolve=True)
191
- return json.dumps(container, indent=4, sort_keys=True)
192
-
193
- def to_dict(self):
194
- return OmegaConf.to_container(self.config)
195
-
196
-
197
- def node_to_dict(node):
198
- return OmegaConf.to_container(node)
199
-
200
-
201
- class ConfigValidator:
202
- """
203
- This is a preliminary implementation to centralize and validate the configuration.
204
- May be altered in the future.
205
-
206
- A helper class to validate configurations from yaml file.
207
-
208
- This serves the following purposes:
209
- 1. Ensure all the options in the yaml are defined, raise error if not.
210
- 2. when type mismatches are found, the validator will raise an error.
211
- 3. a central place to store and display helpful messages for supported configurations.
212
-
213
- """
214
-
215
- class _Argument:
216
- def __init__(self, name, choices=None, type=None, help=None):
217
- self.name = name
218
- self.val = None
219
- self.choices = choices
220
- self.type = type
221
- self.help = help
222
-
223
- def __str__(self):
224
- s = f"{self.name}={self.val}"
225
- if self.type is not None:
226
- s += f", ({self.type})"
227
- if self.choices is not None:
228
- s += f", choices: {self.choices}"
229
- if self.help is not None:
230
- s += f", ({self.help})"
231
- return s
232
-
233
- def __init__(self, description):
234
- self.description = description
235
-
236
- self.arguments = dict()
237
-
238
- self.parsed_args = None
239
-
240
- def __getitem__(self, key):
241
- assert self.parsed_args is not None, "No arguments parsed yet."
242
-
243
- return self.parsed_args[key]
244
-
245
- def __str__(self) -> str:
246
- return self.format_help()
247
-
248
- def add_argument(self, *args, **kwargs):
249
- """
250
- Assume the first argument is the name of the argument.
251
- """
252
- self.arguments[args[0]] = self._Argument(*args, **kwargs)
253
-
254
- def validate(self, config=None):
255
- """
256
- Convert yaml config (dict-like) to list, required by argparse.
257
- """
258
- for k, v in config.items():
259
- assert (
260
- k in self.arguments
261
- ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
262
-
263
- if self.arguments[k].type is not None:
264
- try:
265
- self.arguments[k].val = self.arguments[k].type(v)
266
- except ValueError:
267
- raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
268
-
269
- if self.arguments[k].choices is not None:
270
- assert (
271
- v in self.arguments[k].choices
272
- ), f"""{k} must be one of {self.arguments[k].choices}."""
273
-
274
- return config
275
-
276
- def format_arguments(self):
277
- return str([f"{k}" for k in sorted(self.arguments.keys())])
278
-
279
- def format_help(self):
280
- # description + key-value pair string for each argument
281
- help_msg = str(self.description)
282
- return help_msg + ", available arguments: " + self.format_arguments()
283
-
284
- def print_help(self):
285
- # display help message
286
- print(self.format_help())
287
-
288
-
289
- def create_runner_config_validator():
290
- validator = ConfigValidator(description="Runner configurations")
291
-
292
- validator.add_argument(
293
- "runner",
294
- type=str,
295
- choices=["runner_base", "runner_iter"],
296
- help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
297
- runner runs based on iters. Default: runner_base""",
298
- )
299
- # add argumetns for training dataset ratios
300
- validator.add_argument(
301
- "train_dataset_ratios",
302
- type=Dict[str, float],
303
- help="""Ratios of training dataset. This is used in iteration-based runner.
304
- Do not support for epoch-based runner because how to define an epoch becomes tricky.
305
- Default: None""",
306
- )
307
- validator.add_argument(
308
- "max_iters",
309
- type=float,
310
- help="Maximum number of iterations to run.",
311
- )
312
- validator.add_argument(
313
- "max_epoch",
314
- type=int,
315
- help="Maximum number of epochs to run.",
316
- )
317
- # add arguments for iters_per_inner_epoch
318
- validator.add_argument(
319
- "iters_per_inner_epoch",
320
- type=float,
321
- help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
322
- )
323
- lr_scheds_choices = registry.list_lr_schedulers()
324
- validator.add_argument(
325
- "lr_sched",
326
- type=str,
327
- choices=lr_scheds_choices,
328
- help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
329
- )
330
- task_choices = registry.list_tasks()
331
- validator.add_argument(
332
- "task",
333
- type=str,
334
- choices=task_choices,
335
- help="Task to use, from {}".format(task_choices),
336
- )
337
- # add arguments for init_lr
338
- validator.add_argument(
339
- "init_lr",
340
- type=float,
341
- help="Initial learning rate. This will be the learning rate after warmup and before decay.",
342
- )
343
- # add arguments for min_lr
344
- validator.add_argument(
345
- "min_lr",
346
- type=float,
347
- help="Minimum learning rate (after decay).",
348
- )
349
- # add arguments for warmup_lr
350
- validator.add_argument(
351
- "warmup_lr",
352
- type=float,
353
- help="Starting learning rate for warmup.",
354
- )
355
- # add arguments for learning rate decay rate
356
- validator.add_argument(
357
- "lr_decay_rate",
358
- type=float,
359
- help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
360
- )
361
- # add arguments for weight decay
362
- validator.add_argument(
363
- "weight_decay",
364
- type=float,
365
- help="Weight decay rate.",
366
- )
367
- # add arguments for training batch size
368
- validator.add_argument(
369
- "batch_size_train",
370
- type=int,
371
- help="Training batch size.",
372
- )
373
- # add arguments for evaluation batch size
374
- validator.add_argument(
375
- "batch_size_eval",
376
- type=int,
377
- help="Evaluation batch size, including validation and testing.",
378
- )
379
- # add arguments for number of workers for data loading
380
- validator.add_argument(
381
- "num_workers",
382
- help="Number of workers for data loading.",
383
- )
384
- # add arguments for warm up steps
385
- validator.add_argument(
386
- "warmup_steps",
387
- type=int,
388
- help="Number of warmup steps. Required if a warmup schedule is used.",
389
- )
390
- # add arguments for random seed
391
- validator.add_argument(
392
- "seed",
393
- type=int,
394
- help="Random seed.",
395
- )
396
- # add arguments for output directory
397
- validator.add_argument(
398
- "output_dir",
399
- type=str,
400
- help="Output directory to save checkpoints and logs.",
401
- )
402
- # add arguments for whether only use evaluation
403
- validator.add_argument(
404
- "evaluate",
405
- help="Whether to only evaluate the model. If true, training will not be performed.",
406
- )
407
- # add arguments for splits used for training, e.g. ["train", "val"]
408
- validator.add_argument(
409
- "train_splits",
410
- type=list,
411
- help="Splits to use for training.",
412
- )
413
- # add arguments for splits used for validation, e.g. ["val"]
414
- validator.add_argument(
415
- "valid_splits",
416
- type=list,
417
- help="Splits to use for validation. If not provided, will skip the validation.",
418
- )
419
- # add arguments for splits used for testing, e.g. ["test"]
420
- validator.add_argument(
421
- "test_splits",
422
- type=list,
423
- help="Splits to use for testing. If not provided, will skip the testing.",
424
- )
425
- # add arguments for accumulating gradient for iterations
426
- validator.add_argument(
427
- "accum_grad_iters",
428
- type=int,
429
- help="Number of iterations to accumulate gradient for.",
430
- )
431
-
432
- # ====== distributed training ======
433
- validator.add_argument(
434
- "device",
435
- type=str,
436
- choices=["cpu", "cuda"],
437
- help="Device to use. Support 'cuda' or 'cpu' as for now.",
438
- )
439
- validator.add_argument(
440
- "world_size",
441
- type=int,
442
- help="Number of processes participating in the job.",
443
- )
444
- validator.add_argument("dist_url", type=str)
445
- validator.add_argument("distributed", type=bool)
446
- # add arguments to opt using distributed sampler during evaluation or not
447
- validator.add_argument(
448
- "use_dist_eval_sampler",
449
- type=bool,
450
- help="Whether to use distributed sampler during evaluation or not.",
451
- )
452
-
453
- # ====== task specific ======
454
- # generation task specific arguments
455
- # add arguments for maximal length of text output
456
- validator.add_argument(
457
- "max_len",
458
- type=int,
459
- help="Maximal length of text output.",
460
- )
461
- # add arguments for minimal length of text output
462
- validator.add_argument(
463
- "min_len",
464
- type=int,
465
- help="Minimal length of text output.",
466
- )
467
- # add arguments number of beams
468
- validator.add_argument(
469
- "num_beams",
470
- type=int,
471
- help="Number of beams used for beam search.",
472
- )
473
-
474
- # vqa task specific arguments
475
- # add arguments for number of answer candidates
476
- validator.add_argument(
477
- "num_ans_candidates",
478
- type=int,
479
- help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
480
- )
481
- # add arguments for inference method
482
- validator.add_argument(
483
- "inference_method",
484
- type=str,
485
- choices=["genearte", "rank"],
486
- help="""Inference method to use for question answering. If rank, requires a answer list.""",
487
- )
488
-
489
- # ====== model specific ======
490
- validator.add_argument(
491
- "k_test",
492
- type=int,
493
- help="Number of top k most similar samples from ITC/VTC selection to be tested.",
494
- )
495
-
496
- return validator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/dist_utils.py DELETED
@@ -1,140 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- import datetime
9
- import functools
10
- import os
11
-
12
- import torch
13
- import torch.distributed as dist
14
- import timm.models.hub as timm_hub
15
-
16
-
17
- def setup_for_distributed(is_master):
18
- """
19
- This function disables printing when not in master process
20
- """
21
- import builtins as __builtin__
22
-
23
- builtin_print = __builtin__.print
24
-
25
- def print(*args, **kwargs):
26
- force = kwargs.pop("force", False)
27
- if is_master or force:
28
- builtin_print(*args, **kwargs)
29
-
30
- __builtin__.print = print
31
-
32
-
33
- def is_dist_avail_and_initialized():
34
- if not dist.is_available():
35
- return False
36
- if not dist.is_initialized():
37
- return False
38
- return True
39
-
40
-
41
- def get_world_size():
42
- if not is_dist_avail_and_initialized():
43
- return 1
44
- return dist.get_world_size()
45
-
46
-
47
- def get_rank():
48
- if not is_dist_avail_and_initialized():
49
- return 0
50
- return dist.get_rank()
51
-
52
-
53
- def is_main_process():
54
- return get_rank() == 0
55
-
56
-
57
- def init_distributed_mode(args):
58
- if args.distributed is False:
59
- print("Not using distributed mode")
60
- return
61
- elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
62
- args.rank = int(os.environ["RANK"])
63
- args.world_size = int(os.environ["WORLD_SIZE"])
64
- args.gpu = int(os.environ["LOCAL_RANK"])
65
- elif "SLURM_PROCID" in os.environ:
66
- args.rank = int(os.environ["SLURM_PROCID"])
67
- args.gpu = args.rank % torch.cuda.device_count()
68
- else:
69
- print("Not using distributed mode")
70
- args.distributed = False
71
- return
72
-
73
- args.distributed = True
74
-
75
- torch.cuda.set_device(args.gpu)
76
- args.dist_backend = "nccl"
77
- print(
78
- "| distributed init (rank {}, world {}): {}".format(
79
- args.rank, args.world_size, args.dist_url
80
- ),
81
- flush=True,
82
- )
83
- torch.distributed.init_process_group(
84
- backend=args.dist_backend,
85
- init_method=args.dist_url,
86
- world_size=args.world_size,
87
- rank=args.rank,
88
- timeout=datetime.timedelta(
89
- days=365
90
- ), # allow auto-downloading and de-compressing
91
- )
92
- torch.distributed.barrier()
93
- setup_for_distributed(args.rank == 0)
94
-
95
-
96
- def get_dist_info():
97
- if torch.__version__ < "1.0":
98
- initialized = dist._initialized
99
- else:
100
- initialized = dist.is_initialized()
101
- if initialized:
102
- rank = dist.get_rank()
103
- world_size = dist.get_world_size()
104
- else: # non-distributed training
105
- rank = 0
106
- world_size = 1
107
- return rank, world_size
108
-
109
-
110
- def main_process(func):
111
- @functools.wraps(func)
112
- def wrapper(*args, **kwargs):
113
- rank, _ = get_dist_info()
114
- if rank == 0:
115
- return func(*args, **kwargs)
116
-
117
- return wrapper
118
-
119
-
120
- def download_cached_file(url, check_hash=True, progress=False):
121
- """
122
- Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
123
- If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
124
- """
125
-
126
- def get_cached_file_path():
127
- # a hack to sync the file path across processes
128
- parts = torch.hub.urlparse(url)
129
- filename = os.path.basename(parts.path)
130
- cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
131
-
132
- return cached_file
133
-
134
- if is_main_process():
135
- timm_hub.download_cached_file(url, check_hash, progress)
136
-
137
- if is_dist_avail_and_initialized():
138
- dist.barrier()
139
-
140
- return get_cached_file_path()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/eval_utils.py DELETED
@@ -1,76 +0,0 @@
1
- import argparse
2
- import numpy as np
3
- from nltk.translate.bleu_score import sentence_bleu
4
-
5
- from minigpt4.common.registry import registry
6
- from minigpt4.common.config import Config
7
-
8
- # imports modules for registration
9
- from minigpt4.datasets.builders import *
10
- from minigpt4.models import *
11
- from minigpt4.processors import *
12
- from minigpt4.runners import *
13
- from minigpt4.tasks import *
14
-
15
-
16
-
17
- def eval_parser():
18
- parser = argparse.ArgumentParser(description="Demo")
19
- parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
20
- parser.add_argument("--name", type=str, default='A2', help="evaluation name")
21
- parser.add_argument("--ckpt", type=str, help="path to configuration file.")
22
- parser.add_argument("--eval_opt", type=str, default='all', help="path to configuration file.")
23
- parser.add_argument("--max_new_tokens", type=int, default=10, help="max number of generated tokens")
24
- parser.add_argument("--batch_size", type=int, default=32)
25
- parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model")
26
- parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha")
27
- parser.add_argument(
28
- "--options",
29
- nargs="+",
30
- help="override some settings in the used config, the key-value pair "
31
- "in xxx=yyy format will be merged into config file (deprecate), "
32
- "change to --cfg-options instead.",
33
- )
34
- return parser
35
-
36
-
37
- def prepare_texts(texts, conv_temp):
38
- convs = [conv_temp.copy() for _ in range(len(texts))]
39
- [conv.append_message(
40
- conv.roles[0], '<Img><ImageHere></Img> {}'.format(text)) for conv, text in zip(convs, texts)]
41
- [conv.append_message(conv.roles[1], None) for conv in convs]
42
- texts = [conv.get_prompt() for conv in convs]
43
- return texts
44
-
45
-
46
- def init_model(args):
47
- print('Initialization Model')
48
- cfg = Config(args)
49
- # cfg.model_cfg.ckpt = args.ckpt
50
- # cfg.model_cfg.lora_r = args.lora_r
51
- # cfg.model_cfg.lora_alpha = args.lora_alpha
52
-
53
- model_config = cfg.model_cfg
54
- model_cls = registry.get_model_class(model_config.arch)
55
- model = model_cls.from_config(model_config).to('cuda:0')
56
-
57
- # import pudb; pudb.set_trace()
58
- key = list(cfg.datasets_cfg.keys())[0]
59
- vis_processor_cfg = cfg.datasets_cfg.get(key).vis_processor.train
60
- vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
61
- print('Initialization Finished')
62
- return model, vis_processor
63
-
64
- def computeIoU(bbox1, bbox2):
65
- x1, y1, x2, y2 = bbox1
66
- x3, y3, x4, y4 = bbox2
67
- intersection_x1 = max(x1, x3)
68
- intersection_y1 = max(y1, y3)
69
- intersection_x2 = min(x2, x4)
70
- intersection_y2 = min(y2, y4)
71
- intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
72
- bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
73
- bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
74
- union_area = bbox1_area + bbox2_area - intersection_area
75
- iou = intersection_area / union_area
76
- return iou
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/gradcam.py DELETED
@@ -1,24 +0,0 @@
1
- import numpy as np
2
- from matplotlib import pyplot as plt
3
- from scipy.ndimage import filters
4
- from skimage import transform as skimage_transform
5
-
6
-
7
- def getAttMap(img, attMap, blur=True, overlap=True):
8
- attMap -= attMap.min()
9
- if attMap.max() > 0:
10
- attMap /= attMap.max()
11
- attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12
- if blur:
13
- attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14
- attMap -= attMap.min()
15
- attMap /= attMap.max()
16
- cmap = plt.get_cmap("jet")
17
- attMapV = cmap(attMap)
18
- attMapV = np.delete(attMapV, 3, 2)
19
- if overlap:
20
- attMap = (
21
- 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22
- + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23
- )
24
- return attMap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/logger.py DELETED
@@ -1,195 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- import datetime
9
- import logging
10
- import time
11
- from collections import defaultdict, deque
12
-
13
- import torch
14
- import torch.distributed as dist
15
-
16
- from minigpt4.common import dist_utils
17
-
18
-
19
- class SmoothedValue(object):
20
- """Track a series of values and provide access to smoothed values over a
21
- window or the global series average.
22
- """
23
-
24
- def __init__(self, window_size=20, fmt=None):
25
- if fmt is None:
26
- fmt = "{median:.4f} ({global_avg:.4f})"
27
- self.deque = deque(maxlen=window_size)
28
- self.total = 0.0
29
- self.count = 0
30
- self.fmt = fmt
31
-
32
- def update(self, value, n=1):
33
- self.deque.append(value)
34
- self.count += n
35
- self.total += value * n
36
-
37
- def synchronize_between_processes(self):
38
- """
39
- Warning: does not synchronize the deque!
40
- """
41
- if not dist_utils.is_dist_avail_and_initialized():
42
- return
43
- t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44
- dist.barrier()
45
- dist.all_reduce(t)
46
- t = t.tolist()
47
- self.count = int(t[0])
48
- self.total = t[1]
49
-
50
- @property
51
- def median(self):
52
- d = torch.tensor(list(self.deque))
53
- return d.median().item()
54
-
55
- @property
56
- def avg(self):
57
- d = torch.tensor(list(self.deque), dtype=torch.float32)
58
- return d.mean().item()
59
-
60
- @property
61
- def global_avg(self):
62
- return self.total / self.count
63
-
64
- @property
65
- def max(self):
66
- return max(self.deque)
67
-
68
- @property
69
- def value(self):
70
- return self.deque[-1]
71
-
72
- def __str__(self):
73
- return self.fmt.format(
74
- median=self.median,
75
- avg=self.avg,
76
- global_avg=self.global_avg,
77
- max=self.max,
78
- value=self.value,
79
- )
80
-
81
-
82
- class MetricLogger(object):
83
- def __init__(self, delimiter="\t"):
84
- self.meters = defaultdict(SmoothedValue)
85
- self.delimiter = delimiter
86
-
87
- def update(self, **kwargs):
88
- for k, v in kwargs.items():
89
- if isinstance(v, torch.Tensor):
90
- v = v.item()
91
- assert isinstance(v, (float, int))
92
- self.meters[k].update(v)
93
-
94
- def __getattr__(self, attr):
95
- if attr in self.meters:
96
- return self.meters[attr]
97
- if attr in self.__dict__:
98
- return self.__dict__[attr]
99
- raise AttributeError(
100
- "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
101
- )
102
-
103
- def __str__(self):
104
- loss_str = []
105
- for name, meter in self.meters.items():
106
- loss_str.append("{}: {}".format(name, str(meter)))
107
- return self.delimiter.join(loss_str)
108
-
109
- def global_avg(self):
110
- loss_str = []
111
- for name, meter in self.meters.items():
112
- loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
113
- return self.delimiter.join(loss_str)
114
-
115
- def synchronize_between_processes(self):
116
- for meter in self.meters.values():
117
- meter.synchronize_between_processes()
118
-
119
- def add_meter(self, name, meter):
120
- self.meters[name] = meter
121
-
122
- def log_every(self, iterable, print_freq, header=None):
123
- i = 0
124
- if not header:
125
- header = ""
126
- start_time = time.time()
127
- end = time.time()
128
- iter_time = SmoothedValue(fmt="{avg:.4f}")
129
- data_time = SmoothedValue(fmt="{avg:.4f}")
130
- space_fmt = ":" + str(len(str(len(iterable)))) + "d"
131
- log_msg = [
132
- header,
133
- "[{0" + space_fmt + "}/{1}]",
134
- "eta: {eta}",
135
- "{meters}",
136
- "time: {time}",
137
- "data: {data}",
138
- ]
139
- if torch.cuda.is_available():
140
- log_msg.append("max mem: {memory:.0f}")
141
- log_msg = self.delimiter.join(log_msg)
142
- MB = 1024.0 * 1024.0
143
- for obj in iterable:
144
- data_time.update(time.time() - end)
145
- yield obj
146
- iter_time.update(time.time() - end)
147
- if i % print_freq == 0 or i == len(iterable) - 1:
148
- eta_seconds = iter_time.global_avg * (len(iterable) - i)
149
- eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
150
- if torch.cuda.is_available():
151
- print(
152
- log_msg.format(
153
- i,
154
- len(iterable),
155
- eta=eta_string,
156
- meters=str(self),
157
- time=str(iter_time),
158
- data=str(data_time),
159
- memory=torch.cuda.max_memory_allocated() / MB,
160
- )
161
- )
162
- else:
163
- print(
164
- log_msg.format(
165
- i,
166
- len(iterable),
167
- eta=eta_string,
168
- meters=str(self),
169
- time=str(iter_time),
170
- data=str(data_time),
171
- )
172
- )
173
- i += 1
174
- end = time.time()
175
- total_time = time.time() - start_time
176
- total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177
- print(
178
- "{} Total time: {} ({:.4f} s / it)".format(
179
- header, total_time_str, total_time / len(iterable)
180
- )
181
- )
182
-
183
-
184
- class AttrDict(dict):
185
- def __init__(self, *args, **kwargs):
186
- super(AttrDict, self).__init__(*args, **kwargs)
187
- self.__dict__ = self
188
-
189
-
190
- def setup_logger():
191
- logging.basicConfig(
192
- level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
193
- format="%(asctime)s [%(levelname)s] %(message)s",
194
- handlers=[logging.StreamHandler()],
195
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/optims.py DELETED
@@ -1,119 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- import math
9
-
10
- from minigpt4.common.registry import registry
11
-
12
-
13
- @registry.register_lr_scheduler("linear_warmup_step_lr")
14
- class LinearWarmupStepLRScheduler:
15
- def __init__(
16
- self,
17
- optimizer,
18
- max_epoch,
19
- min_lr,
20
- init_lr,
21
- decay_rate=1,
22
- warmup_start_lr=-1,
23
- warmup_steps=0,
24
- **kwargs
25
- ):
26
- self.optimizer = optimizer
27
-
28
- self.max_epoch = max_epoch
29
- self.min_lr = min_lr
30
-
31
- self.decay_rate = decay_rate
32
-
33
- self.init_lr = init_lr
34
- self.warmup_steps = warmup_steps
35
- self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36
-
37
- def step(self, cur_epoch, cur_step):
38
- if cur_epoch == 0:
39
- warmup_lr_schedule(
40
- step=cur_step,
41
- optimizer=self.optimizer,
42
- max_step=self.warmup_steps,
43
- init_lr=self.warmup_start_lr,
44
- max_lr=self.init_lr,
45
- )
46
- else:
47
- step_lr_schedule(
48
- epoch=cur_epoch,
49
- optimizer=self.optimizer,
50
- init_lr=self.init_lr,
51
- min_lr=self.min_lr,
52
- decay_rate=self.decay_rate,
53
- )
54
-
55
-
56
- @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57
- class LinearWarmupCosineLRScheduler:
58
- def __init__(
59
- self,
60
- optimizer,
61
- max_epoch,
62
- iters_per_epoch,
63
- min_lr,
64
- init_lr,
65
- warmup_steps=0,
66
- warmup_start_lr=-1,
67
- **kwargs
68
- ):
69
- self.optimizer = optimizer
70
-
71
- self.max_epoch = max_epoch
72
- self.iters_per_epoch = iters_per_epoch
73
- self.min_lr = min_lr
74
-
75
- self.init_lr = init_lr
76
- self.warmup_steps = warmup_steps
77
- self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
78
-
79
- def step(self, cur_epoch, cur_step):
80
- total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
81
- if total_cur_step < self.warmup_steps:
82
- warmup_lr_schedule(
83
- step=cur_step,
84
- optimizer=self.optimizer,
85
- max_step=self.warmup_steps,
86
- init_lr=self.warmup_start_lr,
87
- max_lr=self.init_lr,
88
- )
89
- else:
90
- cosine_lr_schedule(
91
- epoch=total_cur_step,
92
- optimizer=self.optimizer,
93
- max_epoch=self.max_epoch * self.iters_per_epoch,
94
- init_lr=self.init_lr,
95
- min_lr=self.min_lr,
96
- )
97
-
98
-
99
- def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
100
- """Decay the learning rate"""
101
- lr = (init_lr - min_lr) * 0.5 * (
102
- 1.0 + math.cos(math.pi * epoch / max_epoch)
103
- ) + min_lr
104
- for param_group in optimizer.param_groups:
105
- param_group["lr"] = lr
106
-
107
-
108
- def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
109
- """Warmup the learning rate"""
110
- lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
111
- for param_group in optimizer.param_groups:
112
- param_group["lr"] = lr
113
-
114
-
115
- def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
116
- """Decay the learning rate"""
117
- lr = max(min_lr, init_lr * (decay_rate**epoch))
118
- for param_group in optimizer.param_groups:
119
- param_group["lr"] = lr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/registry.py DELETED
@@ -1,329 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
-
9
- class Registry:
10
- mapping = {
11
- "builder_name_mapping": {},
12
- "task_name_mapping": {},
13
- "processor_name_mapping": {},
14
- "model_name_mapping": {},
15
- "lr_scheduler_name_mapping": {},
16
- "runner_name_mapping": {},
17
- "state": {},
18
- "paths": {},
19
- }
20
-
21
- @classmethod
22
- def register_builder(cls, name):
23
- r"""Register a dataset builder to registry with key 'name'
24
-
25
- Args:
26
- name: Key with which the builder will be registered.
27
-
28
- Usage:
29
-
30
- from minigpt4.common.registry import registry
31
- from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder
32
- """
33
-
34
- def wrap(builder_cls):
35
- from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
36
-
37
- assert issubclass(
38
- builder_cls, BaseDatasetBuilder
39
- ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
40
- builder_cls
41
- )
42
- if name in cls.mapping["builder_name_mapping"]:
43
- raise KeyError(
44
- "Name '{}' already registered for {}.".format(
45
- name, cls.mapping["builder_name_mapping"][name]
46
- )
47
- )
48
- cls.mapping["builder_name_mapping"][name] = builder_cls
49
- return builder_cls
50
-
51
- return wrap
52
-
53
- @classmethod
54
- def register_task(cls, name):
55
- r"""Register a task to registry with key 'name'
56
-
57
- Args:
58
- name: Key with which the task will be registered.
59
-
60
- Usage:
61
-
62
- from minigpt4.common.registry import registry
63
- """
64
-
65
- def wrap(task_cls):
66
- from minigpt4.tasks.base_task import BaseTask
67
-
68
- assert issubclass(
69
- task_cls, BaseTask
70
- ), "All tasks must inherit BaseTask class"
71
- if name in cls.mapping["task_name_mapping"]:
72
- raise KeyError(
73
- "Name '{}' already registered for {}.".format(
74
- name, cls.mapping["task_name_mapping"][name]
75
- )
76
- )
77
- cls.mapping["task_name_mapping"][name] = task_cls
78
- return task_cls
79
-
80
- return wrap
81
-
82
- @classmethod
83
- def register_model(cls, name):
84
- r"""Register a task to registry with key 'name'
85
-
86
- Args:
87
- name: Key with which the task will be registered.
88
-
89
- Usage:
90
-
91
- from minigpt4.common.registry import registry
92
- """
93
-
94
- def wrap(model_cls):
95
- from minigpt4.models import BaseModel
96
-
97
- assert issubclass(
98
- model_cls, BaseModel
99
- ), "All models must inherit BaseModel class"
100
- if name in cls.mapping["model_name_mapping"]:
101
- raise KeyError(
102
- "Name '{}' already registered for {}.".format(
103
- name, cls.mapping["model_name_mapping"][name]
104
- )
105
- )
106
- cls.mapping["model_name_mapping"][name] = model_cls
107
- return model_cls
108
-
109
- return wrap
110
-
111
- @classmethod
112
- def register_processor(cls, name):
113
- r"""Register a processor to registry with key 'name'
114
-
115
- Args:
116
- name: Key with which the task will be registered.
117
-
118
- Usage:
119
-
120
- from minigpt4.common.registry import registry
121
- """
122
-
123
- def wrap(processor_cls):
124
- from minigpt4.processors import BaseProcessor
125
-
126
- assert issubclass(
127
- processor_cls, BaseProcessor
128
- ), "All processors must inherit BaseProcessor class"
129
- if name in cls.mapping["processor_name_mapping"]:
130
- raise KeyError(
131
- "Name '{}' already registered for {}.".format(
132
- name, cls.mapping["processor_name_mapping"][name]
133
- )
134
- )
135
- cls.mapping["processor_name_mapping"][name] = processor_cls
136
- return processor_cls
137
-
138
- return wrap
139
-
140
- @classmethod
141
- def register_lr_scheduler(cls, name):
142
- r"""Register a model to registry with key 'name'
143
-
144
- Args:
145
- name: Key with which the task will be registered.
146
-
147
- Usage:
148
-
149
- from minigpt4.common.registry import registry
150
- """
151
-
152
- def wrap(lr_sched_cls):
153
- if name in cls.mapping["lr_scheduler_name_mapping"]:
154
- raise KeyError(
155
- "Name '{}' already registered for {}.".format(
156
- name, cls.mapping["lr_scheduler_name_mapping"][name]
157
- )
158
- )
159
- cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
160
- return lr_sched_cls
161
-
162
- return wrap
163
-
164
- @classmethod
165
- def register_runner(cls, name):
166
- r"""Register a model to registry with key 'name'
167
-
168
- Args:
169
- name: Key with which the task will be registered.
170
-
171
- Usage:
172
-
173
- from minigpt4.common.registry import registry
174
- """
175
-
176
- def wrap(runner_cls):
177
- if name in cls.mapping["runner_name_mapping"]:
178
- raise KeyError(
179
- "Name '{}' already registered for {}.".format(
180
- name, cls.mapping["runner_name_mapping"][name]
181
- )
182
- )
183
- cls.mapping["runner_name_mapping"][name] = runner_cls
184
- return runner_cls
185
-
186
- return wrap
187
-
188
- @classmethod
189
- def register_path(cls, name, path):
190
- r"""Register a path to registry with key 'name'
191
-
192
- Args:
193
- name: Key with which the path will be registered.
194
-
195
- Usage:
196
-
197
- from minigpt4.common.registry import registry
198
- """
199
- assert isinstance(path, str), "All path must be str."
200
- if name in cls.mapping["paths"]:
201
- raise KeyError("Name '{}' already registered.".format(name))
202
- cls.mapping["paths"][name] = path
203
-
204
- @classmethod
205
- def register(cls, name, obj):
206
- r"""Register an item to registry with key 'name'
207
-
208
- Args:
209
- name: Key with which the item will be registered.
210
-
211
- Usage::
212
-
213
- from minigpt4.common.registry import registry
214
-
215
- registry.register("config", {})
216
- """
217
- path = name.split(".")
218
- current = cls.mapping["state"]
219
-
220
- for part in path[:-1]:
221
- if part not in current:
222
- current[part] = {}
223
- current = current[part]
224
-
225
- current[path[-1]] = obj
226
-
227
- # @classmethod
228
- # def get_trainer_class(cls, name):
229
- # return cls.mapping["trainer_name_mapping"].get(name, None)
230
-
231
- @classmethod
232
- def get_builder_class(cls, name):
233
- return cls.mapping["builder_name_mapping"].get(name, None)
234
-
235
- @classmethod
236
- def get_model_class(cls, name):
237
- return cls.mapping["model_name_mapping"].get(name, None)
238
-
239
- @classmethod
240
- def get_task_class(cls, name):
241
- return cls.mapping["task_name_mapping"].get(name, None)
242
-
243
- @classmethod
244
- def get_processor_class(cls, name):
245
- return cls.mapping["processor_name_mapping"].get(name, None)
246
-
247
- @classmethod
248
- def get_lr_scheduler_class(cls, name):
249
- return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
250
-
251
- @classmethod
252
- def get_runner_class(cls, name):
253
- return cls.mapping["runner_name_mapping"].get(name, None)
254
-
255
- @classmethod
256
- def list_runners(cls):
257
- return sorted(cls.mapping["runner_name_mapping"].keys())
258
-
259
- @classmethod
260
- def list_models(cls):
261
- return sorted(cls.mapping["model_name_mapping"].keys())
262
-
263
- @classmethod
264
- def list_tasks(cls):
265
- return sorted(cls.mapping["task_name_mapping"].keys())
266
-
267
- @classmethod
268
- def list_processors(cls):
269
- return sorted(cls.mapping["processor_name_mapping"].keys())
270
-
271
- @classmethod
272
- def list_lr_schedulers(cls):
273
- return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
274
-
275
- @classmethod
276
- def list_datasets(cls):
277
- return sorted(cls.mapping["builder_name_mapping"].keys())
278
-
279
- @classmethod
280
- def get_path(cls, name):
281
- return cls.mapping["paths"].get(name, None)
282
-
283
- @classmethod
284
- def get(cls, name, default=None, no_warning=False):
285
- r"""Get an item from registry with key 'name'
286
-
287
- Args:
288
- name (string): Key whose value needs to be retrieved.
289
- default: If passed and key is not in registry, default value will
290
- be returned with a warning. Default: None
291
- no_warning (bool): If passed as True, warning when key doesn't exist
292
- will not be generated. Useful for MMF's
293
- internal operations. Default: False
294
- """
295
- original_name = name
296
- name = name.split(".")
297
- value = cls.mapping["state"]
298
- for subname in name:
299
- value = value.get(subname, default)
300
- if value is default:
301
- break
302
-
303
- if (
304
- "writer" in cls.mapping["state"]
305
- and value == default
306
- and no_warning is False
307
- ):
308
- cls.mapping["state"]["writer"].warning(
309
- "Key {} is not present in registry, returning default value "
310
- "of {}".format(original_name, default)
311
- )
312
- return value
313
-
314
- @classmethod
315
- def unregister(cls, name):
316
- r"""Remove an item from registry with key 'name'
317
-
318
- Args:
319
- name: Key which needs to be removed.
320
- Usage::
321
-
322
- from mmf.common.registry import registry
323
-
324
- config = registry.unregister("config")
325
- """
326
- return cls.mapping["state"].pop(name, None)
327
-
328
-
329
- registry = Registry()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/utils.py DELETED
@@ -1,424 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- import io
9
- import json
10
- import logging
11
- import os
12
- import pickle
13
- import re
14
- import shutil
15
- import urllib
16
- import urllib.error
17
- import urllib.request
18
- from typing import Optional
19
- from urllib.parse import urlparse
20
-
21
- import numpy as np
22
- import pandas as pd
23
- import yaml
24
- from iopath.common.download import download
25
- from iopath.common.file_io import file_lock, g_pathmgr
26
- from minigpt4.common.registry import registry
27
- from torch.utils.model_zoo import tqdm
28
- from torchvision.datasets.utils import (
29
- check_integrity,
30
- download_file_from_google_drive,
31
- extract_archive,
32
- )
33
-
34
-
35
- def now():
36
- from datetime import datetime
37
-
38
- return datetime.now().strftime("%Y%m%d%H%M")[:-1]
39
-
40
-
41
- def is_url(url_or_filename):
42
- parsed = urlparse(url_or_filename)
43
- return parsed.scheme in ("http", "https")
44
-
45
-
46
- def get_cache_path(rel_path):
47
- return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
48
-
49
-
50
- def get_abs_path(rel_path):
51
- return os.path.join(registry.get_path("library_root"), rel_path)
52
-
53
-
54
- def load_json(filename):
55
- with open(filename, "r") as f:
56
- return json.load(f)
57
-
58
-
59
- # The following are adapted from torchvision and vissl
60
- # torchvision: https://github.com/pytorch/vision
61
- # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
62
-
63
-
64
- def makedir(dir_path):
65
- """
66
- Create the directory if it does not exist.
67
- """
68
- is_success = False
69
- try:
70
- if not g_pathmgr.exists(dir_path):
71
- g_pathmgr.mkdirs(dir_path)
72
- is_success = True
73
- except BaseException:
74
- print(f"Error creating directory: {dir_path}")
75
- return is_success
76
-
77
-
78
- def get_redirected_url(url: str):
79
- """
80
- Given a URL, returns the URL it redirects to or the
81
- original URL in case of no indirection
82
- """
83
- import requests
84
-
85
- with requests.Session() as session:
86
- with session.get(url, stream=True, allow_redirects=True) as response:
87
- if response.history:
88
- return response.url
89
- else:
90
- return url
91
-
92
-
93
- def to_google_drive_download_url(view_url: str) -> str:
94
- """
95
- Utility function to transform a view URL of google drive
96
- to a download URL for google drive
97
- Example input:
98
- https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
99
- Example output:
100
- https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
101
- """
102
- splits = view_url.split("/")
103
- assert splits[-1] == "view"
104
- file_id = splits[-2]
105
- return f"https://drive.google.com/uc?export=download&id={file_id}"
106
-
107
-
108
- def download_google_drive_url(url: str, output_path: str, output_file_name: str):
109
- """
110
- Download a file from google drive
111
- Downloading an URL from google drive requires confirmation when
112
- the file of the size is too big (google drive notifies that
113
- anti-viral checks cannot be performed on such files)
114
- """
115
- import requests
116
-
117
- with requests.Session() as session:
118
-
119
- # First get the confirmation token and append it to the URL
120
- with session.get(url, stream=True, allow_redirects=True) as response:
121
- for k, v in response.cookies.items():
122
- if k.startswith("download_warning"):
123
- url = url + "&confirm=" + v
124
-
125
- # Then download the content of the file
126
- with session.get(url, stream=True, verify=True) as response:
127
- makedir(output_path)
128
- path = os.path.join(output_path, output_file_name)
129
- total_size = int(response.headers.get("Content-length", 0))
130
- with open(path, "wb") as file:
131
- from tqdm import tqdm
132
-
133
- with tqdm(total=total_size) as progress_bar:
134
- for block in response.iter_content(
135
- chunk_size=io.DEFAULT_BUFFER_SIZE
136
- ):
137
- file.write(block)
138
- progress_bar.update(len(block))
139
-
140
-
141
- def _get_google_drive_file_id(url: str) -> Optional[str]:
142
- parts = urlparse(url)
143
-
144
- if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
145
- return None
146
-
147
- match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
148
- if match is None:
149
- return None
150
-
151
- return match.group("id")
152
-
153
-
154
- def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
155
- with open(filename, "wb") as fh:
156
- with urllib.request.urlopen(
157
- urllib.request.Request(url, headers={"User-Agent": "vissl"})
158
- ) as response:
159
- with tqdm(total=response.length) as pbar:
160
- for chunk in iter(lambda: response.read(chunk_size), ""):
161
- if not chunk:
162
- break
163
- pbar.update(chunk_size)
164
- fh.write(chunk)
165
-
166
-
167
- def download_url(
168
- url: str,
169
- root: str,
170
- filename: Optional[str] = None,
171
- md5: Optional[str] = None,
172
- ) -> None:
173
- """Download a file from a url and place it in root.
174
- Args:
175
- url (str): URL to download file from
176
- root (str): Directory to place downloaded file in
177
- filename (str, optional): Name to save the file under.
178
- If None, use the basename of the URL.
179
- md5 (str, optional): MD5 checksum of the download. If None, do not check
180
- """
181
- root = os.path.expanduser(root)
182
- if not filename:
183
- filename = os.path.basename(url)
184
- fpath = os.path.join(root, filename)
185
-
186
- makedir(root)
187
-
188
- # check if file is already present locally
189
- if check_integrity(fpath, md5):
190
- print("Using downloaded and verified file: " + fpath)
191
- return
192
-
193
- # expand redirect chain if needed
194
- url = get_redirected_url(url)
195
-
196
- # check if file is located on Google Drive
197
- file_id = _get_google_drive_file_id(url)
198
- if file_id is not None:
199
- return download_file_from_google_drive(file_id, root, filename, md5)
200
-
201
- # download the file
202
- try:
203
- print("Downloading " + url + " to " + fpath)
204
- _urlretrieve(url, fpath)
205
- except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
206
- if url[:5] == "https":
207
- url = url.replace("https:", "http:")
208
- print(
209
- "Failed download. Trying https -> http instead."
210
- " Downloading " + url + " to " + fpath
211
- )
212
- _urlretrieve(url, fpath)
213
- else:
214
- raise e
215
-
216
- # check integrity of downloaded file
217
- if not check_integrity(fpath, md5):
218
- raise RuntimeError("File not found or corrupted.")
219
-
220
-
221
- def download_and_extract_archive(
222
- url: str,
223
- download_root: str,
224
- extract_root: Optional[str] = None,
225
- filename: Optional[str] = None,
226
- md5: Optional[str] = None,
227
- remove_finished: bool = False,
228
- ) -> None:
229
- download_root = os.path.expanduser(download_root)
230
- if extract_root is None:
231
- extract_root = download_root
232
- if not filename:
233
- filename = os.path.basename(url)
234
-
235
- download_url(url, download_root, filename, md5)
236
-
237
- archive = os.path.join(download_root, filename)
238
- print("Extracting {} to {}".format(archive, extract_root))
239
- extract_archive(archive, extract_root, remove_finished)
240
-
241
-
242
- def cache_url(url: str, cache_dir: str) -> str:
243
- """
244
- This implementation downloads the remote resource and caches it locally.
245
- The resource will only be downloaded if not previously requested.
246
- """
247
- parsed_url = urlparse(url)
248
- dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
249
- makedir(dirname)
250
- filename = url.split("/")[-1]
251
- cached = os.path.join(dirname, filename)
252
- with file_lock(cached):
253
- if not os.path.isfile(cached):
254
- logging.info(f"Downloading {url} to {cached} ...")
255
- cached = download(url, dirname, filename=filename)
256
- logging.info(f"URL {url} cached in {cached}")
257
- return cached
258
-
259
-
260
- # TODO (prigoyal): convert this into RAII-style API
261
- def create_file_symlink(file1, file2):
262
- """
263
- Simply create the symlinks for a given file1 to file2.
264
- Useful during model checkpointing to symlinks to the
265
- latest successful checkpoint.
266
- """
267
- try:
268
- if g_pathmgr.exists(file2):
269
- g_pathmgr.rm(file2)
270
- g_pathmgr.symlink(file1, file2)
271
- except Exception as e:
272
- logging.info(f"Could NOT create symlink. Error: {e}")
273
-
274
-
275
- def save_file(data, filename, append_to_json=True, verbose=True):
276
- """
277
- Common i/o utility to handle saving data to various file formats.
278
- Supported:
279
- .pkl, .pickle, .npy, .json
280
- Specifically for .json, users have the option to either append (default)
281
- or rewrite by passing in Boolean value to append_to_json.
282
- """
283
- if verbose:
284
- logging.info(f"Saving data to file: {filename}")
285
- file_ext = os.path.splitext(filename)[1]
286
- if file_ext in [".pkl", ".pickle"]:
287
- with g_pathmgr.open(filename, "wb") as fopen:
288
- pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
289
- elif file_ext == ".npy":
290
- with g_pathmgr.open(filename, "wb") as fopen:
291
- np.save(fopen, data)
292
- elif file_ext == ".json":
293
- if append_to_json:
294
- with g_pathmgr.open(filename, "a") as fopen:
295
- fopen.write(json.dumps(data, sort_keys=True) + "\n")
296
- fopen.flush()
297
- else:
298
- with g_pathmgr.open(filename, "w") as fopen:
299
- fopen.write(json.dumps(data, sort_keys=True) + "\n")
300
- fopen.flush()
301
- elif file_ext == ".yaml":
302
- with g_pathmgr.open(filename, "w") as fopen:
303
- dump = yaml.dump(data)
304
- fopen.write(dump)
305
- fopen.flush()
306
- else:
307
- raise Exception(f"Saving {file_ext} is not supported yet")
308
-
309
- if verbose:
310
- logging.info(f"Saved data to file: {filename}")
311
-
312
-
313
- def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
314
- """
315
- Common i/o utility to handle loading data from various file formats.
316
- Supported:
317
- .pkl, .pickle, .npy, .json
318
- For the npy files, we support reading the files in mmap_mode.
319
- If the mmap_mode of reading is not successful, we load data without the
320
- mmap_mode.
321
- """
322
- if verbose:
323
- logging.info(f"Loading data from file: {filename}")
324
-
325
- file_ext = os.path.splitext(filename)[1]
326
- if file_ext == ".txt":
327
- with g_pathmgr.open(filename, "r") as fopen:
328
- data = fopen.readlines()
329
- elif file_ext in [".pkl", ".pickle"]:
330
- with g_pathmgr.open(filename, "rb") as fopen:
331
- data = pickle.load(fopen, encoding="latin1")
332
- elif file_ext == ".npy":
333
- if mmap_mode:
334
- try:
335
- with g_pathmgr.open(filename, "rb") as fopen:
336
- data = np.load(
337
- fopen,
338
- allow_pickle=allow_pickle,
339
- encoding="latin1",
340
- mmap_mode=mmap_mode,
341
- )
342
- except ValueError as e:
343
- logging.info(
344
- f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
345
- )
346
- data = np.load(
347
- filename,
348
- allow_pickle=allow_pickle,
349
- encoding="latin1",
350
- mmap_mode=mmap_mode,
351
- )
352
- logging.info("Successfully loaded without g_pathmgr")
353
- except Exception:
354
- logging.info("Could not mmap without g_pathmgr. Trying without mmap")
355
- with g_pathmgr.open(filename, "rb") as fopen:
356
- data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
357
- else:
358
- with g_pathmgr.open(filename, "rb") as fopen:
359
- data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
360
- elif file_ext == ".json":
361
- with g_pathmgr.open(filename, "r") as fopen:
362
- data = json.load(fopen)
363
- elif file_ext == ".yaml":
364
- with g_pathmgr.open(filename, "r") as fopen:
365
- data = yaml.load(fopen, Loader=yaml.FullLoader)
366
- elif file_ext == ".csv":
367
- with g_pathmgr.open(filename, "r") as fopen:
368
- data = pd.read_csv(fopen)
369
- else:
370
- raise Exception(f"Reading from {file_ext} is not supported yet")
371
- return data
372
-
373
-
374
- def abspath(resource_path: str):
375
- """
376
- Make a path absolute, but take into account prefixes like
377
- "http://" or "manifold://"
378
- """
379
- regex = re.compile(r"^\w+://")
380
- if regex.match(resource_path) is None:
381
- return os.path.abspath(resource_path)
382
- else:
383
- return resource_path
384
-
385
-
386
- def makedir(dir_path):
387
- """
388
- Create the directory if it does not exist.
389
- """
390
- is_success = False
391
- try:
392
- if not g_pathmgr.exists(dir_path):
393
- g_pathmgr.mkdirs(dir_path)
394
- is_success = True
395
- except BaseException:
396
- logging.info(f"Error creating directory: {dir_path}")
397
- return is_success
398
-
399
-
400
- def is_url(input_url):
401
- """
402
- Check if an input string is a url. look for http(s):// and ignoring the case
403
- """
404
- is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
405
- return is_url
406
-
407
-
408
- def cleanup_dir(dir):
409
- """
410
- Utility for deleting a directory. Useful for cleaning the storage space
411
- that contains various training artifacts like checkpoints, data etc.
412
- """
413
- if os.path.exists(dir):
414
- logging.info(f"Deleting directory: {dir}")
415
- shutil.rmtree(dir)
416
- logging.info(f"Deleted contents of directory: {dir}")
417
-
418
-
419
- def get_file_size(filename):
420
- """
421
- Given a file, get the size of file in MB
422
- """
423
- size_in_mb = os.path.getsize(filename) / float(1024**2)
424
- return size_in_mb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py DELETED
@@ -1,89 +0,0 @@
1
- # coding: utf-8
2
-
3
- import sys
4
- dataDir = '../../VQA'
5
- sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(dataDir))
6
- from vqa import VQA
7
- from vqaEvaluation.vqaEval import VQAEval
8
- import matplotlib.pyplot as plt
9
- import skimage.io as io
10
- import json
11
- import random
12
- import os
13
-
14
- # set up file names and paths
15
- versionType ='v2_' # this should be '' when using VQA v2.0 dataset
16
- taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
17
- dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
18
- dataSubType ='train2014'
19
- annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
20
- quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
21
- imgDir ='%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
22
- resultType ='fake'
23
- fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType']
24
-
25
- # An example result json file has been provided in './Results' folder.
26
-
27
- [resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json'%(dataDir, versionType, taskType, dataType, dataSubType, \
28
- resultType, fileType) for fileType in fileTypes]
29
-
30
- # create vqa object and vqaRes object
31
- vqa = VQA(annFile, quesFile)
32
- vqaRes = vqa.loadRes(resFile, quesFile)
33
-
34
- # create vqaEval object by taking vqa and vqaRes
35
- vqaEval = VQAEval(vqa, vqaRes, n=2) #n is precision of accuracy (number of places after decimal), default is 2
36
-
37
- # evaluate results
38
- """
39
- If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
40
- By default it uses all the question ids in annotation file
41
- """
42
- vqaEval.evaluate()
43
-
44
- # print accuracies
45
- print "\n"
46
- print "Overall Accuracy is: %.02f\n" %(vqaEval.accuracy['overall'])
47
- print "Per Question Type Accuracy is the following:"
48
- for quesType in vqaEval.accuracy['perQuestionType']:
49
- print "%s : %.02f" %(quesType, vqaEval.accuracy['perQuestionType'][quesType])
50
- print "\n"
51
- print "Per Answer Type Accuracy is the following:"
52
- for ansType in vqaEval.accuracy['perAnswerType']:
53
- print "%s : %.02f" %(ansType, vqaEval.accuracy['perAnswerType'][ansType])
54
- print "\n"
55
- # demo how to use evalQA to retrieve low score result
56
- evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId]<35] #35 is per question percentage accuracy
57
- if len(evals) > 0:
58
- print 'ground truth answers'
59
- randomEval = random.choice(evals)
60
- randomAnn = vqa.loadQA(randomEval)
61
- vqa.showQA(randomAnn)
62
-
63
- print '\n'
64
- print 'generated answer (accuracy %.02f)'%(vqaEval.evalQA[randomEval])
65
- ann = vqaRes.loadQA(randomEval)[0]
66
- print "Answer: %s\n" %(ann['answer'])
67
-
68
- imgId = randomAnn[0]['image_id']
69
- imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
70
- if os.path.isfile(imgDir + imgFilename):
71
- I = io.imread(imgDir + imgFilename)
72
- plt.imshow(I)
73
- plt.axis('off')
74
- plt.show()
75
-
76
- # plot accuracy for various question types
77
- plt.bar(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].values(), align='center')
78
- plt.xticks(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].keys(), rotation='0',fontsize=10)
79
- plt.title('Per Question Type Accuracy', fontsize=10)
80
- plt.xlabel('Question Types', fontsize=10)
81
- plt.ylabel('Accuracy', fontsize=10)
82
- plt.show()
83
-
84
- # save evaluation results to ./Results folder
85
- json.dump(vqaEval.accuracy, open(accuracyFile, 'w'))
86
- json.dump(vqaEval.evalQA, open(evalQAFile, 'w'))
87
- json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w'))
88
- json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w'))
89
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py DELETED
@@ -1 +0,0 @@
1
- author='aagrawal'
 
 
minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py DELETED
@@ -1,192 +0,0 @@
1
- # coding=utf-8
2
-
3
- __author__='aagrawal'
4
-
5
- import re
6
- # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
7
- # (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
8
- import sys
9
-
10
-
11
- class VQAEval:
12
- def __init__(self, vqa, vqaRes, n=2):
13
- self.n = n
14
- self.accuracy = {}
15
- self.evalQA = {}
16
- self.evalQuesType = {}
17
- self.evalAnsType = {}
18
- self.vqa = vqa
19
- self.vqaRes = vqaRes
20
- self.params = {'question_id': vqa.getQuesIds()}
21
- self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \
22
- "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \
23
- "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \
24
- "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \
25
- "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \
26
- "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \
27
- "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \
28
- "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \
29
- "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \
30
- "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \
31
- "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \
32
- "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \
33
- "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \
34
- "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \
35
- "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \
36
- "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \
37
- "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \
38
- "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \
39
- "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \
40
- "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \
41
- "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \
42
- "youll": "you'll", "youre": "you're", "youve": "you've"}
43
- self.manualMap = { 'none': '0',
44
- 'zero': '0',
45
- 'one': '1',
46
- 'two': '2',
47
- 'three': '3',
48
- 'four': '4',
49
- 'five': '5',
50
- 'six': '6',
51
- 'seven': '7',
52
- 'eight': '8',
53
- 'nine': '9',
54
- 'ten': '10'
55
- }
56
- self.articles = ['a',
57
- 'an',
58
- 'the'
59
- ]
60
-
61
-
62
- self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
63
- self.commaStrip = re.compile("(\d)(\,)(\d)")
64
- self.punct = [';', r"/", '[', ']', '"', '{', '}',
65
- '(', ')', '=', '+', '\\', '_', '-',
66
- '>', '<', '@', '`', ',', '?', '!']
67
-
68
-
69
- def evaluate(self, quesIds=None):
70
- if quesIds == None:
71
- quesIds = [quesId for quesId in self.params['question_id']]
72
- gts = {}
73
- res = {}
74
- for quesId in quesIds:
75
- gts[quesId] = self.vqa.qa[quesId]
76
- res[quesId] = self.vqaRes.qa[quesId]
77
-
78
- # =================================================
79
- # Compute accuracy
80
- # =================================================
81
- accQA = []
82
- accQuesType = {}
83
- accAnsType = {}
84
- # print "computing accuracy"
85
- step = 0
86
- for quesId in quesIds:
87
- for ansDic in gts[quesId]['answers']:
88
- ansDic['answer'] = ansDic['answer'].replace('\n', ' ')
89
- ansDic['answer'] = ansDic['answer'].replace('\t', ' ')
90
- ansDic['answer'] = ansDic['answer'].strip()
91
- resAns = res[quesId]['answer']
92
- resAns = resAns.replace('\n', ' ')
93
- resAns = resAns.replace('\t', ' ')
94
- resAns = resAns.strip()
95
- gtAcc = []
96
- gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
97
-
98
- if len(set(gtAnswers)) > 1:
99
- for ansDic in gts[quesId]['answers']:
100
- ansDic['answer'] = self.processPunctuation(ansDic['answer'])
101
- ansDic['answer'] = self.processDigitArticle(ansDic['answer'])
102
- resAns = self.processPunctuation(resAns)
103
- resAns = self.processDigitArticle(resAns)
104
-
105
- for gtAnsDatum in gts[quesId]['answers']:
106
- otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum]
107
- matchingAns = [item for item in otherGTAns if item['answer'].lower()==resAns.lower()]
108
- acc = min(1, float(len(matchingAns))/3)
109
- gtAcc.append(acc)
110
- quesType = gts[quesId]['question_type']
111
- ansType = gts[quesId]['answer_type']
112
- avgGTAcc = float(sum(gtAcc))/len(gtAcc)
113
- accQA.append(avgGTAcc)
114
- if quesType not in accQuesType:
115
- accQuesType[quesType] = []
116
- accQuesType[quesType].append(avgGTAcc)
117
- if ansType not in accAnsType:
118
- accAnsType[ansType] = []
119
- accAnsType[ansType].append(avgGTAcc)
120
- self.setEvalQA(quesId, avgGTAcc)
121
- self.setEvalQuesType(quesId, quesType, avgGTAcc)
122
- self.setEvalAnsType(quesId, ansType, avgGTAcc)
123
- if step%100 == 0:
124
- self.updateProgress(step/float(len(quesIds)))
125
- step = step + 1
126
-
127
- self.setAccuracy(accQA, accQuesType, accAnsType)
128
- # print "Done computing accuracy"
129
-
130
- def processPunctuation(self, inText):
131
- outText = inText
132
- for p in self.punct:
133
- if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None):
134
- outText = outText.replace(p, '')
135
- else:
136
- outText = outText.replace(p, ' ')
137
- outText = self.periodStrip.sub("",
138
- outText,
139
- re.UNICODE)
140
- return outText
141
-
142
- def processDigitArticle(self, inText):
143
- outText = []
144
- tempText = inText.lower().split()
145
- for word in tempText:
146
- word = self.manualMap.setdefault(word, word)
147
- if word not in self.articles:
148
- outText.append(word)
149
- else:
150
- pass
151
- for wordId, word in enumerate(outText):
152
- if word in self.contractions:
153
- outText[wordId] = self.contractions[word]
154
- outText = ' '.join(outText)
155
- return outText
156
-
157
- def setAccuracy(self, accQA, accQuesType, accAnsType):
158
- self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n)
159
- self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
160
- self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}
161
-
162
- def setEvalQA(self, quesId, acc):
163
- self.evalQA[quesId] = round(100*acc, self.n)
164
-
165
- def setEvalQuesType(self, quesId, quesType, acc):
166
- if quesType not in self.evalQuesType:
167
- self.evalQuesType[quesType] = {}
168
- self.evalQuesType[quesType][quesId] = round(100*acc, self.n)
169
-
170
- def setEvalAnsType(self, quesId, ansType, acc):
171
- if ansType not in self.evalAnsType:
172
- self.evalAnsType[ansType] = {}
173
- self.evalAnsType[ansType][quesId] = round(100*acc, self.n)
174
-
175
- def updateProgress(self, progress):
176
- barLength = 20
177
- status = ""
178
- if isinstance(progress, int):
179
- progress = float(progress)
180
- if not isinstance(progress, float):
181
- progress = 0
182
- status = "error: progress var must be float\r\n"
183
- if progress < 0:
184
- progress = 0
185
- status = "Halt...\r\n"
186
- if progress >= 1:
187
- progress = 1
188
- status = "Done...\r\n"
189
- block = int(round(barLength*progress))
190
- text = "\rFinshed Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status)
191
- sys.stdout.write(text)
192
- sys.stdout.flush()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py DELETED
@@ -1,73 +0,0 @@
1
- # coding: utf-8
2
-
3
- from vqaTools.vqa import VQA
4
- import random
5
- import skimage.io as io
6
- import matplotlib.pyplot as plt
7
- import os
8
-
9
- dataDir ='../../VQA'
10
- versionType ='v2_' # this should be '' when using VQA v2.0 dataset
11
- taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
12
- dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
13
- dataSubType ='train2014'
14
- annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
15
- quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
16
- imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
17
-
18
- # initialize VQA api for QA annotations
19
- vqa=VQA(annFile, quesFile)
20
-
21
- # load and display QA annotations for given question types
22
- """
23
- All possible quesTypes for abstract and mscoco has been provided in respective text files in ../QuestionTypes/ folder.
24
- """
25
- annIds = vqa.getQuesIds(quesTypes='how many');
26
- anns = vqa.loadQA(annIds)
27
- randomAnn = random.choice(anns)
28
- vqa.showQA([randomAnn])
29
- imgId = randomAnn['image_id']
30
- imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
31
- if os.path.isfile(imgDir + imgFilename):
32
- I = io.imread(imgDir + imgFilename)
33
- plt.imshow(I)
34
- plt.axis('off')
35
- plt.show()
36
-
37
- # load and display QA annotations for given answer types
38
- """
39
- ansTypes can be one of the following
40
- yes/no
41
- number
42
- other
43
- """
44
- annIds = vqa.getQuesIds(ansTypes='yes/no');
45
- anns = vqa.loadQA(annIds)
46
- randomAnn = random.choice(anns)
47
- vqa.showQA([randomAnn])
48
- imgId = randomAnn['image_id']
49
- imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
50
- if os.path.isfile(imgDir + imgFilename):
51
- I = io.imread(imgDir + imgFilename)
52
- plt.imshow(I)
53
- plt.axis('off')
54
- plt.show()
55
-
56
- # load and display QA annotations for given images
57
- """
58
- Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[])
59
- Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types.
60
- """
61
- ids = vqa.getImgIds()
62
- annIds = vqa.getQuesIds(imgIds=random.sample(ids,5));
63
- anns = vqa.loadQA(annIds)
64
- randomAnn = random.choice(anns)
65
- vqa.showQA([randomAnn])
66
- imgId = randomAnn['image_id']
67
- imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
68
- if os.path.isfile(imgDir + imgFilename):
69
- I = io.imread(imgDir + imgFilename)
70
- plt.imshow(I)
71
- plt.axis('off')
72
- plt.show()
73
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py DELETED
@@ -1 +0,0 @@
1
- __author__ = 'aagrawal'
 
 
minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py DELETED
@@ -1,179 +0,0 @@
1
- __author__ = 'aagrawal'
2
- __version__ = '0.9'
3
-
4
- # Interface for accessing the VQA dataset.
5
-
6
- # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
7
- # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
8
-
9
- # The following functions are defined:
10
- # VQA - VQA class that loads VQA annotation file and prepares data structures.
11
- # getQuesIds - Get question ids that satisfy given filter conditions.
12
- # getImgIds - Get image ids that satisfy given filter conditions.
13
- # loadQA - Load questions and answers with the specified question ids.
14
- # showQA - Display the specified questions and answers.
15
- # loadRes - Load result file and create result object.
16
-
17
- # Help on each function can be accessed by: "help(COCO.function)"
18
-
19
- import json
20
- import datetime
21
- import copy
22
-
23
-
24
- class VQA:
25
- def __init__(self, annotation_file=None, question_file=None):
26
- """
27
- Constructor of VQA helper class for reading and visualizing questions and answers.
28
- :param annotation_file (str): location of VQA annotation file
29
- :return:
30
- """
31
- # load dataset
32
- self.dataset = {}
33
- self.questions = {}
34
- self.qa = {}
35
- self.qqa = {}
36
- self.imgToQA = {}
37
- if not annotation_file == None and not question_file == None:
38
- # print 'loading VQA annotations and questions into memory...'
39
- time_t = datetime.datetime.utcnow()
40
- dataset = json.load(open(annotation_file, 'r'))
41
- questions = json.load(open(question_file, 'r'))
42
- # print datetime.datetime.utcnow() - time_t
43
- self.dataset = dataset
44
- self.questions = questions
45
- self.createIndex()
46
-
47
- def createIndex(self):
48
- imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
49
- qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
50
- qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
51
- for ann in self.dataset['annotations']:
52
- imgToQA[ann['image_id']] += [ann]
53
- qa[ann['question_id']] = ann
54
- for ques in self.questions['questions']:
55
- qqa[ques['question_id']] = ques
56
- # print 'index created!'
57
-
58
- # create class members
59
- self.qa = qa
60
- self.qqa = qqa
61
- self.imgToQA = imgToQA
62
-
63
- def info(self):
64
- """
65
- Print information about the VQA annotation file.
66
- :return:
67
- """
68
-
69
- # for key, value in self.datset['info'].items():
70
- # print '%s: %s'%(key, value)
71
-
72
- def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
73
- """
74
- Get question ids that satisfy given filter conditions. default skips that filter
75
- :param imgIds (int array) : get question ids for given imgs
76
- quesTypes (str array) : get question ids for given question types
77
- ansTypes (str array) : get question ids for given answer types
78
- :return: ids (int array) : integer array of question ids
79
- """
80
- imgIds = imgIds if type(imgIds) == list else [imgIds]
81
- quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
82
- ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
83
-
84
- if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
85
- anns = self.dataset['annotations']
86
- else:
87
- if not len(imgIds) == 0:
88
- anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], [])
89
- else:
90
- anns = self.dataset['annotations']
91
- anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
92
- anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
93
- ids = [ann['question_id'] for ann in anns]
94
- return ids
95
-
96
- def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
97
- """
98
- Get image ids that satisfy given filter conditions. default skips that filter
99
- :param quesIds (int array) : get image ids for given question ids
100
- quesTypes (str array) : get image ids for given question types
101
- ansTypes (str array) : get image ids for given answer types
102
- :return: ids (int array) : integer array of image ids
103
- """
104
- quesIds = quesIds if type(quesIds) == list else [quesIds]
105
- quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
106
- ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
107
-
108
- if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
109
- anns = self.dataset['annotations']
110
- else:
111
- if not len(quesIds) == 0:
112
- anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa], [])
113
- else:
114
- anns = self.dataset['annotations']
115
- anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
116
- anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
117
- ids = [ann['image_id'] for ann in anns]
118
- return ids
119
-
120
- def loadQA(self, ids=[]):
121
- """
122
- Load questions and answers with the specified question ids.
123
- :param ids (int array) : integer ids specifying question ids
124
- :return: qa (object array) : loaded qa objects
125
- """
126
- if type(ids) == list:
127
- return [self.qa[id] for id in ids]
128
- elif type(ids) == int:
129
- return [self.qa[ids]]
130
-
131
- def showQA(self, anns):
132
- """
133
- Display the specified annotations.
134
- :param anns (array of object): annotations to display
135
- :return: None
136
- """
137
- if len(anns) == 0:
138
- return 0
139
- for ann in anns:
140
- quesId = ann['question_id']
141
- print("Question: %s" % (self.qqa[quesId]['question']))
142
- for ans in ann['answers']:
143
- print("Answer %d: %s" % (ans['answer_id'], ans['answer']))
144
-
145
- def loadRes(self, resFile, quesFile):
146
- """
147
- Load result file and return a result object.
148
- :param resFile (str) : file name of result file
149
- :return: res (obj) : result api object
150
- """
151
- res = VQA()
152
- res.questions = json.load(open(quesFile))
153
- res.dataset['info'] = copy.deepcopy(self.questions['info'])
154
- res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
155
- res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
156
- res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
157
- res.dataset['license'] = copy.deepcopy(self.questions['license'])
158
-
159
- # print 'Loading and preparing results... '
160
- time_t = datetime.datetime.utcnow()
161
- anns = json.load(open(resFile))
162
- assert type(anns) == list, 'results is not an array of objects'
163
- annsQuesIds = [ann['question_id'] for ann in anns]
164
- assert set(annsQuesIds) == set(self.getQuesIds()), \
165
- 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
166
- for ann in anns:
167
- quesId = ann['question_id']
168
- if res.dataset['task_type'] == 'Multiple Choice':
169
- assert ann['answer'] in self.qqa[quesId][
170
- 'multiple_choices'], 'predicted answer is not one of the multiple choices'
171
- qaAnn = self.qa[quesId]
172
- ann['image_id'] = qaAnn['image_id']
173
- ann['question_type'] = qaAnn['question_type']
174
- ann['answer_type'] = qaAnn['answer_type']
175
- # print 'DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())
176
-
177
- res.dataset['annotations'] = anns
178
- res.createIndex()
179
- return res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt DELETED
@@ -1,81 +0,0 @@
1
- how many
2
- what color is the
3
- is the
4
- where is the
5
- what
6
- what is
7
- are the
8
- what is the
9
- is there a
10
- does the
11
- is the woman
12
- is the man
13
- what is on the
14
- is it
15
- is the girl
16
- is the boy
17
- is the dog
18
- are they
19
- who is
20
- what kind of
21
- what color are the
22
- what is in the
23
- what is the man
24
- is there
25
- what is the woman
26
- what are the
27
- what is the boy
28
- are there
29
- what is the girl
30
- is this
31
- how
32
- which
33
- how many people are
34
- is the cat
35
- why is the
36
- are
37
- will the
38
- what type of
39
- what is the dog
40
- do
41
- is she
42
- does
43
- do the
44
- is
45
- is the baby
46
- are there any
47
- is the lady
48
- can
49
- what animal is
50
- where are the
51
- is the sun
52
- what are they
53
- did the
54
- what is the cat
55
- what is the lady
56
- how many clouds are
57
- is that
58
- is the little girl
59
- is he
60
- are these
61
- how many trees are
62
- how many pillows
63
- are the people
64
- why
65
- is the young
66
- how many windows are
67
- is this a
68
- what is the little
69
- is the tv
70
- how many animals are
71
- who
72
- how many pictures
73
- how many plants are
74
- how many birds are
75
- what color is
76
- what is the baby
77
- is anyone
78
- what color
79
- how many bushes
80
- is the old man
81
- none of the above
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt DELETED
@@ -1,65 +0,0 @@
1
- how many
2
- is the
3
- what
4
- what color is the
5
- what is the
6
- is this
7
- is this a
8
- what is
9
- are the
10
- what kind of
11
- is there a
12
- what type of
13
- is it
14
- what are the
15
- where is the
16
- is there
17
- does the
18
- what color are the
19
- are these
20
- are there
21
- which
22
- is
23
- what is the man
24
- is the man
25
- are
26
- how
27
- does this
28
- what is on the
29
- what does the
30
- how many people are
31
- what is in the
32
- what is this
33
- do
34
- what are
35
- are they
36
- what time
37
- what sport is
38
- are there any
39
- is he
40
- what color is
41
- why
42
- where are the
43
- what color
44
- who is
45
- what animal is
46
- is the woman
47
- is this an
48
- do you
49
- how many people are in
50
- what room is
51
- has
52
- is this person
53
- what is the woman
54
- can you
55
- why is the
56
- is the person
57
- what is the color of the
58
- what is the person
59
- could
60
- was
61
- is that a
62
- what number is
63
- what is the name
64
- what brand
65
- none of the above
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/vqa_tools/VQA/README.md DELETED
@@ -1,80 +0,0 @@
1
- Python API and Evaluation Code for v2.0 and v1.0 releases of the VQA dataset.
2
- ===================
3
- ## VQA v2.0 release ##
4
- This release consists of
5
- - Real
6
- - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download))
7
- - 443,757 questions for training, 214,354 questions for validation and 447,793 questions for testing
8
- - 4,437,570 answers for training and 2,143,540 answers for validation (10 per question)
9
-
10
- There is only one type of task
11
- - Open-ended task
12
-
13
- ## VQA v1.0 release ##
14
- This release consists of
15
- - Real
16
- - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download))
17
- - 248,349 questions for training, 121,512 questions for validation and 244,302 questions for testing (3 per image)
18
- - 2,483,490 answers for training and 1,215,120 answers for validation (10 per question)
19
- - Abstract
20
- - 20,000 training images, 10,000 validation images and 20,000 MS COCO testing images
21
- - 60,000 questions for training, 30,000 questions for validation and 60,000 questions for testing (3 per image)
22
- - 600,000 answers for training and 300,000 answers for validation (10 per question)
23
-
24
- There are two types of tasks
25
- - Open-ended task
26
- - Multiple-choice task (18 choices per question)
27
-
28
- ## Requirements ##
29
- - python 2.7
30
- - scikit-image (visit [this page](http://scikit-image.org/docs/dev/install.html) for installation)
31
- - matplotlib (visit [this page](http://matplotlib.org/users/installing.html) for installation)
32
-
33
- ## Files ##
34
- ./Questions
35
- - For v2.0, download the question files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
36
- - For v1.0, both real and abstract, question files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
37
- - Question files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
38
- - [training question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Train_mscoco.zip)
39
- - [validation question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Val_mscoco.zip)
40
- - Question files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Questions_Train_mscoco.zip).
41
-
42
- ./Annotations
43
- - For v2.0, download the annotations files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
44
- - For v1.0, for both real and abstract, annotation files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
45
- - Annotation files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
46
- - [training annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Train_mscoco.zip)
47
- - [validation annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Val_mscoco.zip)
48
- - Annotation files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Annotations_Train_mscoco.zip).
49
-
50
- ./Images
51
- - For real, create a directory with name mscoco inside this directory. For each of train, val and test, create directories with names train2014, val2014 and test2015 respectively inside mscoco directory, download respective images from [MS COCO website](http://mscoco.org/dataset/#download) and place them in respective folders.
52
- - For abstract, create a directory with name abstract_v002 inside this directory. For each of train, val and test, create directories with names train2015, val2015 and test2015 respectively inside abstract_v002 directory, download respective images from [VQA download page](http://www.visualqa.org/download.html) and place them in respective folders.
53
-
54
- ./PythonHelperTools
55
- - This directory contains the Python API to read and visualize the VQA dataset
56
- - vqaDemo.py (demo script)
57
- - vqaTools (API to read and visualize data)
58
-
59
- ./PythonEvaluationTools
60
- - This directory contains the Python evaluation code
61
- - vqaEvalDemo.py (evaluation demo script)
62
- - vqaEvaluation (evaluation code)
63
-
64
- ./Results
65
- - OpenEnded_mscoco_train2014_fake_results.json (an example of a fake results file for v1.0 to run the demo)
66
- - Visit [VQA evaluation page] (http://visualqa.org/evaluation) for more details.
67
-
68
- ./QuestionTypes
69
- - This directory contains the following lists of question types for both real and abstract questions (question types are unchanged from v1.0 to v2.0). In a list, if there are question types of length n+k and length n with the same first n words, then the question type of length n does not include questions that belong to the question type of length n+k.
70
- - mscoco_question_types.txt
71
- - abstract_v002_question_types.txt
72
-
73
- ## References ##
74
- - [VQA: Visual Question Answering](http://visualqa.org/)
75
- - [Microsoft COCO](http://mscoco.org/)
76
-
77
- ## Developers ##
78
- - Aishwarya Agrawal (Virginia Tech)
79
- - Code for API is based on [MSCOCO API code](https://github.com/pdollar/coco).
80
- - The format of the code for evaluation is based on [MSCOCO evaluation code](https://github.com/tylin/coco-caption).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/vqa_tools/VQA/license.txt DELETED
@@ -1,30 +0,0 @@
1
- Copyright (c) 2014, Aishwarya Agrawal
2
- All rights reserved.
3
-
4
- Redistribution and use in source and binary forms, with or without
5
- modification, are permitted provided that the following conditions are met:
6
-
7
- 1. Redistributions of source code must retain the above copyright notice, this
8
- list of conditions and the following disclaimer.
9
- 2. Redistributions in binary form must reproduce the above copyright notice,
10
- this list of conditions and the following disclaimer in the documentation
11
- and/or other materials provided with the distribution.
12
-
13
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14
- AND
15
- ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
16
- WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
18
- FOR
19
- ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
- (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
- LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22
- ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
- (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
- SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
-
26
- The views and conclusions contained in the software and documentation are
27
- those
28
- of the authors and should not be interpreted as representing official
29
- policies,
30
- either expressed or implied, of the FreeBSD Project.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/vqa_tools/__init__.py DELETED
@@ -1,8 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- __author__ = "aagrawal"
 
 
 
 
 
 
 
 
 
minigpt4/common/vqa_tools/vqa.py DELETED
@@ -1,211 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- __author__ = "aagrawal"
9
- __version__ = "0.9"
10
-
11
- # Interface for accessing the VQA dataset.
12
-
13
- # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
14
- # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
15
-
16
- # The following functions are defined:
17
- # VQA - VQA class that loads VQA annotation file and prepares data structures.
18
- # getQuesIds - Get question ids that satisfy given filter conditions.
19
- # getImgIds - Get image ids that satisfy given filter conditions.
20
- # loadQA - Load questions and answers with the specified question ids.
21
- # showQA - Display the specified questions and answers.
22
- # loadRes - Load result file and create result object.
23
-
24
- # Help on each function can be accessed by: "help(COCO.function)"
25
-
26
- import json
27
- import datetime
28
- import copy
29
-
30
-
31
- class VQA:
32
- def __init__(self, annotation_file=None, question_file=None):
33
- """
34
- Constructor of VQA helper class for reading and visualizing questions and answers.
35
- :param annotation_file (str): location of VQA annotation file
36
- :return:
37
- """
38
- # load dataset
39
- self.dataset = {}
40
- self.questions = {}
41
- self.qa = {}
42
- self.qqa = {}
43
- self.imgToQA = {}
44
- if not annotation_file == None and not question_file == None:
45
- print("loading VQA annotations and questions into memory...")
46
- time_t = datetime.datetime.utcnow()
47
- dataset = json.load(open(annotation_file, "r"))
48
- questions = json.load(open(question_file, "r"))
49
- self.dataset = dataset
50
- self.questions = questions
51
- self.createIndex()
52
-
53
- def createIndex(self):
54
- # create index
55
- print("creating index...")
56
- imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]}
57
- qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
58
- qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
59
- for ann in self.dataset["annotations"]:
60
- imgToQA[ann["image_id"]] += [ann]
61
- qa[ann["question_id"]] = ann
62
- for ques in self.questions["questions"]:
63
- qqa[ques["question_id"]] = ques
64
- print("index created!")
65
-
66
- # create class members
67
- self.qa = qa
68
- self.qqa = qqa
69
- self.imgToQA = imgToQA
70
-
71
- def info(self):
72
- """
73
- Print information about the VQA annotation file.
74
- :return:
75
- """
76
- for key, value in self.datset["info"].items():
77
- print("%s: %s" % (key, value))
78
-
79
- def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
80
- """
81
- Get question ids that satisfy given filter conditions. default skips that filter
82
- :param imgIds (int array) : get question ids for given imgs
83
- quesTypes (str array) : get question ids for given question types
84
- ansTypes (str array) : get question ids for given answer types
85
- :return: ids (int array) : integer array of question ids
86
- """
87
- imgIds = imgIds if type(imgIds) == list else [imgIds]
88
- quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
89
- ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
90
-
91
- if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
92
- anns = self.dataset["annotations"]
93
- else:
94
- if not len(imgIds) == 0:
95
- anns = sum(
96
- [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],
97
- [],
98
- )
99
- else:
100
- anns = self.dataset["annotations"]
101
- anns = (
102
- anns
103
- if len(quesTypes) == 0
104
- else [ann for ann in anns if ann["question_type"] in quesTypes]
105
- )
106
- anns = (
107
- anns
108
- if len(ansTypes) == 0
109
- else [ann for ann in anns if ann["answer_type"] in ansTypes]
110
- )
111
- ids = [ann["question_id"] for ann in anns]
112
- return ids
113
-
114
- def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
115
- """
116
- Get image ids that satisfy given filter conditions. default skips that filter
117
- :param quesIds (int array) : get image ids for given question ids
118
- quesTypes (str array) : get image ids for given question types
119
- ansTypes (str array) : get image ids for given answer types
120
- :return: ids (int array) : integer array of image ids
121
- """
122
- quesIds = quesIds if type(quesIds) == list else [quesIds]
123
- quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
124
- ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
125
-
126
- if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
127
- anns = self.dataset["annotations"]
128
- else:
129
- if not len(quesIds) == 0:
130
- anns = sum(
131
- [self.qa[quesId] for quesId in quesIds if quesId in self.qa], []
132
- )
133
- else:
134
- anns = self.dataset["annotations"]
135
- anns = (
136
- anns
137
- if len(quesTypes) == 0
138
- else [ann for ann in anns if ann["question_type"] in quesTypes]
139
- )
140
- anns = (
141
- anns
142
- if len(ansTypes) == 0
143
- else [ann for ann in anns if ann["answer_type"] in ansTypes]
144
- )
145
- ids = [ann["image_id"] for ann in anns]
146
- return ids
147
-
148
- def loadQA(self, ids=[]):
149
- """
150
- Load questions and answers with the specified question ids.
151
- :param ids (int array) : integer ids specifying question ids
152
- :return: qa (object array) : loaded qa objects
153
- """
154
- if type(ids) == list:
155
- return [self.qa[id] for id in ids]
156
- elif type(ids) == int:
157
- return [self.qa[ids]]
158
-
159
- def showQA(self, anns):
160
- """
161
- Display the specified annotations.
162
- :param anns (array of object): annotations to display
163
- :return: None
164
- """
165
- if len(anns) == 0:
166
- return 0
167
- for ann in anns:
168
- quesId = ann["question_id"]
169
- print("Question: %s" % (self.qqa[quesId]["question"]))
170
- for ans in ann["answers"]:
171
- print("Answer %d: %s" % (ans["answer_id"], ans["answer"]))
172
-
173
- def loadRes(self, resFile, quesFile):
174
- """
175
- Load result file and return a result object.
176
- :param resFile (str) : file name of result file
177
- :return: res (obj) : result api object
178
- """
179
- res = VQA()
180
- res.questions = json.load(open(quesFile))
181
- res.dataset["info"] = copy.deepcopy(self.questions["info"])
182
- res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"])
183
- res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"])
184
- res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"])
185
- res.dataset["license"] = copy.deepcopy(self.questions["license"])
186
-
187
- print("Loading and preparing results... ")
188
- time_t = datetime.datetime.utcnow()
189
- anns = json.load(open(resFile))
190
- assert type(anns) == list, "results is not an array of objects"
191
- annsQuesIds = [ann["question_id"] for ann in anns]
192
- assert set(annsQuesIds) == set(
193
- self.getQuesIds()
194
- ), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file."
195
- for ann in anns:
196
- quesId = ann["question_id"]
197
- if res.dataset["task_type"] == "Multiple Choice":
198
- assert (
199
- ann["answer"] in self.qqa[quesId]["multiple_choices"]
200
- ), "predicted answer is not one of the multiple choices"
201
- qaAnn = self.qa[quesId]
202
- ann["image_id"] = qaAnn["image_id"]
203
- ann["question_type"] = qaAnn["question_type"]
204
- ann["answer_type"] = qaAnn["answer_type"]
205
- print(
206
- "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds())
207
- )
208
-
209
- res.dataset["annotations"] = anns
210
- res.createIndex()
211
- return res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/common/vqa_tools/vqa_eval.py DELETED
@@ -1,324 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- # coding=utf-8
9
-
10
- __author__ = "aagrawal"
11
-
12
- # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
13
- # (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
14
- import sys
15
- import re
16
-
17
-
18
- class VQAEval:
19
- def __init__(self, vqa=None, vqaRes=None, n=2):
20
- self.n = n
21
- self.accuracy = {}
22
- self.evalQA = {}
23
- self.evalQuesType = {}
24
- self.evalAnsType = {}
25
- self.vqa = vqa
26
- self.vqaRes = vqaRes
27
- if vqa is not None:
28
- self.params = {"question_id": vqa.getQuesIds()}
29
- self.contractions = {
30
- "aint": "ain't",
31
- "arent": "aren't",
32
- "cant": "can't",
33
- "couldve": "could've",
34
- "couldnt": "couldn't",
35
- "couldn'tve": "couldn't've",
36
- "couldnt've": "couldn't've",
37
- "didnt": "didn't",
38
- "doesnt": "doesn't",
39
- "dont": "don't",
40
- "hadnt": "hadn't",
41
- "hadnt've": "hadn't've",
42
- "hadn'tve": "hadn't've",
43
- "hasnt": "hasn't",
44
- "havent": "haven't",
45
- "hed": "he'd",
46
- "hed've": "he'd've",
47
- "he'dve": "he'd've",
48
- "hes": "he's",
49
- "howd": "how'd",
50
- "howll": "how'll",
51
- "hows": "how's",
52
- "Id've": "I'd've",
53
- "I'dve": "I'd've",
54
- "Im": "I'm",
55
- "Ive": "I've",
56
- "isnt": "isn't",
57
- "itd": "it'd",
58
- "itd've": "it'd've",
59
- "it'dve": "it'd've",
60
- "itll": "it'll",
61
- "let's": "let's",
62
- "maam": "ma'am",
63
- "mightnt": "mightn't",
64
- "mightnt've": "mightn't've",
65
- "mightn'tve": "mightn't've",
66
- "mightve": "might've",
67
- "mustnt": "mustn't",
68
- "mustve": "must've",
69
- "neednt": "needn't",
70
- "notve": "not've",
71
- "oclock": "o'clock",
72
- "oughtnt": "oughtn't",
73
- "ow's'at": "'ow's'at",
74
- "'ows'at": "'ow's'at",
75
- "'ow'sat": "'ow's'at",
76
- "shant": "shan't",
77
- "shed've": "she'd've",
78
- "she'dve": "she'd've",
79
- "she's": "she's",
80
- "shouldve": "should've",
81
- "shouldnt": "shouldn't",
82
- "shouldnt've": "shouldn't've",
83
- "shouldn'tve": "shouldn't've",
84
- "somebody'd": "somebodyd",
85
- "somebodyd've": "somebody'd've",
86
- "somebody'dve": "somebody'd've",
87
- "somebodyll": "somebody'll",
88
- "somebodys": "somebody's",
89
- "someoned": "someone'd",
90
- "someoned've": "someone'd've",
91
- "someone'dve": "someone'd've",
92
- "someonell": "someone'll",
93
- "someones": "someone's",
94
- "somethingd": "something'd",
95
- "somethingd've": "something'd've",
96
- "something'dve": "something'd've",
97
- "somethingll": "something'll",
98
- "thats": "that's",
99
- "thered": "there'd",
100
- "thered've": "there'd've",
101
- "there'dve": "there'd've",
102
- "therere": "there're",
103
- "theres": "there's",
104
- "theyd": "they'd",
105
- "theyd've": "they'd've",
106
- "they'dve": "they'd've",
107
- "theyll": "they'll",
108
- "theyre": "they're",
109
- "theyve": "they've",
110
- "twas": "'twas",
111
- "wasnt": "wasn't",
112
- "wed've": "we'd've",
113
- "we'dve": "we'd've",
114
- "weve": "we've",
115
- "werent": "weren't",
116
- "whatll": "what'll",
117
- "whatre": "what're",
118
- "whats": "what's",
119
- "whatve": "what've",
120
- "whens": "when's",
121
- "whered": "where'd",
122
- "wheres": "where's",
123
- "whereve": "where've",
124
- "whod": "who'd",
125
- "whod've": "who'd've",
126
- "who'dve": "who'd've",
127
- "wholl": "who'll",
128
- "whos": "who's",
129
- "whove": "who've",
130
- "whyll": "why'll",
131
- "whyre": "why're",
132
- "whys": "why's",
133
- "wont": "won't",
134
- "wouldve": "would've",
135
- "wouldnt": "wouldn't",
136
- "wouldnt've": "wouldn't've",
137
- "wouldn'tve": "wouldn't've",
138
- "yall": "y'all",
139
- "yall'll": "y'all'll",
140
- "y'allll": "y'all'll",
141
- "yall'd've": "y'all'd've",
142
- "y'alld've": "y'all'd've",
143
- "y'all'dve": "y'all'd've",
144
- "youd": "you'd",
145
- "youd've": "you'd've",
146
- "you'dve": "you'd've",
147
- "youll": "you'll",
148
- "youre": "you're",
149
- "youve": "you've",
150
- }
151
- self.manualMap = {
152
- "none": "0",
153
- "zero": "0",
154
- "one": "1",
155
- "two": "2",
156
- "three": "3",
157
- "four": "4",
158
- "five": "5",
159
- "six": "6",
160
- "seven": "7",
161
- "eight": "8",
162
- "nine": "9",
163
- "ten": "10",
164
- }
165
- self.articles = ["a", "an", "the"]
166
-
167
- self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
168
- self.commaStrip = re.compile("(\d)(,)(\d)")
169
- self.punct = [
170
- ";",
171
- r"/",
172
- "[",
173
- "]",
174
- '"',
175
- "{",
176
- "}",
177
- "(",
178
- ")",
179
- "=",
180
- "+",
181
- "\\",
182
- "_",
183
- "-",
184
- ">",
185
- "<",
186
- "@",
187
- "`",
188
- ",",
189
- "?",
190
- "!",
191
- ]
192
-
193
- def evaluate(self, quesIds=None):
194
- if quesIds == None:
195
- quesIds = [quesId for quesId in self.params["question_id"]]
196
- gts = {}
197
- res = {}
198
- for quesId in quesIds:
199
- gts[quesId] = self.vqa.qa[quesId]
200
- res[quesId] = self.vqaRes.qa[quesId]
201
-
202
- # =================================================
203
- # Compute accuracy
204
- # =================================================
205
- accQA = []
206
- accQuesType = {}
207
- accAnsType = {}
208
- print("computing accuracy")
209
- step = 0
210
- for quesId in quesIds:
211
- resAns = res[quesId]["answer"]
212
- resAns = resAns.replace("\n", " ")
213
- resAns = resAns.replace("\t", " ")
214
- resAns = resAns.strip()
215
- resAns = self.processPunctuation(resAns)
216
- resAns = self.processDigitArticle(resAns)
217
- gtAcc = []
218
- gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
219
- if len(set(gtAnswers)) > 1:
220
- for ansDic in gts[quesId]["answers"]:
221
- ansDic["answer"] = self.processPunctuation(ansDic["answer"])
222
- for gtAnsDatum in gts[quesId]["answers"]:
223
- otherGTAns = [
224
- item for item in gts[quesId]["answers"] if item != gtAnsDatum
225
- ]
226
- matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
227
- acc = min(1, float(len(matchingAns)) / 3)
228
- gtAcc.append(acc)
229
- quesType = gts[quesId]["question_type"]
230
- ansType = gts[quesId]["answer_type"]
231
- avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
232
- accQA.append(avgGTAcc)
233
- if quesType not in accQuesType:
234
- accQuesType[quesType] = []
235
- accQuesType[quesType].append(avgGTAcc)
236
- if ansType not in accAnsType:
237
- accAnsType[ansType] = []
238
- accAnsType[ansType].append(avgGTAcc)
239
- self.setEvalQA(quesId, avgGTAcc)
240
- self.setEvalQuesType(quesId, quesType, avgGTAcc)
241
- self.setEvalAnsType(quesId, ansType, avgGTAcc)
242
- if step % 100 == 0:
243
- self.updateProgress(step / float(len(quesIds)))
244
- step = step + 1
245
-
246
- self.setAccuracy(accQA, accQuesType, accAnsType)
247
- print("Done computing accuracy")
248
-
249
- def processPunctuation(self, inText):
250
- outText = inText
251
- for p in self.punct:
252
- if (p + " " in inText or " " + p in inText) or (
253
- re.search(self.commaStrip, inText) != None
254
- ):
255
- outText = outText.replace(p, "")
256
- else:
257
- outText = outText.replace(p, " ")
258
- outText = self.periodStrip.sub("", outText, re.UNICODE)
259
- return outText
260
-
261
- def processDigitArticle(self, inText):
262
- outText = []
263
- tempText = inText.lower().split()
264
- for word in tempText:
265
- word = self.manualMap.setdefault(word, word)
266
- if word not in self.articles:
267
- outText.append(word)
268
- else:
269
- pass
270
- for wordId, word in enumerate(outText):
271
- if word in self.contractions:
272
- outText[wordId] = self.contractions[word]
273
- outText = " ".join(outText)
274
- return outText
275
-
276
- def setAccuracy(self, accQA, accQuesType, accAnsType):
277
- self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
278
- self.accuracy["perQuestionType"] = {
279
- quesType: round(
280
- 100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
281
- self.n,
282
- )
283
- for quesType in accQuesType
284
- }
285
- self.accuracy["perAnswerType"] = {
286
- ansType: round(
287
- 100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
288
- )
289
- for ansType in accAnsType
290
- }
291
-
292
- def setEvalQA(self, quesId, acc):
293
- self.evalQA[quesId] = round(100 * acc, self.n)
294
-
295
- def setEvalQuesType(self, quesId, quesType, acc):
296
- if quesType not in self.evalQuesType:
297
- self.evalQuesType[quesType] = {}
298
- self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
299
-
300
- def setEvalAnsType(self, quesId, ansType, acc):
301
- if ansType not in self.evalAnsType:
302
- self.evalAnsType[ansType] = {}
303
- self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
304
-
305
- def updateProgress(self, progress):
306
- barLength = 20
307
- status = ""
308
- if isinstance(progress, int):
309
- progress = float(progress)
310
- if not isinstance(progress, float):
311
- progress = 0
312
- status = "error: progress var must be float\r\n"
313
- if progress < 0:
314
- progress = 0
315
- status = "Halt...\r\n"
316
- if progress >= 1:
317
- progress = 1
318
- status = "Done...\r\n"
319
- block = int(round(barLength * progress))
320
- text = "\rFinshed Percent: [{0}] {1}% {2}".format(
321
- "#" * block + "-" * (barLength - block), int(progress * 100), status
322
- )
323
- sys.stdout.write(text)
324
- sys.stdout.flush()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/configs/datasets/cc_combine/align.yaml DELETED
@@ -1,16 +0,0 @@
1
- # Copyright (c) 2022, salesforce.com, inc.
2
- # All rights reserved.
3
- # SPDX-License-Identifier: BSD-3-Clause
4
- # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
-
6
- datasets:
7
- cc_align:
8
- data_type: images
9
- build_info:
10
- # Be careful not to append minus sign (-) before split to avoid itemizing
11
- annotations:
12
- train:
13
- url: placeholder
14
- storage: /ibex/project/c2133/blip_dataset/image_alignment_cc/filter_cap.json
15
- images:
16
- storage: /ibex/project/c2133/blip_dataset/image_alignment_cc/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/configs/datasets/cc_combine/defaults.yaml DELETED
@@ -1,11 +0,0 @@
1
- # Copyright (c) 2022, salesforce.com, inc.
2
- # All rights reserved.
3
- # SPDX-License-Identifier: BSD-3-Clause
4
- # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
-
6
- datasets:
7
- cc_combine:
8
- data_type: images
9
- build_info:
10
- # Be careful not to append minus sign (-) before split to avoid itemizing
11
- storage: /ibex/project/c2133/blip_dataset/cc3m/cc3m_cc12m_sbu/{00000..01255}.tar
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/configs/datasets/laion/defaults.yaml DELETED
@@ -1,13 +0,0 @@
1
- # Copyright (c) 2022, salesforce.com, inc.
2
- # All rights reserved.
3
- # SPDX-License-Identifier: BSD-3-Clause
4
- # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
-
6
- datasets:
7
- laion:
8
-
9
- data_type: images
10
-
11
- build_info:
12
- # Be careful not to append minus sign (-) before split to avoid itemizing
13
- storage: /ibex/project/c2133/blip_dataset/laion_1b/laion_gpu/{00000..10488}.tar
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/configs/default.yaml DELETED
@@ -1,5 +0,0 @@
1
- env:
2
- # For default users
3
- # cache_root: "cache"
4
- # For internal use with persistent storage
5
- cache_root: "/export/home/.cache/minigpt4"
 
 
 
 
 
 
minigpt4/configs/models/minigpt4_vicuna0.yaml DELETED
@@ -1,32 +0,0 @@
1
- model:
2
- arch: minigpt4
3
-
4
- # vit encoder
5
- image_size: 224
6
- drop_path_rate: 0
7
- use_grad_checkpoint: False
8
- vit_precision: "fp16"
9
- freeze_vit: True
10
- freeze_qformer: True
11
-
12
- # Q-Former
13
- num_query_token: 32
14
-
15
- # generation configs
16
- prompt: ""
17
-
18
- llama_model: "wangrongsheng/MiniGPT-4-LLaMA-7B"
19
-
20
- preprocess:
21
- vis_processor:
22
- train:
23
- name: "blip2_image_train"
24
- image_size: 224
25
- eval:
26
- name: "blip2_image_eval"
27
- image_size: 224
28
- text_processor:
29
- train:
30
- name: "blip_caption"
31
- eval:
32
- name: "blip_caption"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/conversation/__init__.py DELETED
File without changes
minigpt4/conversation/conversation.py DELETED
@@ -1,233 +0,0 @@
1
- import argparse
2
- import time
3
- from threading import Thread
4
- from PIL import Image
5
-
6
- import torch
7
- from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
8
- from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
9
-
10
- import dataclasses
11
- from enum import auto, Enum
12
- from typing import List, Tuple, Any
13
-
14
- from minigpt4.common.registry import registry
15
-
16
-
17
- class SeparatorStyle(Enum):
18
- """Different separator style."""
19
- SINGLE = auto()
20
- TWO = auto()
21
-
22
-
23
- @dataclasses.dataclass
24
- class Conversation:
25
- """A class that keeps all conversation history."""
26
- system: str
27
- roles: List[str]
28
- messages: List[List[str]]
29
- offset: int
30
- # system_img: List[Image.Image] = []
31
- sep_style: SeparatorStyle = SeparatorStyle.SINGLE
32
- sep: str = "###"
33
- sep2: str = None
34
-
35
- skip_next: bool = False
36
- conv_id: Any = None
37
-
38
- def get_prompt(self):
39
- if self.sep_style == SeparatorStyle.SINGLE:
40
- ret = self.system + self.sep
41
- for role, message in self.messages:
42
- if message:
43
- ret += role + message + self.sep
44
- else:
45
- ret += role
46
- return ret
47
- elif self.sep_style == SeparatorStyle.TWO:
48
- seps = [self.sep, self.sep2]
49
- ret = self.system + seps[0]
50
- for i, (role, message) in enumerate(self.messages):
51
- if message:
52
- ret += role + message + seps[i % 2]
53
- else:
54
- ret += role
55
- return ret
56
- else:
57
- raise ValueError(f"Invalid style: {self.sep_style}")
58
-
59
- def append_message(self, role, message):
60
- self.messages.append([role, message])
61
-
62
- def to_gradio_chatbot(self):
63
- ret = []
64
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
65
- if i % 2 == 0:
66
- ret.append([msg, None])
67
- else:
68
- ret[-1][-1] = msg
69
- return ret
70
-
71
- def copy(self):
72
- return Conversation(
73
- system=self.system,
74
- # system_img=self.system_img,
75
- roles=self.roles,
76
- messages=[[x, y] for x, y in self.messages],
77
- offset=self.offset,
78
- sep_style=self.sep_style,
79
- sep=self.sep,
80
- sep2=self.sep2,
81
- conv_id=self.conv_id)
82
-
83
- def dict(self):
84
- return {
85
- "system": self.system,
86
- # "system_img": self.system_img,
87
- "roles": self.roles,
88
- "messages": self.messages,
89
- "offset": self.offset,
90
- "sep": self.sep,
91
- "sep2": self.sep2,
92
- "conv_id": self.conv_id,
93
- }
94
-
95
-
96
- class StoppingCriteriaSub(StoppingCriteria):
97
-
98
- def __init__(self, stops=[], encounters=1):
99
- super().__init__()
100
- self.stops = stops
101
-
102
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
103
- for stop in self.stops:
104
- if torch.all(input_ids[:, -len(stop):] == stop).item():
105
- return True
106
-
107
- return False
108
-
109
-
110
- CONV_VISION_Vicuna0 = Conversation(
111
- system="Give the following image: <Img>ImageContent</Img>. "
112
- "You will be able to see the image once I provide it to you. Please answer my questions.",
113
- roles=("Human: ", "Assistant: "),
114
- messages=[],
115
- offset=2,
116
- sep_style=SeparatorStyle.SINGLE,
117
- sep="###",
118
- )
119
-
120
- CONV_VISION_LLama2 = Conversation(
121
- system="Give the following image: <Img>ImageContent</Img>. "
122
- "You will be able to see the image once I provide it to you. Please answer my questions.",
123
- roles=("<s>[INST] ", " [/INST] "),
124
- messages=[],
125
- offset=2,
126
- sep_style=SeparatorStyle.SINGLE,
127
- sep="",
128
- )
129
-
130
- CONV_VISION_minigptv2 = Conversation(
131
- system="",
132
- roles=("<s>[INST] ", " [/INST]"),
133
- messages=[],
134
- offset=2,
135
- sep_style=SeparatorStyle.SINGLE,
136
- sep="",
137
- )
138
-
139
- class Chat:
140
- def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None):
141
- self.device = device
142
- self.model = model
143
- self.vis_processor = vis_processor
144
-
145
- if stopping_criteria is not None:
146
- self.stopping_criteria = stopping_criteria
147
- else:
148
- stop_words_ids = [torch.tensor([2]).to(self.device)]
149
- self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
150
-
151
- def ask(self, text, conv):
152
- if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
153
- and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
154
- conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
155
- else:
156
- conv.append_message(conv.roles[0], text)
157
-
158
- def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
159
- repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
160
- conv.append_message(conv.roles[1], None)
161
- prompt = conv.get_prompt()
162
- embs = self.model.get_context_emb(prompt, img_list)
163
-
164
- current_max_len = embs.shape[1] + max_new_tokens
165
- if current_max_len - max_length > 0:
166
- print('Warning: The number of tokens in current conversation exceeds the max length. '
167
- 'The model will not see the contexts outside the range.')
168
- begin_idx = max(0, current_max_len - max_length)
169
- embs = embs[:, begin_idx:]
170
-
171
- generation_kwargs = dict(
172
- inputs_embeds=embs,
173
- max_new_tokens=max_new_tokens,
174
- stopping_criteria=self.stopping_criteria,
175
- num_beams=num_beams,
176
- do_sample=True,
177
- min_length=min_length,
178
- top_p=top_p,
179
- repetition_penalty=repetition_penalty,
180
- length_penalty=length_penalty,
181
- temperature=float(temperature),
182
- )
183
- return generation_kwargs
184
-
185
- def answer(self, conv, img_list, **kargs):
186
- generation_dict = self.answer_prepare(conv, img_list, **kargs)
187
- output_token = self.model_generate(**generation_dict)[0]
188
- output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
189
-
190
- output_text = output_text.split('###')[0] # remove the stop sign '###'
191
- output_text = output_text.split('Assistant:')[-1].strip()
192
-
193
- conv.messages[-1][1] = output_text
194
- return output_text, output_token.cpu().numpy()
195
-
196
- def stream_answer(self, conv, img_list, **kargs):
197
- generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
198
- streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
199
- generation_kwargs['streamer'] = streamer
200
- thread = Thread(target=self.model_generate, kwargs=generation_kwargs)
201
- thread.start()
202
- return streamer
203
-
204
- def model_generate(self, *args, **kwargs):
205
- # for 8 bit and 16 bit compatibility
206
- with self.model.maybe_autocast():
207
- output = self.model.llama_model.generate(*args, **kwargs)
208
- return output
209
-
210
- def encode_img(self, img_list):
211
- image = img_list[0]
212
- img_list.pop(0)
213
- if isinstance(image, str): # is a image path
214
- raw_image = Image.open(image).convert('RGB')
215
- image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
216
- elif isinstance(image, Image.Image):
217
- raw_image = image
218
- image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
219
- elif isinstance(image, torch.Tensor):
220
- if len(image.shape) == 3:
221
- image = image.unsqueeze(0)
222
- image = image.to(self.device)
223
-
224
- image_emb, _ = self.model.encode_img(image)
225
- img_list.append(image_emb)
226
-
227
- def upload_img(self, image, conv, img_list):
228
- conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
229
- img_list.append(image)
230
- msg = "Received."
231
-
232
- return msg
233
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/__init__.py DELETED
File without changes
minigpt4/datasets/builders/__init__.py DELETED
@@ -1,72 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
9
- from minigpt4.datasets.builders.image_text_pair_builder import (
10
- CCSBUBuilder,
11
- LaionBuilder,
12
- CCSBUAlignBuilder
13
- )
14
- from minigpt4.common.registry import registry
15
-
16
- __all__ = [
17
- "CCSBUBuilder",
18
- "LaionBuilder",
19
- "CCSBUAlignBuilder"
20
- ]
21
-
22
-
23
- def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
24
- """
25
- Example
26
-
27
- >>> dataset = load_dataset("coco_caption", cfg=None)
28
- >>> splits = dataset.keys()
29
- >>> print([len(dataset[split]) for split in splits])
30
-
31
- """
32
- if cfg_path is None:
33
- cfg = None
34
- else:
35
- cfg = load_dataset_config(cfg_path)
36
-
37
- try:
38
- builder = registry.get_builder_class(name)(cfg)
39
- except TypeError:
40
- print(
41
- f"Dataset {name} not found. Available datasets:\n"
42
- + ", ".join([str(k) for k in dataset_zoo.get_names()])
43
- )
44
- exit(1)
45
-
46
- if vis_path is not None:
47
- if data_type is None:
48
- # use default data type in the config
49
- data_type = builder.config.data_type
50
-
51
- assert (
52
- data_type in builder.config.build_info
53
- ), f"Invalid data_type {data_type} for {name}."
54
-
55
- builder.config.build_info.get(data_type).storage = vis_path
56
-
57
- dataset = builder.build_datasets()
58
- return dataset
59
-
60
-
61
- class DatasetZoo:
62
- def __init__(self) -> None:
63
- self.dataset_zoo = {
64
- k: list(v.DATASET_CONFIG_DICT.keys())
65
- for k, v in sorted(registry.mapping["builder_name_mapping"].items())
66
- }
67
-
68
- def get_names(self):
69
- return list(self.dataset_zoo.keys())
70
-
71
-
72
- dataset_zoo = DatasetZoo()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/builders/base_dataset_builder.py DELETED
@@ -1,236 +0,0 @@
1
- """
2
- This file is from
3
- Copyright (c) 2022, salesforce.com, inc.
4
- All rights reserved.
5
- SPDX-License-Identifier: BSD-3-Clause
6
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7
- """
8
-
9
- import logging
10
- import os
11
- import shutil
12
- import warnings
13
-
14
- from omegaconf import OmegaConf
15
- import torch.distributed as dist
16
- from torchvision.datasets.utils import download_url
17
-
18
- import minigpt4.common.utils as utils
19
- from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process
20
- from minigpt4.common.registry import registry
21
- from minigpt4.processors.base_processor import BaseProcessor
22
-
23
-
24
-
25
- class BaseDatasetBuilder:
26
- train_dataset_cls, eval_dataset_cls = None, None
27
-
28
- def __init__(self, cfg=None):
29
- super().__init__()
30
-
31
- if cfg is None:
32
- # help to create datasets from default config.
33
- self.config = load_dataset_config(self.default_config_path())
34
- elif isinstance(cfg, str):
35
- self.config = load_dataset_config(cfg)
36
- else:
37
- # when called from task.build_dataset()
38
- self.config = cfg
39
-
40
- self.data_type = self.config.data_type
41
-
42
- self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
43
- self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
44
-
45
- def build_datasets(self):
46
- # download, split, etc...
47
- # only called on 1 GPU/TPU in distributed
48
-
49
- if is_main_process():
50
- self._download_data()
51
-
52
- if is_dist_avail_and_initialized():
53
- dist.barrier()
54
-
55
- # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
56
- logging.info("Building datasets...")
57
- datasets = self.build() # dataset['train'/'val'/'test']
58
-
59
- return datasets
60
-
61
- def build_processors(self):
62
- vis_proc_cfg = self.config.get("vis_processor")
63
- txt_proc_cfg = self.config.get("text_processor")
64
-
65
- if vis_proc_cfg is not None:
66
- vis_train_cfg = vis_proc_cfg.get("train")
67
- vis_eval_cfg = vis_proc_cfg.get("eval")
68
-
69
- self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
70
- self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
71
-
72
- if txt_proc_cfg is not None:
73
- txt_train_cfg = txt_proc_cfg.get("train")
74
- txt_eval_cfg = txt_proc_cfg.get("eval")
75
-
76
- self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
77
- self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
78
-
79
- @staticmethod
80
- def _build_proc_from_cfg(cfg):
81
- return (
82
- registry.get_processor_class(cfg.name).from_config(cfg)
83
- if cfg is not None
84
- else None
85
- )
86
-
87
- @classmethod
88
- def default_config_path(cls, type="default"):
89
- return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
90
-
91
- def _download_data(self):
92
- self._download_ann()
93
- self._download_vis()
94
-
95
- def _download_ann(self):
96
- """
97
- Download annotation files if necessary.
98
- All the vision-language datasets should have annotations of unified format.
99
-
100
- storage_path can be:
101
- (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
102
- (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
103
-
104
- Local annotation paths should be relative.
105
- """
106
- anns = self.config.build_info.annotations
107
-
108
- splits = anns.keys()
109
-
110
- cache_root = registry.get_path("cache_root")
111
-
112
- for split in splits:
113
- info = anns[split]
114
-
115
- urls, storage_paths = info.get("url", None), info.storage
116
-
117
- if isinstance(urls, str):
118
- urls = [urls]
119
- if isinstance(storage_paths, str):
120
- storage_paths = [storage_paths]
121
-
122
- assert len(urls) == len(storage_paths)
123
-
124
- for url_or_filename, storage_path in zip(urls, storage_paths):
125
- # if storage_path is relative, make it full by prefixing with cache_root.
126
- if not os.path.isabs(storage_path):
127
- storage_path = os.path.join(cache_root, storage_path)
128
-
129
- dirname = os.path.dirname(storage_path)
130
- if not os.path.exists(dirname):
131
- os.makedirs(dirname)
132
-
133
- if os.path.isfile(url_or_filename):
134
- src, dst = url_or_filename, storage_path
135
- if not os.path.exists(dst):
136
- shutil.copyfile(src=src, dst=dst)
137
- else:
138
- logging.info("Using existing file {}.".format(dst))
139
- else:
140
- if os.path.isdir(storage_path):
141
- # if only dirname is provided, suffix with basename of URL.
142
- raise ValueError(
143
- "Expecting storage_path to be a file path, got directory {}".format(
144
- storage_path
145
- )
146
- )
147
- else:
148
- filename = os.path.basename(storage_path)
149
-
150
- download_url(url=url_or_filename, root=dirname, filename=filename)
151
-
152
- def _download_vis(self):
153
-
154
- storage_path = self.config.build_info.get(self.data_type).storage
155
- storage_path = utils.get_cache_path(storage_path)
156
-
157
- if not os.path.exists(storage_path):
158
- warnings.warn(
159
- f"""
160
- The specified path {storage_path} for visual inputs does not exist.
161
- Please provide a correct path to the visual inputs or
162
- refer to datasets/download_scripts/README.md for downloading instructions.
163
- """
164
- )
165
-
166
- def build(self):
167
- """
168
- Create by split datasets inheriting torch.utils.data.Datasets.
169
-
170
- # build() can be dataset-specific. Overwrite to customize.
171
- """
172
- self.build_processors()
173
-
174
- build_info = self.config.build_info
175
-
176
- ann_info = build_info.annotations
177
- vis_info = build_info.get(self.data_type)
178
-
179
- datasets = dict()
180
- for split in ann_info.keys():
181
- if split not in ["train", "val", "test"]:
182
- continue
183
-
184
- is_train = split == "train"
185
-
186
- # processors
187
- vis_processor = (
188
- self.vis_processors["train"]
189
- if is_train
190
- else self.vis_processors["eval"]
191
- )
192
- text_processor = (
193
- self.text_processors["train"]
194
- if is_train
195
- else self.text_processors["eval"]
196
- )
197
-
198
- # annotation path
199
- ann_paths = ann_info.get(split).storage
200
- if isinstance(ann_paths, str):
201
- ann_paths = [ann_paths]
202
-
203
- abs_ann_paths = []
204
- for ann_path in ann_paths:
205
- if not os.path.isabs(ann_path):
206
- ann_path = utils.get_cache_path(ann_path)
207
- abs_ann_paths.append(ann_path)
208
- ann_paths = abs_ann_paths
209
-
210
- # visual data storage path
211
- vis_path = os.path.join(vis_info.storage, split)
212
-
213
- if not os.path.isabs(vis_path):
214
- # vis_path = os.path.join(utils.get_cache_path(), vis_path)
215
- vis_path = utils.get_cache_path(vis_path)
216
-
217
- if not os.path.exists(vis_path):
218
- warnings.warn("storage path {} does not exist.".format(vis_path))
219
-
220
- # create datasets
221
- dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
222
- datasets[split] = dataset_cls(
223
- vis_processor=vis_processor,
224
- text_processor=text_processor,
225
- ann_paths=ann_paths,
226
- vis_root=vis_path,
227
- )
228
-
229
- return datasets
230
-
231
-
232
- def load_dataset_config(cfg_path):
233
- cfg = OmegaConf.load(cfg_path).datasets
234
- cfg = cfg[list(cfg.keys())[0]]
235
-
236
- return cfg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/builders/image_text_pair_builder.py DELETED
@@ -1,535 +0,0 @@
1
- import os
2
- import logging
3
- import warnings
4
-
5
- from minigpt4.common.registry import registry
6
- from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7
- from minigpt4.datasets.datasets.laion_dataset import LaionDataset
8
- from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
9
- from minigpt4.datasets.datasets.text_caps import TextCapDataset
10
- from minigpt4.datasets.datasets.llava_dataset import LlavaDetailDataset, LlavaReasonDataset, LlavaConversationDataset
11
- from minigpt4.datasets.datasets.unnatural_instruction import UnnaturalDataset
12
- from minigpt4.datasets.datasets.multitask_conversation import MultiTaskConversationDataset
13
- from minigpt4.datasets.datasets.flickr import GroundedDetailDataset,CaptionToObjectDataset,PhraseToObjectDataset
14
- from minigpt4.datasets.datasets.vg_dataset import ReferVisualGenomeDataset
15
- from minigpt4.datasets.datasets.coco_dataset import ReferCOCODataset, InvReferCOCODataset
16
- from minigpt4.datasets.datasets.gqa_datasets import GQADataset
17
- from minigpt4.datasets.datasets.aok_vqa_datasets import AOKVQADataset
18
- from minigpt4.datasets.datasets.coco_vqa_datasets import COCOVQADataset
19
- from minigpt4.datasets.datasets.ocrvqa_dataset import OCRVQADataset
20
- from minigpt4.datasets.datasets.coco_caption import COCOCapDataset
21
-
22
-
23
- @registry.register_builder("multitask_conversation")
24
- class MultitaskConversationBuilder(BaseDatasetBuilder):
25
- train_dataset_cls = MultiTaskConversationDataset
26
- DATASET_CONFIG_DICT = {
27
- "default": "configs/datasets/multitask_conversation/default.yaml",
28
- }
29
-
30
- def build_datasets(self):
31
- # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
32
- logging.info("Building datasets...")
33
- self.build_processors()
34
- build_info = self.config.build_info
35
- datasets = dict()
36
-
37
- # create datasets
38
- dataset_cls = self.train_dataset_cls
39
- datasets['train'] = dataset_cls(
40
- vis_processor=self.vis_processors["train"],
41
- text_processor=self.text_processors["train"],
42
- ann_path=build_info.ann_path,
43
- vis_root=build_info.image_path,
44
- )
45
-
46
- return datasets
47
-
48
-
49
- @registry.register_builder("unnatural_instruction")
50
- class UnnaturalInstructionBuilder(BaseDatasetBuilder):
51
- train_dataset_cls = UnnaturalDataset
52
- DATASET_CONFIG_DICT = {
53
- "default": "configs/datasets/nlp/unnatural_instruction.yaml",
54
- }
55
-
56
- def build_datasets(self):
57
- # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
58
- logging.info("Building datasets...")
59
- self.build_processors()
60
- build_info = self.config.build_info
61
- datasets = dict()
62
-
63
- # create datasets
64
- dataset_cls = self.train_dataset_cls
65
- datasets['train'] = dataset_cls(
66
- text_processor=self.text_processors["train"],
67
- ann_path=build_info.ann_path,
68
- )
69
-
70
- return datasets
71
-
72
-
73
-
74
- @registry.register_builder("llava_detail")
75
- class LlavaDetailBuilder(BaseDatasetBuilder):
76
- train_dataset_cls = LlavaDetailDataset
77
- DATASET_CONFIG_DICT = {
78
- "default": "configs/datasets/llava/detail.yaml",
79
- }
80
-
81
- def build_datasets(self):
82
- # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
83
- logging.info("Building datasets...")
84
- self.build_processors()
85
- build_info = self.config.build_info
86
- datasets = dict()
87
-
88
- # create datasets
89
- dataset_cls = self.train_dataset_cls
90
- datasets['train'] = dataset_cls(
91
- vis_processor=self.vis_processors["train"],
92
- text_processor=self.text_processors["train"],
93
- ann_path=build_info.ann_path,
94
- vis_root=build_info.image_path,
95
- )
96
-
97
- return datasets
98
-
99
-
100
-
101
- @registry.register_builder("llava_reason")
102
- class LlavaReasonBuilder(BaseDatasetBuilder):
103
- train_dataset_cls = LlavaReasonDataset
104
- DATASET_CONFIG_DICT = {
105
- "default": "configs/datasets/llava/reason.yaml",
106
- }
107
-
108
- def build_datasets(self):
109
- # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
110
- logging.info("Building datasets...")
111
- self.build_processors()
112
- build_info = self.config.build_info
113
- datasets = dict()
114
-
115
- # create datasets
116
- dataset_cls = self.train_dataset_cls
117
- datasets['train'] = dataset_cls(
118
- vis_processor=self.vis_processors["train"],
119
- text_processor=self.text_processors["train"],
120
- ann_path=build_info.ann_path,
121
- vis_root=build_info.image_path,
122
- )
123
-
124
- return datasets
125
-
126
- @registry.register_builder("llava_conversation")
127
- class LlavaReasonBuilder(BaseDatasetBuilder):
128
- train_dataset_cls = LlavaConversationDataset
129
- DATASET_CONFIG_DICT = {
130
- "default": "configs/datasets/llava/conversation.yaml",
131
- }
132
-
133
- def build_datasets(self):
134
- # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
135
- logging.info("Building datasets...")
136
- self.build_processors()
137
- build_info = self.config.build_info
138
- datasets = dict()
139
-
140
- # create datasets
141
- dataset_cls = self.train_dataset_cls
142
- datasets['train'] = dataset_cls(
143
- vis_processor=self.vis_processors["train"],
144
- text_processor=self.text_processors["train"],
145
- ann_path=build_info.ann_path,
146
- vis_root=build_info.image_path,
147
- )
148
-
149
- return datasets
150
-
151
-
152
- class AllRefCOCOBuilder(BaseDatasetBuilder):
153
-
154
- def build_datasets(self):
155
- # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
156
- logging.info("Building datasets...")
157
- self.build_processors()
158
-
159
- build_info = self.config.build_info
160
- image_path = build_info.image_path
161
- ann_path = build_info.ann_path
162
-
163
- datasets = dict()
164
-
165
- if not os.path.exists(image_path):
166
- warnings.warn("image path {} does not exist.".format(image_path))
167
- if not os.path.exists(ann_path):
168
- warnings.warn("ann path {} does not exist.".format(ann_path))
169
-
170
- # create datasets
171
- dataset_cls = self.train_dataset_cls
172
- datasets['train'] = dataset_cls(
173
- vis_processor=self.vis_processors["train"],
174
- text_processor=self.text_processors["train"],
175
- ann_path=ann_path,
176
- vis_root=image_path,
177
- dataset=build_info.dataset,
178
- splitBy=build_info.splitBy
179
- )
180
-
181
- return datasets
182
-
183
-
184
- @registry.register_builder("refcoco")
185
- class RefCOCOBuilder(AllRefCOCOBuilder):
186
- train_dataset_cls = ReferCOCODataset
187
- DATASET_CONFIG_DICT = {
188
- "default": "configs/datasets/coco_bbox/refcoco.yaml",
189
- }
190
-
191
- @registry.register_builder("refcocop")
192
- class RefCOCOPBuilder(AllRefCOCOBuilder):
193
- train_dataset_cls = ReferCOCODataset
194
- DATASET_CONFIG_DICT = {
195
- "default": "configs/datasets/coco_bbox/refcocop.yaml",
196
- }
197
-
198
-
199
- @registry.register_builder("refcocog")
200
- class RefCOCOGBuilder(AllRefCOCOBuilder):
201
- train_dataset_cls = ReferCOCODataset
202
- DATASET_CONFIG_DICT = {
203
- "default": "configs/datasets/coco_bbox/refcocog.yaml",
204
- }
205
-
206
- @registry.register_builder("invrefcoco")
207
- class RefCOCOBuilder(AllRefCOCOBuilder):
208
- train_dataset_cls = InvReferCOCODataset
209
- DATASET_CONFIG_DICT = {
210
- "default": "configs/datasets/coco_bbox/invrefcoco.yaml",
211
- }
212
-
213
-
214
- @registry.register_builder("invrefcocop")
215
- class RefCOCOPBuilder(AllRefCOCOBuilder):
216
- train_dataset_cls = InvReferCOCODataset
217
- DATASET_CONFIG_DICT = {
218
- "default": "configs/datasets/coco_bbox/invrefcocop.yaml",
219
- }
220
-
221
-
222
- @registry.register_builder("invrefcocog")
223
- class RefCOCOGBuilder(AllRefCOCOBuilder):
224
- train_dataset_cls = InvReferCOCODataset
225
- DATASET_CONFIG_DICT = {
226
- "default": "configs/datasets/coco_bbox/invrefcocog.yaml",
227
- }
228
-
229
- @registry.register_builder("refvg")
230
- class RefVisualGenomeBuilder(BaseDatasetBuilder):
231
- train_dataset_cls = ReferVisualGenomeDataset
232
- DATASET_CONFIG_DICT = {
233
- "default": "configs/datasets/vg/ref.yaml",
234
- }
235
-
236
- def build_datasets(self):
237
- # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
238
- logging.info("Building datasets...")
239
- self.build_processors()
240
-
241
- build_info = self.config.build_info
242
- data_dir = build_info.data_dir
243
- datasets = dict()
244
-
245
- # create datasets
246
- dataset_cls = self.train_dataset_cls
247
- datasets['train'] = dataset_cls(
248
- vis_processor=self.vis_processors["train"],
249
- text_processor=self.text_processors["train"],
250
- data_dir=data_dir,
251
- )
252
-
253
- return datasets
254
-
255
-
256
- @registry.register_builder("textcaps_caption")
257
- class TextcapCaptionBuilder(BaseDatasetBuilder):
258
- train_dataset_cls = TextCapDataset
259
-
260
- DATASET_CONFIG_DICT = {"default": "configs/datasets/textcaps/caption.yaml"}
261
-
262
- def _download_ann(self):
263
- pass
264
-
265
- def _download_vis(self):
266
- pass
267
-
268
- def build(self):
269
- self.build_processors()
270
-
271
- build_info = self.config.build_info
272
-
273
- datasets = dict()
274
- split = "train"
275
-
276
- # create datasets
277
- # [NOTE] return inner_datasets (wds.DataPipeline)
278
- dataset_cls = self.train_dataset_cls
279
- datasets[split] = dataset_cls(
280
- vis_processor=self.vis_processors[split],
281
- text_processor=self.text_processors[split],
282
- ann_path=build_info.ann_path,
283
- vis_root=build_info.image_path,
284
- )
285
-
286
- return datasets
287
-
288
- @registry.register_builder("coco_vqa")
289
- class COCOVQABuilder(BaseDatasetBuilder):
290
- train_dataset_cls = COCOVQADataset
291
-
292
- DATASET_CONFIG_DICT = {
293
- "default": "configs/datasets/coco/defaults_vqa.yaml",
294
- }
295
-
296
- @registry.register_builder("ok_vqa")
297
- class OKVQABuilder(COCOVQABuilder):
298
- DATASET_CONFIG_DICT = {
299
- "default": "configs/datasets/okvqa/defaults.yaml",
300
- }
301
-
302
-
303
- @registry.register_builder("aok_vqa")
304
- class AOKVQABuilder(BaseDatasetBuilder):
305
- train_dataset_cls = AOKVQADataset
306
-
307
- DATASET_CONFIG_DICT = {"default": "configs/datasets/aokvqa/defaults.yaml"}
308
-
309
-
310
- @registry.register_builder("gqa")
311
- class GQABuilder(BaseDatasetBuilder):
312
- train_dataset_cls = GQADataset
313
- DATASET_CONFIG_DICT = {
314
- "default": "configs/datasets/gqa/balanced_val.yaml",
315
- }
316
-
317
-
318
-
319
-
320
- @registry.register_builder("flickr_grounded_caption")
321
- class GroundedCaptionBuilder(BaseDatasetBuilder):
322
- train_dataset_cls = GroundedDetailDataset
323
- DATASET_CONFIG_DICT = {
324
- "default": "configs/datasets/flickr/default.yaml",
325
- }
326
-
327
- def build_datasets(self):
328
- # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
329
- logging.info("Building datasets...")
330
- self.build_processors()
331
- build_info = self.config.build_info
332
- datasets = dict()
333
-
334
- # create datasets
335
- dataset_cls = self.train_dataset_cls
336
- datasets['train'] = dataset_cls(
337
- vis_processor=self.vis_processors["train"],
338
- text_processor=self.text_processors["train"],
339
- ann_path=build_info.ann_path,
340
- vis_root=build_info.image_path,
341
- )
342
-
343
- return datasets
344
-
345
-
346
- @registry.register_builder("flickr_CaptionToPhrase")
347
- class CaptionToPhraseBuilder(BaseDatasetBuilder):
348
- train_dataset_cls = CaptionToObjectDataset
349
- DATASET_CONFIG_DICT = {
350
- "default": "configs/datasets/flickr/caption_to_phrase.yaml",
351
- }
352
-
353
- def build_datasets(self):
354
- # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
355
- logging.info("Building datasets...")
356
- self.build_processors()
357
- build_info = self.config.build_info
358
- datasets = dict()
359
-
360
- # create datasets
361
- dataset_cls = self.train_dataset_cls
362
- datasets['train'] = dataset_cls(
363
- vis_processor=self.vis_processors["train"],
364
- text_processor=self.text_processors["train"],
365
- ann_path=build_info.ann_path,
366
- vis_root=build_info.image_path,
367
- )
368
-
369
- return datasets
370
-
371
- @registry.register_builder("flickr_ObjectToPhrase")
372
- class CaptionToPhraseBuilder(BaseDatasetBuilder):
373
- train_dataset_cls = PhraseToObjectDataset
374
- DATASET_CONFIG_DICT = {
375
- "default": "configs/datasets/flickr/object_to_phrase.yaml",
376
- }
377
-
378
- def build_datasets(self):
379
- # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
380
- logging.info("Building datasets...")
381
- self.build_processors()
382
- build_info = self.config.build_info
383
- datasets = dict()
384
-
385
- # create datasets
386
- dataset_cls = self.train_dataset_cls
387
- datasets['train'] = dataset_cls(
388
- vis_processor=self.vis_processors["train"],
389
- text_processor=self.text_processors["train"],
390
- ann_path=build_info.ann_path,
391
- vis_root=build_info.image_path,
392
- )
393
-
394
- return datasets
395
-
396
-
397
-
398
-
399
- class DocumentVQABuilder(BaseDatasetBuilder):
400
- def _download_ann(self):
401
- pass
402
-
403
- def _download_vis(self):
404
- pass
405
-
406
- def build(self):
407
- self.build_processors()
408
- build_info = self.config.build_info
409
-
410
- datasets = dict()
411
- split = "train"
412
-
413
- dataset_cls = self.train_dataset_cls
414
- datasets[split] = dataset_cls(
415
- vis_processor=self.vis_processors[split],
416
- text_processor=self.text_processors[split],
417
- vis_root=build_info.image_path,
418
- ann_path=build_info.ann_path
419
- )
420
-
421
- return datasets
422
-
423
-
424
- @registry.register_builder("ocrvqa")
425
- class OCRVQABuilder(DocumentVQABuilder):
426
- train_dataset_cls = OCRVQADataset
427
- DATASET_CONFIG_DICT = {"default": "configs/datasets/ocrvqa/ocrvqa.yaml"}
428
-
429
-
430
- @registry.register_builder("cc_sbu")
431
- class CCSBUBuilder(BaseDatasetBuilder):
432
- train_dataset_cls = CCSBUDataset
433
-
434
- DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
435
-
436
- def _download_ann(self):
437
- pass
438
-
439
- def _download_vis(self):
440
- pass
441
-
442
- def build(self):
443
- self.build_processors()
444
-
445
- build_info = self.config.build_info
446
-
447
- datasets = dict()
448
- split = "train"
449
-
450
- # create datasets
451
- # [NOTE] return inner_datasets (wds.DataPipeline)
452
- dataset_cls = self.train_dataset_cls
453
- datasets[split] = dataset_cls(
454
- vis_processor=self.vis_processors[split],
455
- text_processor=self.text_processors[split],
456
- location=build_info.storage,
457
- ).inner_dataset
458
-
459
- return datasets
460
-
461
-
462
- @registry.register_builder("laion")
463
- class LaionBuilder(BaseDatasetBuilder):
464
- train_dataset_cls = LaionDataset
465
-
466
- DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
467
-
468
- def _download_ann(self):
469
- pass
470
-
471
- def _download_vis(self):
472
- pass
473
-
474
- def build(self):
475
- self.build_processors()
476
-
477
- build_info = self.config.build_info
478
-
479
- datasets = dict()
480
- split = "train"
481
-
482
- # create datasets
483
- # [NOTE] return inner_datasets (wds.DataPipeline)
484
- dataset_cls = self.train_dataset_cls
485
- datasets[split] = dataset_cls(
486
- vis_processor=self.vis_processors[split],
487
- text_processor=self.text_processors[split],
488
- location=build_info.storage,
489
- ).inner_dataset
490
-
491
- return datasets
492
-
493
-
494
-
495
- @registry.register_builder("coco_caption")
496
- class COCOCapBuilder(BaseDatasetBuilder):
497
- train_dataset_cls = COCOCapDataset
498
-
499
- DATASET_CONFIG_DICT = {
500
- "default": "configs/datasets/coco/caption.yaml",
501
- }
502
-
503
-
504
-
505
- @registry.register_builder("cc_sbu_align")
506
- class CCSBUAlignBuilder(BaseDatasetBuilder):
507
- train_dataset_cls = CCSBUAlignDataset
508
-
509
- DATASET_CONFIG_DICT = {
510
- "default": "configs/datasets/cc_sbu/align.yaml",
511
- }
512
-
513
- def build_datasets(self):
514
- # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
515
- logging.info("Building datasets...")
516
- self.build_processors()
517
-
518
- build_info = self.config.build_info
519
- storage_path = build_info.storage
520
-
521
- datasets = dict()
522
-
523
- if not os.path.exists(storage_path):
524
- warnings.warn("storage path {} does not exist.".format(storage_path))
525
-
526
- # create datasets
527
- dataset_cls = self.train_dataset_cls
528
- datasets['train'] = dataset_cls(
529
- vis_processor=self.vis_processors["train"],
530
- text_processor=self.text_processors["train"],
531
- ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
532
- vis_root=os.path.join(storage_path, 'image'),
533
- )
534
-
535
- return datasets
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/data_utils.py DELETED
@@ -1,199 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- import gzip
9
- import logging
10
- import os
11
- import random as rnd
12
- import tarfile
13
- import zipfile
14
- import random
15
- from typing import List
16
- from tqdm import tqdm
17
-
18
- import decord
19
- from decord import VideoReader
20
- import webdataset as wds
21
- import numpy as np
22
- import torch
23
- from torch.utils.data.dataset import IterableDataset
24
-
25
- from minigpt4.common.registry import registry
26
- from minigpt4.datasets.datasets.base_dataset import ConcatDataset
27
-
28
-
29
- decord.bridge.set_bridge("torch")
30
- MAX_INT = registry.get("MAX_INT")
31
-
32
-
33
- class ChainDataset(wds.DataPipeline):
34
- r"""Dataset for chaining multiple :class:`DataPipeline` s.
35
-
36
- This class is useful to assemble different existing dataset streams. The
37
- chaining operation is done on-the-fly, so concatenating large-scale
38
- datasets with this class will be efficient.
39
-
40
- Args:
41
- datasets (iterable of IterableDataset): datasets to be chained together
42
- """
43
- def __init__(self, datasets: List[wds.DataPipeline]) -> None:
44
- super().__init__()
45
- self.datasets = datasets
46
- self.prob = []
47
- self.names = []
48
- for dataset in self.datasets:
49
- if hasattr(dataset, 'name'):
50
- self.names.append(dataset.name)
51
- else:
52
- self.names.append('Unknown')
53
- if hasattr(dataset, 'sample_ratio'):
54
- self.prob.append(dataset.sample_ratio)
55
- else:
56
- self.prob.append(1)
57
- logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")
58
-
59
- def __iter__(self):
60
- datastreams = [iter(dataset) for dataset in self.datasets]
61
- while True:
62
- select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
63
- yield next(select_datastream)
64
-
65
-
66
- def apply_to_sample(f, sample):
67
- if len(sample) == 0:
68
- return {}
69
-
70
- def _apply(x):
71
- if torch.is_tensor(x):
72
- return f(x)
73
- elif isinstance(x, dict):
74
- return {key: _apply(value) for key, value in x.items()}
75
- elif isinstance(x, list):
76
- return [_apply(x) for x in x]
77
- else:
78
- return x
79
-
80
- return _apply(sample)
81
-
82
-
83
- def move_to_cuda(sample):
84
- def _move_to_cuda(tensor):
85
- return tensor.cuda()
86
-
87
- return apply_to_sample(_move_to_cuda, sample)
88
-
89
-
90
- def prepare_sample(samples, cuda_enabled=True):
91
- if cuda_enabled:
92
- samples = move_to_cuda(samples)
93
-
94
- # TODO fp16 support
95
-
96
- return samples
97
-
98
-
99
- def reorg_datasets_by_split(datasets, batch_sizes):
100
- """
101
- Organizes datasets by split.
102
-
103
- Args:
104
- datasets: dict of torch.utils.data.Dataset objects by name.
105
-
106
- Returns:
107
- Dict of datasets by split {split_name: List[Datasets]}.
108
- """
109
- # if len(datasets) == 1:
110
- # return datasets[list(datasets.keys())[0]]
111
- # else:
112
- reorg_datasets = dict()
113
- reorg_batch_sizes = dict()
114
-
115
- # reorganize by split
116
- for dataset_name, dataset in datasets.items():
117
- for split_name, dataset_split in dataset.items():
118
- if split_name not in reorg_datasets:
119
- reorg_datasets[split_name] = [dataset_split]
120
- reorg_batch_sizes[split_name] = [batch_sizes[dataset_name]]
121
- else:
122
- reorg_datasets[split_name].append(dataset_split)
123
- reorg_batch_sizes[split_name].append(batch_sizes[dataset_name])
124
-
125
- return reorg_datasets, reorg_batch_sizes
126
-
127
-
128
- def concat_datasets(datasets):
129
- """
130
- Concatenates multiple datasets into a single dataset.
131
-
132
- It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support
133
- generic IterableDataset because it requires creating separate samplers.
134
-
135
- Now only supports conctenating training datasets and assuming validation and testing
136
- have only a single dataset. This is because metrics should not be computed on the concatenated
137
- datasets.
138
-
139
- Args:
140
- datasets: dict of torch.utils.data.Dataset objects by split.
141
-
142
- Returns:
143
- Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
144
- "val" and "test" remain the same.
145
-
146
- If the input training datasets contain both map-style and DataPipeline datasets, returns
147
- a tuple, where the first element is a concatenated map-style dataset and the second
148
- element is a chained DataPipeline dataset.
149
-
150
- """
151
- # concatenate datasets in the same split
152
- for split_name in datasets:
153
- if split_name != "train":
154
- assert (
155
- len(datasets[split_name]) == 1
156
- ), "Do not support multiple {} datasets.".format(split_name)
157
- datasets[split_name] = datasets[split_name][0]
158
- else:
159
- iterable_datasets, map_datasets = [], []
160
- for dataset in datasets[split_name]:
161
- if isinstance(dataset, wds.DataPipeline):
162
- logging.info(
163
- "Dataset {} is IterableDataset, can't be concatenated.".format(
164
- dataset
165
- )
166
- )
167
- iterable_datasets.append(dataset)
168
- elif isinstance(dataset, IterableDataset):
169
- raise NotImplementedError(
170
- "Do not support concatenation of generic IterableDataset."
171
- )
172
- else:
173
- map_datasets.append(dataset)
174
-
175
- # if len(iterable_datasets) > 0:
176
- # concatenate map-style datasets and iterable-style datasets separately
177
- if len(iterable_datasets) > 1:
178
- chained_datasets = (
179
- ChainDataset(iterable_datasets)
180
- )
181
- elif len(iterable_datasets) == 1:
182
- chained_datasets = iterable_datasets[0]
183
- else:
184
- chained_datasets = None
185
-
186
- concat_datasets = (
187
- ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
188
- )
189
-
190
- train_datasets = concat_datasets, chained_datasets
191
- train_datasets = tuple([x for x in train_datasets if x is not None])
192
- train_datasets = (
193
- train_datasets[0] if len(train_datasets) == 1 else train_datasets
194
- )
195
-
196
- datasets[split_name] = train_datasets
197
-
198
- return datasets
199
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/datasets/__init__.py DELETED
File without changes
minigpt4/datasets/datasets/aok_vqa_datasets.py DELETED
@@ -1,116 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- from collections import OrderedDict
9
- import json
10
- import os
11
- import random
12
- import torch
13
-
14
- from PIL import Image
15
-
16
- from minigpt4.datasets.datasets.vqa_datasets import VQADataset #, VQAEvalDataset
17
-
18
-
19
- class __DisplMixin:
20
- def displ_item(self, index):
21
- sample, ann = self.__getitem__(index), self.annotation[index]
22
- return OrderedDict(
23
- {
24
- "file": ann["image"],
25
- "question": ann["question"],
26
- "question_id": ann["question_id"],
27
- "direct_answers": "; ".join(ann["direct_answers"]),
28
- "choices": "; ".join(ann["choices"]),
29
- "correct_choice": ann["choices"][ann["correct_choice_idx"]],
30
- "image": sample["image"],
31
- }
32
- )
33
-
34
-
35
- class AOKVQADataset(VQADataset, __DisplMixin):
36
- def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
37
- super().__init__(vis_processor, text_processor, vis_root, ann_paths)
38
-
39
- self.instruction_pool =[
40
- "[vqa] {}",
41
- "[vqa] Based on the image, respond to this question with a short answer: {}"
42
- ]
43
-
44
- exist_annotation = []
45
- for ann in self.annotation:
46
- image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
47
- if os.path.exists(image_path):
48
- exist_annotation.append(ann)
49
- self.annotation = exist_annotation
50
-
51
- def get_data(self, index):
52
- ann = self.annotation[index]
53
-
54
- image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
55
- image = Image.open(image_path).convert("RGB")
56
-
57
- image = self.vis_processor(image)
58
- question = self.text_processor(ann["question"])
59
-
60
- answer_key = "direct_answers"
61
-
62
- answer_weight = {}
63
- for answer in ann[answer_key]:
64
- if answer in answer_weight.keys():
65
- answer_weight[answer] += 1 / len(ann[answer_key])
66
- else:
67
- answer_weight[answer] = 1 / len(ann[answer_key])
68
-
69
- answers = list(answer_weight.keys())
70
- weights = list(answer_weight.values())
71
-
72
- answer = random.choices(answers, weights=weights, k=1)[0] # random sample an answer according to weights
73
-
74
- return {
75
- "image": image,
76
- "question": question,
77
- "answer": answer,
78
- }
79
-
80
- def __getitem__(self, index):
81
- data = self.get_data(index)
82
- question = self.text_processor(data["question"])
83
- instruction = random.choice(self.instruction_pool).format(question)
84
-
85
- instruction = "<Img><ImageHere></Img> {} ".format(instruction)
86
- answer = self.text_processor(data['answer'])
87
-
88
- return {
89
- "image": data['image'],
90
- "instruction_input": instruction,
91
- "answer": answer,
92
- }
93
-
94
-
95
- class AOKVQGDataset(AOKVQADataset):
96
-
97
- def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
98
- super().__init__(vis_processor, text_processor, vis_root, ann_paths)
99
- self.instruction_pool = [
100
- 'Given the image, generate a question whose answer is: {}',
101
- 'Based on the image, provide a question with the answer: {}',
102
- 'Given the visual representation, create a question for which the answer is "{}"',
103
- 'From the image provided, craft a question that leads to the reply: {}',
104
- 'Considering the picture, come up with a question where the answer is: {}',
105
- 'Taking the image into account, generate an question that has the answer: {}'
106
- ]
107
-
108
- def __getitem__(self, index):
109
- data = self.get_data(index)
110
- instruction = random.choice(self.instruction_pool).format(data['answer'])
111
-
112
- return {
113
- "image": data['image'],
114
- "instruction_input": instruction,
115
- "answer": data['question'],
116
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/datasets/base_dataset.py DELETED
@@ -1,78 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- import json
9
- from typing import Iterable
10
-
11
- from torch.utils.data import Dataset, ConcatDataset
12
- from torch.utils.data.dataloader import default_collate
13
-
14
-
15
-
16
-
17
- class BaseDataset(Dataset):
18
- def __init__(
19
- self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
20
- ):
21
- """
22
- vis_root (string): Root directory of images (e.g. coco/images/)
23
- ann_root (string): directory to store the annotation file
24
- """
25
- self.vis_root = vis_root
26
-
27
- self.annotation = []
28
- # print("ann paths", ann_paths)
29
- for ann_path in ann_paths:
30
- # print("ann_path", ann_path)
31
- ann = json.load(open(ann_path, "r"))
32
- if isinstance(ann, dict):
33
- self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
34
- # self.annotation.extend(json.load(open(ann_path, "r")))
35
- else:
36
- self.annotation.extend(json.load(open(ann_path, "r")))
37
-
38
- self.vis_processor = vis_processor
39
- self.text_processor = text_processor
40
-
41
- self._add_instance_ids()
42
-
43
- def __len__(self):
44
- return len(self.annotation)
45
-
46
- def collater(self, samples):
47
- return default_collate(samples)
48
-
49
- def set_processors(self, vis_processor, text_processor):
50
- self.vis_processor = vis_processor
51
- self.text_processor = text_processor
52
-
53
- def _add_instance_ids(self, key="instance_id"):
54
- for idx, ann in enumerate(self.annotation):
55
- ann[key] = str(idx)
56
-
57
-
58
-
59
- class ConcatDataset(ConcatDataset):
60
- def __init__(self, datasets: Iterable[Dataset]) -> None:
61
- super().__init__(datasets)
62
-
63
- def collater(self, samples):
64
- # TODO For now only supports datasets with same underlying collater implementations
65
-
66
- all_keys = set()
67
- for s in samples:
68
- all_keys.update(s)
69
-
70
- shared_keys = all_keys
71
- for s in samples:
72
- shared_keys = shared_keys & set(s.keys())
73
-
74
- samples_shared_keys = []
75
- for s in samples:
76
- samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
77
-
78
- return self.datasets[0].collater(samples_shared_keys)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/datasets/caption_datasets.py DELETED
@@ -1,151 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- import os
9
- from collections import OrderedDict
10
-
11
- from minigpt4.datasets.datasets.base_dataset import BaseDataset
12
- from PIL import Image
13
- import random
14
-
15
-
16
- class __DisplMixin:
17
- def displ_item(self, index):
18
- sample, ann = self.__getitem__(index), self.annotation[index]
19
-
20
- return OrderedDict(
21
- {
22
- "file": ann["image"],
23
- "caption": ann["caption"],
24
- "image": sample["image"],
25
- }
26
- )
27
-
28
-
29
- class CaptionDataset(BaseDataset, __DisplMixin):
30
- def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
31
- """
32
- vis_root (string): Root directory of images (e.g. coco/images/)
33
- ann_root (string): directory to store the annotation file
34
- """
35
- super().__init__(vis_processor, text_processor, vis_root, ann_paths)
36
-
37
- self.img_ids = {}
38
- n = 0
39
- for ann in self.annotation:
40
- img_id = ann["image_id"]
41
- if img_id not in self.img_ids.keys():
42
- self.img_ids[img_id] = n
43
- n += 1
44
-
45
- def __getitem__(self, index):
46
-
47
- # TODO this assumes image input, not general enough
48
- ann = self.annotation[index]
49
-
50
- img_file = '{:0>12}.jpg'.format(ann["image_id"])
51
- image_path = os.path.join(self.vis_root, img_file)
52
- image = Image.open(image_path).convert("RGB")
53
-
54
- image = self.vis_processor(image)
55
- caption = self.text_processor(ann["caption"])
56
-
57
- return {
58
- "image": image,
59
- "text_input": caption,
60
- "image_id": self.img_ids[ann["image_id"]],
61
- }
62
-
63
-
64
-
65
- class COCOCaptionDataset(BaseDataset, __DisplMixin):
66
- def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
67
- """
68
- vis_root (string): Root directory of images (e.g. coco/images/)
69
- ann_root (string): directory to store the annotation file
70
- """
71
- super().__init__(vis_processor, text_processor, vis_root, ann_paths)
72
-
73
- self.img_ids = {}
74
- n = 0
75
-
76
- self.filter_anntation = []
77
-
78
- for ann in self.annotation:
79
- if "train" in ann["image"]:
80
- self.filter_anntation.append(ann)
81
- self.annotation = self.filter_anntation
82
-
83
- for ann in self.annotation:
84
- img_id = ann["image_id"]
85
- if img_id not in self.img_ids.keys():
86
- self.img_ids[img_id] = n
87
- n += 1
88
-
89
- self.instruction_pool = [
90
- 'Briefly describe this image.',
91
- 'Provide a concise depiction of this image.',
92
- 'Present a short description of this image.',
93
- 'Summarize this image in a few words.',
94
- 'A short image caption:',
95
- 'A short image description:',
96
- 'A photo of ',
97
- 'An image that shows ',
98
- 'Write a short description for the image. ',
99
- 'Write a description for the photo.',
100
- 'Provide a description of what is presented in the photo.',
101
- 'Briefly describe the content of the image.',
102
- 'Can you briefly explain what you see in the image?',
103
- 'Could you use a few words to describe what you perceive in the photo?',
104
- 'Please provide a short depiction of the picture.',
105
- 'Using language, provide a short account of the image.',
106
- 'Use a few words to illustrate what is happening in the picture.',
107
- ]
108
- def __getitem__(self, index):
109
-
110
- # TODO this assumes image input, not general enough
111
- ann = self.annotation[index]
112
-
113
- img_file = ann["image"].split("/")[-1]
114
- image_path = os.path.join(self.vis_root, img_file)
115
- image = Image.open(image_path).convert("RGB")
116
-
117
- image = self.vis_processor(image)
118
- caption = self.text_processor(ann["caption"])
119
-
120
- instruction = random.choice(self.instruction_pool)
121
- instruction = "<Img><ImageHere></Img> [caption] {} ".format(instruction)
122
-
123
- return {
124
- "image": image,
125
- "answer": caption,
126
- "instruction_input": instruction,
127
- }
128
-
129
- class CaptionEvalDataset(BaseDataset, __DisplMixin):
130
- def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
131
- """
132
- vis_root (string): Root directory of images (e.g. coco/images/)
133
- ann_root (string): directory to store the annotation file
134
- split (string): val or test
135
- """
136
- super().__init__(vis_processor, text_processor, vis_root, ann_paths)
137
-
138
- def __getitem__(self, index):
139
-
140
- ann = self.annotation[index]
141
-
142
- image_path = os.path.join(self.vis_root, ann["image"])
143
- image = Image.open(image_path).convert("RGB")
144
-
145
- image = self.vis_processor(image)
146
-
147
- return {
148
- "image": image,
149
- "image_id": ann["image_id"],
150
- "instance_id": ann["instance_id"],
151
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/datasets/cc_sbu_dataset.py DELETED
@@ -1,47 +0,0 @@
1
- import os
2
- from PIL import Image
3
- import webdataset as wds
4
- from minigpt4.datasets.datasets.base_dataset import BaseDataset
5
- from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
6
-
7
-
8
- class CCSBUDataset(BaseDataset):
9
- def __init__(self, vis_processor, text_processor, location):
10
- super().__init__(vis_processor=vis_processor, text_processor=text_processor)
11
-
12
- self.inner_dataset = wds.DataPipeline(
13
- wds.ResampledShards(location),
14
- wds.tarfile_to_samples(handler=wds.warn_and_continue),
15
- wds.shuffle(1000, handler=wds.warn_and_continue),
16
- wds.decode("pilrgb", handler=wds.warn_and_continue),
17
- wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
18
- wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
19
- wds.map(self.to_dict, handler=wds.warn_and_continue),
20
- )
21
-
22
- def to_dict(self, sample):
23
- return {
24
- "image": sample[0],
25
- "answer": self.text_processor(sample[1]["caption"]),
26
- }
27
-
28
-
29
- class CCSBUAlignDataset(CaptionDataset):
30
-
31
- def __getitem__(self, index):
32
-
33
- # TODO this assumes image input, not general enough
34
- ann = self.annotation[index]
35
-
36
- img_file = '{}.jpg'.format(ann["image_id"])
37
- image_path = os.path.join(self.vis_root, img_file)
38
- image = Image.open(image_path).convert("RGB")
39
-
40
- image = self.vis_processor(image)
41
- caption = ann["caption"]
42
-
43
- return {
44
- "image": image,
45
- "answer": caption,
46
- "image_id": self.img_ids[ann["image_id"]],
47
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/datasets/coco_caption.py DELETED
@@ -1,120 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- import os
9
- import json
10
- import torch
11
- import numpy as np
12
-
13
- from PIL import Image
14
- from PIL import ImageFile
15
-
16
- ImageFile.LOAD_TRUNCATED_IMAGES = True
17
-
18
- from minigpt4.datasets.datasets.caption_datasets import COCOCaptionDataset, CaptionEvalDataset
19
-
20
- COCOCapDataset = COCOCaptionDataset
21
-
22
-
23
-
24
-
25
-
26
- class COCOCapEvalDataset(CaptionEvalDataset):
27
- def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
28
- """
29
- vis_root (string): Root directory of images (e.g. coco/images/)
30
- ann_root (string): directory to store the annotation file
31
- split (string): val or test
32
- """
33
- super().__init__(vis_processor, text_processor, vis_root, ann_paths)
34
-
35
- def __getitem__(self, index):
36
- ann = self.annotation[index]
37
-
38
- image_path = os.path.join(self.vis_root, ann["image"])
39
- image = Image.open(image_path).convert("RGB")
40
-
41
- image = self.vis_processor(image)
42
-
43
- img_id = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1]
44
-
45
- return {
46
- "image": image,
47
- "image_id": img_id,
48
- "instance_id": ann["instance_id"],
49
- }
50
-
51
-
52
- class NoCapsEvalDataset(CaptionEvalDataset):
53
- def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
54
- """
55
- vis_root (string): Root directory of images (e.g. coco/images/)
56
- ann_root (string): directory to store the annotation file
57
- split (string): val or test
58
- """
59
- super().__init__(vis_processor, text_processor, vis_root, ann_paths)
60
-
61
- def __getitem__(self, index):
62
- ann = self.annotation[index]
63
-
64
- image_path = os.path.join(self.vis_root, ann["image"])
65
- image = Image.open(image_path).convert("RGB")
66
-
67
- image = self.vis_processor(image)
68
-
69
- img_id = ann["img_id"]
70
-
71
- return {
72
- "image": image,
73
- "image_id": img_id,
74
- "instance_id": ann["instance_id"],
75
- }
76
-
77
-
78
- class RefCOCOEvalData(torch.utils.data.Dataset):
79
- def __init__(self, loaded_data, vis_processor, root_path):
80
- self.loaded_data = loaded_data
81
- self.root_path = root_path
82
- self.vis_processor = vis_processor
83
-
84
- def __len__(self):
85
- return len(self.loaded_data)
86
-
87
- def __getitem__(self, idx):
88
- data = self.loaded_data[idx]
89
- img_id = data['img_id']
90
- sent = data['sents']
91
- image_path = os.path.join(self.root_path, f'{img_id[:27]}.jpg')
92
- image = Image.open(image_path).convert('RGB')
93
- image = self.vis_processor(image)
94
- question = f"[refer] give me the location of {sent}"
95
- return image, question, img_id
96
-
97
- class EvalCaptionData(torch.utils.data.Dataset):
98
- def __init__(self, loaded_data, vis_processor, root_path):
99
- self.loaded_data = loaded_data
100
- self.root_path = root_path
101
- self.vis_processor = vis_processor
102
- ann = dict()
103
- for item in self.loaded_data:
104
- image_id = item['image_id']
105
- ann[image_id] = item['image']
106
- self.ann = [{'image_id':image_id, 'image': ann[image_id]} for image_id in ann]
107
-
108
- def __len__(self):
109
- return len(self.ann)
110
-
111
- def __getitem__(self, idx):
112
- data = self.ann[idx]
113
- image_id = data['image_id']
114
- img_file = data['image'].split('/')[-1]
115
- image_path = os.path.join(self.root_path, img_file)
116
- image = Image.open(image_path).convert('RGB')
117
-
118
- image = self.vis_processor(image)
119
- question = f"[caption] please describe this image?"
120
- return image, question, image_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/datasets/coco_dataset.py DELETED
@@ -1,348 +0,0 @@
1
- import os
2
- import json
3
- import pickle
4
- import random
5
- import time
6
- import itertools
7
-
8
- import numpy as np
9
- from PIL import Image
10
- import skimage.io as io
11
- import matplotlib.pyplot as plt
12
- from matplotlib.collections import PatchCollection
13
- from matplotlib.patches import Polygon, Rectangle
14
- from torch.utils.data import Dataset
15
- import webdataset as wds
16
-
17
- from minigpt4.datasets.datasets.base_dataset import BaseDataset
18
- from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
19
-
20
-
21
- class ReferCOCODataset(Dataset):
22
- def __init__(self, vis_processor, text_processor, vis_root, ann_path, dataset='refcoco', splitBy='unc'):
23
- """
24
- vis_root (string): Root directory of images (e.g. coco/images/)
25
- ann_root (string): directory to store the annotation file
26
- """
27
- self.vis_root = vis_root
28
-
29
- self.vis_processor = vis_processor
30
- self.text_processor = text_processor
31
-
32
- self.refer = REFER(ann_path, vis_root, dataset, splitBy)
33
- self.ref_ids = self.refer.getRefIds(split="train")
34
-
35
- self.instruction_pool = [
36
- "[refer] {}",
37
- "[refer] give me the location of {}",
38
- "[refer] where is {} ?",
39
- "[refer] from this image, tell me the location of {}",
40
- "[refer] the location of {} is",
41
- "[refer] could you tell me the location for {} ?",
42
- "[refer] where can I locate the {} ?",
43
- ]
44
-
45
-
46
- def __len__(self):
47
- return len(self.ref_ids)
48
-
49
- def preprocess(self, index):
50
- ref_id = self.ref_ids[index]
51
- ref = self.refer.loadRefs(ref_id)[0]
52
-
53
- image_file = 'COCO_train2014_{:0>12}.jpg'.format(ref["image_id"])
54
- image_path = os.path.join(self.vis_root, image_file)
55
- image = Image.open(image_path).convert("RGB")
56
- image_orig_size = image.size
57
- image = self.vis_processor(image)
58
- image_new_size = [image.shape[1], image.shape[2]]
59
-
60
- image_new_size = [100,100]
61
-
62
- sample_sentence = random.choice(ref['sentences'])['raw']
63
- refer_sentence = self.text_processor(sample_sentence)
64
-
65
-
66
- bbox = self.refer.getRefBox(ref['ref_id'])
67
- bbox = [
68
- bbox[0] / image_orig_size[0] * image_new_size[0],
69
- bbox[1] / image_orig_size[1] * image_new_size[1],
70
- (bbox[0] + bbox[2]) / image_orig_size[0] * image_new_size[0],
71
- (bbox[1] + bbox[3]) / image_orig_size[1] * image_new_size[1]
72
- ]
73
- bbox = [int(x) for x in bbox]
74
- bbox = "{{<{}><{}><{}><{}>}}".format(*bbox)
75
- return {
76
- "image": image,
77
- "refer_sentence": refer_sentence,
78
- "bbox": bbox,
79
- "image_id": ref['image_id'],
80
- }
81
-
82
- def __getitem__(self, index):
83
- data = self.preprocess(index)
84
- instruction = random.choice(self.instruction_pool).format(data['refer_sentence'])
85
-
86
- instruction = "<Img><ImageHere></Img> {} ".format(instruction)
87
-
88
- return {
89
- "image": data['image'],
90
- "instruction_input": instruction,
91
- "answer": data['bbox'],
92
- "image_id": data['image_id'],
93
- }
94
-
95
-
96
- class InvReferCOCODataset(ReferCOCODataset):
97
- def __init__(self, *args, **kwargs):
98
- super(InvReferCOCODataset, self).__init__(*args, **kwargs)
99
-
100
- self.instruction_pool = [
101
- "[identify] {}",
102
- "[identify] what object is in this location {}",
103
- "[identify] identify the object present at this location {}",
104
- "[identify] what is it in {}",
105
- "[identify] describe this object in {}",
106
- "[identify] this {} is",
107
- "[identify] the object in {} is",
108
- ]
109
-
110
- def __getitem__(self, index):
111
- data = self.preprocess(index)
112
-
113
- instruction = random.choice(self.instruction_pool).format(data['bbox'])
114
-
115
- instruction = "<Img><ImageHere></Img> {} ".format(instruction)
116
-
117
- return {
118
- "image": data['image'],
119
- "instruction_input": instruction,
120
- "answer": self.text_processor(data['refer_sentence']),
121
- "image_id": data['image_id'],
122
- }
123
-
124
-
125
- class REFER:
126
- def __init__(self, data_root, vis_root, dataset='refcoco', splitBy='unc'):
127
- # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
128
- # also provide dataset name and splitBy information
129
- # e.g., dataset = 'refcoco', splitBy = 'unc'
130
- dataset = dataset.split('inv')[-1] # inv dataset is stored in the same path as normal dataset
131
- print('loading dataset %s into memory...' % dataset)
132
- self.ann_dir = os.path.join(data_root, dataset)
133
- if dataset in ['refcoco', 'refcoco+', 'refcocog']:
134
- self.vis_root = vis_root
135
- elif dataset == 'refclef':
136
- raise 'No RefClef image data'
137
- else:
138
- raise 'No refer dataset is called [%s]' % dataset
139
-
140
- # load refs from data/dataset/refs(dataset).json
141
- tic = time.time()
142
- ref_file = os.path.join(self.ann_dir, 'refs(' + splitBy + ').p')
143
- self.data = {}
144
- self.data['dataset'] = dataset
145
- self.data['refs'] = pickle.load(open(ref_file, 'rb'))
146
-
147
- # load annotations from data/dataset/instances.json
148
- instances_file = os.path.join(self.ann_dir, 'instances.json')
149
- instances = json.load(open(instances_file, 'r'))
150
- self.data['images'] = instances['images']
151
- self.data['annotations'] = instances['annotations']
152
- self.data['categories'] = instances['categories']
153
-
154
- # create index
155
- self.createIndex()
156
- print('DONE (t=%.2fs)' % (time.time() - tic))
157
-
158
- def createIndex(self):
159
- # create sets of mapping
160
- # 1) Refs: {ref_id: ref}
161
- # 2) Anns: {ann_id: ann}
162
- # 3) Imgs: {image_id: image}
163
- # 4) Cats: {category_id: category_name}
164
- # 5) Sents: {sent_id: sent}
165
- # 6) imgToRefs: {image_id: refs}
166
- # 7) imgToAnns: {image_id: anns}
167
- # 8) refToAnn: {ref_id: ann}
168
- # 9) annToRef: {ann_id: ref}
169
- # 10) catToRefs: {category_id: refs}
170
- # 11) sentToRef: {sent_id: ref}
171
- # 12) sentToTokens: {sent_id: tokens}
172
- print('creating index...')
173
- # fetch info from instances
174
- Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
175
- for ann in self.data['annotations']:
176
- Anns[ann['id']] = ann
177
- imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann]
178
- for img in self.data['images']:
179
- Imgs[img['id']] = img
180
- for cat in self.data['categories']:
181
- Cats[cat['id']] = cat['name']
182
-
183
- # fetch info from refs
184
- Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
185
- Sents, sentToRef, sentToTokens = {}, {}, {}
186
- for ref in self.data['refs']:
187
- # ids
188
- ref_id = ref['ref_id']
189
- ann_id = ref['ann_id']
190
- category_id = ref['category_id']
191
- image_id = ref['image_id']
192
-
193
- # add mapping related to ref
194
- Refs[ref_id] = ref
195
- imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
196
- catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
197
- refToAnn[ref_id] = Anns[ann_id]
198
- annToRef[ann_id] = ref
199
-
200
- # add mapping of sent
201
- for sent in ref['sentences']:
202
- Sents[sent['sent_id']] = sent
203
- sentToRef[sent['sent_id']] = ref
204
- sentToTokens[sent['sent_id']] = sent['tokens']
205
-
206
- # create class members
207
- self.Refs = Refs
208
- self.Anns = Anns
209
- self.Imgs = Imgs
210
- self.Cats = Cats
211
- self.Sents = Sents
212
- self.imgToRefs = imgToRefs
213
- self.imgToAnns = imgToAnns
214
- self.refToAnn = refToAnn
215
- self.annToRef = annToRef
216
- self.catToRefs = catToRefs
217
- self.sentToRef = sentToRef
218
- self.sentToTokens = sentToTokens
219
- print('index created.')
220
-
221
- def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
222
- image_ids = image_ids if type(image_ids) == list else [image_ids]
223
- cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
224
- ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
225
-
226
- if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
227
- refs = self.data['refs']
228
- else:
229
- if not len(image_ids) == 0:
230
- refs = [self.imgToRefs[image_id] for image_id in image_ids]
231
- else:
232
- refs = self.data['refs']
233
- if not len(cat_ids) == 0:
234
- refs = [ref for ref in refs if ref['category_id'] in cat_ids]
235
- if not len(ref_ids) == 0:
236
- refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
237
- if not len(split) == 0:
238
- if split in ['testA', 'testB', 'testC']:
239
- refs = [ref for ref in refs if
240
- split[-1] in ref['split']] # we also consider testAB, testBC, ...
241
- elif split in ['testAB', 'testBC', 'testAC']:
242
- refs = [ref for ref in refs if ref['split'] == split] # rarely used I guess...
243
- elif split == 'test':
244
- refs = [ref for ref in refs if 'test' in ref['split']]
245
- elif split == 'train' or split == 'val':
246
- refs = [ref for ref in refs if ref['split'] == split]
247
- else:
248
- raise 'No such split [%s]' % split
249
- ref_ids = [ref['ref_id'] for ref in refs]
250
- return ref_ids
251
-
252
- def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
253
- image_ids = image_ids if type(image_ids) == list else [image_ids]
254
- cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
255
- ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
256
-
257
- if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
258
- ann_ids = [ann['id'] for ann in self.data['annotations']]
259
- else:
260
- if not len(image_ids) == 0:
261
- lists = [self.imgToAnns[image_id] for image_id in image_ids if image_id in self.imgToAnns] # list of [anns]
262
- anns = list(itertools.chain.from_iterable(lists))
263
- else:
264
- anns = self.data['annotations']
265
- if not len(cat_ids) == 0:
266
- anns = [ann for ann in anns if ann['category_id'] in cat_ids]
267
- ann_ids = [ann['id'] for ann in anns]
268
- if not len(ref_ids) == 0:
269
- ids = set(ann_ids).intersection(set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]))
270
- return ann_ids
271
-
272
- def getImgIds(self, ref_ids=[]):
273
- ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
274
-
275
- if not len(ref_ids) == 0:
276
- image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids]))
277
- else:
278
- image_ids = self.Imgs.keys()
279
- return image_ids
280
-
281
- def getCatIds(self):
282
- return self.Cats.keys()
283
-
284
- def loadRefs(self, ref_ids=[]):
285
- if type(ref_ids) == list:
286
- return [self.Refs[ref_id] for ref_id in ref_ids]
287
- elif type(ref_ids) == int:
288
- return [self.Refs[ref_ids]]
289
-
290
- def loadAnns(self, ann_ids=[]):
291
- if type(ann_ids) == list:
292
- return [self.Anns[ann_id] for ann_id in ann_ids]
293
- elif type(ann_ids) == int:
294
- return [self.Anns[ann_ids]]
295
-
296
- def loadImgs(self, image_ids=[]):
297
- if type(image_ids) == list:
298
- return [self.Imgs[image_id] for image_id in image_ids]
299
- elif type(image_ids) == int:
300
- return [self.Imgs[image_ids]]
301
-
302
- def loadCats(self, cat_ids=[]):
303
- if type(cat_ids) == list:
304
- return [self.Cats[cat_id] for cat_id in cat_ids]
305
- elif type(cat_ids) == int:
306
- return [self.Cats[cat_ids]]
307
-
308
- def getRefBox(self, ref_id):
309
- ref = self.Refs[ref_id]
310
- ann = self.refToAnn[ref_id]
311
- return ann['bbox'] # [x, y, w, h]
312
-
313
- def showRef(self, ref, seg_box='box'):
314
- ax = plt.gca()
315
- # show image
316
- image = self.Imgs[ref['image_id']]
317
- I = io.imread(os.path.join(self.vis_root, image['file_name']))
318
- ax.imshow(I)
319
- # show refer expression
320
- for sid, sent in enumerate(ref['sentences']):
321
- print('%s. %s' % (sid + 1, sent['sent']))
322
- # show segmentations
323
- if seg_box == 'seg':
324
- ann_id = ref['ann_id']
325
- ann = self.Anns[ann_id]
326
- polygons = []
327
- color = []
328
- c = 'none'
329
- if type(ann['segmentation'][0]) == list:
330
- # polygon used for refcoco*
331
- for seg in ann['segmentation']:
332
- poly = np.array(seg).reshape((len(seg) / 2, 2))
333
- polygons.append(Polygon(poly, True, alpha=0.4))
334
- color.append(c)
335
- p = PatchCollection(polygons, facecolors=color, edgecolors=(1, 1, 0, 0), linewidths=3, alpha=1)
336
- ax.add_collection(p) # thick yellow polygon
337
- p = PatchCollection(polygons, facecolors=color, edgecolors=(1, 0, 0, 0), linewidths=1, alpha=1)
338
- ax.add_collection(p) # thin red polygon
339
- else:
340
- # mask used for refclef
341
- raise NotImplementedError('RefClef is not downloaded')
342
- # show bounding-box
343
- elif seg_box == 'box':
344
- ann_id = ref['ann_id']
345
- ann = self.Anns[ann_id]
346
- bbox = self.getRefBox(ref['ref_id'])
347
- box_plot = Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
348
- ax.add_patch(box_plot)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/datasets/coco_vqa_datasets.py DELETED
@@ -1,145 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- import os
9
- import json
10
- import random
11
-
12
- from PIL import Image
13
-
14
- from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset
15
-
16
- from collections import OrderedDict
17
-
18
-
19
- class __DisplMixin:
20
- def displ_item(self, index):
21
- sample, ann = self.__getitem__(index), self.annotation[index]
22
-
23
- return OrderedDict(
24
- {
25
- "file": ann["image"],
26
- "question": ann["question"],
27
- "question_id": ann["question_id"],
28
- "answers": "; ".join(ann["answer"]),
29
- "image": sample["image"],
30
- }
31
- )
32
-
33
-
34
- class COCOVQADataset(VQADataset, __DisplMixin):
35
- def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
36
- super().__init__(vis_processor, text_processor, vis_root, ann_paths)
37
-
38
- self.instruction_pool =[
39
- "[vqa] {}",
40
- "[vqa] Based on the image, respond to this question with a short answer: {}"
41
- ]
42
-
43
- exist_annotation = []
44
- for ann in self.annotation:
45
- image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
46
- if os.path.exists(image_path):
47
- exist_annotation.append(ann)
48
- self.annotation = exist_annotation
49
-
50
-
51
- def get_data(self, index):
52
- ann = self.annotation[index]
53
-
54
- image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
55
- image = Image.open(image_path).convert("RGB")
56
-
57
- image = self.vis_processor(image)
58
- question = self.text_processor(ann["question"])
59
- question_id = ann["question_id"]
60
-
61
- answer_weight = {}
62
- for answer in ann["answer"]:
63
- if answer in answer_weight.keys():
64
- answer_weight[answer] += 1 / len(ann["answer"])
65
- else:
66
- answer_weight[answer] = 1 / len(ann["answer"])
67
-
68
- answers = list(answer_weight.keys())
69
- weights = list(answer_weight.values())
70
-
71
- answer = random.choices(answers, weights=weights, k=1)[0] # random sample an answer according to weights
72
-
73
-
74
- return {
75
- "image": image,
76
- "question": question,
77
- "question_id": question_id,
78
- "answer": answer,
79
- }
80
-
81
- def __getitem__(self, index):
82
- data = self.get_data(index)
83
- instruction = random.choice(self.instruction_pool).format(data['question'])
84
- instruction = "<Img><ImageHere></Img> {} ".format(instruction)
85
-
86
- return {
87
- "image": data['image'],
88
- "question_id": data["question_id"],
89
- "instruction_input": instruction,
90
- "answer": self.text_processor(data['answer']),
91
- }
92
-
93
-
94
- class COCOVQAEvalDataset(VQAEvalDataset, __DisplMixin):
95
- def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
96
- """
97
- vis_root (string): Root directory of images (e.g. coco/images/)
98
- ann_root (string): directory to store the annotation file
99
- """
100
-
101
- self.instruction_pool = [
102
- 'Question: {} Short answer:',
103
- ]
104
- self.vis_root = vis_root
105
-
106
- self.annotation = json.load(open(ann_paths[0]))
107
-
108
- answer_list_path = ann_paths[1]
109
- if os.path.exists(answer_list_path):
110
- self.answer_list = json.load(open(answer_list_path))
111
- else:
112
- self.answer_list = None
113
-
114
- try:
115
- self.coco_fmt_qust_file = ann_paths[2]
116
- self.coco_fmt_anno_file = ann_paths[3]
117
- except IndexError:
118
- self.coco_fmt_qust_file = None
119
- self.coco_fmt_anno_file = None
120
-
121
- self.vis_processor = vis_processor
122
- self.text_processor = text_processor
123
-
124
- self._add_instance_ids()
125
-
126
- def __getitem__(self, index):
127
- ann = self.annotation[index]
128
-
129
- image_path = os.path.join(self.vis_root, ann["image"])
130
- image = Image.open(image_path).convert("RGB")
131
-
132
- image = self.vis_processor(image)
133
- question = self.text_processor(ann["question"])
134
-
135
- instruction = random.choice(self.instruction_pool).format(question)
136
- instruction = "<Img><ImageHere></Img> {} ".format(instruction)
137
-
138
- return {
139
- "image": image,
140
- 'image_path': image_path,
141
- "question": question,
142
- "question_id": ann["question_id"],
143
- "instruction_input": instruction,
144
- "instance_id": ann["instance_id"],
145
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/datasets/dataloader_utils.py DELETED
@@ -1,162 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- import time
9
- import random
10
- import torch
11
- from minigpt4.datasets.data_utils import move_to_cuda
12
- from torch.utils.data import DataLoader
13
-
14
-
15
- class MultiIterLoader:
16
- """
17
- A simple wrapper for iterating over multiple iterators.
18
-
19
- Args:
20
- loaders (List[Loader]): List of Iterator loaders.
21
- ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
22
- """
23
-
24
- def __init__(self, loaders, ratios=None):
25
- # assert all loaders has __next__ method
26
- for loader in loaders:
27
- assert hasattr(
28
- loader, "__next__"
29
- ), "Loader {} has no __next__ method.".format(loader)
30
-
31
- if ratios is None:
32
- ratios = [1.0] * len(loaders)
33
- else:
34
- assert len(ratios) == len(loaders)
35
- ratios = [float(ratio) / sum(ratios) for ratio in ratios]
36
-
37
- self.loaders = loaders
38
- self.ratios = ratios
39
-
40
- def __next__(self):
41
- # random sample from each loader by ratio
42
- loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
43
- return next(self.loaders[loader_idx])
44
-
45
-
46
- class PrefetchLoader(object):
47
- """
48
- Modified from https://github.com/ChenRocks/UNITER.
49
-
50
- overlap compute and cuda data transfer
51
- (copied and then modified from nvidia apex)
52
- """
53
-
54
- def __init__(self, loader):
55
- self.loader = loader
56
- self.stream = torch.cuda.Stream()
57
-
58
- def __iter__(self):
59
- loader_it = iter(self.loader)
60
- self.preload(loader_it)
61
- batch = self.next(loader_it)
62
- while batch is not None:
63
- is_tuple = isinstance(batch, tuple)
64
- if is_tuple:
65
- task, batch = batch
66
-
67
- if is_tuple:
68
- yield task, batch
69
- else:
70
- yield batch
71
- batch = self.next(loader_it)
72
-
73
- def __len__(self):
74
- return len(self.loader)
75
-
76
- def preload(self, it):
77
- try:
78
- self.batch = next(it)
79
- except StopIteration:
80
- self.batch = None
81
- return
82
- # if record_stream() doesn't work, another option is to make sure
83
- # device inputs are created on the main stream.
84
- # self.next_input_gpu = torch.empty_like(self.next_input,
85
- # device='cuda')
86
- # self.next_target_gpu = torch.empty_like(self.next_target,
87
- # device='cuda')
88
- # Need to make sure the memory allocated for next_* is not still in use
89
- # by the main stream at the time we start copying to next_*:
90
- # self.stream.wait_stream(torch.cuda.current_stream())
91
- with torch.cuda.stream(self.stream):
92
- self.batch = move_to_cuda(self.batch)
93
- # more code for the alternative if record_stream() doesn't work:
94
- # copy_ will record the use of the pinned source tensor in this
95
- # side stream.
96
- # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
97
- # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
98
- # self.next_input = self.next_input_gpu
99
- # self.next_target = self.next_target_gpu
100
-
101
- def next(self, it):
102
- torch.cuda.current_stream().wait_stream(self.stream)
103
- batch = self.batch
104
- if batch is not None:
105
- record_cuda_stream(batch)
106
- self.preload(it)
107
- return batch
108
-
109
- def __getattr__(self, name):
110
- method = self.loader.__getattribute__(name)
111
- return method
112
-
113
-
114
- def record_cuda_stream(batch):
115
- if isinstance(batch, torch.Tensor):
116
- batch.record_stream(torch.cuda.current_stream())
117
- elif isinstance(batch, list) or isinstance(batch, tuple):
118
- for t in batch:
119
- record_cuda_stream(t)
120
- elif isinstance(batch, dict):
121
- for t in batch.values():
122
- record_cuda_stream(t)
123
- else:
124
- pass
125
-
126
-
127
- class IterLoader:
128
- """
129
- A wrapper to convert DataLoader as an infinite iterator.
130
-
131
- Modified from:
132
- https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
133
- """
134
-
135
- def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
136
- self._dataloader = dataloader
137
- self.iter_loader = iter(self._dataloader)
138
- self._use_distributed = use_distributed
139
- self._epoch = 0
140
-
141
- @property
142
- def epoch(self) -> int:
143
- return self._epoch
144
-
145
- def __next__(self):
146
- try:
147
- data = next(self.iter_loader)
148
- except StopIteration:
149
- self._epoch += 1
150
- if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
151
- self._dataloader.sampler.set_epoch(self._epoch)
152
- time.sleep(2) # Prevent possible deadlock during epoch transition
153
- self.iter_loader = iter(self._dataloader)
154
- data = next(self.iter_loader)
155
-
156
- return data
157
-
158
- def __iter__(self):
159
- return self
160
-
161
- def __len__(self):
162
- return len(self._dataloader)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/datasets/flickr.py DELETED
@@ -1,159 +0,0 @@
1
- import os
2
- import json
3
- import pickle
4
- import random
5
- import time
6
- import itertools
7
-
8
- import numpy as np
9
- from PIL import Image
10
- import skimage.io as io
11
- import matplotlib.pyplot as plt
12
- from matplotlib.collections import PatchCollection
13
- from matplotlib.patches import Polygon, Rectangle
14
- from torch.utils.data import Dataset
15
- import webdataset as wds
16
-
17
- from minigpt4.datasets.datasets.base_dataset import BaseDataset
18
- from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
19
-
20
-
21
- class GroundedDetailDataset(Dataset):
22
- def __init__(self, vis_processor, text_processor, vis_root, ann_path):
23
- """
24
- vis_root (string): Root directory of images (e.g. coco/images/)
25
- ann_root (string): directory to store the annotation file
26
- """
27
- self.vis_root = vis_root
28
-
29
- self.vis_processor = vis_processor
30
- self.text_processor = text_processor
31
-
32
- self.instruction_pool = [
33
- '[grounding] please describe this image in details',
34
- '[grounding] describe this image as detailed as possible',
35
- '[grounding] summarize this image in details',
36
- '[grounding] give a thorough description of what you see in this image',
37
- ]
38
-
39
- with open(ann_path, 'r') as f:
40
- self.ann = json.load(f)
41
-
42
- def __len__(self):
43
- return len(self.ann)
44
-
45
- def __getitem__(self, index):
46
- info = self.ann[index]
47
-
48
- # image_file = 'COCO_train2014_{}.jpg'.format(info['image_id'])
49
- image_file = '{}.jpg'.format(info['image_id'])
50
- image_path = os.path.join(self.vis_root, image_file)
51
- image = Image.open(image_path).convert("RGB")
52
- image = self.vis_processor(image)
53
-
54
- answer = info['grounded_caption']
55
- instruction = random.choice(self.instruction_pool)
56
- instruction = "<Img><ImageHere></Img> {} ".format(instruction)
57
-
58
- return {
59
- "image": image,
60
- "instruction_input": instruction,
61
- "answer": answer,
62
- "image_id": info['image_id'],
63
- }
64
-
65
-
66
-
67
-
68
- class CaptionToObjectDataset(Dataset):
69
- def __init__(self, vis_processor, text_processor, vis_root, ann_path):
70
- """
71
- vis_root (string): Root directory of images (e.g. coco/images/)
72
- ann_root (string): directory to store the annotation file
73
- """
74
- self.vis_root = vis_root
75
-
76
- self.vis_processor = vis_processor
77
- self.text_processor = text_processor
78
-
79
- self.instruction_pool = [
80
- '[detection] {}',
81
- ]
82
-
83
- with open(ann_path, 'r') as f:
84
- self.ann = json.load(f)
85
-
86
- def __len__(self):
87
- return len(self.ann)
88
-
89
- def __getitem__(self, index):
90
- info = self.ann[index]
91
-
92
- image_file = '{}.jpg'.format(info['image_id'])
93
- image_path = os.path.join(self.vis_root, image_file)
94
- image = Image.open(image_path).convert("RGB")
95
- image = self.vis_processor(image)
96
-
97
- input = info["caption"]
98
- answer = info["output"]
99
-
100
- instruction = random.choice(self.instruction_pool).format(input)
101
-
102
- instruction = "<Img><ImageHere></Img> {} ".format(instruction)
103
-
104
- print("CaptionToObject instruction", instruction)
105
- print("CaptionToObject answer", answer)
106
-
107
- return {
108
- "image": image,
109
- "instruction_input": instruction,
110
- "answer": answer,
111
- "image_id": info['image_id'],
112
- }
113
-
114
-
115
-
116
-
117
- class PhraseToObjectDataset(Dataset):
118
- def __init__(self, vis_processor, text_processor, vis_root, ann_path):
119
- """
120
- vis_root (string): Root directory of images (e.g. coco/images/)
121
- ann_root (string): directory to store the annotation file
122
- """
123
- self.vis_root = vis_root
124
-
125
- self.vis_processor = vis_processor
126
- self.text_processor = text_processor
127
-
128
- self.instruction_pool = [
129
- '[detection] {}',
130
- ]
131
-
132
- with open(ann_path, 'r') as f:
133
- self.ann = json.load(f)
134
-
135
- def __len__(self):
136
- return len(self.ann)
137
-
138
- def __getitem__(self, index):
139
- info = self.ann[index]
140
- image_file = '{}.jpg'.format(info['image_id'])
141
- image_path = os.path.join(self.vis_root, image_file)
142
- image = Image.open(image_path).convert("RGB")
143
- image = self.vis_processor(image)
144
-
145
- input = info["phrase"]
146
- answer = "<p>"+input+"</p> "+info["bbox"]
147
- instruction = random.choice(self.instruction_pool).format(input)
148
-
149
- instruction = "<Img><ImageHere></Img> {} ".format(instruction)
150
-
151
- print("PhraseToObject instruction", instruction)
152
- print("PhraseToObject answer", answer)
153
-
154
- return {
155
- "image": image,
156
- "instruction_input": instruction,
157
- "answer": answer,
158
- "image_id": info['image_id'],
159
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/datasets/gqa_datasets.py DELETED
@@ -1,60 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- import os
9
- import json
10
-
11
- from PIL import Image
12
-
13
- from minigpt4.datasets.datasets.vqa_datasets import VQADataset
14
-
15
- from collections import OrderedDict
16
- import random
17
-
18
- class __DisplMixin:
19
- def displ_item(self, index):
20
- sample, ann = self.__getitem__(index), self.annotation[index]
21
-
22
- return OrderedDict(
23
- {
24
- "file": ann["image"],
25
- "question": ann["question"],
26
- "question_id": ann["question_id"],
27
- "answers": "; ".join(ann["answer"]),
28
- "image": sample["image"],
29
- }
30
- )
31
-
32
-
33
- class GQADataset(VQADataset, __DisplMixin):
34
- def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
35
- super().__init__(vis_processor, text_processor, vis_root, ann_paths)
36
- self.instruction_pool =[
37
- "[vqa] {}",
38
- "[vqa] Based on the image, respond to this question with a short answer: {}"
39
- ]
40
-
41
- def __getitem__(self, index):
42
- ann = self.annotation[index]
43
-
44
- image_path = os.path.join(self.vis_root, ann["image"])
45
- image = Image.open(image_path).convert("RGB")
46
-
47
- image = self.vis_processor(image)
48
- question = self.text_processor(ann["question"])
49
-
50
- instruction = random.choice(self.instruction_pool).format(question)
51
- instruction = "<Img><ImageHere></Img> {} ".format(instruction)
52
-
53
- answers = self.text_processor(ann["answer"])
54
-
55
- return {
56
- "image": image,
57
- "instruction_input": instruction,
58
- "answer": answers,
59
- }
60
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/datasets/laion_dataset.py DELETED
@@ -1,31 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- import webdataset as wds
9
- from minigpt4.datasets.datasets.base_dataset import BaseDataset
10
-
11
-
12
- class LaionDataset(BaseDataset):
13
- def __init__(self, vis_processor, text_processor, location):
14
- super().__init__(vis_processor=vis_processor, text_processor=text_processor)
15
-
16
- self.inner_dataset = wds.DataPipeline(
17
- wds.ResampledShards(location),
18
- wds.tarfile_to_samples(handler=wds.warn_and_continue),
19
- wds.shuffle(1000, handler=wds.warn_and_continue),
20
- wds.decode("pilrgb", handler=wds.warn_and_continue),
21
- wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
22
- wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
23
- wds.map(self.to_dict, handler=wds.warn_and_continue),
24
- )
25
-
26
- def to_dict(self, sample):
27
- return {
28
- "image": sample[0],
29
- "answer": self.text_processor(sample[1]["caption"]),
30
- }
31
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/datasets/llava_dataset.py DELETED
@@ -1,149 +0,0 @@
1
- import os
2
- import json
3
- import pickle
4
- import random
5
- import time
6
- import numpy as np
7
- from PIL import Image
8
- import skimage.io as io
9
- import matplotlib.pyplot as plt
10
- from matplotlib.collections import PatchCollection
11
- from matplotlib.patches import Polygon, Rectangle
12
- from torch.utils.data import Dataset
13
- import webdataset as wds
14
-
15
- from minigpt4.datasets.datasets.base_dataset import BaseDataset
16
- from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
17
-
18
- class LlavaDetailDataset(Dataset):
19
- def __init__(self, vis_processor, text_processor, vis_root, ann_path):
20
- """
21
- vis_root (string): Root directory of images (e.g. coco/images/)
22
- ann_root (string): directory to store the annotation file
23
- """
24
- self.vis_root = vis_root
25
-
26
- self.vis_processor = vis_processor
27
- self.text_processor = text_processor
28
-
29
- with open(ann_path, 'r') as f:
30
- self.ann = json.load(f)
31
-
32
- def __len__(self):
33
- return len(self.ann)
34
-
35
- def __getitem__(self, index):
36
- info = self.ann[index]
37
-
38
- image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
39
- image_path = os.path.join(self.vis_root, image_file)
40
- image = Image.open(image_path).convert("RGB")
41
- image = self.vis_processor(image)
42
-
43
- answer = info['conversations'][1]['value']
44
- instruction = info['conversations'][0]['value'].replace('<image>', '').replace('\n', '').strip()
45
-
46
- instruction = '<Img><ImageHere></Img> {} '.format(self.text_processor(instruction))
47
-
48
- return {
49
- "image": image,
50
- "instruction_input": instruction,
51
- "answer": answer,
52
- "image_id": info['id'],
53
- }
54
-
55
- class LlavaReasonDataset(Dataset):
56
- def __init__(self, vis_processor, text_processor, vis_root, ann_path):
57
- """
58
- vis_root (string): Root directory of images (e.g. coco/images/)
59
- ann_root (string): directory to store the annotation file
60
- """
61
- self.vis_root = vis_root
62
-
63
- self.vis_processor = vis_processor
64
- self.text_processor = text_processor
65
-
66
- with open(ann_path, 'r') as f:
67
- self.ann = json.load(f)
68
-
69
- def __len__(self):
70
- return len(self.ann)
71
-
72
- def __getitem__(self, index):
73
- info = self.ann[index]
74
-
75
- image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
76
- image_path = os.path.join(self.vis_root, image_file)
77
- image = Image.open(image_path).convert("RGB")
78
- image = self.vis_processor(image)
79
-
80
- answer = info['conversations'][1]['value']
81
- instruction = info['conversations'][0]['value'].replace('<image>', '').replace('\n', '').strip()
82
-
83
- instruction = '<Img><ImageHere></Img> {} '.format(self.text_processor(instruction))
84
-
85
- return {
86
- "image": image,
87
- "instruction_input": instruction,
88
- "answer": answer,
89
- "image_id": info['id'],
90
- }
91
-
92
-
93
-
94
-
95
- class LlavaConversationDataset(Dataset):
96
- def __init__(self, vis_processor, text_processor, vis_root, ann_path):
97
- """
98
- vis_root (string): Root directory of images (e.g. coco/images/)
99
- ann_root (string): directory to store the annotation file
100
- """
101
- self.vis_root = vis_root
102
-
103
- self.vis_processor = vis_processor
104
- self.text_processor = text_processor
105
-
106
- self.ann=[]
107
-
108
-
109
- with open(ann_path, 'r') as f:
110
- self.ann = json.load(f)
111
-
112
- self.connect_sym = "!@#"
113
-
114
- def __len__(self):
115
- return len(self.ann)
116
-
117
- def __getitem__(self, index):
118
- info = self.ann[index]
119
-
120
- image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
121
- image_path = os.path.join(self.vis_root, image_file)
122
- image = Image.open(image_path).convert("RGB")
123
- image = self.vis_processor(image)
124
-
125
- first_instruction = info['conversations'][0]['value'].replace('<image>', '').replace('\n', '').strip()
126
- first_instruction = '<Img><ImageHere></Img> {} '.format(first_instruction)
127
-
128
- questions = [first_instruction]
129
- answers = []
130
-
131
- for i, item in enumerate(info["conversations"][1:]):
132
- if i % 2 ==0: # assistant
133
- assistant_answer = item["value"]
134
- answers.append(assistant_answer)
135
- else:
136
- human_instruction = item["value"]+" "
137
- questions.append(human_instruction)
138
-
139
- questions = self.connect_sym.join(questions)
140
- answers = self.connect_sym.join(answers)
141
-
142
-
143
- return {
144
- "image": image,
145
- "conv_q": questions,
146
- 'conv_a': answers,
147
- "image_id": info['id'],
148
- "connect_sym": self.connect_sym
149
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/datasets/multitask_conversation.py DELETED
@@ -1,75 +0,0 @@
1
- import os
2
- import json
3
- import pickle
4
- import random
5
- import time
6
- import itertools
7
-
8
- import numpy as np
9
- from PIL import Image
10
- import skimage.io as io
11
- import matplotlib.pyplot as plt
12
- from matplotlib.collections import PatchCollection
13
- from matplotlib.patches import Polygon, Rectangle
14
- from torch.utils.data import Dataset
15
- import webdataset as wds
16
-
17
- from minigpt4.datasets.datasets.base_dataset import BaseDataset
18
- from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
19
-
20
-
21
-
22
-
23
- class MultiTaskConversationDataset(Dataset):
24
- def __init__(self, vis_processor, text_processor, vis_root, ann_path):
25
- """
26
- vis_root (string): Root directory of images (e.g. coco/images/)
27
- ann_root (string): directory to store the annotation file
28
- """
29
- self.vis_root = vis_root
30
-
31
- self.vis_processor = vis_processor
32
- self.text_processor = text_processor
33
-
34
-
35
- with open(ann_path, 'r') as f:
36
- self.ann = json.load(f)
37
-
38
- self.connect_sym = "!@#"
39
-
40
- def __len__(self):
41
- return len(self.ann)
42
-
43
- def __getitem__(self, index):
44
- info = self.ann[index]
45
-
46
- image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
47
- image_path = os.path.join(self.vis_root, image_file)
48
- image = Image.open(image_path).convert("RGB")
49
- image = self.vis_processor(image)
50
-
51
- first_instruction = info['conversations'][0]['value'].replace('<image>', '').replace('\n', '').strip()
52
- first_instruction = '<Img><ImageHere></Img> {} '.format(first_instruction)
53
-
54
- questions = [first_instruction]
55
- answers = []
56
-
57
- for i, item in enumerate(info["conversations"][1:]):
58
- if i % 2 ==0: # assistant
59
- assistant_answer = item["value"]
60
- answers.append(assistant_answer)
61
- else:
62
- human_instruction = item["value"]+" "
63
- questions.append(human_instruction)
64
-
65
- questions = self.connect_sym.join(questions)
66
- answers = self.connect_sym.join(answers)
67
-
68
-
69
- return {
70
- "image": image,
71
- "conv_q": questions,
72
- 'conv_a': answers,
73
- "image_id": info['id'],
74
- "connect_sym": self.connect_sym
75
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
minigpt4/datasets/datasets/ocrvqa_dataset.py DELETED
@@ -1,77 +0,0 @@
1
- import os
2
- import json
3
- import pickle
4
- import random
5
- import time
6
- import itertools
7
-
8
- import numpy as np
9
- from PIL import Image
10
- import skimage.io as io
11
- import matplotlib.pyplot as plt
12
- from matplotlib.collections import PatchCollection
13
- from matplotlib.patches import Polygon, Rectangle
14
- from torch.utils.data import Dataset
15
- import webdataset as wds
16
-
17
- from minigpt4.datasets.datasets.base_dataset import BaseDataset
18
- from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
19
-
20
-
21
- class OCRVQADataset(Dataset):
22
- def __init__(self, vis_processor, text_processor, vis_root, ann_path):
23
- """
24
- vis_root (string): Root directory of images (e.g. coco/images/)
25
- ann_root (string): directory to store the annotation file
26
- """
27
- self.vis_root = vis_root
28
-
29
- self.vis_processor = vis_processor
30
- self.text_processor = text_processor
31
- self.data = self.create_data(ann_path)
32
-
33
- self.instruction_pool =[
34
- "[vqa] {}",
35
- "[vqa] Based on the image, respond to this question with a short answer: {}"
36
- ]
37
-
38
- def create_data(self, ann_path):
39
- processed_data = []
40
- with open(ann_path, 'r') as f:
41
- data = json.load(f)
42
- for k in data.keys():
43
- if data[k]['split'] != 1: continue # 1 for training, 2 for validation, 3 for test
44
- ext = os.path.splitext(data[k]['imageURL'])[1]
45
- imageFile = k + ext
46
- assert len(data[k]['questions']) == len(data[k]['answers'])
47
- for q, a in zip(data[k]['questions'], data[k]['answers']):
48
- processed_data.append(
49
- {'question': q,
50
- 'answer': a,
51
- 'image_path': imageFile,
52
- 'image_id': k,
53
- 'title': data[k]['title'],
54
- 'genre': data[k]['genre'],
55
- }
56
- )
57
- return processed_data
58
-
59
- def __len__(self):
60
- return len(self.data)
61
-
62
- def __getitem__(self, index):
63
- sample = self.data[index]
64
- image = Image.open(os.path.join(self.vis_root, sample['image_path'])).convert("RGB")
65
- image = self.vis_processor(image)
66
- question = self.text_processor(sample["question"])
67
- answer = self.text_processor(sample["answer"])
68
-
69
- instruction = random.choice(self.instruction_pool).format(question)
70
- instruction = "<Img><ImageHere></Img> {} ".format(instruction)
71
- return {
72
- "image": image,
73
- "instruction_input": instruction,
74
- "answer": answer,
75
- "image_id": sample['image_id']
76
- }
77
-