nicolaus625 commited on
Commit
092410b
1 Parent(s): 4fe117d

Upload model

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. modelling_musilingo.py +778 -14
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
- license: cc-by-nc-4.0
3
  language:
4
  - en
 
5
  library_name: transformers
6
  tags:
7
  - music
 
1
  ---
 
2
  language:
3
  - en
4
+ license: cc-by-nc-4.0
5
  library_name: transformers
6
  tags:
7
  - music
modelling_musilingo.py CHANGED
@@ -3,6 +3,11 @@ import os
3
  import random
4
  import math
5
  import re
 
 
 
 
 
6
  from typing import List, Optional, Tuple, Union
7
 
8
  from torch.cuda.amp import autocast as autocast
@@ -14,7 +19,7 @@ from torch.nn import CrossEntropyLoss
14
  from transformers import Wav2Vec2FeatureExtractor
15
  from omegaconf import OmegaConf
16
 
17
- from musilingo_huggingface.configuration_musilingo import MusiLingoConfig, PATH
18
  import timm.models.hub as timm_hub
19
 
20
 
@@ -28,6 +33,765 @@ from transformers import PreTrainedModel
28
 
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  class Registry:
32
  mapping = {
33
  "builder_name_mapping": {},
@@ -49,12 +813,12 @@ class Registry:
49
 
50
  Usage:
51
 
52
- from lavi.common.registry import registry
53
- from lavi.datasets.base_dataset_builder import BaseDatasetBuilder
54
  """
55
 
56
  def wrap(builder_cls):
57
- from musilingo.datasets.builders.base_dataset_builder import BaseDatasetBuilder
58
 
59
  assert issubclass(
60
  builder_cls, BaseDatasetBuilder
@@ -81,11 +845,11 @@ class Registry:
81
 
82
  Usage:
83
 
84
- from lavi.common.registry import registry
85
  """
86
 
87
  def wrap(task_cls):
88
- from musilingo.tasks.base_task import BaseTask
89
 
90
  assert issubclass(
91
  task_cls, BaseTask
@@ -110,7 +874,7 @@ class Registry:
110
 
111
  Usage:
112
 
113
- from lavi.common.registry import registry
114
  """
115
 
116
  def wrap(model_cls):
@@ -138,11 +902,11 @@ class Registry:
138
 
139
  Usage:
140
 
141
- from lavi.common.registry import registry
142
  """
143
 
144
  def wrap(processor_cls):
145
- from musilingo.processors import BaseProcessor
146
 
147
  assert issubclass(
148
  processor_cls, BaseProcessor
@@ -167,7 +931,7 @@ class Registry:
167
 
168
  Usage:
169
 
170
- from minigpt4.common.registry import registry
171
  """
172
 
173
  def wrap(lr_sched_cls):
@@ -191,7 +955,7 @@ class Registry:
191
 
192
  Usage:
193
 
194
- from minigpt4.common.registry import registry
195
  """
196
 
197
  def wrap(runner_cls):
@@ -215,7 +979,7 @@ class Registry:
215
 
216
  Usage:
217
 
218
- from minigpt4.common.registry import registry
219
  """
220
  assert isinstance(path, str), "All path must be str."
221
  if name in cls.mapping["paths"]:
@@ -231,7 +995,7 @@ class Registry:
231
 
232
  Usage::
233
 
234
- from minigpt4.common.registry import registry
235
 
236
  registry.register("config", {})
237
  """
@@ -340,7 +1104,7 @@ class Registry:
340
  name: Key which needs to be removed.
341
  Usage::
342
 
343
- from mmf.common.registry import registry
344
 
345
  config = registry.unregister("config")
346
  """
 
3
  import random
4
  import math
5
  import re
6
+ import shutil
7
+ import warnings
8
+ import datetime
9
+ import time
10
+ from collections import defaultdict, deque
11
  from typing import List, Optional, Tuple, Union
12
 
13
  from torch.cuda.amp import autocast as autocast
 
19
  from transformers import Wav2Vec2FeatureExtractor
20
  from omegaconf import OmegaConf
21
 
22
+ from .configuration_musilingo import MusiLingoConfig, PATH
23
  import timm.models.hub as timm_hub
24
 
25
 
 
33
 
34
 
35
 
36
+ def download_url(
37
+ url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
38
+ ) -> None:
39
+ """Download a file from a url and place it in root.
40
+
41
+ Args:
42
+ url (str): URL to download file from
43
+ root (str): Directory to place downloaded file in
44
+ filename (str, optional): Name to save the file under. If None, use the basename of the URL
45
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
46
+ max_redirect_hops (int, optional): Maximum number of redirect hops allowed
47
+ """
48
+ root = os.path.expanduser(root)
49
+ if not filename:
50
+ filename = os.path.basename(url)
51
+ fpath = os.path.join(root, filename)
52
+
53
+ os.makedirs(root, exist_ok=True)
54
+
55
+ # check if file is already present locally
56
+ if check_integrity(fpath, md5):
57
+ print("Using downloaded and verified file: " + fpath)
58
+ return
59
+
60
+ if _is_remote_location_available():
61
+ _download_file_from_remote_location(fpath, url)
62
+ else:
63
+ # expand redirect chain if needed
64
+ url = _get_redirect_url(url, max_hops=max_redirect_hops)
65
+
66
+ # check if file is located on Google Drive
67
+ file_id = _get_google_drive_file_id(url)
68
+ if file_id is not None:
69
+ return download_file_from_google_drive(file_id, root, filename, md5)
70
+
71
+ # download the file
72
+ try:
73
+ print("Downloading " + url + " to " + fpath)
74
+ _urlretrieve(url, fpath)
75
+ except (urllib.error.URLError, OSError) as e: # type: ignore[attr-defined]
76
+ if url[:5] == "https":
77
+ url = url.replace("https:", "http:")
78
+ print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
79
+ _urlretrieve(url, fpath)
80
+ else:
81
+ raise e
82
+
83
+ # check integrity of downloaded file
84
+ if not check_integrity(fpath, md5):
85
+ raise RuntimeError("File not found or corrupted.")
86
+
87
+
88
+
89
+ def load_dataset_config(cfg_path):
90
+ cfg = OmegaConf.load(cfg_path).datasets
91
+ cfg = cfg[list(cfg.keys())[0]]
92
+
93
+ return cfg
94
+
95
+ class SmoothedValue(object):
96
+ """Track a series of values and provide access to smoothed values over a
97
+ window or the global series average.
98
+ """
99
+
100
+ def __init__(self, window_size=20, fmt=None):
101
+ if fmt is None:
102
+ fmt = "{median:.4f} ({global_avg:.4f})"
103
+ self.deque = deque(maxlen=window_size)
104
+ self.total = 0.0
105
+ self.count = 0
106
+ self.fmt = fmt
107
+
108
+ def update(self, value, n=1):
109
+ self.deque.append(value)
110
+ self.count += n
111
+ self.total += value * n
112
+
113
+ def synchronize_between_processes(self):
114
+ """
115
+ Warning: does not synchronize the deque!
116
+ """
117
+ if not is_dist_avail_and_initialized():
118
+ return
119
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
120
+ dist.barrier()
121
+ dist.all_reduce(t)
122
+ t = t.tolist()
123
+ self.count = int(t[0])
124
+ self.total = t[1]
125
+
126
+ @property
127
+ def median(self):
128
+ d = torch.tensor(list(self.deque))
129
+ return d.median().item()
130
+
131
+ @property
132
+ def avg(self):
133
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
134
+ return d.mean().item()
135
+
136
+ @property
137
+ def global_avg(self):
138
+ return self.total / self.count
139
+
140
+ @property
141
+ def max(self):
142
+ return max(self.deque)
143
+
144
+ @property
145
+ def value(self):
146
+ return self.deque[-1]
147
+
148
+ def __str__(self):
149
+ return self.fmt.format(
150
+ median=self.median,
151
+ avg=self.avg,
152
+ global_avg=self.global_avg,
153
+ max=self.max,
154
+ value=self.value,
155
+ )
156
+
157
+
158
+ class MetricLogger(object):
159
+ def __init__(self, delimiter="\t"):
160
+ self.meters = defaultdict(SmoothedValue)
161
+ self.delimiter = delimiter
162
+
163
+ def update(self, **kwargs):
164
+ for k, v in kwargs.items():
165
+ if isinstance(v, torch.Tensor):
166
+ v = v.item()
167
+ assert isinstance(v, (float, int))
168
+ self.meters[k].update(v)
169
+
170
+ def __getattr__(self, attr):
171
+ if attr in self.meters:
172
+ return self.meters[attr]
173
+ if attr in self.__dict__:
174
+ return self.__dict__[attr]
175
+ raise AttributeError(
176
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
177
+ )
178
+
179
+ def __str__(self):
180
+ loss_str = []
181
+ for name, meter in self.meters.items():
182
+ loss_str.append("{}: {}".format(name, str(meter)))
183
+ return self.delimiter.join(loss_str)
184
+
185
+ def global_avg(self):
186
+ loss_str = []
187
+ for name, meter in self.meters.items():
188
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
189
+ return self.delimiter.join(loss_str)
190
+
191
+ def synchronize_between_processes(self):
192
+ for meter in self.meters.values():
193
+ meter.synchronize_between_processes()
194
+
195
+ def add_meter(self, name, meter):
196
+ self.meters[name] = meter
197
+
198
+ def log_every(self, iterable, print_freq, header=None):
199
+ i = 0
200
+ if not header:
201
+ header = ""
202
+ start_time = time.time()
203
+ end = time.time()
204
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
205
+ data_time = SmoothedValue(fmt="{avg:.4f}")
206
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
207
+ log_msg = [
208
+ header,
209
+ "[{0" + space_fmt + "}/{1}]",
210
+ "eta: {eta}",
211
+ "{meters}",
212
+ "time: {time}",
213
+ "data: {data}",
214
+ ]
215
+ if torch.cuda.is_available():
216
+ log_msg.append("max mem: {memory:.0f}")
217
+ log_msg = self.delimiter.join(log_msg)
218
+ MB = 1024.0 * 1024.0
219
+ for obj in iterable:
220
+ data_time.update(time.time() - end)
221
+ yield obj
222
+ iter_time.update(time.time() - end)
223
+ if i % print_freq == 0 or i == len(iterable) - 1:
224
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
225
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
226
+ if torch.cuda.is_available():
227
+ print(
228
+ log_msg.format(
229
+ i,
230
+ len(iterable),
231
+ eta=eta_string,
232
+ meters=str(self),
233
+ time=str(iter_time),
234
+ data=str(data_time),
235
+ memory=torch.cuda.max_memory_allocated() / MB,
236
+ )
237
+ )
238
+ else:
239
+ print(
240
+ log_msg.format(
241
+ i,
242
+ len(iterable),
243
+ eta=eta_string,
244
+ meters=str(self),
245
+ time=str(iter_time),
246
+ data=str(data_time),
247
+ )
248
+ )
249
+ i += 1
250
+ end = time.time()
251
+ total_time = time.time() - start_time
252
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
253
+ print(
254
+ "{} Total time: {} ({:.4f} s / it)".format(
255
+ header, total_time_str, total_time / len(iterable)
256
+ )
257
+ )
258
+
259
+
260
+ def move_to_cuda(sample):
261
+ def _move_to_cuda(tensor):
262
+ return tensor.cuda()
263
+
264
+ return apply_to_sample(_move_to_cuda, sample)
265
+
266
+ def apply_to_sample(f, sample):
267
+ if len(sample) == 0:
268
+ return {}
269
+
270
+ def _apply(x):
271
+ if torch.is_tensor(x):
272
+ return f(x)
273
+ elif isinstance(x, dict):
274
+ return {key: _apply(value) for key, value in x.items()}
275
+ elif isinstance(x, list):
276
+ return [_apply(x) for x in x]
277
+ else:
278
+ return x
279
+
280
+ return _apply(sample)
281
+
282
+ def prepare_sample(samples, cuda_enabled=True):
283
+ if cuda_enabled:
284
+ samples = move_to_cuda(samples)
285
+
286
+ # TODO fp16 support
287
+
288
+ return samples
289
+
290
+ def get_world_size():
291
+ if not is_dist_avail_and_initialized():
292
+ return 1
293
+ return dist.get_world_size()
294
+
295
+ class BaseTask:
296
+ def __init__(self, **kwargs):
297
+ super().__init__()
298
+
299
+ self.inst_id_key = "instance_id"
300
+
301
+ @classmethod
302
+ def setup_task(cls, **kwargs):
303
+ return cls()
304
+
305
+ def build_model(self, cfg):
306
+ model_config = cfg.model_cfg
307
+
308
+ model_cls = registry.get_model_class(model_config.arch)
309
+ return model_cls.from_config(model_config)
310
+
311
+ def build_datasets(self, cfg):
312
+ """
313
+ Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
314
+ Download dataset and annotations automatically if not exist.
315
+
316
+ Args:
317
+ cfg (common.config.Config): _description_
318
+
319
+ Returns:
320
+ dict: Dictionary of torch.utils.data.Dataset objects by split.
321
+ """
322
+
323
+ datasets = dict()
324
+
325
+ datasets_config = cfg.datasets_cfg
326
+
327
+ assert len(datasets_config) > 0, "At least one dataset has to be specified."
328
+
329
+ for name in datasets_config:
330
+ dataset_config = datasets_config[name]
331
+
332
+ builder = registry.get_builder_class(name)(dataset_config)
333
+ dataset = builder.build_datasets()
334
+
335
+ dataset['train'].name = name
336
+ if 'sample_ratio' in dataset_config:
337
+ dataset['train'].sample_ratio = dataset_config.sample_ratio
338
+
339
+ datasets[name] = dataset
340
+
341
+ return datasets
342
+
343
+ def train_step(self, model, samples):
344
+ loss = model(samples)["loss"]
345
+ return loss
346
+
347
+ def valid_step(self, model, samples):
348
+ raise NotImplementedError
349
+
350
+ def before_evaluation(self, model, dataset, **kwargs):
351
+ model.before_evaluation(dataset=dataset, task_type=type(self))
352
+
353
+ def after_evaluation(self, **kwargs):
354
+ pass
355
+
356
+ def inference_step(self):
357
+ raise NotImplementedError
358
+
359
+ def evaluation(self, model, data_loader, cuda_enabled=True):
360
+ metric_logger = MetricLogger(delimiter=" ")
361
+ header = "Evaluation"
362
+ # TODO make it configurable
363
+ print_freq = 10
364
+
365
+ results = []
366
+
367
+ for samples in metric_logger.log_every(data_loader, print_freq, header):
368
+ samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
369
+
370
+ eval_output = self.valid_step(model=model, samples=samples)
371
+ results.extend(eval_output)
372
+
373
+ if is_dist_avail_and_initialized():
374
+ dist.barrier()
375
+
376
+ return results
377
+
378
+ def train_epoch(
379
+ self,
380
+ epoch,
381
+ model,
382
+ data_loader,
383
+ optimizer,
384
+ lr_scheduler,
385
+ scaler=None,
386
+ cuda_enabled=False,
387
+ log_freq=50,
388
+ accum_grad_iters=1,
389
+ ):
390
+ return self._train_inner_loop(
391
+ epoch=epoch,
392
+ iters_per_epoch=lr_scheduler.iters_per_epoch,
393
+ model=model,
394
+ data_loader=data_loader,
395
+ optimizer=optimizer,
396
+ scaler=scaler,
397
+ lr_scheduler=lr_scheduler,
398
+ log_freq=log_freq,
399
+ cuda_enabled=cuda_enabled,
400
+ accum_grad_iters=accum_grad_iters,
401
+ )
402
+
403
+ def train_iters(
404
+ self,
405
+ epoch,
406
+ start_iters,
407
+ iters_per_inner_epoch,
408
+ model,
409
+ data_loader,
410
+ optimizer,
411
+ lr_scheduler,
412
+ scaler=None,
413
+ cuda_enabled=False,
414
+ log_freq=50,
415
+ accum_grad_iters=1,
416
+ ):
417
+ return self._train_inner_loop(
418
+ epoch=epoch,
419
+ start_iters=start_iters,
420
+ iters_per_epoch=iters_per_inner_epoch,
421
+ model=model,
422
+ data_loader=data_loader,
423
+ optimizer=optimizer,
424
+ scaler=scaler,
425
+ lr_scheduler=lr_scheduler,
426
+ log_freq=log_freq,
427
+ cuda_enabled=cuda_enabled,
428
+ accum_grad_iters=accum_grad_iters,
429
+ )
430
+
431
+ def _train_inner_loop(
432
+ self,
433
+ epoch,
434
+ iters_per_epoch,
435
+ model,
436
+ data_loader,
437
+ optimizer,
438
+ lr_scheduler,
439
+ scaler=None,
440
+ start_iters=None,
441
+ log_freq=50,
442
+ cuda_enabled=False,
443
+ accum_grad_iters=1,
444
+ ):
445
+ """
446
+ An inner training loop compatible with both epoch-based and iter-based training.
447
+
448
+ When using epoch-based, training stops after one epoch; when using iter-based,
449
+ training stops after #iters_per_epoch iterations.
450
+ """
451
+ use_amp = scaler is not None
452
+
453
+ if not hasattr(data_loader, "__next__"):
454
+ # convert to iterator if not already
455
+ data_loader = iter(data_loader)
456
+
457
+ metric_logger = MetricLogger(delimiter=" ")
458
+ metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
459
+ metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
460
+
461
+ # if iter-based runner, schedule lr based on inner epoch.
462
+ logging.info(
463
+ "Start training epoch {}, {} iters per inner epoch.".format(
464
+ epoch, iters_per_epoch
465
+ )
466
+ )
467
+ header = "Train: data epoch: [{}]".format(epoch)
468
+ if start_iters is None:
469
+ # epoch-based runner
470
+ inner_epoch = epoch
471
+ else:
472
+ # In iter-based runner, we schedule the learning rate based on iterations.
473
+ inner_epoch = start_iters // iters_per_epoch
474
+ header = header + "; inner epoch [{}]".format(inner_epoch)
475
+
476
+ for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
477
+ # if using iter-based runner, we stop after iters_per_epoch iterations.
478
+ if i >= iters_per_epoch:
479
+ break
480
+
481
+ samples = next(data_loader)
482
+
483
+ samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
484
+ samples.update(
485
+ {
486
+ "epoch": inner_epoch,
487
+ "num_iters_per_epoch": iters_per_epoch,
488
+ "iters": i,
489
+ }
490
+ )
491
+
492
+ lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)
493
+
494
+ with torch.cuda.amp.autocast(enabled=use_amp):
495
+ loss = self.train_step(model=model, samples=samples)
496
+
497
+ # after_train_step()
498
+ if use_amp:
499
+ scaler.scale(loss).backward()
500
+ else:
501
+ loss.backward()
502
+
503
+ # update gradients every accum_grad_iters iterations
504
+ if (i + 1) % accum_grad_iters == 0:
505
+ if use_amp:
506
+ scaler.step(optimizer)
507
+ scaler.update()
508
+ else:
509
+ optimizer.step()
510
+ optimizer.zero_grad()
511
+
512
+ metric_logger.update(loss=loss.item())
513
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
514
+
515
+ # after train_epoch()
516
+ # gather the stats from all processes
517
+ metric_logger.synchronize_between_processes()
518
+ logging.info("Averaged stats: " + str(metric_logger.global_avg()))
519
+ return {
520
+ k: "{:.3f}".format(meter.global_avg)
521
+ for k, meter in metric_logger.meters.items()
522
+ }
523
+
524
+ @staticmethod
525
+ def save_result(result, result_dir, filename, remove_duplicate=""):
526
+ import json
527
+
528
+ result_file = os.path.join(
529
+ result_dir, "%s_rank%d.json" % (filename, get_rank())
530
+ )
531
+ final_result_file = os.path.join(result_dir, "%s.json" % filename)
532
+
533
+ json.dump(result, open(result_file, "w"))
534
+
535
+ if is_dist_avail_and_initialized():
536
+ dist.barrier()
537
+
538
+ if is_main_process():
539
+ logging.warning("rank %d starts merging results." % get_rank())
540
+ # combine results from all processes
541
+ result = []
542
+
543
+ for rank in range(get_world_size()):
544
+ result_file = os.path.join(
545
+ result_dir, "%s_rank%d.json" % (filename, rank)
546
+ )
547
+ res = json.load(open(result_file, "r"))
548
+ result += res
549
+
550
+ if remove_duplicate:
551
+ result_new = []
552
+ id_list = []
553
+ for res in result:
554
+ if res[remove_duplicate] not in id_list:
555
+ id_list.append(res[remove_duplicate])
556
+ result_new.append(res)
557
+ result = result_new
558
+
559
+ json.dump(result, open(final_result_file, "w"))
560
+ print("result file saved to %s" % final_result_file)
561
+
562
+ return final_result_file
563
+
564
+
565
+ class BaseProcessor:
566
+ def __init__(self):
567
+ self.transform = lambda x: x
568
+ return
569
+
570
+ def __call__(self, item):
571
+ return self.transform(item)
572
+
573
+ @classmethod
574
+ def from_config(cls, cfg=None):
575
+ return cls()
576
+
577
+ def build(self, **kwargs):
578
+ cfg = OmegaConf.create(kwargs)
579
+
580
+ return self.from_config(cfg)
581
+
582
+ def get_cache_path(rel_path):
583
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
584
+
585
+
586
+ class BaseDatasetBuilder:
587
+ train_dataset_cls, eval_dataset_cls = None, None
588
+
589
+ def __init__(self, cfg=None):
590
+ super().__init__()
591
+
592
+ if cfg is None:
593
+ # help to create datasets from default config.
594
+ self.config = load_dataset_config(self.default_config_path())
595
+ elif isinstance(cfg, str):
596
+ self.config = load_dataset_config(cfg)
597
+ else:
598
+ # when called from task.build_dataset()
599
+ self.config = cfg
600
+
601
+ self.data_type = self.config.data_type
602
+
603
+ self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
604
+ self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
605
+
606
+ def build_datasets(self):
607
+ # download, split, etc...
608
+ # only called on 1 GPU/TPU in distributed
609
+
610
+ if is_main_process():
611
+ self._download_data()
612
+
613
+ if is_dist_avail_and_initialized():
614
+ dist.barrier()
615
+
616
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
617
+ logging.info("Building datasets...")
618
+ datasets = self.build() # dataset['train'/'val'/'test']
619
+
620
+ return datasets
621
+
622
+ def build_processors(self):
623
+ vis_proc_cfg = self.config.get("vis_processor")
624
+ txt_proc_cfg = self.config.get("text_processor")
625
+
626
+ if vis_proc_cfg is not None:
627
+ vis_train_cfg = vis_proc_cfg.get("train")
628
+ vis_eval_cfg = vis_proc_cfg.get("eval")
629
+
630
+ self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
631
+ self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
632
+
633
+ if txt_proc_cfg is not None:
634
+ txt_train_cfg = txt_proc_cfg.get("train")
635
+ txt_eval_cfg = txt_proc_cfg.get("eval")
636
+
637
+ self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
638
+ self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
639
+
640
+ @staticmethod
641
+ def _build_proc_from_cfg(cfg):
642
+ return (
643
+ registry.get_processor_class(cfg.name).from_config(cfg)
644
+ if cfg is not None
645
+ else None
646
+ )
647
+
648
+ @classmethod
649
+ def default_config_path(cls, type="default"):
650
+ return get_abs_path(cls.DATASET_CONFIG_DICT[type])
651
+
652
+ def _download_data(self):
653
+ self._download_ann()
654
+ self._download_vis()
655
+
656
+ def _download_ann(self):
657
+ """
658
+ Download annotation files if necessary.
659
+ All the vision-language datasets should have annotations of unified format.
660
+
661
+ storage_path can be:
662
+ (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
663
+ (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
664
+
665
+ Local annotation paths should be relative.
666
+ """
667
+ anns = self.config.build_info.annotations
668
+
669
+ splits = anns.keys()
670
+
671
+ cache_root = registry.get_path("cache_root")
672
+
673
+ for split in splits:
674
+ info = anns[split]
675
+
676
+ urls, storage_paths = info.get("url", None), info.storage
677
+
678
+ if isinstance(urls, str):
679
+ urls = [urls]
680
+ if isinstance(storage_paths, str):
681
+ storage_paths = [storage_paths]
682
+
683
+ assert len(urls) == len(storage_paths)
684
+
685
+ for url_or_filename, storage_path in zip(urls, storage_paths):
686
+ # if storage_path is relative, make it full by prefixing with cache_root.
687
+ if not os.path.isabs(storage_path):
688
+ storage_path = os.path.join(cache_root, storage_path)
689
+
690
+ dirname = os.path.dirname(storage_path)
691
+ if not os.path.exists(dirname):
692
+ os.makedirs(dirname)
693
+
694
+ if os.path.isfile(url_or_filename):
695
+ src, dst = url_or_filename, storage_path
696
+ if not os.path.exists(dst):
697
+ shutil.copyfile(src=src, dst=dst)
698
+ else:
699
+ logging.info("Using existing file {}.".format(dst))
700
+ else:
701
+ if os.path.isdir(storage_path):
702
+ # if only dirname is provided, suffix with basename of URL.
703
+ raise ValueError(
704
+ "Expecting storage_path to be a file path, got directory {}".format(
705
+ storage_path
706
+ )
707
+ )
708
+ else:
709
+ filename = os.path.basename(storage_path)
710
+
711
+ download_url(url=url_or_filename, root=dirname, filename=filename)
712
+
713
+ def _download_vis(self):
714
+
715
+ storage_path = self.config.build_info.get(self.data_type).storage
716
+ storage_path = get_cache_path(storage_path)
717
+
718
+ if not os.path.exists(storage_path):
719
+ warnings.warn(
720
+ f"""
721
+ The specified path {storage_path} for visual inputs does not exist.
722
+ Please provide a correct path to the visual inputs or
723
+ refer to datasets/download_scripts/README.md for downloading instructions.
724
+ """
725
+ )
726
+
727
+ def build(self):
728
+ """
729
+ Create by split datasets inheriting torch.utils.data.Datasets.
730
+
731
+ # build() can be dataset-specific. Overwrite to customize.
732
+ """
733
+ self.build_processors()
734
+
735
+ build_info = self.config.build_info
736
+
737
+ ann_info = build_info.annotations
738
+ vis_info = build_info.get(self.data_type)
739
+
740
+ datasets = dict()
741
+ for split in ann_info.keys():
742
+ if split not in ["train", "val", "test"]:
743
+ continue
744
+
745
+ is_train = split == "train"
746
+
747
+ # processors
748
+ vis_processor = (
749
+ self.vis_processors["train"]
750
+ if is_train
751
+ else self.vis_processors["eval"]
752
+ )
753
+ text_processor = (
754
+ self.text_processors["train"]
755
+ if is_train
756
+ else self.text_processors["eval"]
757
+ )
758
+
759
+ # annotation path
760
+ ann_paths = ann_info.get(split).storage
761
+ if isinstance(ann_paths, str):
762
+ ann_paths = [ann_paths]
763
+
764
+ abs_ann_paths = []
765
+ for ann_path in ann_paths:
766
+ if not os.path.isabs(ann_path):
767
+ ann_path = get_cache_path(ann_path)
768
+ abs_ann_paths.append(ann_path)
769
+ ann_paths = abs_ann_paths
770
+
771
+ # visual data storage path
772
+ vis_path = os.path.join(vis_info.storage, split)
773
+
774
+ if not os.path.isabs(vis_path):
775
+ # vis_path = os.path.join(utils.get_cache_path(), vis_path)
776
+ vis_path = get_cache_path(vis_path)
777
+
778
+ if not os.path.exists(vis_path):
779
+ warnings.warn("storage path {} does not exist.".format(vis_path))
780
+
781
+ # create datasets
782
+ dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
783
+ datasets[split] = dataset_cls(
784
+ vis_processor=vis_processor,
785
+ text_processor=text_processor,
786
+ ann_paths=ann_paths,
787
+ vis_root=vis_path,
788
+ )
789
+
790
+ return datasets
791
+
792
+
793
+
794
+
795
  class Registry:
796
  mapping = {
797
  "builder_name_mapping": {},
 
813
 
814
  Usage:
815
 
816
+ # from lavi.common.registry import registry
817
+ # from lavi.datasets.base_dataset_builder import BaseDatasetBuilder
818
  """
819
 
820
  def wrap(builder_cls):
821
+ # from musilingo.datasets.builders.base_dataset_builder import BaseDatasetBuilder
822
 
823
  assert issubclass(
824
  builder_cls, BaseDatasetBuilder
 
845
 
846
  Usage:
847
 
848
+ # from lavi.common.registry import registry
849
  """
850
 
851
  def wrap(task_cls):
852
+ # from musilingo.tasks.base_task import BaseTask
853
 
854
  assert issubclass(
855
  task_cls, BaseTask
 
874
 
875
  Usage:
876
 
877
+ # from lavi.common.registry import registry
878
  """
879
 
880
  def wrap(model_cls):
 
902
 
903
  Usage:
904
 
905
+ # from lavi.common.registry import registry
906
  """
907
 
908
  def wrap(processor_cls):
909
+ # from musilingo.processors import BaseProcessor
910
 
911
  assert issubclass(
912
  processor_cls, BaseProcessor
 
931
 
932
  Usage:
933
 
934
+ # from minigpt4.common.registry import registry
935
  """
936
 
937
  def wrap(lr_sched_cls):
 
955
 
956
  Usage:
957
 
958
+ # from minigpt4.common.registry import registry
959
  """
960
 
961
  def wrap(runner_cls):
 
979
 
980
  Usage:
981
 
982
+ # from minigpt4.common.registry import registry
983
  """
984
  assert isinstance(path, str), "All path must be str."
985
  if name in cls.mapping["paths"]:
 
995
 
996
  Usage::
997
 
998
+ # from minigpt4.common.registry import registry
999
 
1000
  registry.register("config", {})
1001
  """
 
1104
  name: Key which needs to be removed.
1105
  Usage::
1106
 
1107
+ # from mmf.common.registry import registry
1108
 
1109
  config = registry.unregister("config")
1110
  """