saicharan1234 committed
Commit 6742cd5 • 1 Parent(s): adbca6c

Delete minigpt4

This view is limited to 50 files because it contains too many changes. See raw diff.
- minigpt4/__init__.py +0 -31
- minigpt4/common/__init__.py +0 -0
- minigpt4/common/config.py +0 -496
- minigpt4/common/dist_utils.py +0 -140
- minigpt4/common/eval_utils.py +0 -76
- minigpt4/common/gradcam.py +0 -24
- minigpt4/common/logger.py +0 -195
- minigpt4/common/optims.py +0 -119
- minigpt4/common/registry.py +0 -329
- minigpt4/common/utils.py +0 -424
- minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py +0 -89
- minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py +0 -1
- minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py +0 -192
- minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py +0 -73
- minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py +0 -1
- minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py +0 -179
- minigpt4/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt +0 -81
- minigpt4/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt +0 -65
- minigpt4/common/vqa_tools/VQA/README.md +0 -80
- minigpt4/common/vqa_tools/VQA/license.txt +0 -30
- minigpt4/common/vqa_tools/__init__.py +0 -8
- minigpt4/common/vqa_tools/vqa.py +0 -211
- minigpt4/common/vqa_tools/vqa_eval.py +0 -324
- minigpt4/configs/datasets/cc_combine/align.yaml +0 -16
- minigpt4/configs/datasets/cc_combine/defaults.yaml +0 -11
- minigpt4/configs/datasets/laion/defaults.yaml +0 -13
- minigpt4/configs/default.yaml +0 -5
- minigpt4/configs/models/minigpt4_vicuna0.yaml +0 -32
- minigpt4/conversation/__init__.py +0 -0
- minigpt4/conversation/conversation.py +0 -233
- minigpt4/datasets/__init__.py +0 -0
- minigpt4/datasets/builders/__init__.py +0 -72
- minigpt4/datasets/builders/base_dataset_builder.py +0 -236
- minigpt4/datasets/builders/image_text_pair_builder.py +0 -535
- minigpt4/datasets/data_utils.py +0 -199
- minigpt4/datasets/datasets/__init__.py +0 -0
- minigpt4/datasets/datasets/aok_vqa_datasets.py +0 -116
- minigpt4/datasets/datasets/base_dataset.py +0 -78
- minigpt4/datasets/datasets/caption_datasets.py +0 -151
- minigpt4/datasets/datasets/cc_sbu_dataset.py +0 -47
- minigpt4/datasets/datasets/coco_caption.py +0 -120
- minigpt4/datasets/datasets/coco_dataset.py +0 -348
- minigpt4/datasets/datasets/coco_vqa_datasets.py +0 -145
- minigpt4/datasets/datasets/dataloader_utils.py +0 -162
- minigpt4/datasets/datasets/flickr.py +0 -159
- minigpt4/datasets/datasets/gqa_datasets.py +0 -60
- minigpt4/datasets/datasets/laion_dataset.py +0 -31
- minigpt4/datasets/datasets/llava_dataset.py +0 -149
- minigpt4/datasets/datasets/multitask_conversation.py +0 -75
- minigpt4/datasets/datasets/ocrvqa_dataset.py +0 -77
minigpt4/__init__.py
DELETED
@@ -1,31 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
import sys

from omegaconf import OmegaConf

from minigpt4.common.registry import registry

from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.tasks import *


root_dir = os.path.dirname(os.path.abspath(__file__))
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))

registry.register_path("library_root", root_dir)
repo_root = os.path.join(root_dir, "..")
registry.register_path("repo_root", repo_root)
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
registry.register_path("cache_root", cache_root)

registry.register("MAX_INT", sys.maxsize)
registry.register("SPLIT_NAMES", ["train", "val", "test"])
minigpt4/common/__init__.py
DELETED
File without changes
minigpt4/common/config.py
DELETED
@@ -1,496 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import json
from typing import Dict

from omegaconf import OmegaConf
from minigpt4.common.registry import registry


class Config:
    def __init__(self, args):
        self.config = {}

        self.args = args

        # Register the config and configuration for setup
        registry.register("configuration", self)

        user_config = self._build_opt_list(self.args.options)

        config = OmegaConf.load(self.args.cfg_path)

        runner_config = self.build_runner_config(config)
        model_config = self.build_model_config(config, **user_config)
        dataset_config = self.build_dataset_config(config)
        evaluation_dataset_config = self.build_evaluation_dataset_config(config)

        # Validate the user-provided runner configuration
        # model and dataset configuration are supposed to be validated by the respective classes
        # [TODO] validate the model/dataset configuration
        # self._validate_runner_config(runner_config)

        # Override the default configuration with user options.
        self.config = OmegaConf.merge(
            runner_config, model_config, dataset_config, evaluation_dataset_config, user_config
        )

    def _validate_runner_config(self, runner_config):
        """
        This method validates the configuration, such that
            1) all the user specified options are valid;
            2) no type mismatches between the user specified options and the config.
        """
        runner_config_validator = create_runner_config_validator()
        runner_config_validator.validate(runner_config)

    def _build_opt_list(self, opts):
        opts_dot_list = self._convert_to_dot_list(opts)
        return OmegaConf.from_dotlist(opts_dot_list)

    @staticmethod
    def build_model_config(config, **kwargs):
        model = config.get("model", None)
        assert model is not None, "Missing model configuration file."

        model_cls = registry.get_model_class(model.arch)
        assert model_cls is not None, f"Model '{model.arch}' has not been registered."

        model_type = kwargs.get("model.model_type", None)
        if not model_type:
            model_type = model.get("model_type", None)
        # else use the model type selected by user.

        assert model_type is not None, "Missing model_type."

        model_config_path = model_cls.default_config_path(model_type=model_type)

        model_config = OmegaConf.create()
        # hierarchy override, customized config > default config
        model_config = OmegaConf.merge(
            model_config,
            OmegaConf.load(model_config_path),
            {"model": config["model"]},
        )

        return model_config

    @staticmethod
    def build_runner_config(config):
        return {"run": config.run}

    @staticmethod
    def build_dataset_config(config):
        datasets = config.get("datasets", None)
        if datasets is None:
            raise KeyError(
                "Expecting 'datasets' as the root key for dataset configuration."
            )

        dataset_config = OmegaConf.create()

        for dataset_name in datasets:
            builder_cls = registry.get_builder_class(dataset_name)

            dataset_config_type = datasets[dataset_name].get("type", "default")
            dataset_config_path = builder_cls.default_config_path(
                type=dataset_config_type
            )

            # hierarchy override, customized config > default config
            dataset_config = OmegaConf.merge(
                dataset_config,
                OmegaConf.load(dataset_config_path),
                {"datasets": {dataset_name: config["datasets"][dataset_name]}},
            )

        return dataset_config

    @staticmethod
    def build_evaluation_dataset_config(config):
        datasets = config.get("evaluation_datasets", None)
        # if datasets is None:
        #     raise KeyError(
        #         "Expecting 'datasets' as the root key for dataset configuration."
        #     )

        dataset_config = OmegaConf.create()

        if datasets is not None:
            for dataset_name in datasets:
                builder_cls = registry.get_builder_class(dataset_name)

                # hierarchy override, customized config > default config
                dataset_config = OmegaConf.merge(
                    dataset_config,
                    {"evaluation_datasets": {dataset_name: config["evaluation_datasets"][dataset_name]}},
                )

        return dataset_config

    def _convert_to_dot_list(self, opts):
        if opts is None:
            opts = []

        if len(opts) == 0:
            return opts

        has_equal = opts[0].find("=") != -1

        if has_equal:
            return opts

        return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]

    def get_config(self):
        return self.config

    @property
    def run_cfg(self):
        return self.config.run

    @property
    def datasets_cfg(self):
        return self.config.datasets

    @property
    def evaluation_datasets_cfg(self):
        return self.config.evaluation_datasets

    @property
    def model_cfg(self):
        return self.config.model

    def pretty_print(self):
        logging.info("\n=====  Running Parameters  =====")
        logging.info(self._convert_node_to_json(self.config.run))

        logging.info("\n======  Dataset Attributes  ======")
        datasets = self.config.datasets

        for dataset in datasets:
            if dataset in self.config.datasets:
                logging.info(f"\n======== {dataset} =======")
                dataset_config = self.config.datasets[dataset]
                logging.info(self._convert_node_to_json(dataset_config))
            else:
                logging.warning(f"No dataset named '{dataset}' in config. Skipping")

        logging.info(f"\n======  Model Attributes  ======")
        logging.info(self._convert_node_to_json(self.config.model))

    def _convert_node_to_json(self, node):
        container = OmegaConf.to_container(node, resolve=True)
        return json.dumps(container, indent=4, sort_keys=True)

    def to_dict(self):
        return OmegaConf.to_container(self.config)


def node_to_dict(node):
    return OmegaConf.to_container(node)


class ConfigValidator:
    """
    This is a preliminary implementation to centralize and validate the configuration.
    May be altered in the future.

    A helper class to validate configurations from yaml file.

    This serves the following purposes:
        1. Ensure all the options in the yaml are defined, raise error if not.
        2. when type mismatches are found, the validator will raise an error.
        3. a central place to store and display helpful messages for supported configurations.
    """

    class _Argument:
        def __init__(self, name, choices=None, type=None, help=None):
            self.name = name
            self.val = None
            self.choices = choices
            self.type = type
            self.help = help

        def __str__(self):
            s = f"{self.name}={self.val}"
            if self.type is not None:
                s += f", ({self.type})"
            if self.choices is not None:
                s += f", choices: {self.choices}"
            if self.help is not None:
                s += f", ({self.help})"
            return s

    def __init__(self, description):
        self.description = description

        self.arguments = dict()

        self.parsed_args = None

    def __getitem__(self, key):
        assert self.parsed_args is not None, "No arguments parsed yet."

        return self.parsed_args[key]

    def __str__(self) -> str:
        return self.format_help()

    def add_argument(self, *args, **kwargs):
        """
        Assume the first argument is the name of the argument.
        """
        self.arguments[args[0]] = self._Argument(*args, **kwargs)

    def validate(self, config=None):
        """
        Convert yaml config (dict-like) to list, required by argparse.
        """
        for k, v in config.items():
            assert (
                k in self.arguments
            ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""

            if self.arguments[k].type is not None:
                try:
                    self.arguments[k].val = self.arguments[k].type(v)
                except ValueError:
                    raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")

            if self.arguments[k].choices is not None:
                assert (
                    v in self.arguments[k].choices
                ), f"""{k} must be one of {self.arguments[k].choices}."""

        return config

    def format_arguments(self):
        return str([f"{k}" for k in sorted(self.arguments.keys())])

    def format_help(self):
        # description + key-value pair string for each argument
        help_msg = str(self.description)
        return help_msg + ", available arguments: " + self.format_arguments()

    def print_help(self):
        # display help message
        print(self.format_help())


def create_runner_config_validator():
    validator = ConfigValidator(description="Runner configurations")

    validator.add_argument(
        "runner",
        type=str,
        choices=["runner_base", "runner_iter"],
        help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
            runner runs based on iters. Default: runner_base""",
    )
    # add arguments for training dataset ratios
    validator.add_argument(
        "train_dataset_ratios",
        type=Dict[str, float],
        help="""Ratios of training dataset. This is used in iteration-based runner.
        Do not support for epoch-based runner because how to define an epoch becomes tricky.
        Default: None""",
    )
    validator.add_argument(
        "max_iters",
        type=float,
        help="Maximum number of iterations to run.",
    )
    validator.add_argument(
        "max_epoch",
        type=int,
        help="Maximum number of epochs to run.",
    )
    # add arguments for iters_per_inner_epoch
    validator.add_argument(
        "iters_per_inner_epoch",
        type=float,
        help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
    )
    lr_scheds_choices = registry.list_lr_schedulers()
    validator.add_argument(
        "lr_sched",
        type=str,
        choices=lr_scheds_choices,
        help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
    )
    task_choices = registry.list_tasks()
    validator.add_argument(
        "task",
        type=str,
        choices=task_choices,
        help="Task to use, from {}".format(task_choices),
    )
    # add arguments for init_lr
    validator.add_argument(
        "init_lr",
        type=float,
        help="Initial learning rate. This will be the learning rate after warmup and before decay.",
    )
    # add arguments for min_lr
    validator.add_argument(
        "min_lr",
        type=float,
        help="Minimum learning rate (after decay).",
    )
    # add arguments for warmup_lr
    validator.add_argument(
        "warmup_lr",
        type=float,
        help="Starting learning rate for warmup.",
    )
    # add arguments for learning rate decay rate
    validator.add_argument(
        "lr_decay_rate",
        type=float,
        help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
    )
    # add arguments for weight decay
    validator.add_argument(
        "weight_decay",
        type=float,
        help="Weight decay rate.",
    )
    # add arguments for training batch size
    validator.add_argument(
        "batch_size_train",
        type=int,
        help="Training batch size.",
    )
    # add arguments for evaluation batch size
    validator.add_argument(
        "batch_size_eval",
        type=int,
        help="Evaluation batch size, including validation and testing.",
    )
    # add arguments for number of workers for data loading
    validator.add_argument(
        "num_workers",
        help="Number of workers for data loading.",
    )
    # add arguments for warm up steps
    validator.add_argument(
        "warmup_steps",
        type=int,
        help="Number of warmup steps. Required if a warmup schedule is used.",
    )
    # add arguments for random seed
    validator.add_argument(
        "seed",
        type=int,
        help="Random seed.",
    )
    # add arguments for output directory
    validator.add_argument(
        "output_dir",
        type=str,
        help="Output directory to save checkpoints and logs.",
    )
    # add arguments for whether only use evaluation
    validator.add_argument(
        "evaluate",
        help="Whether to only evaluate the model. If true, training will not be performed.",
    )
    # add arguments for splits used for training, e.g. ["train", "val"]
    validator.add_argument(
        "train_splits",
        type=list,
        help="Splits to use for training.",
    )
    # add arguments for splits used for validation, e.g. ["val"]
    validator.add_argument(
        "valid_splits",
        type=list,
        help="Splits to use for validation. If not provided, will skip the validation.",
    )
    # add arguments for splits used for testing, e.g. ["test"]
    validator.add_argument(
        "test_splits",
        type=list,
        help="Splits to use for testing. If not provided, will skip the testing.",
    )
    # add arguments for accumulating gradient for iterations
    validator.add_argument(
        "accum_grad_iters",
        type=int,
        help="Number of iterations to accumulate gradient for.",
    )

    # ====== distributed training ======
    validator.add_argument(
        "device",
        type=str,
        choices=["cpu", "cuda"],
        help="Device to use. Support 'cuda' or 'cpu' as for now.",
    )
    validator.add_argument(
        "world_size",
        type=int,
        help="Number of processes participating in the job.",
    )
    validator.add_argument("dist_url", type=str)
    validator.add_argument("distributed", type=bool)
    # add arguments to opt using distributed sampler during evaluation or not
    validator.add_argument(
        "use_dist_eval_sampler",
        type=bool,
        help="Whether to use distributed sampler during evaluation or not.",
    )

    # ====== task specific ======
    # generation task specific arguments
    # add arguments for maximal length of text output
    validator.add_argument(
        "max_len",
        type=int,
        help="Maximal length of text output.",
    )
    # add arguments for minimal length of text output
    validator.add_argument(
        "min_len",
        type=int,
        help="Minimal length of text output.",
    )
    # add arguments number of beams
    validator.add_argument(
        "num_beams",
        type=int,
        help="Number of beams used for beam search.",
    )

    # vqa task specific arguments
    # add arguments for number of answer candidates
    validator.add_argument(
        "num_ans_candidates",
        type=int,
        help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
    )
    # add arguments for inference method
    validator.add_argument(
        "inference_method",
        type=str,
        choices=["genearte", "rank"],
        help="""Inference method to use for question answering. If rank, requires a answer list.""",
    )

    # ====== model specific ======
    validator.add_argument(
        "k_test",
        type=int,
        help="Number of top k most similar samples from ITC/VTC selection to be tested.",
    )

    return validator
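For orientation, a minimal sketch of how this deleted Config class was typically driven. The YAML path and the dotlist override are hypothetical, not part of the diff; options use OmegaConf dotlist syntax:

from argparse import Namespace
from minigpt4.common.config import Config

# "eval_configs/minigpt4_eval.yaml" is a hypothetical path to a run config.
args = Namespace(cfg_path="eval_configs/minigpt4_eval.yaml", options=["run.seed=42"])
cfg = Config(args)
cfg.pretty_print()          # logs run, dataset, and model attributes as indented JSON
model_arch = cfg.model_cfg.arch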
minigpt4/common/dist_utils.py
DELETED
@@ -1,140 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import functools
import os

import torch
import torch.distributed as dist
import timm.models.hub as timm_hub


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def init_distributed_mode(args):
    if args.distributed is False:
        print("Not using distributed mode")
        return
    elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(
        "| distributed init (rank {}, world {}): {}".format(
            args.rank, args.world_size, args.dist_url
        ),
        flush=True,
    )
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=datetime.timedelta(
            days=365
        ),  # allow auto-downloading and de-compressing
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def get_dist_info():
    if torch.__version__ < "1.0":
        initialized = dist._initialized
    else:
        initialized = dist.is_initialized()
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:  # non-distributed training
        rank = 0
        world_size = 1
    return rank, world_size


def main_process(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper


def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
    """

    def get_cached_file_path():
        # a hack to sync the file path across processes
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)

        return cached_file

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()
minigpt4/common/eval_utils.py
DELETED
@@ -1,76 +0,0 @@
import argparse
import numpy as np
from nltk.translate.bleu_score import sentence_bleu

from minigpt4.common.registry import registry
from minigpt4.common.config import Config

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *


def eval_parser():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--name", type=str, default='A2', help="evaluation name")
    parser.add_argument("--ckpt", type=str, help="path to configuration file.")
    parser.add_argument("--eval_opt", type=str, default='all', help="path to configuration file.")
    parser.add_argument("--max_new_tokens", type=int, default=10, help="max number of generated tokens")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model")
    parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
             "in xxx=yyy format will be merged into config file (deprecate), "
             "change to --cfg-options instead.",
    )
    return parser


def prepare_texts(texts, conv_temp):
    convs = [conv_temp.copy() for _ in range(len(texts))]
    [conv.append_message(
        conv.roles[0], '<Img><ImageHere></Img> {}'.format(text)) for conv, text in zip(convs, texts)]
    [conv.append_message(conv.roles[1], None) for conv in convs]
    texts = [conv.get_prompt() for conv in convs]
    return texts


def init_model(args):
    print('Initialization Model')
    cfg = Config(args)
    # cfg.model_cfg.ckpt = args.ckpt
    # cfg.model_cfg.lora_r = args.lora_r
    # cfg.model_cfg.lora_alpha = args.lora_alpha

    model_config = cfg.model_cfg
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cuda:0')

    # import pudb; pudb.set_trace()
    key = list(cfg.datasets_cfg.keys())[0]
    vis_processor_cfg = cfg.datasets_cfg.get(key).vis_processor.train
    vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
    print('Initialization Finished')
    return model, vis_processor


def computeIoU(bbox1, bbox2):
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2
    intersection_x1 = max(x1, x3)
    intersection_y1 = max(y1, y3)
    intersection_x2 = min(x2, x4)
    intersection_y2 = min(y2, y4)
    intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
    bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
    bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
    union_area = bbox1_area + bbox2_area - intersection_area
    iou = intersection_area / union_area
    return iou
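For reference, a quick sanity check of the deleted computeIoU helper. The boxes are hypothetical; note the function uses inclusive pixel coordinates, hence the +1 terms:

from minigpt4.common.eval_utils import computeIoU

box_a = (0, 0, 9, 9)    # 10x10 box: area = (9-0+1) * (9-0+1) = 100
box_b = (5, 0, 14, 9)   # same-size box shifted right by 5
# intersection: x in [5, 9], y in [0, 9] -> 5 * 10 = 50
# union: 100 + 100 - 50 = 150, so IoU = 50 / 150 = 1/3
assert computeIoU(box_a, box_a) == 1.0
assert abs(computeIoU(box_a, box_b) - 1 / 3) < 1e-9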
minigpt4/common/gradcam.py
DELETED
@@ -1,24 +0,0 @@
import numpy as np
from matplotlib import pyplot as plt
from scipy.ndimage import filters
from skimage import transform as skimage_transform


def getAttMap(img, attMap, blur=True, overlap=True):
    attMap -= attMap.min()
    if attMap.max() > 0:
        attMap /= attMap.max()
    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
    if blur:
        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
        attMap -= attMap.min()
        attMap /= attMap.max()
    cmap = plt.get_cmap("jet")
    attMapV = cmap(attMap)
    attMapV = np.delete(attMapV, 3, 2)
    if overlap:
        attMap = (
            1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
            + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
        )
    return attMap
minigpt4/common/logger.py
DELETED
@@ -1,195 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import logging
import time
from collections import defaultdict, deque

import torch
import torch.distributed as dist

from minigpt4.common import dist_utils


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not dist_utils.is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def global_avg(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        log_msg = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if torch.cuda.is_available():
            log_msg.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(
            "{} Total time: {} ({:.4f} s / it)".format(
                header, total_time_str, total_time / len(iterable)
            )
        )


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def setup_logger():
    logging.basicConfig(
        level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler()],
    )
minigpt4/common/optims.py
DELETED
@@ -1,119 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import math

from minigpt4.common.registry import registry


@registry.register_lr_scheduler("linear_warmup_step_lr")
class LinearWarmupStepLRScheduler:
    def __init__(
        self,
        optimizer,
        max_epoch,
        min_lr,
        init_lr,
        decay_rate=1,
        warmup_start_lr=-1,
        warmup_steps=0,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr

        self.decay_rate = decay_rate

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        if cur_epoch == 0:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            step_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
                decay_rate=self.decay_rate,
            )


@registry.register_lr_scheduler("linear_warmup_cosine_lr")
class LinearWarmupCosineLRScheduler:
    def __init__(
        self,
        optimizer,
        max_epoch,
        iters_per_epoch,
        min_lr,
        init_lr,
        warmup_steps=0,
        warmup_start_lr=-1,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.iters_per_epoch = iters_per_epoch
        self.min_lr = min_lr

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
        if total_cur_step < self.warmup_steps:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            cosine_lr_schedule(
                epoch=total_cur_step,
                optimizer=self.optimizer,
                max_epoch=self.max_epoch * self.iters_per_epoch,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
            )


def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Decay the learning rate"""
    lr = (init_lr - min_lr) * 0.5 * (
        1.0 + math.cos(math.pi * epoch / max_epoch)
    ) + min_lr
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Warmup the learning rate"""
    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Decay the learning rate"""
    lr = max(min_lr, init_lr * (decay_rate**epoch))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
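A quick worked example of the cosine decay formula in the deleted cosine_lr_schedule (the values are hypothetical): halfway through training the cosine term is cos(pi/2) = 0, so the learning rate sits roughly midway between init_lr and min_lr.

import math

init_lr, min_lr, max_epoch = 1e-4, 1e-6, 100
# at epoch 50 of 100: lr = (1e-4 - 1e-6) * 0.5 * (1 + cos(pi/2)) + 1e-6 = 5.05e-5
lr = (init_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * 50 / max_epoch)) + min_lr
assert abs(lr - 5.05e-5) < 1e-12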
minigpt4/common/registry.py
DELETED
@@ -1,329 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""


class Registry:
    mapping = {
        "builder_name_mapping": {},
        "task_name_mapping": {},
        "processor_name_mapping": {},
        "model_name_mapping": {},
        "lr_scheduler_name_mapping": {},
        "runner_name_mapping": {},
        "state": {},
        "paths": {},
    }

    @classmethod
    def register_builder(cls, name):
        r"""Register a dataset builder to registry with key 'name'

        Args:
            name: Key with which the builder will be registered.

        Usage:

            from minigpt4.common.registry import registry
            from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder
        """

        def wrap(builder_cls):
            from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder

            assert issubclass(
                builder_cls, BaseDatasetBuilder
            ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
                builder_cls
            )
            if name in cls.mapping["builder_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["builder_name_mapping"][name]
                    )
                )
            cls.mapping["builder_name_mapping"][name] = builder_cls
            return builder_cls

        return wrap

    @classmethod
    def register_task(cls, name):
        r"""Register a task to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(task_cls):
            from minigpt4.tasks.base_task import BaseTask

            assert issubclass(
                task_cls, BaseTask
            ), "All tasks must inherit BaseTask class"
            if name in cls.mapping["task_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["task_name_mapping"][name]
                    )
                )
            cls.mapping["task_name_mapping"][name] = task_cls
            return task_cls

        return wrap

    @classmethod
    def register_model(cls, name):
        r"""Register a task to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(model_cls):
            from minigpt4.models import BaseModel

            assert issubclass(
                model_cls, BaseModel
            ), "All models must inherit BaseModel class"
            if name in cls.mapping["model_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["model_name_mapping"][name]
                    )
                )
            cls.mapping["model_name_mapping"][name] = model_cls
            return model_cls

        return wrap

    @classmethod
    def register_processor(cls, name):
        r"""Register a processor to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(processor_cls):
            from minigpt4.processors import BaseProcessor

            assert issubclass(
                processor_cls, BaseProcessor
            ), "All processors must inherit BaseProcessor class"
            if name in cls.mapping["processor_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["processor_name_mapping"][name]
                    )
                )
            cls.mapping["processor_name_mapping"][name] = processor_cls
            return processor_cls

        return wrap

    @classmethod
    def register_lr_scheduler(cls, name):
        r"""Register a model to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(lr_sched_cls):
            if name in cls.mapping["lr_scheduler_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["lr_scheduler_name_mapping"][name]
                    )
                )
            cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
            return lr_sched_cls

        return wrap

    @classmethod
    def register_runner(cls, name):
        r"""Register a model to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(runner_cls):
            if name in cls.mapping["runner_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["runner_name_mapping"][name]
                    )
                )
            cls.mapping["runner_name_mapping"][name] = runner_cls
            return runner_cls

        return wrap

    @classmethod
    def register_path(cls, name, path):
        r"""Register a path to registry with key 'name'

        Args:
            name: Key with which the path will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """
        assert isinstance(path, str), "All path must be str."
        if name in cls.mapping["paths"]:
            raise KeyError("Name '{}' already registered.".format(name))
        cls.mapping["paths"][name] = path

    @classmethod
    def register(cls, name, obj):
        r"""Register an item to registry with key 'name'

        Args:
            name: Key with which the item will be registered.

        Usage::

            from minigpt4.common.registry import registry

            registry.register("config", {})
        """
        path = name.split(".")
        current = cls.mapping["state"]

        for part in path[:-1]:
            if part not in current:
                current[part] = {}
            current = current[part]

        current[path[-1]] = obj

    # @classmethod
    # def get_trainer_class(cls, name):
    #     return cls.mapping["trainer_name_mapping"].get(name, None)

    @classmethod
    def get_builder_class(cls, name):
        return cls.mapping["builder_name_mapping"].get(name, None)

    @classmethod
    def get_model_class(cls, name):
        return cls.mapping["model_name_mapping"].get(name, None)

    @classmethod
    def get_task_class(cls, name):
        return cls.mapping["task_name_mapping"].get(name, None)

    @classmethod
    def get_processor_class(cls, name):
        return cls.mapping["processor_name_mapping"].get(name, None)

    @classmethod
    def get_lr_scheduler_class(cls, name):
        return cls.mapping["lr_scheduler_name_mapping"].get(name, None)

    @classmethod
    def get_runner_class(cls, name):
        return cls.mapping["runner_name_mapping"].get(name, None)

    @classmethod
    def list_runners(cls):
        return sorted(cls.mapping["runner_name_mapping"].keys())

    @classmethod
    def list_models(cls):
        return sorted(cls.mapping["model_name_mapping"].keys())

    @classmethod
    def list_tasks(cls):
        return sorted(cls.mapping["task_name_mapping"].keys())

    @classmethod
    def list_processors(cls):
        return sorted(cls.mapping["processor_name_mapping"].keys())

    @classmethod
    def list_lr_schedulers(cls):
        return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())

    @classmethod
    def list_datasets(cls):
        return sorted(cls.mapping["builder_name_mapping"].keys())

    @classmethod
    def get_path(cls, name):
        return cls.mapping["paths"].get(name, None)

    @classmethod
    def get(cls, name, default=None, no_warning=False):
        r"""Get an item from registry with key 'name'

        Args:
            name (string): Key whose value needs to be retrieved.
            default: If passed and key is not in registry, default value will
                     be returned with a warning. Default: None
            no_warning (bool): If passed as True, warning when key doesn't exist
                               will not be generated. Useful for MMF's
                               internal operations. Default: False
        """
        original_name = name
        name = name.split(".")
        value = cls.mapping["state"]
        for subname in name:
            value = value.get(subname, default)
            if value is default:
                break

        if (
            "writer" in cls.mapping["state"]
            and value == default
            and no_warning is False
        ):
            cls.mapping["state"]["writer"].warning(
                "Key {} is not present in registry, returning default value "
                "of {}".format(original_name, default)
            )
        return value

    @classmethod
    def unregister(cls, name):
        r"""Remove an item from registry with key 'name'

        Args:
            name: Key which needs to be removed.
        Usage::

            from mmf.common.registry import registry

            config = registry.unregister("config")
        """
        return cls.mapping["state"].pop(name, None)


registry = Registry()
minigpt4/common/utils.py
DELETED
@@ -1,424 +0,0 @@
-"""
-Copyright (c) 2022, salesforce.com, inc.
-All rights reserved.
-SPDX-License-Identifier: BSD-3-Clause
-For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-"""
-
-import io
-import json
-import logging
-import os
-import pickle
-import re
-import shutil
-import urllib
-import urllib.error
-import urllib.request
-from typing import Optional
-from urllib.parse import urlparse
-
-import numpy as np
-import pandas as pd
-import yaml
-from iopath.common.download import download
-from iopath.common.file_io import file_lock, g_pathmgr
-from minigpt4.common.registry import registry
-from torch.utils.model_zoo import tqdm
-from torchvision.datasets.utils import (
-    check_integrity,
-    download_file_from_google_drive,
-    extract_archive,
-)
-
-
-def now():
-    from datetime import datetime
-
-    return datetime.now().strftime("%Y%m%d%H%M")[:-1]
-
-
-def is_url(url_or_filename):
-    parsed = urlparse(url_or_filename)
-    return parsed.scheme in ("http", "https")
-
-
-def get_cache_path(rel_path):
-    return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
-
-
-def get_abs_path(rel_path):
-    return os.path.join(registry.get_path("library_root"), rel_path)
-
-
-def load_json(filename):
-    with open(filename, "r") as f:
-        return json.load(f)
-
-
-# The following are adapted from torchvision and vissl
-# torchvision: https://github.com/pytorch/vision
-# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
-
-
-def makedir(dir_path):
-    """
-    Create the directory if it does not exist.
-    """
-    is_success = False
-    try:
-        if not g_pathmgr.exists(dir_path):
-            g_pathmgr.mkdirs(dir_path)
-        is_success = True
-    except BaseException:
-        print(f"Error creating directory: {dir_path}")
-    return is_success
-
-
-def get_redirected_url(url: str):
-    """
-    Given a URL, returns the URL it redirects to or the
-    original URL in case of no indirection
-    """
-    import requests
-
-    with requests.Session() as session:
-        with session.get(url, stream=True, allow_redirects=True) as response:
-            if response.history:
-                return response.url
-            else:
-                return url
-
-
-def to_google_drive_download_url(view_url: str) -> str:
-    """
-    Utility function to transform a view URL of google drive
-    to a download URL for google drive
-    Example input:
-        https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
-    Example output:
-        https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
-    """
-    splits = view_url.split("/")
-    assert splits[-1] == "view"
-    file_id = splits[-2]
-    return f"https://drive.google.com/uc?export=download&id={file_id}"
-
-
-def download_google_drive_url(url: str, output_path: str, output_file_name: str):
-    """
-    Download a file from google drive
-    Downloading an URL from google drive requires confirmation when
-    the file of the size is too big (google drive notifies that
-    anti-viral checks cannot be performed on such files)
-    """
-    import requests
-
-    with requests.Session() as session:
-
-        # First get the confirmation token and append it to the URL
-        with session.get(url, stream=True, allow_redirects=True) as response:
-            for k, v in response.cookies.items():
-                if k.startswith("download_warning"):
-                    url = url + "&confirm=" + v
-
-        # Then download the content of the file
-        with session.get(url, stream=True, verify=True) as response:
-            makedir(output_path)
-            path = os.path.join(output_path, output_file_name)
-            total_size = int(response.headers.get("Content-length", 0))
-            with open(path, "wb") as file:
-                from tqdm import tqdm
-
-                with tqdm(total=total_size) as progress_bar:
-                    for block in response.iter_content(
-                        chunk_size=io.DEFAULT_BUFFER_SIZE
-                    ):
-                        file.write(block)
-                        progress_bar.update(len(block))
-
-
-def _get_google_drive_file_id(url: str) -> Optional[str]:
-    parts = urlparse(url)
-
-    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
-        return None
-
-    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
-    if match is None:
-        return None
-
-    return match.group("id")
-
-
-def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
-    with open(filename, "wb") as fh:
-        with urllib.request.urlopen(
-            urllib.request.Request(url, headers={"User-Agent": "vissl"})
-        ) as response:
-            with tqdm(total=response.length) as pbar:
-                for chunk in iter(lambda: response.read(chunk_size), ""):
-                    if not chunk:
-                        break
-                    pbar.update(chunk_size)
-                    fh.write(chunk)
-
-
-def download_url(
-    url: str,
-    root: str,
-    filename: Optional[str] = None,
-    md5: Optional[str] = None,
-) -> None:
-    """Download a file from a url and place it in root.
-    Args:
-        url (str): URL to download file from
-        root (str): Directory to place downloaded file in
-        filename (str, optional): Name to save the file under.
-            If None, use the basename of the URL.
-        md5 (str, optional): MD5 checksum of the download. If None, do not check
-    """
-    root = os.path.expanduser(root)
-    if not filename:
-        filename = os.path.basename(url)
-    fpath = os.path.join(root, filename)
-
-    makedir(root)
-
-    # check if file is already present locally
-    if check_integrity(fpath, md5):
-        print("Using downloaded and verified file: " + fpath)
-        return
-
-    # expand redirect chain if needed
-    url = get_redirected_url(url)
-
-    # check if file is located on Google Drive
-    file_id = _get_google_drive_file_id(url)
-    if file_id is not None:
-        return download_file_from_google_drive(file_id, root, filename, md5)
-
-    # download the file
-    try:
-        print("Downloading " + url + " to " + fpath)
-        _urlretrieve(url, fpath)
-    except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
-        if url[:5] == "https":
-            url = url.replace("https:", "http:")
-            print(
-                "Failed download. Trying https -> http instead."
-                " Downloading " + url + " to " + fpath
-            )
-            _urlretrieve(url, fpath)
-        else:
-            raise e
-
-    # check integrity of downloaded file
-    if not check_integrity(fpath, md5):
-        raise RuntimeError("File not found or corrupted.")
-
-
-def download_and_extract_archive(
-    url: str,
-    download_root: str,
-    extract_root: Optional[str] = None,
-    filename: Optional[str] = None,
-    md5: Optional[str] = None,
-    remove_finished: bool = False,
-) -> None:
-    download_root = os.path.expanduser(download_root)
-    if extract_root is None:
-        extract_root = download_root
-    if not filename:
-        filename = os.path.basename(url)
-
-    download_url(url, download_root, filename, md5)
-
-    archive = os.path.join(download_root, filename)
-    print("Extracting {} to {}".format(archive, extract_root))
-    extract_archive(archive, extract_root, remove_finished)
-
-
-def cache_url(url: str, cache_dir: str) -> str:
-    """
-    This implementation downloads the remote resource and caches it locally.
-    The resource will only be downloaded if not previously requested.
-    """
-    parsed_url = urlparse(url)
-    dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
-    makedir(dirname)
-    filename = url.split("/")[-1]
-    cached = os.path.join(dirname, filename)
-    with file_lock(cached):
-        if not os.path.isfile(cached):
-            logging.info(f"Downloading {url} to {cached} ...")
-            cached = download(url, dirname, filename=filename)
-    logging.info(f"URL {url} cached in {cached}")
-    return cached
-
-
-# TODO (prigoyal): convert this into RAII-style API
-def create_file_symlink(file1, file2):
-    """
-    Simply create the symlinks for a given file1 to file2.
-    Useful during model checkpointing to symlinks to the
-    latest successful checkpoint.
-    """
-    try:
-        if g_pathmgr.exists(file2):
-            g_pathmgr.rm(file2)
-        g_pathmgr.symlink(file1, file2)
-    except Exception as e:
-        logging.info(f"Could NOT create symlink. Error: {e}")
-
-
-def save_file(data, filename, append_to_json=True, verbose=True):
-    """
-    Common i/o utility to handle saving data to various file formats.
-    Supported:
-        .pkl, .pickle, .npy, .json
-    Specifically for .json, users have the option to either append (default)
-    or rewrite by passing in Boolean value to append_to_json.
-    """
-    if verbose:
-        logging.info(f"Saving data to file: {filename}")
-    file_ext = os.path.splitext(filename)[1]
-    if file_ext in [".pkl", ".pickle"]:
-        with g_pathmgr.open(filename, "wb") as fopen:
-            pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
-    elif file_ext == ".npy":
-        with g_pathmgr.open(filename, "wb") as fopen:
-            np.save(fopen, data)
-    elif file_ext == ".json":
-        if append_to_json:
-            with g_pathmgr.open(filename, "a") as fopen:
-                fopen.write(json.dumps(data, sort_keys=True) + "\n")
-                fopen.flush()
-        else:
-            with g_pathmgr.open(filename, "w") as fopen:
-                fopen.write(json.dumps(data, sort_keys=True) + "\n")
-                fopen.flush()
-    elif file_ext == ".yaml":
-        with g_pathmgr.open(filename, "w") as fopen:
-            dump = yaml.dump(data)
-            fopen.write(dump)
-            fopen.flush()
-    else:
-        raise Exception(f"Saving {file_ext} is not supported yet")
-
-    if verbose:
-        logging.info(f"Saved data to file: {filename}")
-
-
-def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
-    """
-    Common i/o utility to handle loading data from various file formats.
-    Supported:
-        .pkl, .pickle, .npy, .json
-    For the npy files, we support reading the files in mmap_mode.
-    If the mmap_mode of reading is not successful, we load data without the
-    mmap_mode.
-    """
-    if verbose:
-        logging.info(f"Loading data from file: {filename}")
-
-    file_ext = os.path.splitext(filename)[1]
-    if file_ext == ".txt":
-        with g_pathmgr.open(filename, "r") as fopen:
-            data = fopen.readlines()
-    elif file_ext in [".pkl", ".pickle"]:
-        with g_pathmgr.open(filename, "rb") as fopen:
-            data = pickle.load(fopen, encoding="latin1")
-    elif file_ext == ".npy":
-        if mmap_mode:
-            try:
-                with g_pathmgr.open(filename, "rb") as fopen:
-                    data = np.load(
-                        fopen,
-                        allow_pickle=allow_pickle,
-                        encoding="latin1",
-                        mmap_mode=mmap_mode,
-                    )
-            except ValueError as e:
-                logging.info(
-                    f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
-                )
-                data = np.load(
-                    filename,
-                    allow_pickle=allow_pickle,
-                    encoding="latin1",
-                    mmap_mode=mmap_mode,
-                )
-                logging.info("Successfully loaded without g_pathmgr")
-            except Exception:
-                logging.info("Could not mmap without g_pathmgr. Trying without mmap")
-                with g_pathmgr.open(filename, "rb") as fopen:
-                    data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
-        else:
-            with g_pathmgr.open(filename, "rb") as fopen:
-                data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
-    elif file_ext == ".json":
-        with g_pathmgr.open(filename, "r") as fopen:
-            data = json.load(fopen)
-    elif file_ext == ".yaml":
-        with g_pathmgr.open(filename, "r") as fopen:
-            data = yaml.load(fopen, Loader=yaml.FullLoader)
-    elif file_ext == ".csv":
-        with g_pathmgr.open(filename, "r") as fopen:
-            data = pd.read_csv(fopen)
-    else:
-        raise Exception(f"Reading from {file_ext} is not supported yet")
-    return data
-
-
-def abspath(resource_path: str):
-    """
-    Make a path absolute, but take into account prefixes like
-    "http://" or "manifold://"
-    """
-    regex = re.compile(r"^\w+://")
-    if regex.match(resource_path) is None:
-        return os.path.abspath(resource_path)
-    else:
-        return resource_path
-
-
-def makedir(dir_path):
-    """
-    Create the directory if it does not exist.
-    """
-    is_success = False
-    try:
-        if not g_pathmgr.exists(dir_path):
-            g_pathmgr.mkdirs(dir_path)
-        is_success = True
-    except BaseException:
-        logging.info(f"Error creating directory: {dir_path}")
-    return is_success
-
-
-def is_url(input_url):
-    """
-    Check if an input string is a url. look for http(s):// and ignoring the case
-    """
-    is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
-    return is_url
-
-
-def cleanup_dir(dir):
-    """
-    Utility for deleting a directory. Useful for cleaning the storage space
-    that contains various training artifacts like checkpoints, data etc.
-    """
-    if os.path.exists(dir):
-        logging.info(f"Deleting directory: {dir}")
-        shutil.rmtree(dir)
-    logging.info(f"Deleted contents of directory: {dir}")
-
-
-def get_file_size(filename):
-    """
-    Given a file, get the size of file in MB
-    """
-    size_in_mb = os.path.getsize(filename) / float(1024**2)
-    return size_in_mb
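
`save_file` and `load_file` above dispatch purely on the file extension. A hedged usage sketch, assuming the module as it existed before this deletion (the payload and path are made-up examples):

```python
from minigpt4.common.utils import save_file, load_file  # module removed by this commit

record = {"split": "train", "accuracy": 0.42}  # hypothetical payload
# Assumes /tmp/metrics.json does not already exist; append mode adds one
# JSON object per line, and json.load only parses a single object.
save_file(record, "/tmp/metrics.json", append_to_json=True)
loaded = load_file("/tmp/metrics.json")  # ".json" extension routes to json.load
assert loaded == record
```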
minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py
DELETED
@@ -1,89 +0,0 @@
-# coding: utf-8
-
-import sys
-dataDir = '../../VQA'
-sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(dataDir))
-from vqa import VQA
-from vqaEvaluation.vqaEval import VQAEval
-import matplotlib.pyplot as plt
-import skimage.io as io
-import json
-import random
-import os
-
-# set up file names and paths
-versionType ='v2_' # this should be '' when using VQA v2.0 dataset
-taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
-dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
-dataSubType ='train2014'
-annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
-quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
-imgDir ='%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
-resultType ='fake'
-fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType']
-
-# An example result json file has been provided in './Results' folder.
-
-[resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json'%(dataDir, versionType, taskType, dataType, dataSubType, \
-resultType, fileType) for fileType in fileTypes]
-
-# create vqa object and vqaRes object
-vqa = VQA(annFile, quesFile)
-vqaRes = vqa.loadRes(resFile, quesFile)
-
-# create vqaEval object by taking vqa and vqaRes
-vqaEval = VQAEval(vqa, vqaRes, n=2) #n is precision of accuracy (number of places after decimal), default is 2
-
-# evaluate results
-"""
-If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
-By default it uses all the question ids in annotation file
-"""
-vqaEval.evaluate()
-
-# print accuracies
-print "\n"
-print "Overall Accuracy is: %.02f\n" %(vqaEval.accuracy['overall'])
-print "Per Question Type Accuracy is the following:"
-for quesType in vqaEval.accuracy['perQuestionType']:
-    print "%s : %.02f" %(quesType, vqaEval.accuracy['perQuestionType'][quesType])
-print "\n"
-print "Per Answer Type Accuracy is the following:"
-for ansType in vqaEval.accuracy['perAnswerType']:
-    print "%s : %.02f" %(ansType, vqaEval.accuracy['perAnswerType'][ansType])
-print "\n"
-# demo how to use evalQA to retrieve low score result
-evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId]<35] #35 is per question percentage accuracy
-if len(evals) > 0:
-    print 'ground truth answers'
-    randomEval = random.choice(evals)
-    randomAnn = vqa.loadQA(randomEval)
-    vqa.showQA(randomAnn)
-
-    print '\n'
-    print 'generated answer (accuracy %.02f)'%(vqaEval.evalQA[randomEval])
-    ann = vqaRes.loadQA(randomEval)[0]
-    print "Answer: %s\n" %(ann['answer'])
-
-    imgId = randomAnn[0]['image_id']
-    imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
-    if os.path.isfile(imgDir + imgFilename):
-        I = io.imread(imgDir + imgFilename)
-        plt.imshow(I)
-        plt.axis('off')
-        plt.show()
-
-# plot accuracy for various question types
-plt.bar(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].values(), align='center')
-plt.xticks(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].keys(), rotation='0',fontsize=10)
-plt.title('Per Question Type Accuracy', fontsize=10)
-plt.xlabel('Question Types', fontsize=10)
-plt.ylabel('Accuracy', fontsize=10)
-plt.show()
-
-# save evaluation results to ./Results folder
-json.dump(vqaEval.accuracy, open(accuracyFile, 'w'))
-json.dump(vqaEval.evalQA, open(evalQAFile, 'w'))
-json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w'))
-json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w'))
-
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py
DELETED
@@ -1 +0,0 @@
-author='aagrawal'

minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py
DELETED
@@ -1,192 +0,0 @@
-# coding=utf-8
-
-__author__='aagrawal'
-
-import re
-# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
-# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
-import sys
-
-
-class VQAEval:
-    def __init__(self, vqa, vqaRes, n=2):
-        self.n = n
-        self.accuracy = {}
-        self.evalQA = {}
-        self.evalQuesType = {}
-        self.evalAnsType = {}
-        self.vqa = vqa
-        self.vqaRes = vqaRes
-        self.params = {'question_id': vqa.getQuesIds()}
-        self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \
-                             "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \
-                             "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \
-                             "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \
-                             "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \
-                             "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \
-                             "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \
-                             "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \
-                             "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \
-                             "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \
-                             "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \
-                             "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \
-                             "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \
-                             "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \
-                             "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \
-                             "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \
-                             "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \
-                             "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \
-                             "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \
-                             "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \
-                             "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \
-                             "youll": "you'll", "youre": "you're", "youve": "you've"}
-        self.manualMap = { 'none': '0',
-                           'zero': '0',
-                           'one': '1',
-                           'two': '2',
-                           'three': '3',
-                           'four': '4',
-                           'five': '5',
-                           'six': '6',
-                           'seven': '7',
-                           'eight': '8',
-                           'nine': '9',
-                           'ten': '10'
-                         }
-        self.articles = ['a',
-                         'an',
-                         'the'
-                        ]
-
-
-        self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
-        self.commaStrip = re.compile("(\d)(\,)(\d)")
-        self.punct = [';', r"/", '[', ']', '"', '{', '}',
-                      '(', ')', '=', '+', '\\', '_', '-',
-                      '>', '<', '@', '`', ',', '?', '!']
-
-
-    def evaluate(self, quesIds=None):
-        if quesIds == None:
-            quesIds = [quesId for quesId in self.params['question_id']]
-        gts = {}
-        res = {}
-        for quesId in quesIds:
-            gts[quesId] = self.vqa.qa[quesId]
-            res[quesId] = self.vqaRes.qa[quesId]
-
-        # =================================================
-        # Compute accuracy
-        # =================================================
-        accQA = []
-        accQuesType = {}
-        accAnsType = {}
-        # print "computing accuracy"
-        step = 0
-        for quesId in quesIds:
-            for ansDic in gts[quesId]['answers']:
-                ansDic['answer'] = ansDic['answer'].replace('\n', ' ')
-                ansDic['answer'] = ansDic['answer'].replace('\t', ' ')
-                ansDic['answer'] = ansDic['answer'].strip()
-            resAns = res[quesId]['answer']
-            resAns = resAns.replace('\n', ' ')
-            resAns = resAns.replace('\t', ' ')
-            resAns = resAns.strip()
-            gtAcc = []
-            gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
-
-            if len(set(gtAnswers)) > 1:
-                for ansDic in gts[quesId]['answers']:
-                    ansDic['answer'] = self.processPunctuation(ansDic['answer'])
-                    ansDic['answer'] = self.processDigitArticle(ansDic['answer'])
-                resAns = self.processPunctuation(resAns)
-                resAns = self.processDigitArticle(resAns)
-
-            for gtAnsDatum in gts[quesId]['answers']:
-                otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum]
-                matchingAns = [item for item in otherGTAns if item['answer'].lower()==resAns.lower()]
-                acc = min(1, float(len(matchingAns))/3)
-                gtAcc.append(acc)
-            quesType = gts[quesId]['question_type']
-            ansType = gts[quesId]['answer_type']
-            avgGTAcc = float(sum(gtAcc))/len(gtAcc)
-            accQA.append(avgGTAcc)
-            if quesType not in accQuesType:
-                accQuesType[quesType] = []
-            accQuesType[quesType].append(avgGTAcc)
-            if ansType not in accAnsType:
-                accAnsType[ansType] = []
-            accAnsType[ansType].append(avgGTAcc)
-            self.setEvalQA(quesId, avgGTAcc)
-            self.setEvalQuesType(quesId, quesType, avgGTAcc)
-            self.setEvalAnsType(quesId, ansType, avgGTAcc)
-            if step%100 == 0:
-                self.updateProgress(step/float(len(quesIds)))
-            step = step + 1
-
-        self.setAccuracy(accQA, accQuesType, accAnsType)
-        # print "Done computing accuracy"
-
-    def processPunctuation(self, inText):
-        outText = inText
-        for p in self.punct:
-            if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None):
-                outText = outText.replace(p, '')
-            else:
-                outText = outText.replace(p, ' ')
-        outText = self.periodStrip.sub("",
-                                       outText,
-                                       re.UNICODE)
-        return outText
-
-    def processDigitArticle(self, inText):
-        outText = []
-        tempText = inText.lower().split()
-        for word in tempText:
-            word = self.manualMap.setdefault(word, word)
-            if word not in self.articles:
-                outText.append(word)
-            else:
-                pass
-        for wordId, word in enumerate(outText):
-            if word in self.contractions:
-                outText[wordId] = self.contractions[word]
-        outText = ' '.join(outText)
-        return outText
-
-    def setAccuracy(self, accQA, accQuesType, accAnsType):
-        self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n)
-        self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
-        self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}
-
-    def setEvalQA(self, quesId, acc):
-        self.evalQA[quesId] = round(100*acc, self.n)
-
-    def setEvalQuesType(self, quesId, quesType, acc):
-        if quesType not in self.evalQuesType:
-            self.evalQuesType[quesType] = {}
-        self.evalQuesType[quesType][quesId] = round(100*acc, self.n)
-
-    def setEvalAnsType(self, quesId, ansType, acc):
-        if ansType not in self.evalAnsType:
-            self.evalAnsType[ansType] = {}
-        self.evalAnsType[ansType][quesId] = round(100*acc, self.n)
-
-    def updateProgress(self, progress):
-        barLength = 20
-        status = ""
-        if isinstance(progress, int):
-            progress = float(progress)
-        if not isinstance(progress, float):
-            progress = 0
-            status = "error: progress var must be float\r\n"
-        if progress < 0:
-            progress = 0
-            status = "Halt...\r\n"
-        if progress >= 1:
-            progress = 1
-            status = "Done...\r\n"
-        block = int(round(barLength*progress))
-        text = "\rFinshed Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status)
-        sys.stdout.write(text)
-        sys.stdout.flush()
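
The inner loop of `evaluate` above implements the standard VQA consensus metric: for each of the 10 human answers, the prediction is scored against the *other* nine as `min(1, matches / 3)`, and the ten scores are averaged. A standalone sketch with a worked example (this omits the punctuation/digit/article normalisation the class applies first):

```python
def vqa_accuracy(pred, gt_answers):
    # Mirror of the acc = min(1, len(matchingAns) / 3) loop in evaluate().
    accs = []
    for i in range(len(gt_answers)):
        others = gt_answers[:i] + gt_answers[i + 1:]  # leave one annotator out
        matches = sum(a.lower() == pred.lower() for a in others)
        accs.append(min(1.0, matches / 3.0))
    return sum(accs) / len(accs)

# 2 of 10 annotators said "blue": (2 * 1/3 + 8 * 2/3) / 10 = 0.6
assert abs(vqa_accuracy("blue", ["blue"] * 2 + ["green"] * 8) - 0.6) < 1e-9
```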
minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py
DELETED
@@ -1,73 +0,0 @@
-# coding: utf-8
-
-from vqaTools.vqa import VQA
-import random
-import skimage.io as io
-import matplotlib.pyplot as plt
-import os
-
-dataDir ='../../VQA'
-versionType ='v2_' # this should be '' when using VQA v2.0 dataset
-taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
-dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
-dataSubType ='train2014'
-annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
-quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
-imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
-
-# initialize VQA api for QA annotations
-vqa=VQA(annFile, quesFile)
-
-# load and display QA annotations for given question types
-"""
-All possible quesTypes for abstract and mscoco has been provided in respective text files in ../QuestionTypes/ folder.
-"""
-annIds = vqa.getQuesIds(quesTypes='how many');
-anns = vqa.loadQA(annIds)
-randomAnn = random.choice(anns)
-vqa.showQA([randomAnn])
-imgId = randomAnn['image_id']
-imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
-if os.path.isfile(imgDir + imgFilename):
-    I = io.imread(imgDir + imgFilename)
-    plt.imshow(I)
-    plt.axis('off')
-    plt.show()
-
-# load and display QA annotations for given answer types
-"""
-ansTypes can be one of the following
-yes/no
-number
-other
-"""
-annIds = vqa.getQuesIds(ansTypes='yes/no');
-anns = vqa.loadQA(annIds)
-randomAnn = random.choice(anns)
-vqa.showQA([randomAnn])
-imgId = randomAnn['image_id']
-imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
-if os.path.isfile(imgDir + imgFilename):
-    I = io.imread(imgDir + imgFilename)
-    plt.imshow(I)
-    plt.axis('off')
-    plt.show()
-
-# load and display QA annotations for given images
-"""
-Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[])
-Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types.
-"""
-ids = vqa.getImgIds()
-annIds = vqa.getQuesIds(imgIds=random.sample(ids,5));
-anns = vqa.loadQA(annIds)
-randomAnn = random.choice(anns)
-vqa.showQA([randomAnn])
-imgId = randomAnn['image_id']
-imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
-if os.path.isfile(imgDir + imgFilename):
-    I = io.imread(imgDir + imgFilename)
-    plt.imshow(I)
-    plt.axis('off')
-    plt.show()
-
minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py
DELETED
@@ -1 +0,0 @@
-__author__ = 'aagrawal'

minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py
DELETED
@@ -1,179 +0,0 @@
-__author__ = 'aagrawal'
-__version__ = '0.9'
-
-# Interface for accessing the VQA dataset.
-
-# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
-# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
-
-# The following functions are defined:
-# VQA - VQA class that loads VQA annotation file and prepares data structures.
-# getQuesIds - Get question ids that satisfy given filter conditions.
-# getImgIds - Get image ids that satisfy given filter conditions.
-# loadQA - Load questions and answers with the specified question ids.
-# showQA - Display the specified questions and answers.
-# loadRes - Load result file and create result object.
-
-# Help on each function can be accessed by: "help(COCO.function)"
-
-import json
-import datetime
-import copy
-
-
-class VQA:
-    def __init__(self, annotation_file=None, question_file=None):
-        """
-        Constructor of VQA helper class for reading and visualizing questions and answers.
-        :param annotation_file (str): location of VQA annotation file
-        :return:
-        """
-        # load dataset
-        self.dataset = {}
-        self.questions = {}
-        self.qa = {}
-        self.qqa = {}
-        self.imgToQA = {}
-        if not annotation_file == None and not question_file == None:
-            # print 'loading VQA annotations and questions into memory...'
-            time_t = datetime.datetime.utcnow()
-            dataset = json.load(open(annotation_file, 'r'))
-            questions = json.load(open(question_file, 'r'))
-            # print datetime.datetime.utcnow() - time_t
-            self.dataset = dataset
-            self.questions = questions
-            self.createIndex()
-
-    def createIndex(self):
-        imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
-        qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
-        qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
-        for ann in self.dataset['annotations']:
-            imgToQA[ann['image_id']] += [ann]
-            qa[ann['question_id']] = ann
-        for ques in self.questions['questions']:
-            qqa[ques['question_id']] = ques
-        # print 'index created!'
-
-        # create class members
-        self.qa = qa
-        self.qqa = qqa
-        self.imgToQA = imgToQA
-
-    def info(self):
-        """
-        Print information about the VQA annotation file.
-        :return:
-        """
-
-        # for key, value in self.datset['info'].items():
-        #     print '%s: %s'%(key, value)
-
-    def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
-        """
-        Get question ids that satisfy given filter conditions. default skips that filter
-        :param imgIds (int array) : get question ids for given imgs
-               quesTypes (str array) : get question ids for given question types
-               ansTypes (str array) : get question ids for given answer types
-        :return: ids (int array) : integer array of question ids
-        """
-        imgIds = imgIds if type(imgIds) == list else [imgIds]
-        quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
-        ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
-
-        if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
-            anns = self.dataset['annotations']
-        else:
-            if not len(imgIds) == 0:
-                anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], [])
-            else:
-                anns = self.dataset['annotations']
-            anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
-            anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
-        ids = [ann['question_id'] for ann in anns]
-        return ids
-
-    def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
-        """
-        Get image ids that satisfy given filter conditions. default skips that filter
-        :param quesIds (int array) : get image ids for given question ids
-               quesTypes (str array) : get image ids for given question types
-               ansTypes (str array) : get image ids for given answer types
-        :return: ids (int array) : integer array of image ids
-        """
-        quesIds = quesIds if type(quesIds) == list else [quesIds]
-        quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
-        ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
-
-        if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
-            anns = self.dataset['annotations']
-        else:
-            if not len(quesIds) == 0:
-                anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa], [])
-            else:
-                anns = self.dataset['annotations']
-            anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
-            anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
-        ids = [ann['image_id'] for ann in anns]
-        return ids
-
-    def loadQA(self, ids=[]):
-        """
-        Load questions and answers with the specified question ids.
-        :param ids (int array) : integer ids specifying question ids
-        :return: qa (object array) : loaded qa objects
-        """
-        if type(ids) == list:
-            return [self.qa[id] for id in ids]
-        elif type(ids) == int:
-            return [self.qa[ids]]
-
-    def showQA(self, anns):
-        """
-        Display the specified annotations.
-        :param anns (array of object): annotations to display
-        :return: None
-        """
-        if len(anns) == 0:
-            return 0
-        for ann in anns:
-            quesId = ann['question_id']
-            print("Question: %s" % (self.qqa[quesId]['question']))
-            for ans in ann['answers']:
-                print("Answer %d: %s" % (ans['answer_id'], ans['answer']))
-
-    def loadRes(self, resFile, quesFile):
-        """
-        Load result file and return a result object.
-        :param resFile (str) : file name of result file
-        :return: res (obj) : result api object
-        """
-        res = VQA()
-        res.questions = json.load(open(quesFile))
-        res.dataset['info'] = copy.deepcopy(self.questions['info'])
-        res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
-        res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
-        res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
-        res.dataset['license'] = copy.deepcopy(self.questions['license'])
-
-        # print 'Loading and preparing results... '
-        time_t = datetime.datetime.utcnow()
-        anns = json.load(open(resFile))
-        assert type(anns) == list, 'results is not an array of objects'
-        annsQuesIds = [ann['question_id'] for ann in anns]
-        assert set(annsQuesIds) == set(self.getQuesIds()), \
-            'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
-        for ann in anns:
-            quesId = ann['question_id']
-            if res.dataset['task_type'] == 'Multiple Choice':
-                assert ann['answer'] in self.qqa[quesId][
-                    'multiple_choices'], 'predicted answer is not one of the multiple choices'
-            qaAnn = self.qa[quesId]
-            ann['image_id'] = qaAnn['image_id']
-            ann['question_type'] = qaAnn['question_type']
-            ann['answer_type'] = qaAnn['answer_type']
-        # print 'DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())
-
-        res.dataset['annotations'] = anns
-        res.createIndex()
-        return res
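
`loadRes` above expects the result file to be a JSON list with one `{question_id, answer}` record for every question id in the annotation file (the `set(...) == set(...)` assertion fails otherwise). A sketch of writing such a file, with made-up question ids, in the spirit of the `fake` results file the README below mentions:

```python
import json

results = [
    {"question_id": 1, "answer": "down"},  # hypothetical ids; the list must
    {"question_id": 2, "answer": "2"},     # cover exactly the annotated ids
]
with open("OpenEnded_mscoco_train2014_fake_results.json", "w") as f:
    json.dump(results, f)
```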
minigpt4/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt
DELETED
@@ -1,81 +0,0 @@
-how many
-what color is the
-is the
-where is the
-what
-what is
-are the
-what is the
-is there a
-does the
-is the woman
-is the man
-what is on the
-is it
-is the girl
-is the boy
-is the dog
-are they
-who is
-what kind of
-what color are the
-what is in the
-what is the man
-is there
-what is the woman
-what are the
-what is the boy
-are there
-what is the girl
-is this
-how
-which
-how many people are
-is the cat
-why is the
-are
-will the
-what type of
-what is the dog
-do
-is she
-does
-do the
-is
-is the baby
-are there any
-is the lady
-can
-what animal is
-where are the
-is the sun
-what are they
-did the
-what is the cat
-what is the lady
-how many clouds are
-is that
-is the little girl
-is he
-are these
-how many trees are
-how many pillows
-are the people
-why
-is the young
-how many windows are
-is this a
-what is the little
-is the tv
-how many animals are
-who
-how many pictures
-how many plants are
-how many birds are
-what color is
-what is the baby
-is anyone
-what color
-how many bushes
-is the old man
-none of the above
minigpt4/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt
DELETED
@@ -1,65 +0,0 @@
-how many
-is the
-what
-what color is the
-what is the
-is this
-is this a
-what is
-are the
-what kind of
-is there a
-what type of
-is it
-what are the
-where is the
-is there
-does the
-what color are the
-are these
-are there
-which
-is
-what is the man
-is the man
-are
-how
-does this
-what is on the
-what does the
-how many people are
-what is in the
-what is this
-do
-what are
-are they
-what time
-what sport is
-are there any
-is he
-what color is
-why
-where are the
-what color
-who is
-what animal is
-is the woman
-is this an
-do you
-how many people are in
-what room is
-has
-is this person
-what is the woman
-can you
-why is the
-is the person
-what is the color of the
-what is the person
-could
-was
-is that a
-what number is
-what is the name
-what brand
-none of the above
minigpt4/common/vqa_tools/VQA/README.md
DELETED
@@ -1,80 +0,0 @@
-Python API and Evaluation Code for v2.0 and v1.0 releases of the VQA dataset.
-===================
-## VQA v2.0 release ##
-This release consists of
-- Real
-  - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download))
-  - 443,757 questions for training, 214,354 questions for validation and 447,793 questions for testing
-  - 4,437,570 answers for training and 2,143,540 answers for validation (10 per question)
-
-There is only one type of task
-- Open-ended task
-
-## VQA v1.0 release ##
-This release consists of
-- Real
-  - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download))
-  - 248,349 questions for training, 121,512 questions for validation and 244,302 questions for testing (3 per image)
-  - 2,483,490 answers for training and 1,215,120 answers for validation (10 per question)
-- Abstract
-  - 20,000 training images, 10,000 validation images and 20,000 MS COCO testing images
-  - 60,000 questions for training, 30,000 questions for validation and 60,000 questions for testing (3 per image)
-  - 600,000 answers for training and 300,000 answers for validation (10 per question)
-
-There are two types of tasks
-- Open-ended task
-- Multiple-choice task (18 choices per question)
-
-## Requirements ##
-- python 2.7
-- scikit-image (visit [this page](http://scikit-image.org/docs/dev/install.html) for installation)
-- matplotlib (visit [this page](http://matplotlib.org/users/installing.html) for installation)
-
-## Files ##
-./Questions
-- For v2.0, download the question files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
-- For v1.0, both real and abstract, question files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
-- Question files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
-  - [training question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Train_mscoco.zip)
-  - [validation question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Val_mscoco.zip)
-- Question files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Questions_Train_mscoco.zip).
-
-./Annotations
-- For v2.0, download the annotations files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
-- For v1.0, for both real and abstract, annotation files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
-- Annotation files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
-  - [training annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Train_mscoco.zip)
-  - [validation annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Val_mscoco.zip)
-- Annotation files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Annotations_Train_mscoco.zip).
-
-./Images
-- For real, create a directory with name mscoco inside this directory. For each of train, val and test, create directories with names train2014, val2014 and test2015 respectively inside mscoco directory, download respective images from [MS COCO website](http://mscoco.org/dataset/#download) and place them in respective folders.
-- For abstract, create a directory with name abstract_v002 inside this directory. For each of train, val and test, create directories with names train2015, val2015 and test2015 respectively inside abstract_v002 directory, download respective images from [VQA download page](http://www.visualqa.org/download.html) and place them in respective folders.
-
-./PythonHelperTools
-- This directory contains the Python API to read and visualize the VQA dataset
-- vqaDemo.py (demo script)
-- vqaTools (API to read and visualize data)
-
-./PythonEvaluationTools
-- This directory contains the Python evaluation code
-- vqaEvalDemo.py (evaluation demo script)
-- vqaEvaluation (evaluation code)
-
-./Results
-- OpenEnded_mscoco_train2014_fake_results.json (an example of a fake results file for v1.0 to run the demo)
-- Visit [VQA evaluation page] (http://visualqa.org/evaluation) for more details.
-
-./QuestionTypes
-- This directory contains the following lists of question types for both real and abstract questions (question types are unchanged from v1.0 to v2.0). In a list, if there are question types of length n+k and length n with the same first n words, then the question type of length n does not include questions that belong to the question type of length n+k.
-- mscoco_question_types.txt
-- abstract_v002_question_types.txt
-
-## References ##
-- [VQA: Visual Question Answering](http://visualqa.org/)
-- [Microsoft COCO](http://mscoco.org/)
-
-## Developers ##
-- Aishwarya Agrawal (Virginia Tech)
-- Code for API is based on [MSCOCO API code](https://github.com/pdollar/coco).
-- The format of the code for evaluation is based on [MSCOCO evaluation code](https://github.com/tylin/coco-caption).
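
Since the `./Images` layout in the README above is easy to get wrong, here is a small hedged helper that creates the directory tree it describes (directory names come from the README; the working directory is an assumption):

```python
import os

for sub in ("train2014", "val2014", "test2015"):  # real (MS COCO) splits
    os.makedirs(os.path.join("Images", "mscoco", sub), exist_ok=True)
for sub in ("train2015", "val2015", "test2015"):  # abstract scene splits
    os.makedirs(os.path.join("Images", "abstract_v002", sub), exist_ok=True)
```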
minigpt4/common/vqa_tools/VQA/license.txt
DELETED
@@ -1,30 +0,0 @@
-Copyright (c) 2014, Aishwarya Agrawal
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright notice, this
-   list of conditions and the following disclaimer.
-2. Redistributions in binary form must reproduce the above copyright notice,
-   this list of conditions and the following disclaimer in the documentation
-   and/or other materials provided with the distribution.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-AND
-ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
-FOR
-ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-The views and conclusions contained in the software and documentation are
-those
-of the authors and should not be interpreted as representing official
-policies,
-either expressed or implied, of the FreeBSD Project.
|
minigpt4/common/vqa_tools/__init__.py
DELETED
@@ -1,8 +0,0 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

__author__ = "aagrawal"
minigpt4/common/vqa_tools/vqa.py
DELETED
@@ -1,211 +0,0 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

__author__ = "aagrawal"
__version__ = "0.9"

# Interface for accessing the VQA dataset.

# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).

# The following functions are defined:
#  VQA        - VQA class that loads VQA annotation file and prepares data structures.
#  getQuesIds - Get question ids that satisfy given filter conditions.
#  getImgIds  - Get image ids that satisfy given filter conditions.
#  loadQA     - Load questions and answers with the specified question ids.
#  showQA     - Display the specified questions and answers.
#  loadRes    - Load result file and create result object.

# Help on each function can be accessed by: "help(COCO.function)"

import json
import datetime
import copy


class VQA:
    def __init__(self, annotation_file=None, question_file=None):
        """
        Constructor of VQA helper class for reading and visualizing questions and answers.
        :param annotation_file (str): location of VQA annotation file
        :return:
        """
        # load dataset
        self.dataset = {}
        self.questions = {}
        self.qa = {}
        self.qqa = {}
        self.imgToQA = {}
        if annotation_file is not None and question_file is not None:
            print("loading VQA annotations and questions into memory...")
            time_t = datetime.datetime.utcnow()
            dataset = json.load(open(annotation_file, "r"))
            questions = json.load(open(question_file, "r"))
            self.dataset = dataset
            self.questions = questions
            self.createIndex()

    def createIndex(self):
        # create index
        print("creating index...")
        imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]}
        qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
        qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
        for ann in self.dataset["annotations"]:
            imgToQA[ann["image_id"]] += [ann]
            qa[ann["question_id"]] = ann
        for ques in self.questions["questions"]:
            qqa[ques["question_id"]] = ques
        print("index created!")

        # create class members
        self.qa = qa
        self.qqa = qqa
        self.imgToQA = imgToQA

    def info(self):
        """
        Print information about the VQA annotation file.
        :return:
        """
        # fixed typo: the original read self.datset, which raised AttributeError
        for key, value in self.dataset["info"].items():
            print("%s: %s" % (key, value))

    def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
        """
        Get question ids that satisfy given filter conditions. default skips that filter
        :param imgIds    (int array): get question ids for given imgs
               quesTypes (str array): get question ids for given question types
               ansTypes  (str array): get question ids for given answer types
        :return: ids (int array): integer array of question ids
        """
        imgIds = imgIds if type(imgIds) == list else [imgIds]
        quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
        ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]

        if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
            anns = self.dataset["annotations"]
        else:
            if not len(imgIds) == 0:
                anns = sum(
                    [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],
                    [],
                )
            else:
                anns = self.dataset["annotations"]
            anns = (
                anns
                if len(quesTypes) == 0
                else [ann for ann in anns if ann["question_type"] in quesTypes]
            )
            anns = (
                anns
                if len(ansTypes) == 0
                else [ann for ann in anns if ann["answer_type"] in ansTypes]
            )
        ids = [ann["question_id"] for ann in anns]
        return ids

    def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
        """
        Get image ids that satisfy given filter conditions. default skips that filter
        :param quesIds   (int array): get image ids for given question ids
               quesTypes (str array): get image ids for given question types
               ansTypes  (str array): get image ids for given answer types
        :return: ids (int array): integer array of image ids
        """
        quesIds = quesIds if type(quesIds) == list else [quesIds]
        quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
        ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]

        if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
            anns = self.dataset["annotations"]
        else:
            if not len(quesIds) == 0:
                # fixed: self.qa maps a question id to a single annotation dict,
                # so no sum(...) flattening is needed (summing dicts raises TypeError)
                anns = [self.qa[quesId] for quesId in quesIds if quesId in self.qa]
            else:
                anns = self.dataset["annotations"]
            anns = (
                anns
                if len(quesTypes) == 0
                else [ann for ann in anns if ann["question_type"] in quesTypes]
            )
            anns = (
                anns
                if len(ansTypes) == 0
                else [ann for ann in anns if ann["answer_type"] in ansTypes]
            )
        ids = [ann["image_id"] for ann in anns]
        return ids

    def loadQA(self, ids=[]):
        """
        Load questions and answers with the specified question ids.
        :param ids (int array): integer ids specifying question ids
        :return: qa (object array): loaded qa objects
        """
        if type(ids) == list:
            return [self.qa[id] for id in ids]
        elif type(ids) == int:
            return [self.qa[ids]]

    def showQA(self, anns):
        """
        Display the specified annotations.
        :param anns (array of object): annotations to display
        :return: None
        """
        if len(anns) == 0:
            return 0
        for ann in anns:
            quesId = ann["question_id"]
            print("Question: %s" % (self.qqa[quesId]["question"]))
            for ans in ann["answers"]:
                print("Answer %d: %s" % (ans["answer_id"], ans["answer"]))

    def loadRes(self, resFile, quesFile):
        """
        Load result file and return a result object.
        :param resFile (str): file name of result file
        :return: res (obj): result api object
        """
        res = VQA()
        res.questions = json.load(open(quesFile))
        res.dataset["info"] = copy.deepcopy(self.questions["info"])
        res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"])
        res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"])
        res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"])
        res.dataset["license"] = copy.deepcopy(self.questions["license"])

        print("Loading and preparing results... ")
        time_t = datetime.datetime.utcnow()
        anns = json.load(open(resFile))
        assert type(anns) == list, "results is not an array of objects"
        annsQuesIds = [ann["question_id"] for ann in anns]
        assert set(annsQuesIds) == set(
            self.getQuesIds()
        ), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in the annotation file, or there is at least one question id that does not belong to the question ids in the annotation file."
        for ann in anns:
            quesId = ann["question_id"]
            if res.dataset["task_type"] == "Multiple Choice":
                assert (
                    ann["answer"] in self.qqa[quesId]["multiple_choices"]
                ), "predicted answer is not one of the multiple choices"
            qaAnn = self.qa[quesId]
            ann["image_id"] = qaAnn["image_id"]
            ann["question_type"] = qaAnn["question_type"]
            ann["answer_type"] = qaAnn["answer_type"]
        print(
            "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds())
        )

        res.dataset["annotations"] = anns
        res.createIndex()
        return res
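A minimal usage sketch for the class above; the annotation/question file names are placeholders for the standard VQA files, not assets shipped with this repository:

# Minimal usage sketch; file paths below are placeholders (assumption).
from minigpt4.common.vqa_tools.vqa import VQA

vqa = VQA("mscoco_val2014_annotations.json", "OpenEnded_mscoco_val2014_questions.json")
ques_ids = vqa.getQuesIds(quesTypes=["what color"])  # filter by question type
vqa.showQA(vqa.loadQA(ques_ids[:3]))                 # print a few question/answer sets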
minigpt4/common/vqa_tools/vqa_eval.py
DELETED
@@ -1,324 +0,0 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

# coding=utf-8

__author__ = "aagrawal"

# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
import sys
import re


class VQAEval:
    def __init__(self, vqa=None, vqaRes=None, n=2):
        self.n = n
        self.accuracy = {}
        self.evalQA = {}
        self.evalQuesType = {}
        self.evalAnsType = {}
        self.vqa = vqa
        self.vqaRes = vqaRes
        if vqa is not None:
            self.params = {"question_id": vqa.getQuesIds()}
        self.contractions = {
            "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've",
            "couldnt": "couldn't", "couldn'tve": "couldn't've", "couldnt've": "couldn't've",
            "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't",
            "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't",
            "havent": "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": "he'd've",
            "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's",
            "Id've": "I'd've", "I'dve": "I'd've", "Im": "I'm", "Ive": "I've",
            "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've",
            "itll": "it'll", "let's": "let's", "maam": "ma'am", "mightnt": "mightn't",
            "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
            "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've",
            "oclock": "o'clock", "oughtnt": "oughtn't", "ow's'at": "'ow's'at",
            "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't",
            "shed've": "she'd've", "she'dve": "she'd've", "she's": "she's",
            "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've",
            "shouldn'tve": "shouldn't've", "somebody'd": "somebodyd",
            "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've",
            "somebodyll": "somebody'll", "somebodys": "somebody's", "someoned": "someone'd",
            "someoned've": "someone'd've", "someone'dve": "someone'd've",
            "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd",
            "somethingd've": "something'd've", "something'dve": "something'd've",
            "somethingll": "something'll", "thats": "that's", "thered": "there'd",
            "thered've": "there'd've", "there'dve": "there'd've", "therere": "there're",
            "theres": "there's", "theyd": "they'd", "theyd've": "they'd've",
            "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're",
            "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": "we'd've",
            "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll",
            "whatre": "what're", "whats": "what's", "whatve": "what've", "whens": "when's",
            "whered": "where'd", "wheres": "where's", "whereve": "where've", "whod": "who'd",
            "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's",
            "whove": "who've", "whyll": "why'll", "whyre": "why're", "whys": "why's",
            "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't",
            "wouldnt've": "wouldn't've", "wouldn'tve": "wouldn't've", "yall": "y'all",
            "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd",
            "youd've": "you'd've", "you'dve": "you'd've", "youll": "you'll",
            "youre": "you're", "youve": "you've",
        }
        self.manualMap = {
            "none": "0", "zero": "0", "one": "1", "two": "2", "three": "3", "four": "4",
            "five": "5", "six": "6", "seven": "7", "eight": "8", "nine": "9", "ten": "10",
        }
        self.articles = ["a", "an", "the"]

        # fixed regex: the original "(?!<=\d)(\.)(?!\d)" mistyped the negative
        # lookbehind; "(?<!\d)" is the intended pattern (keep decimal points intact)
        self.periodStrip = re.compile(r"(?<!\d)(\.)(?!\d)")
        self.commaStrip = re.compile(r"(\d)(,)(\d)")
        self.punct = [
            ";", r"/", "[", "]", '"', "{", "}", "(", ")",
            "=", "+", "\\", "_", "-", ">", "<", "@", "`", ",", "?", "!",
        ]

    def evaluate(self, quesIds=None):
        if quesIds is None:
            quesIds = [quesId for quesId in self.params["question_id"]]
        gts = {}
        res = {}
        for quesId in quesIds:
            gts[quesId] = self.vqa.qa[quesId]
            res[quesId] = self.vqaRes.qa[quesId]

        # =================================================
        # Compute accuracy
        # =================================================
        accQA = []
        accQuesType = {}
        accAnsType = {}
        print("computing accuracy")
        step = 0
        for quesId in quesIds:
            resAns = res[quesId]["answer"]
            resAns = resAns.replace("\n", " ")
            resAns = resAns.replace("\t", " ")
            resAns = resAns.strip()
            resAns = self.processPunctuation(resAns)
            resAns = self.processDigitArticle(resAns)
            gtAcc = []
            gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
            if len(set(gtAnswers)) > 1:
                for ansDic in gts[quesId]["answers"]:
                    ansDic["answer"] = self.processPunctuation(ansDic["answer"])
            for gtAnsDatum in gts[quesId]["answers"]:
                otherGTAns = [
                    item for item in gts[quesId]["answers"] if item != gtAnsDatum
                ]
                matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
                acc = min(1, float(len(matchingAns)) / 3)
                gtAcc.append(acc)
            quesType = gts[quesId]["question_type"]
            ansType = gts[quesId]["answer_type"]
            avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
            accQA.append(avgGTAcc)
            if quesType not in accQuesType:
                accQuesType[quesType] = []
            accQuesType[quesType].append(avgGTAcc)
            if ansType not in accAnsType:
                accAnsType[ansType] = []
            accAnsType[ansType].append(avgGTAcc)
            self.setEvalQA(quesId, avgGTAcc)
            self.setEvalQuesType(quesId, quesType, avgGTAcc)
            self.setEvalAnsType(quesId, ansType, avgGTAcc)
            if step % 100 == 0:
                self.updateProgress(step / float(len(quesIds)))
            step = step + 1

        self.setAccuracy(accQA, accQuesType, accAnsType)
        print("Done computing accuracy")

    def processPunctuation(self, inText):
        outText = inText
        for p in self.punct:
            if (p + " " in inText or " " + p in inText) or (
                re.search(self.commaStrip, inText) is not None
            ):
                outText = outText.replace(p, "")
            else:
                outText = outText.replace(p, " ")
        # fixed: the original passed re.UNICODE as re.sub's positional count
        # argument, silently capping the number of substitutions
        outText = self.periodStrip.sub("", outText)
        return outText

    def processDigitArticle(self, inText):
        outText = []
        tempText = inText.lower().split()
        for word in tempText:
            word = self.manualMap.setdefault(word, word)
            if word not in self.articles:
                outText.append(word)
            else:
                pass
        for wordId, word in enumerate(outText):
            if word in self.contractions:
                outText[wordId] = self.contractions[word]
        outText = " ".join(outText)
        return outText

    def setAccuracy(self, accQA, accQuesType, accAnsType):
        self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
        self.accuracy["perQuestionType"] = {
            quesType: round(
                100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
                self.n,
            )
            for quesType in accQuesType
        }
        self.accuracy["perAnswerType"] = {
            ansType: round(
                100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
            )
            for ansType in accAnsType
        }

    def setEvalQA(self, quesId, acc):
        self.evalQA[quesId] = round(100 * acc, self.n)

    def setEvalQuesType(self, quesId, quesType, acc):
        if quesType not in self.evalQuesType:
            self.evalQuesType[quesType] = {}
        self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)

    def setEvalAnsType(self, quesId, ansType, acc):
        if ansType not in self.evalAnsType:
            self.evalAnsType[ansType] = {}
        self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)

    def updateProgress(self, progress):
        barLength = 20
        status = ""
        if isinstance(progress, int):
            progress = float(progress)
        if not isinstance(progress, float):
            progress = 0
            status = "error: progress var must be float\r\n"
        if progress < 0:
            progress = 0
            status = "Halt...\r\n"
        if progress >= 1:
            progress = 1
            status = "Done...\r\n"
        block = int(round(barLength * progress))
        text = "\rFinished Percent: [{0}] {1}% {2}".format(  # fixed typo: "Finshed"
            "#" * block + "-" * (barLength - block), int(progress * 100), status
        )
        sys.stdout.write(text)
        sys.stdout.flush()
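A hedged end-to-end sketch tying the two modules above together; all file names are placeholders:

# End-to-end sketch; file names are placeholders (assumption), and results.json
# is expected to hold a list of {"question_id": ..., "answer": ...} entries.
from minigpt4.common.vqa_tools.vqa import VQA
from minigpt4.common.vqa_tools.vqa_eval import VQAEval

vqa = VQA("annotations.json", "questions.json")
vqa_res = vqa.loadRes("results.json", "questions.json")

evaluator = VQAEval(vqa, vqa_res, n=2)  # n = decimal places for reported accuracies
evaluator.evaluate()
print(evaluator.accuracy["overall"])
print(evaluator.accuracy["perAnswerType"])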
minigpt4/configs/datasets/cc_combine/align.yaml
DELETED
@@ -1,16 +0,0 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

datasets:
  cc_align:
    data_type: images
    build_info:
      # Be careful not to append minus sign (-) before split to avoid itemizing
      annotations:
        train:
          url: placeholder
          storage: /ibex/project/c2133/blip_dataset/image_alignment_cc/filter_cap.json
      images:
        storage: /ibex/project/c2133/blip_dataset/image_alignment_cc/
minigpt4/configs/datasets/cc_combine/defaults.yaml
DELETED
@@ -1,11 +0,0 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

datasets:
  cc_combine:
    data_type: images
    build_info:
      # Be careful not to append minus sign (-) before split to avoid itemizing
      storage: /ibex/project/c2133/blip_dataset/cc3m/cc3m_cc12m_sbu/{00000..01255}.tar
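The {00000..01255}.tar suffix is brace notation for a numbered range of tar shards, the style consumed by webdataset-style loaders. A purely illustrative expansion:

# Illustrative only: what the {00000..01255}.tar brace range denotes.
shards = [
    f"/ibex/project/c2133/blip_dataset/cc3m/cc3m_cc12m_sbu/{i:05d}.tar"
    for i in range(1256)
]
print(len(shards), shards[0], shards[-1])  # 1256 shards, 00000.tar ... 01255.tar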
minigpt4/configs/datasets/laion/defaults.yaml
DELETED
@@ -1,13 +0,0 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

datasets:
  laion:

    data_type: images

    build_info:
      # Be careful not to append minus sign (-) before split to avoid itemizing
      storage: /ibex/project/c2133/blip_dataset/laion_1b/laion_gpu/{00000..10488}.tar
minigpt4/configs/default.yaml
DELETED
@@ -1,5 +0,0 @@
env:
  # For default users
  # cache_root: "cache"
  # For internal use with persistent storage
  cache_root: "/export/home/.cache/minigpt4"
minigpt4/configs/models/minigpt4_vicuna0.yaml
DELETED
@@ -1,32 +0,0 @@
model:
  arch: minigpt4

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"
  freeze_vit: True
  freeze_qformer: True

  # Q-Former
  num_query_token: 32

  # generation configs
  prompt: ""

  llama_model: "wangrongsheng/MiniGPT-4-LLaMA-7B"

preprocess:
  vis_processor:
    train:
      name: "blip2_image_train"
      image_size: 224
    eval:
      name: "blip2_image_eval"
      image_size: 224
  text_processor:
    train:
      name: "blip_caption"
    eval:
      name: "blip_caption"
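A quick sketch (path assumes the repo root as working directory) of reading this config with OmegaConf, which the package already depends on:

# Sketch: load the model config above with OmegaConf (already a dependency).
from omegaconf import OmegaConf

cfg = OmegaConf.load("minigpt4/configs/models/minigpt4_vicuna0.yaml")
print(cfg.model.arch)                            # minigpt4
print(cfg.preprocess.vis_processor.train.name)   # blip2_image_train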
minigpt4/conversation/__init__.py
DELETED
File without changes
minigpt4/conversation/conversation.py
DELETED
@@ -1,233 +0,0 @@
import argparse
import time
from threading import Thread
from PIL import Image

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer

import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any

from minigpt4.common.registry import registry


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    # system_img: List[Image.Image] = []
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + message + seps[i % 2]
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            # system_img=self.system_img,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id)

    def dict(self):
        return {
            "system": self.system,
            # "system_img": self.system_img,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all(input_ids[:, -len(stop):] == stop).item():
                return True

        return False


CONV_VISION_Vicuna0 = Conversation(
    system="Give the following image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. Please answer my questions.",
    roles=("Human: ", "Assistant: "),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

CONV_VISION_LLama2 = Conversation(
    system="Give the following image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. Please answer my questions.",
    roles=("<s>[INST] ", " [/INST] "),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="",
)

CONV_VISION_minigptv2 = Conversation(
    system="",
    roles=("<s>[INST] ", " [/INST]"),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="",
)


class Chat:
    def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor

        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            stop_words_ids = [torch.tensor([2]).to(self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv):
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
                and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
                       repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        embs = self.model.get_context_emb(prompt, img_list)

        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, current_max_len - max_length)
        embs = embs[:, begin_idx:]

        generation_kwargs = dict(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=float(temperature),
        )
        return generation_kwargs

    def answer(self, conv, img_list, **kargs):
        generation_dict = self.answer_prepare(conv, img_list, **kargs)
        output_token = self.model_generate(**generation_dict)[0]
        output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)

        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()

        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

    def stream_answer(self, conv, img_list, **kargs):
        generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
        streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
        generation_kwargs['streamer'] = streamer
        thread = Thread(target=self.model_generate, kwargs=generation_kwargs)
        thread.start()
        return streamer

    def model_generate(self, *args, **kwargs):
        # for 8 bit and 16 bit compatibility
        with self.model.maybe_autocast():
            output = self.model.llama_model.generate(*args, **kwargs)
        return output

    def encode_img(self, img_list):
        image = img_list[0]
        img_list.pop(0)
        if isinstance(image, str):  # is an image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)

        image_emb, _ = self.model.encode_img(image)
        img_list.append(image_emb)

    def upload_img(self, image, conv, img_list):
        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
        img_list.append(image)
        msg = "Received."

        return msg
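A hedged usage sketch for the Chat wrapper above; `model` and `vis_processor` stand in for an already-built MiniGPT-4 model and its matching image processor, and the image path is a placeholder:

# Usage sketch; `model`, `vis_processor`, and the image path are assumptions.
chat = Chat(model, vis_processor, device="cuda:0")
conv = CONV_VISION_Vicuna0.copy()
img_list = []
chat.upload_img("example.jpg", conv, img_list)  # placeholder image path
chat.encode_img(img_list)                       # replace raw image with its embedding
chat.ask("What is shown in this image?", conv)
answer, _ = chat.answer(conv, img_list, max_new_tokens=300)
print(answer)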
minigpt4/datasets/__init__.py
DELETED
File without changes
minigpt4/datasets/builders/__init__.py
DELETED
@@ -1,72 +0,0 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
from minigpt4.datasets.builders.image_text_pair_builder import (
    CCSBUBuilder,
    LaionBuilder,
    CCSBUAlignBuilder
)
from minigpt4.common.registry import registry

__all__ = [
    "CCSBUBuilder",
    "LaionBuilder",
    "CCSBUAlignBuilder"
]


def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
    """
    Example

    >>> dataset = load_dataset("coco_caption", cfg_path=None)
    >>> splits = dataset.keys()
    >>> print([len(dataset[split]) for split in splits])

    """
    if cfg_path is None:
        cfg = None
    else:
        cfg = load_dataset_config(cfg_path)

    try:
        builder = registry.get_builder_class(name)(cfg)
    except TypeError:
        print(
            f"Dataset {name} not found. Available datasets:\n"
            + ", ".join([str(k) for k in dataset_zoo.get_names()])
        )
        exit(1)

    if vis_path is not None:
        if data_type is None:
            # use default data type in the config
            data_type = builder.config.data_type

        assert (
            data_type in builder.config.build_info
        ), f"Invalid data_type {data_type} for {name}."

        builder.config.build_info.get(data_type).storage = vis_path

    dataset = builder.build_datasets()
    return dataset


class DatasetZoo:
    def __init__(self) -> None:
        self.dataset_zoo = {
            k: list(v.DATASET_CONFIG_DICT.keys())
            for k, v in sorted(registry.mapping["builder_name_mapping"].items())
        }

    def get_names(self):
        return list(self.dataset_zoo.keys())


dataset_zoo = DatasetZoo()
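For quick inspection, the zoo above can be queried directly; the exact names returned depend on which builders have been registered at import time:

# Sketch: list the dataset builders registered with the registry.
from minigpt4.datasets.builders import dataset_zoo
print(dataset_zoo.get_names())  # names depend on registered builders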
minigpt4/datasets/builders/base_dataset_builder.py
DELETED
@@ -1,236 +0,0 @@
"""
This file is from
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import os
import shutil
import warnings

from omegaconf import OmegaConf
import torch.distributed as dist
from torchvision.datasets.utils import download_url

import minigpt4.common.utils as utils
from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process
from minigpt4.common.registry import registry
from minigpt4.processors.base_processor import BaseProcessor


class BaseDatasetBuilder:
    train_dataset_cls, eval_dataset_cls = None, None

    def __init__(self, cfg=None):
        super().__init__()

        if cfg is None:
            # help to create datasets from default config.
            self.config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, str):
            self.config = load_dataset_config(cfg)
        else:
            # when called from task.build_dataset()
            self.config = cfg

        self.data_type = self.config.data_type

        self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
        self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}

    def build_datasets(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed

        if is_main_process():
            self._download_data()

        if is_dist_avail_and_initialized():
            dist.barrier()

        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        datasets = self.build()  # dataset['train'/'val'/'test']

        return datasets

    def build_processors(self):
        vis_proc_cfg = self.config.get("vis_processor")
        txt_proc_cfg = self.config.get("text_processor")

        if vis_proc_cfg is not None:
            vis_train_cfg = vis_proc_cfg.get("train")
            vis_eval_cfg = vis_proc_cfg.get("eval")

            self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
            self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)

        if txt_proc_cfg is not None:
            txt_train_cfg = txt_proc_cfg.get("train")
            txt_eval_cfg = txt_proc_cfg.get("eval")

            self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
            self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)

    @staticmethod
    def _build_proc_from_cfg(cfg):
        return (
            registry.get_processor_class(cfg.name).from_config(cfg)
            if cfg is not None
            else None
        )

    @classmethod
    def default_config_path(cls, type="default"):
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])

    def _download_data(self):
        self._download_ann()
        self._download_vis()

    def _download_ann(self):
        """
        Download annotation files if necessary.
        All the vision-language datasets should have annotations of unified format.

        storage_path can be:
          (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
          (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.

        Local annotation paths should be relative.
        """
        anns = self.config.build_info.annotations

        splits = anns.keys()

        cache_root = registry.get_path("cache_root")

        for split in splits:
            info = anns[split]

            urls, storage_paths = info.get("url", None), info.storage

            if isinstance(urls, str):
                urls = [urls]
            if isinstance(storage_paths, str):
                storage_paths = [storage_paths]

            assert len(urls) == len(storage_paths)

            for url_or_filename, storage_path in zip(urls, storage_paths):
                # if storage_path is relative, make it full by prefixing with cache_root.
                if not os.path.isabs(storage_path):
                    storage_path = os.path.join(cache_root, storage_path)

                dirname = os.path.dirname(storage_path)
                if not os.path.exists(dirname):
                    os.makedirs(dirname)

                if os.path.isfile(url_or_filename):
                    src, dst = url_or_filename, storage_path
                    if not os.path.exists(dst):
                        shutil.copyfile(src=src, dst=dst)
                    else:
                        logging.info("Using existing file {}.".format(dst))
                else:
                    if os.path.isdir(storage_path):
                        # if only dirname is provided, suffix with basename of URL.
                        raise ValueError(
                            "Expecting storage_path to be a file path, got directory {}".format(
                                storage_path
                            )
                        )
                    else:
                        filename = os.path.basename(storage_path)

                    download_url(url=url_or_filename, root=dirname, filename=filename)

    def _download_vis(self):

        storage_path = self.config.build_info.get(self.data_type).storage
        storage_path = utils.get_cache_path(storage_path)

        if not os.path.exists(storage_path):
            warnings.warn(
                f"""
                The specified path {storage_path} for visual inputs does not exist.
                Please provide a correct path to the visual inputs or
                refer to datasets/download_scripts/README.md for downloading instructions.
                """
            )

    def build(self):
        """
        Create by split datasets inheriting torch.utils.data.Datasets.

        # build() can be dataset-specific. Overwrite to customize.
        """
        self.build_processors()

        build_info = self.config.build_info

        ann_info = build_info.annotations
        vis_info = build_info.get(self.data_type)

        datasets = dict()
        for split in ann_info.keys():
            if split not in ["train", "val", "test"]:
                continue

            is_train = split == "train"

            # processors
            vis_processor = (
                self.vis_processors["train"]
                if is_train
                else self.vis_processors["eval"]
            )
            text_processor = (
                self.text_processors["train"]
                if is_train
                else self.text_processors["eval"]
            )

            # annotation path
            ann_paths = ann_info.get(split).storage
            if isinstance(ann_paths, str):
                ann_paths = [ann_paths]

            abs_ann_paths = []
            for ann_path in ann_paths:
                if not os.path.isabs(ann_path):
                    ann_path = utils.get_cache_path(ann_path)
                abs_ann_paths.append(ann_path)
            ann_paths = abs_ann_paths

            # visual data storage path
            vis_path = os.path.join(vis_info.storage, split)

            if not os.path.isabs(vis_path):
                # vis_path = os.path.join(utils.get_cache_path(), vis_path)
                vis_path = utils.get_cache_path(vis_path)

            if not os.path.exists(vis_path):
                warnings.warn("storage path {} does not exist.".format(vis_path))

            # create datasets
            dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
            datasets[split] = dataset_cls(
                vis_processor=vis_processor,
                text_processor=text_processor,
                ann_paths=ann_paths,
                vis_root=vis_path,
            )

        return datasets


def load_dataset_config(cfg_path):
    cfg = OmegaConf.load(cfg_path).datasets
    cfg = cfg[list(cfg.keys())[0]]

    return cfg
minigpt4/datasets/builders/image_text_pair_builder.py
DELETED
@@ -1,535 +0,0 @@
import os
import logging
import warnings

from minigpt4.common.registry import registry
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from minigpt4.datasets.datasets.laion_dataset import LaionDataset
from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
from minigpt4.datasets.datasets.text_caps import TextCapDataset
from minigpt4.datasets.datasets.llava_dataset import LlavaDetailDataset, LlavaReasonDataset, LlavaConversationDataset
from minigpt4.datasets.datasets.unnatural_instruction import UnnaturalDataset
from minigpt4.datasets.datasets.multitask_conversation import MultiTaskConversationDataset
from minigpt4.datasets.datasets.flickr import GroundedDetailDataset, CaptionToObjectDataset, PhraseToObjectDataset
from minigpt4.datasets.datasets.vg_dataset import ReferVisualGenomeDataset
from minigpt4.datasets.datasets.coco_dataset import ReferCOCODataset, InvReferCOCODataset
from minigpt4.datasets.datasets.gqa_datasets import GQADataset
from minigpt4.datasets.datasets.aok_vqa_datasets import AOKVQADataset
from minigpt4.datasets.datasets.coco_vqa_datasets import COCOVQADataset
from minigpt4.datasets.datasets.ocrvqa_dataset import OCRVQADataset
from minigpt4.datasets.datasets.coco_caption import COCOCapDataset


@registry.register_builder("multitask_conversation")
class MultitaskConversationBuilder(BaseDatasetBuilder):
    train_dataset_cls = MultiTaskConversationDataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/multitask_conversation/default.yaml",
    }

    def build_datasets(self):
        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        self.build_processors()
        build_info = self.config.build_info
        datasets = dict()

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_path=build_info.ann_path,
            vis_root=build_info.image_path,
        )

        return datasets


@registry.register_builder("unnatural_instruction")
class UnnaturalInstructionBuilder(BaseDatasetBuilder):
    train_dataset_cls = UnnaturalDataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/nlp/unnatural_instruction.yaml",
    }

    def build_datasets(self):
        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        self.build_processors()
        build_info = self.config.build_info
        datasets = dict()

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            text_processor=self.text_processors["train"],
            ann_path=build_info.ann_path,
        )

        return datasets


@registry.register_builder("llava_detail")
class LlavaDetailBuilder(BaseDatasetBuilder):
    train_dataset_cls = LlavaDetailDataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/llava/detail.yaml",
    }

    def build_datasets(self):
        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        self.build_processors()
        build_info = self.config.build_info
        datasets = dict()

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_path=build_info.ann_path,
            vis_root=build_info.image_path,
        )

        return datasets


@registry.register_builder("llava_reason")
class LlavaReasonBuilder(BaseDatasetBuilder):
    train_dataset_cls = LlavaReasonDataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/llava/reason.yaml",
    }

    def build_datasets(self):
        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        self.build_processors()
        build_info = self.config.build_info
        datasets = dict()

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_path=build_info.ann_path,
            vis_root=build_info.image_path,
        )

        return datasets


@registry.register_builder("llava_conversation")
class LlavaReasonBuilder(BaseDatasetBuilder):
    # note: the class name LlavaReasonBuilder is reused here in the original source;
    # registration is keyed by the decorator string, so llava_conversation still works
    train_dataset_cls = LlavaConversationDataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/llava/conversation.yaml",
    }

    def build_datasets(self):
        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        self.build_processors()
        build_info = self.config.build_info
        datasets = dict()

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_path=build_info.ann_path,
            vis_root=build_info.image_path,
        )

        return datasets


class AllRefCOCOBuilder(BaseDatasetBuilder):

    def build_datasets(self):
        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        self.build_processors()

        build_info = self.config.build_info
        image_path = build_info.image_path
        ann_path = build_info.ann_path

        datasets = dict()

        if not os.path.exists(image_path):
            warnings.warn("image path {} does not exist.".format(image_path))
        if not os.path.exists(ann_path):
            warnings.warn("ann path {} does not exist.".format(ann_path))

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_path=ann_path,
            vis_root=image_path,
            dataset=build_info.dataset,
            splitBy=build_info.splitBy
        )

        return datasets


@registry.register_builder("refcoco")
class RefCOCOBuilder(AllRefCOCOBuilder):
    train_dataset_cls = ReferCOCODataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/coco_bbox/refcoco.yaml",
    }


@registry.register_builder("refcocop")
class RefCOCOPBuilder(AllRefCOCOBuilder):
    train_dataset_cls = ReferCOCODataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/coco_bbox/refcocop.yaml",
    }


@registry.register_builder("refcocog")
class RefCOCOGBuilder(AllRefCOCOBuilder):
    train_dataset_cls = ReferCOCODataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/coco_bbox/refcocog.yaml",
    }


# the RefCOCO* class names below are reused in the original source for the
# inverse variants; only the registered key and train_dataset_cls differ
@registry.register_builder("invrefcoco")
class RefCOCOBuilder(AllRefCOCOBuilder):
    train_dataset_cls = InvReferCOCODataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/coco_bbox/invrefcoco.yaml",
    }


@registry.register_builder("invrefcocop")
class RefCOCOPBuilder(AllRefCOCOBuilder):
    train_dataset_cls = InvReferCOCODataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/coco_bbox/invrefcocop.yaml",
    }


@registry.register_builder("invrefcocog")
class RefCOCOGBuilder(AllRefCOCOBuilder):
    train_dataset_cls = InvReferCOCODataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/coco_bbox/invrefcocog.yaml",
    }


@registry.register_builder("refvg")
class RefVisualGenomeBuilder(BaseDatasetBuilder):
    train_dataset_cls = ReferVisualGenomeDataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/vg/ref.yaml",
    }

    def build_datasets(self):
        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        self.build_processors()

        build_info = self.config.build_info
        data_dir = build_info.data_dir
        datasets = dict()

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            data_dir=data_dir,
        )

        return datasets


@registry.register_builder("textcaps_caption")
class TextcapCaptionBuilder(BaseDatasetBuilder):
    train_dataset_cls = TextCapDataset

    DATASET_CONFIG_DICT = {"default": "configs/datasets/textcaps/caption.yaml"}

    def _download_ann(self):
        pass

    def _download_vis(self):
        pass

    def build(self):
        self.build_processors()

        build_info = self.config.build_info

        datasets = dict()
        split = "train"

        # create datasets
        # [NOTE] return inner_datasets (wds.DataPipeline)
        dataset_cls = self.train_dataset_cls
        datasets[split] = dataset_cls(
            vis_processor=self.vis_processors[split],
            text_processor=self.text_processors[split],
            ann_path=build_info.ann_path,
            vis_root=build_info.image_path,
        )

        return datasets


@registry.register_builder("coco_vqa")
class COCOVQABuilder(BaseDatasetBuilder):
    train_dataset_cls = COCOVQADataset

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/coco/defaults_vqa.yaml",
    }


@registry.register_builder("ok_vqa")
class OKVQABuilder(COCOVQABuilder):
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/okvqa/defaults.yaml",
    }


@registry.register_builder("aok_vqa")
class AOKVQABuilder(BaseDatasetBuilder):
    train_dataset_cls = AOKVQADataset

    DATASET_CONFIG_DICT = {"default": "configs/datasets/aokvqa/defaults.yaml"}


@registry.register_builder("gqa")
class GQABuilder(BaseDatasetBuilder):
    train_dataset_cls = GQADataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/gqa/balanced_val.yaml",
    }


@registry.register_builder("flickr_grounded_caption")
class GroundedCaptionBuilder(BaseDatasetBuilder):
|
322 |
-
train_dataset_cls = GroundedDetailDataset
|
323 |
-
DATASET_CONFIG_DICT = {
|
324 |
-
"default": "configs/datasets/flickr/default.yaml",
|
325 |
-
}
|
326 |
-
|
327 |
-
def build_datasets(self):
|
328 |
-
# at this point, all the annotations and image/videos should be all downloaded to the specified locations.
|
329 |
-
logging.info("Building datasets...")
|
330 |
-
self.build_processors()
|
331 |
-
build_info = self.config.build_info
|
332 |
-
datasets = dict()
|
333 |
-
|
334 |
-
# create datasets
|
335 |
-
dataset_cls = self.train_dataset_cls
|
336 |
-
datasets['train'] = dataset_cls(
|
337 |
-
vis_processor=self.vis_processors["train"],
|
338 |
-
text_processor=self.text_processors["train"],
|
339 |
-
ann_path=build_info.ann_path,
|
340 |
-
vis_root=build_info.image_path,
|
341 |
-
)
|
342 |
-
|
343 |
-
return datasets
|
344 |
-
|
345 |
-
|
346 |
-
@registry.register_builder("flickr_CaptionToPhrase")
|
347 |
-
class CaptionToPhraseBuilder(BaseDatasetBuilder):
|
348 |
-
train_dataset_cls = CaptionToObjectDataset
|
349 |
-
DATASET_CONFIG_DICT = {
|
350 |
-
"default": "configs/datasets/flickr/caption_to_phrase.yaml",
|
351 |
-
}
|
352 |
-
|
353 |
-
def build_datasets(self):
|
354 |
-
# at this point, all the annotations and image/videos should be all downloaded to the specified locations.
|
355 |
-
logging.info("Building datasets...")
|
356 |
-
self.build_processors()
|
357 |
-
build_info = self.config.build_info
|
358 |
-
datasets = dict()
|
359 |
-
|
360 |
-
# create datasets
|
361 |
-
dataset_cls = self.train_dataset_cls
|
362 |
-
datasets['train'] = dataset_cls(
|
363 |
-
vis_processor=self.vis_processors["train"],
|
364 |
-
text_processor=self.text_processors["train"],
|
365 |
-
ann_path=build_info.ann_path,
|
366 |
-
vis_root=build_info.image_path,
|
367 |
-
)
|
368 |
-
|
369 |
-
return datasets
|
370 |
-
|
371 |
-
@registry.register_builder("flickr_ObjectToPhrase")
|
372 |
-
class CaptionToPhraseBuilder(BaseDatasetBuilder):
|
373 |
-
train_dataset_cls = PhraseToObjectDataset
|
374 |
-
DATASET_CONFIG_DICT = {
|
375 |
-
"default": "configs/datasets/flickr/object_to_phrase.yaml",
|
376 |
-
}
|
377 |
-
|
378 |
-
def build_datasets(self):
|
379 |
-
# at this point, all the annotations and image/videos should be all downloaded to the specified locations.
|
380 |
-
logging.info("Building datasets...")
|
381 |
-
self.build_processors()
|
382 |
-
build_info = self.config.build_info
|
383 |
-
datasets = dict()
|
384 |
-
|
385 |
-
# create datasets
|
386 |
-
dataset_cls = self.train_dataset_cls
|
387 |
-
datasets['train'] = dataset_cls(
|
388 |
-
vis_processor=self.vis_processors["train"],
|
389 |
-
text_processor=self.text_processors["train"],
|
390 |
-
ann_path=build_info.ann_path,
|
391 |
-
vis_root=build_info.image_path,
|
392 |
-
)
|
393 |
-
|
394 |
-
return datasets
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
class DocumentVQABuilder(BaseDatasetBuilder):
|
400 |
-
def _download_ann(self):
|
401 |
-
pass
|
402 |
-
|
403 |
-
def _download_vis(self):
|
404 |
-
pass
|
405 |
-
|
406 |
-
def build(self):
|
407 |
-
self.build_processors()
|
408 |
-
build_info = self.config.build_info
|
409 |
-
|
410 |
-
datasets = dict()
|
411 |
-
split = "train"
|
412 |
-
|
413 |
-
dataset_cls = self.train_dataset_cls
|
414 |
-
datasets[split] = dataset_cls(
|
415 |
-
vis_processor=self.vis_processors[split],
|
416 |
-
text_processor=self.text_processors[split],
|
417 |
-
vis_root=build_info.image_path,
|
418 |
-
ann_path=build_info.ann_path
|
419 |
-
)
|
420 |
-
|
421 |
-
return datasets
|
422 |
-
|
423 |
-
|
424 |
-
@registry.register_builder("ocrvqa")
|
425 |
-
class OCRVQABuilder(DocumentVQABuilder):
|
426 |
-
train_dataset_cls = OCRVQADataset
|
427 |
-
DATASET_CONFIG_DICT = {"default": "configs/datasets/ocrvqa/ocrvqa.yaml"}
|
428 |
-
|
429 |
-
|
430 |
-
@registry.register_builder("cc_sbu")
|
431 |
-
class CCSBUBuilder(BaseDatasetBuilder):
|
432 |
-
train_dataset_cls = CCSBUDataset
|
433 |
-
|
434 |
-
DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
|
435 |
-
|
436 |
-
def _download_ann(self):
|
437 |
-
pass
|
438 |
-
|
439 |
-
def _download_vis(self):
|
440 |
-
pass
|
441 |
-
|
442 |
-
def build(self):
|
443 |
-
self.build_processors()
|
444 |
-
|
445 |
-
build_info = self.config.build_info
|
446 |
-
|
447 |
-
datasets = dict()
|
448 |
-
split = "train"
|
449 |
-
|
450 |
-
# create datasets
|
451 |
-
# [NOTE] return inner_datasets (wds.DataPipeline)
|
452 |
-
dataset_cls = self.train_dataset_cls
|
453 |
-
datasets[split] = dataset_cls(
|
454 |
-
vis_processor=self.vis_processors[split],
|
455 |
-
text_processor=self.text_processors[split],
|
456 |
-
location=build_info.storage,
|
457 |
-
).inner_dataset
|
458 |
-
|
459 |
-
return datasets
|
460 |
-
|
461 |
-
|
462 |
-
@registry.register_builder("laion")
|
463 |
-
class LaionBuilder(BaseDatasetBuilder):
|
464 |
-
train_dataset_cls = LaionDataset
|
465 |
-
|
466 |
-
DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
|
467 |
-
|
468 |
-
def _download_ann(self):
|
469 |
-
pass
|
470 |
-
|
471 |
-
def _download_vis(self):
|
472 |
-
pass
|
473 |
-
|
474 |
-
def build(self):
|
475 |
-
self.build_processors()
|
476 |
-
|
477 |
-
build_info = self.config.build_info
|
478 |
-
|
479 |
-
datasets = dict()
|
480 |
-
split = "train"
|
481 |
-
|
482 |
-
# create datasets
|
483 |
-
# [NOTE] return inner_datasets (wds.DataPipeline)
|
484 |
-
dataset_cls = self.train_dataset_cls
|
485 |
-
datasets[split] = dataset_cls(
|
486 |
-
vis_processor=self.vis_processors[split],
|
487 |
-
text_processor=self.text_processors[split],
|
488 |
-
location=build_info.storage,
|
489 |
-
).inner_dataset
|
490 |
-
|
491 |
-
return datasets
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
@registry.register_builder("coco_caption")
|
496 |
-
class COCOCapBuilder(BaseDatasetBuilder):
|
497 |
-
train_dataset_cls = COCOCapDataset
|
498 |
-
|
499 |
-
DATASET_CONFIG_DICT = {
|
500 |
-
"default": "configs/datasets/coco/caption.yaml",
|
501 |
-
}
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
@registry.register_builder("cc_sbu_align")
|
506 |
-
class CCSBUAlignBuilder(BaseDatasetBuilder):
|
507 |
-
train_dataset_cls = CCSBUAlignDataset
|
508 |
-
|
509 |
-
DATASET_CONFIG_DICT = {
|
510 |
-
"default": "configs/datasets/cc_sbu/align.yaml",
|
511 |
-
}
|
512 |
-
|
513 |
-
def build_datasets(self):
|
514 |
-
# at this point, all the annotations and image/videos should be all downloaded to the specified locations.
|
515 |
-
logging.info("Building datasets...")
|
516 |
-
self.build_processors()
|
517 |
-
|
518 |
-
build_info = self.config.build_info
|
519 |
-
storage_path = build_info.storage
|
520 |
-
|
521 |
-
datasets = dict()
|
522 |
-
|
523 |
-
if not os.path.exists(storage_path):
|
524 |
-
warnings.warn("storage path {} does not exist.".format(storage_path))
|
525 |
-
|
526 |
-
# create datasets
|
527 |
-
dataset_cls = self.train_dataset_cls
|
528 |
-
datasets['train'] = dataset_cls(
|
529 |
-
vis_processor=self.vis_processors["train"],
|
530 |
-
text_processor=self.text_processors["train"],
|
531 |
-
ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
|
532 |
-
vis_root=os.path.join(storage_path, 'image'),
|
533 |
-
)
|
534 |
-
|
535 |
-
return datasets
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
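For context, a minimal sketch of how these registered builders are typically consumed. This assumes a LAVIS-style `registry.get_builder_class` accessor and that builders accept an OmegaConf config carrying a `build_info` section; the paths below are placeholders, not real dataset locations from this repo.

# Usage sketch (assumptions: registry.get_builder_class exists as in LAVIS,
# and builders accept an OmegaConf config; paths are placeholders).
from omegaconf import OmegaConf
from minigpt4.common.registry import registry

cfg = OmegaConf.create({
    "build_info": {
        "ann_path": "/data/llava/conversation.json",  # placeholder path
        "image_path": "/data/coco/images",            # placeholder path
    }
})

builder_cls = registry.get_builder_class("llava_conversation")  # -> LlavaReasonBuilder
builder = builder_cls(cfg)
datasets = builder.build_datasets()  # {'train': LlavaConversationDataset(...)}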
minigpt4/datasets/data_utils.py
DELETED
@@ -1,199 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import gzip
import logging
import os
import random as rnd
import tarfile
import zipfile
import random
from typing import List
from tqdm import tqdm

import decord
from decord import VideoReader
import webdataset as wds
import numpy as np
import torch
from torch.utils.data.dataset import IterableDataset

from minigpt4.common.registry import registry
from minigpt4.datasets.datasets.base_dataset import ConcatDataset


decord.bridge.set_bridge("torch")
MAX_INT = registry.get("MAX_INT")


class ChainDataset(wds.DataPipeline):
    r"""Dataset for chaining multiple :class:`DataPipeline` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets: List[wds.DataPipeline]) -> None:
        super().__init__()
        self.datasets = datasets
        self.prob = []
        self.names = []
        for dataset in self.datasets:
            if hasattr(dataset, 'name'):
                self.names.append(dataset.name)
            else:
                self.names.append('Unknown')
            if hasattr(dataset, 'sample_ratio'):
                self.prob.append(dataset.sample_ratio)
            else:
                self.prob.append(1)
                logging.info("One of the datapipelines doesn't define a sample ratio; defaulting it to 1.")

    def __iter__(self):
        datastreams = [iter(dataset) for dataset in self.datasets]
        while True:
            select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
            yield next(select_datastream)


def apply_to_sample(f, sample):
    if len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(x) for x in x]
        else:
            return x

    return _apply(sample)


def move_to_cuda(sample):
    def _move_to_cuda(tensor):
        return tensor.cuda()

    return apply_to_sample(_move_to_cuda, sample)


def prepare_sample(samples, cuda_enabled=True):
    if cuda_enabled:
        samples = move_to_cuda(samples)

    # TODO fp16 support

    return samples


def reorg_datasets_by_split(datasets, batch_sizes):
    """
    Organizes datasets by split.

    Args:
        datasets: dict of torch.utils.data.Dataset objects by name.
        batch_sizes: dict of batch sizes by dataset name.

    Returns:
        Dict of datasets by split {split_name: List[Datasets]} and the matching
        dict of batch sizes by split.
    """
    reorg_datasets = dict()
    reorg_batch_sizes = dict()

    # reorganize by split
    for dataset_name, dataset in datasets.items():
        for split_name, dataset_split in dataset.items():
            if split_name not in reorg_datasets:
                reorg_datasets[split_name] = [dataset_split]
                reorg_batch_sizes[split_name] = [batch_sizes[dataset_name]]
            else:
                reorg_datasets[split_name].append(dataset_split)
                reorg_batch_sizes[split_name].append(batch_sizes[dataset_name])

    return reorg_datasets, reorg_batch_sizes


def concat_datasets(datasets):
    """
    Concatenates multiple datasets into a single dataset.

    It supports map-style datasets and DataPipeline from WebDataset. Currently, it does not support
    generic IterableDataset because that requires creating separate samplers.

    Only training datasets are concatenated; validation and testing splits are assumed
    to contain a single dataset each, because metrics should not be computed on
    concatenated datasets.

    Args:
        datasets: dict of torch.utils.data.Dataset objects by split.

    Returns:
        Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
        "val" and "test" remain the same.

        If the input training datasets contain both map-style and DataPipeline datasets, returns
        a tuple, where the first element is a concatenated map-style dataset and the second
        element is a chained DataPipeline dataset.

    """
    # concatenate datasets in the same split
    for split_name in datasets:
        if split_name != "train":
            assert (
                len(datasets[split_name]) == 1
            ), "Do not support multiple {} datasets.".format(split_name)
            datasets[split_name] = datasets[split_name][0]
        else:
            iterable_datasets, map_datasets = [], []
            for dataset in datasets[split_name]:
                if isinstance(dataset, wds.DataPipeline):
                    logging.info(
                        "Dataset {} is IterableDataset, can't be concatenated.".format(
                            dataset
                        )
                    )
                    iterable_datasets.append(dataset)
                elif isinstance(dataset, IterableDataset):
                    raise NotImplementedError(
                        "Do not support concatenation of generic IterableDataset."
                    )
                else:
                    map_datasets.append(dataset)

            # concatenate map-style datasets and iterable-style datasets separately
            if len(iterable_datasets) > 1:
                chained_datasets = (
                    ChainDataset(iterable_datasets)
                )
            elif len(iterable_datasets) == 1:
                chained_datasets = iterable_datasets[0]
            else:
                chained_datasets = None

            concat_datasets = (
                ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
            )

            train_datasets = concat_datasets, chained_datasets
            train_datasets = tuple([x for x in train_datasets if x is not None])
            train_datasets = (
                train_datasets[0] if len(train_datasets) == 1 else train_datasets
            )

            datasets[split_name] = train_datasets

    return datasets
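The `__iter__` of `ChainDataset` above is the heart of multi-dataset mixing: each step draws one stream with probability proportional to its `sample_ratio`. A small self-contained illustration of that sampling rule, with toy generators standing in for the wds pipelines:

# Toy illustration of ChainDataset's weighted stream selection.
import random
from collections import Counter

def stream(name):
    # infinite stream that always yields its own name
    while True:
        yield name

streams = [iter(stream("laion")), iter(stream("cc_sbu"))]
prob = [115, 14]  # illustrative per-dataset sample ratios

counts = Counter(
    next(random.choices(streams, weights=prob, k=1)[0]) for _ in range(10_000)
)
print(counts)  # roughly 115:14, i.e. about 89% laion / 11% cc_sbu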
minigpt4/datasets/datasets/__init__.py
DELETED
File without changes
minigpt4/datasets/datasets/aok_vqa_datasets.py
DELETED
@@ -1,116 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from collections import OrderedDict
import json
import os
import random
import torch

from PIL import Image

from minigpt4.datasets.datasets.vqa_datasets import VQADataset  #, VQAEvalDataset


class __DisplMixin:
    def displ_item(self, index):
        sample, ann = self.__getitem__(index), self.annotation[index]
        return OrderedDict(
            {
                "file": ann["image"],
                "question": ann["question"],
                "question_id": ann["question_id"],
                "direct_answers": "; ".join(ann["direct_answers"]),
                "choices": "; ".join(ann["choices"]),
                "correct_choice": ann["choices"][ann["correct_choice_idx"]],
                "image": sample["image"],
            }
        )


class AOKVQADataset(VQADataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        self.instruction_pool = [
            "[vqa] {}",
            "[vqa] Based on the image, respond to this question with a short answer: {}"
        ]

        # keep only annotations whose image file actually exists on disk
        exist_annotation = []
        for ann in self.annotation:
            image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
            if os.path.exists(image_path):
                exist_annotation.append(ann)
        self.annotation = exist_annotation

    def get_data(self, index):
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        question = self.text_processor(ann["question"])

        answer_key = "direct_answers"

        # each annotator answer contributes equal weight; repeated answers accumulate
        answer_weight = {}
        for answer in ann[answer_key]:
            if answer in answer_weight.keys():
                answer_weight[answer] += 1 / len(ann[answer_key])
            else:
                answer_weight[answer] = 1 / len(ann[answer_key])

        answers = list(answer_weight.keys())
        weights = list(answer_weight.values())

        answer = random.choices(answers, weights=weights, k=1)[0]  # randomly sample an answer according to weights

        return {
            "image": image,
            "question": question,
            "answer": answer,
        }

    def __getitem__(self, index):
        data = self.get_data(index)
        question = self.text_processor(data["question"])
        instruction = random.choice(self.instruction_pool).format(question)

        instruction = "<Img><ImageHere></Img> {} ".format(instruction)
        answer = self.text_processor(data['answer'])

        return {
            "image": data['image'],
            "instruction_input": instruction,
            "answer": answer,
        }


class AOKVQGDataset(AOKVQADataset):

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
        self.instruction_pool = [
            'Given the image, generate a question whose answer is: {}',
            'Based on the image, provide a question with the answer: {}',
            'Given the visual representation, create a question for which the answer is "{}"',
            'From the image provided, craft a question that leads to the reply: {}',
            'Considering the picture, come up with a question where the answer is: {}',
            'Taking the image into account, generate a question that has the answer: {}'
        ]

    def __getitem__(self, index):
        data = self.get_data(index)
        instruction = random.choice(self.instruction_pool).format(data['answer'])

        return {
            "image": data['image'],
            "instruction_input": instruction,
            "answer": data['question'],
        }
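The `answer_weight` loop in `get_data` implements per-annotator answer weighting: with N ground-truth answers, each occurrence adds 1/N, so answers repeated by several annotators accumulate weight and are sampled proportionally more often. A worked example of just that computation:

# Worked example of the answer-weight computation above.
import random

direct_answers = ["cat"] * 7 + ["kitten"] * 2 + ["dog"]  # 10 annotator answers

answer_weight = {}
for answer in direct_answers:
    answer_weight[answer] = answer_weight.get(answer, 0) + 1 / len(direct_answers)

print(answer_weight)  # {'cat': 0.7, 'kitten': 0.2, 'dog': 0.1} (up to float rounding)

sampled = random.choices(list(answer_weight), weights=list(answer_weight.values()), k=1)[0]
# 'cat' is drawn ~70% of the time, matching annotator consensus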
minigpt4/datasets/datasets/base_dataset.py
DELETED
@@ -1,78 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
from typing import Iterable

from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data.dataloader import default_collate


class BaseDataset(Dataset):
    def __init__(
        self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): paths to the annotation files
        """
        self.vis_root = vis_root

        self.annotation = []
        for ann_path in ann_paths:
            ann = json.load(open(ann_path, "r"))
            # annotations may be stored either as a bare list or under an 'annotations' key
            if isinstance(ann, dict):
                self.annotation.extend(ann['annotations'])
            else:
                self.annotation.extend(ann)

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self._add_instance_ids()

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        return default_collate(samples)

    def set_processors(self, vis_processor, text_processor):
        self.vis_processor = vis_processor
        self.text_processor = text_processor

    def _add_instance_ids(self, key="instance_id"):
        for idx, ann in enumerate(self.annotation):
            ann[key] = str(idx)


class ConcatDataset(ConcatDataset):
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__(datasets)

    def collater(self, samples):
        # TODO For now only supports datasets with same underlying collater implementations

        all_keys = set()
        for s in samples:
            all_keys.update(s)

        # keep only the keys shared by every sample in the batch
        shared_keys = all_keys
        for s in samples:
            shared_keys = shared_keys & set(s.keys())

        samples_shared_keys = []
        for s in samples:
            samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})

        return self.datasets[0].collater(samples_shared_keys)
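The custom `collater` above intersects the keys of every sample so that samples coming from heterogeneous datasets can share a single batch. A quick illustration of the key intersection on its own:

# Illustration of the shared-key collation: only keys present in *every*
# sample survive; dataset-specific extras are dropped.
samples = [
    {"image": 1, "instruction_input": "a", "answer": "x"},
    {"image": 2, "answer": "y", "image_id": 7},
]

shared_keys = set(samples[0])
for s in samples:
    shared_keys &= set(s)

print(shared_keys)  # {'image', 'answer'} -- 'instruction_input' and 'image_id' are dropped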
minigpt4/datasets/datasets/caption_datasets.py
DELETED
@@ -1,151 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
from collections import OrderedDict

from minigpt4.datasets.datasets.base_dataset import BaseDataset
from PIL import Image
import random


class __DisplMixin:
    def displ_item(self, index):
        sample, ann = self.__getitem__(index), self.annotation[index]

        return OrderedDict(
            {
                "file": ann["image"],
                "caption": ann["caption"],
                "image": sample["image"],
            }
        )


class CaptionDataset(BaseDataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): paths to the annotation files
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        self.img_ids = {}
        n = 0
        for ann in self.annotation:
            img_id = ann["image_id"]
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1

    def __getitem__(self, index):

        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        img_file = '{:0>12}.jpg'.format(ann["image_id"])
        image_path = os.path.join(self.vis_root, img_file)
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        caption = self.text_processor(ann["caption"])

        return {
            "image": image,
            "text_input": caption,
            "image_id": self.img_ids[ann["image_id"]],
        }


class COCOCaptionDataset(BaseDataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): paths to the annotation files
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        self.img_ids = {}
        n = 0

        # keep only annotations from the train split
        filter_annotation = []
        for ann in self.annotation:
            if "train" in ann["image"]:
                filter_annotation.append(ann)
        self.annotation = filter_annotation

        for ann in self.annotation:
            img_id = ann["image_id"]
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1

        self.instruction_pool = [
            'Briefly describe this image.',
            'Provide a concise depiction of this image.',
            'Present a short description of this image.',
            'Summarize this image in a few words.',
            'A short image caption:',
            'A short image description:',
            'A photo of ',
            'An image that shows ',
            'Write a short description for the image. ',
            'Write a description for the photo.',
            'Provide a description of what is presented in the photo.',
            'Briefly describe the content of the image.',
            'Can you briefly explain what you see in the image?',
            'Could you use a few words to describe what you perceive in the photo?',
            'Please provide a short depiction of the picture.',
            'Using language, provide a short account of the image.',
            'Use a few words to illustrate what is happening in the picture.',
        ]

    def __getitem__(self, index):

        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        img_file = ann["image"].split("/")[-1]
        image_path = os.path.join(self.vis_root, img_file)
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        caption = self.text_processor(ann["caption"])

        instruction = random.choice(self.instruction_pool)
        instruction = "<Img><ImageHere></Img> [caption] {} ".format(instruction)

        return {
            "image": image,
            "answer": caption,
            "instruction_input": instruction,
        }


class CaptionEvalDataset(BaseDataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): paths to the annotation files
        split (string): val or test
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

    def __getitem__(self, index):

        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)

        return {
            "image": image,
            "image_id": ann["image_id"],
            "instance_id": ann["instance_id"],
        }
minigpt4/datasets/datasets/cc_sbu_dataset.py
DELETED
@@ -1,47 +0,0 @@
import os
from PIL import Image
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset


class CCSBUDataset(BaseDataset):
    def __init__(self, vis_processor, text_processor, location):
        super().__init__(vis_processor=vis_processor, text_processor=text_processor)

        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode("pilrgb", handler=wds.warn_and_continue),
            wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
            wds.map(self.to_dict, handler=wds.warn_and_continue),
        )

    def to_dict(self, sample):
        return {
            "image": sample[0],
            "answer": self.text_processor(sample[1]["caption"]),
        }


class CCSBUAlignDataset(CaptionDataset):

    def __getitem__(self, index):

        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        img_file = '{}.jpg'.format(ann["image_id"])
        image_path = os.path.join(self.vis_root, img_file)
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        caption = ann["caption"]

        return {
            "image": image,
            "answer": caption,
            "image_id": self.img_ids[ann["image_id"]],
        }
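For reference, a sketch of how the `CCSBUDataset` pipeline above might be driven. The shard pattern and processors are illustrative stand-ins, not values from this repo; the pipeline yields `{"image", "answer"}` dicts, so its `inner_dataset` can be wrapped directly in a DataLoader.

# Sketch only: shard path and processors are illustrative assumptions.
from torch.utils.data import DataLoader
from torchvision import transforms

vis_processor = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

def text_processor(s):
    return s  # identity text processor, just for the sketch

dataset = CCSBUDataset(
    vis_processor=vis_processor,
    text_processor=text_processor,
    location="/data/cc_sbu/{00000..00010}.tar",  # placeholder shard pattern
)

# inner_dataset is a wds.DataPipeline (an IterableDataset)
loader = DataLoader(dataset.inner_dataset, batch_size=32, num_workers=0)
for batch in loader:
    images, answers = batch["image"], batch["answer"]
    break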
minigpt4/datasets/datasets/coco_caption.py
DELETED
@@ -1,120 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
import json
import torch
import numpy as np

from PIL import Image
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

from minigpt4.datasets.datasets.caption_datasets import COCOCaptionDataset, CaptionEvalDataset

COCOCapDataset = COCOCaptionDataset


class COCOCapEvalDataset(CaptionEvalDataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): paths to the annotation files
        split (string): val or test
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

    def __getitem__(self, index):
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)

        img_id = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1]

        return {
            "image": image,
            "image_id": img_id,
            "instance_id": ann["instance_id"],
        }


class NoCapsEvalDataset(CaptionEvalDataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): paths to the annotation files
        split (string): val or test
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

    def __getitem__(self, index):
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)

        img_id = ann["img_id"]

        return {
            "image": image,
            "image_id": img_id,
            "instance_id": ann["instance_id"],
        }


class RefCOCOEvalData(torch.utils.data.Dataset):
    def __init__(self, loaded_data, vis_processor, root_path):
        self.loaded_data = loaded_data
        self.root_path = root_path
        self.vis_processor = vis_processor

    def __len__(self):
        return len(self.loaded_data)

    def __getitem__(self, idx):
        data = self.loaded_data[idx]
        img_id = data['img_id']
        sent = data['sents']
        image_path = os.path.join(self.root_path, f'{img_id[:27]}.jpg')
        image = Image.open(image_path).convert('RGB')
        image = self.vis_processor(image)
        question = f"[refer] give me the location of {sent}"
        return image, question, img_id


class EvalCaptionData(torch.utils.data.Dataset):
    def __init__(self, loaded_data, vis_processor, root_path):
        self.loaded_data = loaded_data
        self.root_path = root_path
        self.vis_processor = vis_processor
        # deduplicate by image_id so each image is captioned once
        ann = dict()
        for item in self.loaded_data:
            image_id = item['image_id']
            ann[image_id] = item['image']
        self.ann = [{'image_id': image_id, 'image': ann[image_id]} for image_id in ann]

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, idx):
        data = self.ann[idx]
        image_id = data['image_id']
        img_file = data['image'].split('/')[-1]
        image_path = os.path.join(self.root_path, img_file)
        image = Image.open(image_path).convert('RGB')

        image = self.vis_processor(image)
        question = "[caption] please describe this image?"
        return image, question, image_id
minigpt4/datasets/datasets/coco_dataset.py
DELETED
@@ -1,348 +0,0 @@
import os
import json
import pickle
import random
import time
import itertools

import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from torch.utils.data import Dataset
import webdataset as wds

from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset


class ReferCOCODataset(Dataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_path, dataset='refcoco', splitBy='unc'):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): directory that stores the annotation files
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self.refer = REFER(ann_path, vis_root, dataset, splitBy)
        self.ref_ids = self.refer.getRefIds(split="train")

        self.instruction_pool = [
            "[refer] {}",
            "[refer] give me the location of {}",
            "[refer] where is {} ?",
            "[refer] from this image, tell me the location of {}",
            "[refer] the location of {} is",
            "[refer] could you tell me the location for {} ?",
            "[refer] where can I locate the {} ?",
        ]

    def __len__(self):
        return len(self.ref_ids)

    def preprocess(self, index):
        ref_id = self.ref_ids[index]
        ref = self.refer.loadRefs(ref_id)[0]

        image_file = 'COCO_train2014_{:0>12}.jpg'.format(ref["image_id"])
        image_path = os.path.join(self.vis_root, image_file)
        image = Image.open(image_path).convert("RGB")
        image_orig_size = image.size
        image = self.vis_processor(image)
        image_new_size = [image.shape[1], image.shape[2]]

        # bounding boxes are normalized to a fixed 100x100 grid, regardless of the processed image size
        image_new_size = [100, 100]

        sample_sentence = random.choice(ref['sentences'])['raw']
        refer_sentence = self.text_processor(sample_sentence)

        bbox = self.refer.getRefBox(ref['ref_id'])  # [x, y, w, h]
        bbox = [
            bbox[0] / image_orig_size[0] * image_new_size[0],
            bbox[1] / image_orig_size[1] * image_new_size[1],
            (bbox[0] + bbox[2]) / image_orig_size[0] * image_new_size[0],
            (bbox[1] + bbox[3]) / image_orig_size[1] * image_new_size[1]
        ]
        bbox = [int(x) for x in bbox]
        bbox = "{{<{}><{}><{}><{}>}}".format(*bbox)
        return {
            "image": image,
            "refer_sentence": refer_sentence,
            "bbox": bbox,
            "image_id": ref['image_id'],
        }

    def __getitem__(self, index):
        data = self.preprocess(index)
        instruction = random.choice(self.instruction_pool).format(data['refer_sentence'])

        instruction = "<Img><ImageHere></Img> {} ".format(instruction)

        return {
            "image": data['image'],
            "instruction_input": instruction,
            "answer": data['bbox'],
            "image_id": data['image_id'],
        }


class InvReferCOCODataset(ReferCOCODataset):
    def __init__(self, *args, **kwargs):
        super(InvReferCOCODataset, self).__init__(*args, **kwargs)

        self.instruction_pool = [
            "[identify] {}",
            "[identify] what object is in this location {}",
            "[identify] identify the object present at this location {}",
            "[identify] what is it in {}",
            "[identify] describe this object in {}",
            "[identify] this {} is",
            "[identify] the object in {} is",
        ]

    def __getitem__(self, index):
        data = self.preprocess(index)

        instruction = random.choice(self.instruction_pool).format(data['bbox'])

        instruction = "<Img><ImageHere></Img> {} ".format(instruction)

        return {
            "image": data['image'],
            "instruction_input": instruction,
            "answer": self.text_processor(data['refer_sentence']),
            "image_id": data['image_id'],
        }


class REFER:
    def __init__(self, data_root, vis_root, dataset='refcoco', splitBy='unc'):
        # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
        # also provide dataset name and splitBy information
        # e.g., dataset = 'refcoco', splitBy = 'unc'
        dataset = dataset.split('inv')[-1]  # inv datasets are stored in the same path as the normal dataset
        print('loading dataset %s into memory...' % dataset)
        self.ann_dir = os.path.join(data_root, dataset)
        if dataset in ['refcoco', 'refcoco+', 'refcocog']:
            self.vis_root = vis_root
        elif dataset == 'refclef':
            raise ValueError('No RefClef image data')
        else:
            raise ValueError('No refer dataset is called [%s]' % dataset)

        # load refs from data/dataset/refs(dataset).json
        tic = time.time()
        ref_file = os.path.join(self.ann_dir, 'refs(' + splitBy + ').p')
        self.data = {}
        self.data['dataset'] = dataset
        self.data['refs'] = pickle.load(open(ref_file, 'rb'))

        # load annotations from data/dataset/instances.json
        instances_file = os.path.join(self.ann_dir, 'instances.json')
        instances = json.load(open(instances_file, 'r'))
        self.data['images'] = instances['images']
        self.data['annotations'] = instances['annotations']
        self.data['categories'] = instances['categories']

        # create index
        self.createIndex()
        print('DONE (t=%.2fs)' % (time.time() - tic))

    def createIndex(self):
        # create sets of mapping
        # 1)  Refs:         {ref_id: ref}
        # 2)  Anns:         {ann_id: ann}
        # 3)  Imgs:         {image_id: image}
        # 4)  Cats:         {category_id: category_name}
        # 5)  Sents:        {sent_id: sent}
        # 6)  imgToRefs:    {image_id: refs}
        # 7)  imgToAnns:    {image_id: anns}
        # 8)  refToAnn:     {ref_id: ann}
        # 9)  annToRef:     {ann_id: ref}
        # 10) catToRefs:    {category_id: refs}
        # 11) sentToRef:    {sent_id: ref}
        # 12) sentToTokens: {sent_id: tokens}
        print('creating index...')
        # fetch info from instances
        Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
        for ann in self.data['annotations']:
            Anns[ann['id']] = ann
            imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann]
        for img in self.data['images']:
            Imgs[img['id']] = img
        for cat in self.data['categories']:
            Cats[cat['id']] = cat['name']

        # fetch info from refs
        Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
        Sents, sentToRef, sentToTokens = {}, {}, {}
        for ref in self.data['refs']:
            # ids
            ref_id = ref['ref_id']
            ann_id = ref['ann_id']
            category_id = ref['category_id']
            image_id = ref['image_id']

            # add mapping related to ref
            Refs[ref_id] = ref
            imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
            catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
            refToAnn[ref_id] = Anns[ann_id]
            annToRef[ann_id] = ref

            # add mapping of sent
            for sent in ref['sentences']:
                Sents[sent['sent_id']] = sent
                sentToRef[sent['sent_id']] = ref
                sentToTokens[sent['sent_id']] = sent['tokens']

        # create class members
        self.Refs = Refs
        self.Anns = Anns
        self.Imgs = Imgs
        self.Cats = Cats
        self.Sents = Sents
        self.imgToRefs = imgToRefs
        self.imgToAnns = imgToAnns
        self.refToAnn = refToAnn
        self.annToRef = annToRef
        self.catToRefs = catToRefs
        self.sentToRef = sentToRef
        self.sentToTokens = sentToTokens
        print('index created.')

    def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
        image_ids = image_ids if type(image_ids) == list else [image_ids]
        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
            refs = self.data['refs']
        else:
            if not len(image_ids) == 0:
                refs = [self.imgToRefs[image_id] for image_id in image_ids]
            else:
                refs = self.data['refs']
            if not len(cat_ids) == 0:
                refs = [ref for ref in refs if ref['category_id'] in cat_ids]
            if not len(ref_ids) == 0:
                refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
            if not len(split) == 0:
                if split in ['testA', 'testB', 'testC']:
                    refs = [ref for ref in refs if
                            split[-1] in ref['split']]  # we also consider testAB, testBC, ...
                elif split in ['testAB', 'testBC', 'testAC']:
                    refs = [ref for ref in refs if ref['split'] == split]  # rarely used I guess...
                elif split == 'test':
                    refs = [ref for ref in refs if 'test' in ref['split']]
                elif split == 'train' or split == 'val':
                    refs = [ref for ref in refs if ref['split'] == split]
                else:
                    raise ValueError('No such split [%s]' % split)
        ref_ids = [ref['ref_id'] for ref in refs]
        return ref_ids

    def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
        image_ids = image_ids if type(image_ids) == list else [image_ids]
        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
            ann_ids = [ann['id'] for ann in self.data['annotations']]
        else:
            if not len(image_ids) == 0:
                lists = [self.imgToAnns[image_id] for image_id in image_ids if image_id in self.imgToAnns]  # list of [anns]
                anns = list(itertools.chain.from_iterable(lists))
            else:
                anns = self.data['annotations']
            if not len(cat_ids) == 0:
                anns = [ann for ann in anns if ann['category_id'] in cat_ids]
            ann_ids = [ann['id'] for ann in anns]
            if not len(ref_ids) == 0:
                # restrict to annotations actually referenced by the given ref_ids
                ann_ids = list(set(ann_ids).intersection(set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids])))
        return ann_ids

    def getImgIds(self, ref_ids=[]):
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if not len(ref_ids) == 0:
            image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids]))
        else:
            image_ids = self.Imgs.keys()
        return image_ids

    def getCatIds(self):
        return self.Cats.keys()

    def loadRefs(self, ref_ids=[]):
        if type(ref_ids) == list:
            return [self.Refs[ref_id] for ref_id in ref_ids]
        elif type(ref_ids) == int:
            return [self.Refs[ref_ids]]

    def loadAnns(self, ann_ids=[]):
        if type(ann_ids) == list:
            return [self.Anns[ann_id] for ann_id in ann_ids]
        elif type(ann_ids) == int:
            return [self.Anns[ann_ids]]

    def loadImgs(self, image_ids=[]):
        if type(image_ids) == list:
            return [self.Imgs[image_id] for image_id in image_ids]
        elif type(image_ids) == int:
            return [self.Imgs[image_ids]]

    def loadCats(self, cat_ids=[]):
        if type(cat_ids) == list:
            return [self.Cats[cat_id] for cat_id in cat_ids]
        elif type(cat_ids) == int:
            return [self.Cats[cat_ids]]

    def getRefBox(self, ref_id):
        ref = self.Refs[ref_id]
        ann = self.refToAnn[ref_id]
        return ann['bbox']  # [x, y, w, h]

    def showRef(self, ref, seg_box='box'):
        ax = plt.gca()
        # show image
        image = self.Imgs[ref['image_id']]
        I = io.imread(os.path.join(self.vis_root, image['file_name']))
        ax.imshow(I)
        # show refer expression
        for sid, sent in enumerate(ref['sentences']):
            print('%s. %s' % (sid + 1, sent['sent']))
        # show segmentations
        if seg_box == 'seg':
            ann_id = ref['ann_id']
            ann = self.Anns[ann_id]
            polygons = []
            color = []
            c = 'none'
            if type(ann['segmentation'][0]) == list:
                # polygon used for refcoco*
                for seg in ann['segmentation']:
                    poly = np.array(seg).reshape((len(seg) // 2, 2))  # integer division for a valid shape
                    polygons.append(Polygon(poly, True, alpha=0.4))
                    color.append(c)
                p = PatchCollection(polygons, facecolors=color, edgecolors=(1, 1, 0, 0), linewidths=3, alpha=1)
                ax.add_collection(p)  # thick yellow polygon
                p = PatchCollection(polygons, facecolors=color, edgecolors=(1, 0, 0, 0), linewidths=1, alpha=1)
                ax.add_collection(p)  # thin red polygon
            else:
                # mask used for refclef
                raise NotImplementedError('RefClef is not downloaded')
        # show bounding-box
        elif seg_box == 'box':
            ann_id = ref['ann_id']
            ann = self.Anns[ann_id]
            bbox = self.getRefBox(ref['ref_id'])
            box_plot = Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
            ax.add_patch(box_plot)
minigpt4/datasets/datasets/coco_vqa_datasets.py
DELETED
@@ -1,145 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
import json
import random

from PIL import Image

from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset

from collections import OrderedDict


class __DisplMixin:
    def displ_item(self, index):
        sample, ann = self.__getitem__(index), self.annotation[index]

        return OrderedDict(
            {
                "file": ann["image"],
                "question": ann["question"],
                "question_id": ann["question_id"],
                "answers": "; ".join(ann["answer"]),
                "image": sample["image"],
            }
        )


class COCOVQADataset(VQADataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        self.instruction_pool = [
            "[vqa] {}",
            "[vqa] Based on the image, respond to this question with a short answer: {}",
        ]

        # keep only annotations whose image actually exists on disk
        exist_annotation = []
        for ann in self.annotation:
            image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
            if os.path.exists(image_path):
                exist_annotation.append(ann)
        self.annotation = exist_annotation

    def get_data(self, index):
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        question = self.text_processor(ann["question"])
        question_id = ann["question_id"]

        answer_weight = {}
        for answer in ann["answer"]:
            if answer in answer_weight.keys():
                answer_weight[answer] += 1 / len(ann["answer"])
            else:
                answer_weight[answer] = 1 / len(ann["answer"])

        answers = list(answer_weight.keys())
        weights = list(answer_weight.values())

        answer = random.choices(answers, weights=weights, k=1)[0]  # randomly sample an answer according to its weight

        return {
            "image": image,
            "question": question,
            "question_id": question_id,
            "answer": answer,
        }

    def __getitem__(self, index):
        data = self.get_data(index)
        instruction = random.choice(self.instruction_pool).format(data['question'])
        instruction = "<Img><ImageHere></Img> {} ".format(instruction)

        return {
            "image": data['image'],
            "question_id": data["question_id"],
            "instruction_input": instruction,
            "answer": self.text_processor(data['answer']),
        }


class COCOVQAEvalDataset(VQAEvalDataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """

        self.instruction_pool = [
            'Question: {} Short answer:',
        ]
        self.vis_root = vis_root

        self.annotation = json.load(open(ann_paths[0]))

        answer_list_path = ann_paths[1]
        if os.path.exists(answer_list_path):
            self.answer_list = json.load(open(answer_list_path))
        else:
            self.answer_list = None

        try:
            self.coco_fmt_qust_file = ann_paths[2]
            self.coco_fmt_anno_file = ann_paths[3]
        except IndexError:
            self.coco_fmt_qust_file = None
            self.coco_fmt_anno_file = None

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self._add_instance_ids()

    def __getitem__(self, index):
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        question = self.text_processor(ann["question"])

        instruction = random.choice(self.instruction_pool).format(question)
        instruction = "<Img><ImageHere></Img> {} ".format(instruction)

        return {
            "image": image,
            'image_path': image_path,
            "question": question,
            "question_id": ann["question_id"],
            "instruction_input": instruction,
            "instance_id": ann["instance_id"],
        }
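The answer_weight logic in COCOVQADataset.get_data above is just frequency-weighted sampling over the annotator answers; a minimal standalone sketch with toy data (not from the commit):

import random

# Toy annotator answers: "cat" appears 3 of 4 times, so it is drawn ~75% of the time.
annotator_answers = ["cat", "cat", "cat", "dog"]
answer_weight = {}
for a in annotator_answers:
    answer_weight[a] = answer_weight.get(a, 0) + 1 / len(annotator_answers)

answers = list(answer_weight.keys())    # ["cat", "dog"]
weights = list(answer_weight.values())  # [0.75, 0.25]
print(random.choices(answers, weights=weights, k=1)[0])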
minigpt4/datasets/datasets/dataloader_utils.py
DELETED
@@ -1,162 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import time
import random
import torch
from minigpt4.datasets.data_utils import move_to_cuda
from torch.utils.data import DataLoader


class MultiIterLoader:
    """
    A simple wrapper for iterating over multiple iterators.

    Args:
        loaders (List[Loader]): List of Iterator loaders.
        ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
    """

    def __init__(self, loaders, ratios=None):
        # assert that every loader has a __next__ method
        for loader in loaders:
            assert hasattr(
                loader, "__next__"
            ), "Loader {} has no __next__ method.".format(loader)

        if ratios is None:
            ratios = [1.0] * len(loaders)
        else:
            assert len(ratios) == len(loaders)
            ratios = [float(ratio) / sum(ratios) for ratio in ratios]

        self.loaders = loaders
        self.ratios = ratios

    def __next__(self):
        # randomly pick one loader according to the sampling ratios
        loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
        return next(self.loaders[loader_idx])


class PrefetchLoader(object):
    """
    Modified from https://github.com/ChenRocks/UNITER.

    Overlaps compute and CUDA data transfer
    (copied and then modified from NVIDIA Apex).
    """

    def __init__(self, loader):
        self.loader = loader
        self.stream = torch.cuda.Stream()

    def __iter__(self):
        loader_it = iter(self.loader)
        self.preload(loader_it)
        batch = self.next(loader_it)
        while batch is not None:
            is_tuple = isinstance(batch, tuple)
            if is_tuple:
                task, batch = batch

            if is_tuple:
                yield task, batch
            else:
                yield batch
            batch = self.next(loader_it)

    def __len__(self):
        return len(self.loader)

    def preload(self, it):
        try:
            self.batch = next(it)
        except StopIteration:
            self.batch = None
            return
        # if record_stream() doesn't work, another option is to make sure
        # device inputs are created on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input,
        #                                        device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target,
        #                                         device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use
        # by the main stream at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.batch = move_to_cuda(self.batch)
            # more code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this
            # side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu

    def next(self, it):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is not None:
            record_cuda_stream(batch)
        self.preload(it)
        return batch

    def __getattr__(self, name):
        method = self.loader.__getattribute__(name)
        return method


def record_cuda_stream(batch):
    if isinstance(batch, torch.Tensor):
        batch.record_stream(torch.cuda.current_stream())
    elif isinstance(batch, list) or isinstance(batch, tuple):
        for t in batch:
            record_cuda_stream(t)
    elif isinstance(batch, dict):
        for t in batch.values():
            record_cuda_stream(t)
    else:
        pass


class IterLoader:
    """
    A wrapper that turns a DataLoader into an infinite iterator.

    Modified from:
        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
    """

    def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._use_distributed = use_distributed
        self._epoch = 0

    @property
    def epoch(self) -> int:
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._dataloader)
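A minimal sketch of how the deleted loader wrappers above compose; the tensor datasets and batch sizes are placeholders for illustration:

import torch
from torch.utils.data import DataLoader, TensorDataset

ds_a = TensorDataset(torch.zeros(100, 3))
ds_b = TensorDataset(torch.ones(40, 3))

# IterLoader makes each DataLoader infinite; MultiIterLoader then draws
# batches from them at a 3:1 ratio.
loader_a = IterLoader(DataLoader(ds_a, batch_size=8))
loader_b = IterLoader(DataLoader(ds_b, batch_size=8))
multi = MultiIterLoader(loaders=[loader_a, loader_b], ratios=[3, 1])

for step in range(10):
    batch = next(multi)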
minigpt4/datasets/datasets/flickr.py
DELETED
@@ -1,159 +0,0 @@
import os
import json
import pickle
import random
import time
import itertools

import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from torch.utils.data import Dataset
import webdataset as wds

from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset


class GroundedDetailDataset(Dataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self.instruction_pool = [
            '[grounding] please describe this image in details',
            '[grounding] describe this image as detailed as possible',
            '[grounding] summarize this image in details',
            '[grounding] give a thorough description of what you see in this image',
        ]

        with open(ann_path, 'r') as f:
            self.ann = json.load(f)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        info = self.ann[index]

        # image_file = 'COCO_train2014_{}.jpg'.format(info['image_id'])
        image_file = '{}.jpg'.format(info['image_id'])
        image_path = os.path.join(self.vis_root, image_file)
        image = Image.open(image_path).convert("RGB")
        image = self.vis_processor(image)

        answer = info['grounded_caption']
        instruction = random.choice(self.instruction_pool)
        instruction = "<Img><ImageHere></Img> {} ".format(instruction)

        return {
            "image": image,
            "instruction_input": instruction,
            "answer": answer,
            "image_id": info['image_id'],
        }


class CaptionToObjectDataset(Dataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self.instruction_pool = [
            '[detection] {}',
        ]

        with open(ann_path, 'r') as f:
            self.ann = json.load(f)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        info = self.ann[index]

        image_file = '{}.jpg'.format(info['image_id'])
        image_path = os.path.join(self.vis_root, image_file)
        image = Image.open(image_path).convert("RGB")
        image = self.vis_processor(image)

        input = info["caption"]
        answer = info["output"]

        instruction = random.choice(self.instruction_pool).format(input)

        instruction = "<Img><ImageHere></Img> {} ".format(instruction)

        print("CaptionToObject instruction", instruction)
        print("CaptionToObject answer", answer)

        return {
            "image": image,
            "instruction_input": instruction,
            "answer": answer,
            "image_id": info['image_id'],
        }


class PhraseToObjectDataset(Dataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self.instruction_pool = [
            '[detection] {}',
        ]

        with open(ann_path, 'r') as f:
            self.ann = json.load(f)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        info = self.ann[index]
        image_file = '{}.jpg'.format(info['image_id'])
        image_path = os.path.join(self.vis_root, image_file)
        image = Image.open(image_path).convert("RGB")
        image = self.vis_processor(image)

        input = info["phrase"]
        answer = "<p>" + input + "</p> " + info["bbox"]
        instruction = random.choice(self.instruction_pool).format(input)

        instruction = "<Img><ImageHere></Img> {} ".format(instruction)

        print("PhraseToObject instruction", instruction)
        print("PhraseToObject answer", answer)

        return {
            "image": image,
            "instruction_input": instruction,
            "answer": answer,
            "image_id": info['image_id'],
        }
minigpt4/datasets/datasets/gqa_datasets.py
DELETED
@@ -1,60 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
import json

from PIL import Image

from minigpt4.datasets.datasets.vqa_datasets import VQADataset

from collections import OrderedDict
import random


class __DisplMixin:
    def displ_item(self, index):
        sample, ann = self.__getitem__(index), self.annotation[index]

        return OrderedDict(
            {
                "file": ann["image"],
                "question": ann["question"],
                "question_id": ann["question_id"],
                "answers": "; ".join(ann["answer"]),
                "image": sample["image"],
            }
        )


class GQADataset(VQADataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
        self.instruction_pool = [
            "[vqa] {}",
            "[vqa] Based on the image, respond to this question with a short answer: {}",
        ]

    def __getitem__(self, index):
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        question = self.text_processor(ann["question"])

        instruction = random.choice(self.instruction_pool).format(question)
        instruction = "<Img><ImageHere></Img> {} ".format(instruction)

        answers = self.text_processor(ann["answer"])

        return {
            "image": image,
            "instruction_input": instruction,
            "answer": answers,
        }
minigpt4/datasets/datasets/laion_dataset.py
DELETED
@@ -1,31 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset


class LaionDataset(BaseDataset):
    def __init__(self, vis_processor, text_processor, location):
        super().__init__(vis_processor=vis_processor, text_processor=text_processor)

        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode("pilrgb", handler=wds.warn_and_continue),
            wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
            wds.map(self.to_dict, handler=wds.warn_and_continue),
        )

    def to_dict(self, sample):
        return {
            "image": sample[0],
            "answer": self.text_processor(sample[1]["caption"]),
        }
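Consuming the LaionDataset pipeline above is just iterating inner_dataset; a sketch with identity processors and an assumed shard pattern (brace expansion is standard webdataset syntax):

# The identity processors and shard pattern are stand-ins for illustration only.
dataset = LaionDataset(
    vis_processor=lambda img: img,
    text_processor=lambda txt: txt,
    location="laion/{00000..00010}.tar",  # assumed brace-expanded shard names
)

sample = next(iter(dataset.inner_dataset))
print(sample["answer"])  # the processed caption paired with the decoded image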
minigpt4/datasets/datasets/llava_dataset.py
DELETED
@@ -1,149 +0,0 @@
import os
import json
import pickle
import random
import time
import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from torch.utils.data import Dataset
import webdataset as wds

from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset


class LlavaDetailDataset(Dataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        with open(ann_path, 'r') as f:
            self.ann = json.load(f)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        info = self.ann[index]

        image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
        image_path = os.path.join(self.vis_root, image_file)
        image = Image.open(image_path).convert("RGB")
        image = self.vis_processor(image)

        answer = info['conversations'][1]['value']
        instruction = info['conversations'][0]['value'].replace('<image>', '').replace('\n', '').strip()

        instruction = '<Img><ImageHere></Img> {} '.format(self.text_processor(instruction))

        return {
            "image": image,
            "instruction_input": instruction,
            "answer": answer,
            "image_id": info['id'],
        }


class LlavaReasonDataset(Dataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        with open(ann_path, 'r') as f:
            self.ann = json.load(f)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        info = self.ann[index]

        image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
        image_path = os.path.join(self.vis_root, image_file)
        image = Image.open(image_path).convert("RGB")
        image = self.vis_processor(image)

        answer = info['conversations'][1]['value']
        instruction = info['conversations'][0]['value'].replace('<image>', '').replace('\n', '').strip()

        instruction = '<Img><ImageHere></Img> {} '.format(self.text_processor(instruction))

        return {
            "image": image,
            "instruction_input": instruction,
            "answer": answer,
            "image_id": info['id'],
        }


class LlavaConversationDataset(Dataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self.ann = []

        with open(ann_path, 'r') as f:
            self.ann = json.load(f)

        self.connect_sym = "!@#"

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        info = self.ann[index]

        image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
        image_path = os.path.join(self.vis_root, image_file)
        image = Image.open(image_path).convert("RGB")
        image = self.vis_processor(image)

        first_instruction = info['conversations'][0]['value'].replace('<image>', '').replace('\n', '').strip()
        first_instruction = '<Img><ImageHere></Img> {} '.format(first_instruction)

        questions = [first_instruction]
        answers = []

        for i, item in enumerate(info["conversations"][1:]):
            if i % 2 == 0:  # assistant
                assistant_answer = item["value"]
                answers.append(assistant_answer)
            else:
                human_instruction = item["value"] + " "
                questions.append(human_instruction)

        questions = self.connect_sym.join(questions)
        answers = self.connect_sym.join(answers)

        return {
            "image": image,
            "conv_q": questions,
            'conv_a': answers,
            "image_id": info['id'],
            "connect_sym": self.connect_sym
        }
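LlavaConversationDataset above (and MultiTaskConversationDataset below, which uses the same scheme) flattens a multi-turn dialogue into two connect_sym-joined strings; a consumer splits them back into aligned turns. A minimal sketch with an invented sample:

# Invented sample in the shape returned by LlavaConversationDataset above.
sample = {
    "conv_q": "<Img><ImageHere></Img> What is this? !@#What color is it? ",
    "conv_a": "A bus.!@#Red.",
    "connect_sym": "!@#",
}

questions = sample["conv_q"].split(sample["connect_sym"])
answers = sample["conv_a"].split(sample["connect_sym"])
for q, a in zip(questions, answers):  # aligned (question, answer) turns
    print(repr(q), "->", repr(a))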
minigpt4/datasets/datasets/multitask_conversation.py
DELETED
@@ -1,75 +0,0 @@
import os
import json
import pickle
import random
import time
import itertools

import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from torch.utils.data import Dataset
import webdataset as wds

from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset


class MultiTaskConversationDataset(Dataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        with open(ann_path, 'r') as f:
            self.ann = json.load(f)

        self.connect_sym = "!@#"

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        info = self.ann[index]

        image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
        image_path = os.path.join(self.vis_root, image_file)
        image = Image.open(image_path).convert("RGB")
        image = self.vis_processor(image)

        first_instruction = info['conversations'][0]['value'].replace('<image>', '').replace('\n', '').strip()
        first_instruction = '<Img><ImageHere></Img> {} '.format(first_instruction)

        questions = [first_instruction]
        answers = []

        for i, item in enumerate(info["conversations"][1:]):
            if i % 2 == 0:  # assistant
                assistant_answer = item["value"]
                answers.append(assistant_answer)
            else:
                human_instruction = item["value"] + " "
                questions.append(human_instruction)

        questions = self.connect_sym.join(questions)
        answers = self.connect_sym.join(answers)

        return {
            "image": image,
            "conv_q": questions,
            'conv_a': answers,
            "image_id": info['id'],
            "connect_sym": self.connect_sym
        }
minigpt4/datasets/datasets/ocrvqa_dataset.py
DELETED
@@ -1,77 +0,0 @@
import os
import json
import pickle
import random
import time
import itertools

import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from torch.utils.data import Dataset
import webdataset as wds

from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset


class OCRVQADataset(Dataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.data = self.create_data(ann_path)

        self.instruction_pool = [
            "[vqa] {}",
            "[vqa] Based on the image, respond to this question with a short answer: {}",
        ]

    def create_data(self, ann_path):
        processed_data = []
        with open(ann_path, 'r') as f:
            data = json.load(f)
        for k in data.keys():
            if data[k]['split'] != 1: continue  # 1 for training, 2 for validation, 3 for test
            ext = os.path.splitext(data[k]['imageURL'])[1]
            imageFile = k + ext
            assert len(data[k]['questions']) == len(data[k]['answers'])
            for q, a in zip(data[k]['questions'], data[k]['answers']):
                processed_data.append(
                    {'question': q,
                     'answer': a,
                     'image_path': imageFile,
                     'image_id': k,
                     'title': data[k]['title'],
                     'genre': data[k]['genre'],
                     }
                )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample['image_path'])).convert("RGB")
        image = self.vis_processor(image)
        question = self.text_processor(sample["question"])
        answer = self.text_processor(sample["answer"])

        instruction = random.choice(self.instruction_pool).format(question)
        instruction = "<Img><ImageHere></Img> {} ".format(instruction)
        return {
            "image": image,
            "instruction_input": instruction,
            "answer": answer,
            "image_id": sample['image_id']
        }
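create_data above implies an annotation JSON keyed by image id; a hypothetical entry consistent with the keys it reads (all values invented):

# Hypothetical OCR-VQA annotation entry; the keys match what create_data()
# reads above, the values are invented for illustration.
example_annotation = {
    "B000123456": {
        "imageURL": "http://example.com/B000123456.jpg",
        "split": 1,  # 1 train, 2 validation, 3 test
        "questions": ["Who wrote this book?", "What is the genre of this book?"],
        "answers": ["Jane Doe", "Fiction"],
        "title": "An Example Title",
        "genre": "Fiction",
    }
}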