nicolaus625 committed
Commit
5174e76
1 Parent(s): 5295899

Upload model

Files changed (2)
  1. README.md +1 -1
  2. modelling_musilingo.py +765 -3
README.md CHANGED
@@ -1,7 +1,7 @@
 ---
-license: cc-by-4.0
 language:
 - en
+license: cc-by-4.0
 tags:
 - music
 - art
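
The README.md change only reorders the model-card metadata so that `license` follows `language` inside the single YAML frontmatter block. For reference, the resulting header presumably reads (the closing `---` sits outside the hunk shown above):

    ---
    language:
    - en
    license: cc-by-4.0
    tags:
    - music
    - art
    ---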
modelling_musilingo.py CHANGED
@@ -3,6 +3,11 @@ import os
3
  import random
4
  import math
5
  import re
 
 
 
 
 
6
  from typing import List, Optional, Tuple, Union
7
 
8
  from torch.cuda.amp import autocast as autocast
@@ -28,6 +33,763 @@ from transformers import PreTrainedModel
 
 
 
+def download_url(
+    url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
+) -> None:
+    """Download a file from a url and place it in root.
+
+    Args:
+        url (str): URL to download file from
+        root (str): Directory to place downloaded file in
+        filename (str, optional): Name to save the file under. If None, use the basename of the URL
+        md5 (str, optional): MD5 checksum of the download. If None, do not check
+        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
+    """
+    root = os.path.expanduser(root)
+    if not filename:
+        filename = os.path.basename(url)
+    fpath = os.path.join(root, filename)
+
+    os.makedirs(root, exist_ok=True)
+
+    # check if file is already present locally
+    if check_integrity(fpath, md5):
+        print("Using downloaded and verified file: " + fpath)
+        return
+
+    if _is_remote_location_available():
+        _download_file_from_remote_location(fpath, url)
+    else:
+        # expand redirect chain if needed
+        url = _get_redirect_url(url, max_hops=max_redirect_hops)
+
+        # check if file is located on Google Drive
+        file_id = _get_google_drive_file_id(url)
+        if file_id is not None:
+            return download_file_from_google_drive(file_id, root, filename, md5)
+
+        # download the file
+        try:
+            print("Downloading " + url + " to " + fpath)
+            _urlretrieve(url, fpath)
+        except (urllib.error.URLError, OSError) as e:  # type: ignore[attr-defined]
+            if url[:5] == "https":
+                url = url.replace("https:", "http:")
+                print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
+                _urlretrieve(url, fpath)
+            else:
+                raise e
+
+    # check integrity of downloaded file
+    if not check_integrity(fpath, md5):
+        raise RuntimeError("File not found or corrupted.")
+
+
+def load_dataset_config(cfg_path):
+    cfg = OmegaConf.load(cfg_path).datasets
+    cfg = cfg[list(cfg.keys())[0]]
+
+    return cfg
+
+
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        if not is_dist_avail_and_initialized():
+            return
+        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+        dist.barrier()
+        dist.all_reduce(t)
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median,
+            avg=self.avg,
+            global_avg=self.global_avg,
+            max=self.max,
+            value=self.value,
+        )
+
+
+class MetricLogger(object):
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError(
+            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+        )
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append("{}: {}".format(name, str(meter)))
+        return self.delimiter.join(loss_str)
+
+    def global_avg(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None):
+        i = 0
+        if not header:
+            header = ""
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt="{avg:.4f}")
+        data_time = SmoothedValue(fmt="{avg:.4f}")
+        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+        log_msg = [
+            header,
+            "[{0" + space_fmt + "}/{1}]",
+            "eta: {eta}",
+            "{meters}",
+            "time: {time}",
+            "data: {data}",
+        ]
+        if torch.cuda.is_available():
+            log_msg.append("max mem: {memory:.0f}")
+        log_msg = self.delimiter.join(log_msg)
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0 or i == len(iterable) - 1:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    print(
+                        log_msg.format(
+                            i,
+                            len(iterable),
+                            eta=eta_string,
+                            meters=str(self),
+                            time=str(iter_time),
+                            data=str(data_time),
+                            memory=torch.cuda.max_memory_allocated() / MB,
+                        )
+                    )
+                else:
+                    print(
+                        log_msg.format(
+                            i,
+                            len(iterable),
+                            eta=eta_string,
+                            meters=str(self),
+                            time=str(iter_time),
+                            data=str(data_time),
+                        )
+                    )
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print(
+            "{} Total time: {} ({:.4f} s / it)".format(
+                header, total_time_str, total_time / len(iterable)
+            )
+        )
+
+
+def move_to_cuda(sample):
+    def _move_to_cuda(tensor):
+        return tensor.cuda()
+
+    return apply_to_sample(_move_to_cuda, sample)
+
+
+def apply_to_sample(f, sample):
+    if len(sample) == 0:
+        return {}
+
+    def _apply(x):
+        if torch.is_tensor(x):
+            return f(x)
+        elif isinstance(x, dict):
+            return {key: _apply(value) for key, value in x.items()}
+        elif isinstance(x, list):
+            return [_apply(x) for x in x]
+        else:
+            return x
+
+    return _apply(sample)
+
+
+def prepare_sample(samples, cuda_enabled=True):
+    if cuda_enabled:
+        samples = move_to_cuda(samples)
+
+    # TODO fp16 support
+
+    return samples
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+class BaseTask:
+    def __init__(self, **kwargs):
+        super().__init__()
+
+        self.inst_id_key = "instance_id"
+
+    @classmethod
+    def setup_task(cls, **kwargs):
+        return cls()
+
+    def build_model(self, cfg):
+        model_config = cfg.model_cfg
+
+        model_cls = registry.get_model_class(model_config.arch)
+        return model_cls.from_config(model_config)
+
+    def build_datasets(self, cfg):
+        """
+        Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
+        Download datasets and annotations automatically if they do not exist.
+
+        Args:
+            cfg (common.config.Config): _description_
+
+        Returns:
+            dict: Dictionary of torch.utils.data.Dataset objects by split.
+        """
+
+        datasets = dict()
+
+        datasets_config = cfg.datasets_cfg
+
+        assert len(datasets_config) > 0, "At least one dataset has to be specified."
+
+        for name in datasets_config:
+            dataset_config = datasets_config[name]
+
+            builder = registry.get_builder_class(name)(dataset_config)
+            dataset = builder.build_datasets()
+
+            dataset['train'].name = name
+            if 'sample_ratio' in dataset_config:
+                dataset['train'].sample_ratio = dataset_config.sample_ratio
+
+            datasets[name] = dataset
+
+        return datasets
+
+    def train_step(self, model, samples):
+        loss = model(samples)["loss"]
+        return loss
+
+    def valid_step(self, model, samples):
+        raise NotImplementedError
+
+    def before_evaluation(self, model, dataset, **kwargs):
+        model.before_evaluation(dataset=dataset, task_type=type(self))
+
+    def after_evaluation(self, **kwargs):
+        pass
+
+    def inference_step(self):
+        raise NotImplementedError
+
+    def evaluation(self, model, data_loader, cuda_enabled=True):
+        metric_logger = MetricLogger(delimiter=" ")
+        header = "Evaluation"
+        # TODO make it configurable
+        print_freq = 10
+
+        results = []
+
+        for samples in metric_logger.log_every(data_loader, print_freq, header):
+            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
+
+            eval_output = self.valid_step(model=model, samples=samples)
+            results.extend(eval_output)
+
+        if is_dist_avail_and_initialized():
+            dist.barrier()
+
+        return results
+
+    def train_epoch(
+        self,
+        epoch,
+        model,
+        data_loader,
+        optimizer,
+        lr_scheduler,
+        scaler=None,
+        cuda_enabled=False,
+        log_freq=50,
+        accum_grad_iters=1,
+    ):
+        return self._train_inner_loop(
+            epoch=epoch,
+            iters_per_epoch=lr_scheduler.iters_per_epoch,
+            model=model,
+            data_loader=data_loader,
+            optimizer=optimizer,
+            scaler=scaler,
+            lr_scheduler=lr_scheduler,
+            log_freq=log_freq,
+            cuda_enabled=cuda_enabled,
+            accum_grad_iters=accum_grad_iters,
+        )
+
+    def train_iters(
+        self,
+        epoch,
+        start_iters,
+        iters_per_inner_epoch,
+        model,
+        data_loader,
+        optimizer,
+        lr_scheduler,
+        scaler=None,
+        cuda_enabled=False,
+        log_freq=50,
+        accum_grad_iters=1,
+    ):
+        return self._train_inner_loop(
+            epoch=epoch,
+            start_iters=start_iters,
+            iters_per_epoch=iters_per_inner_epoch,
+            model=model,
+            data_loader=data_loader,
+            optimizer=optimizer,
+            scaler=scaler,
+            lr_scheduler=lr_scheduler,
+            log_freq=log_freq,
+            cuda_enabled=cuda_enabled,
+            accum_grad_iters=accum_grad_iters,
+        )
+
+    def _train_inner_loop(
+        self,
+        epoch,
+        iters_per_epoch,
+        model,
+        data_loader,
+        optimizer,
+        lr_scheduler,
+        scaler=None,
+        start_iters=None,
+        log_freq=50,
+        cuda_enabled=False,
+        accum_grad_iters=1,
+    ):
+        """
+        An inner training loop compatible with both epoch-based and iter-based training.
+
+        When using epoch-based training, the loop stops after one epoch; when using
+        iter-based training, it stops after #iters_per_epoch iterations.
+        """
+        use_amp = scaler is not None
+
+        if not hasattr(data_loader, "__next__"):
+            # convert to iterator if not already
+            data_loader = iter(data_loader)
+
+        metric_logger = MetricLogger(delimiter=" ")
+        metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
+        metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
+
+        # if iter-based runner, schedule lr based on inner epoch.
+        logging.info(
+            "Start training epoch {}, {} iters per inner epoch.".format(
+                epoch, iters_per_epoch
+            )
+        )
+        header = "Train: data epoch: [{}]".format(epoch)
+        if start_iters is None:
+            # epoch-based runner
+            inner_epoch = epoch
+        else:
+            # In iter-based runner, we schedule the learning rate based on iterations.
+            inner_epoch = start_iters // iters_per_epoch
+            header = header + "; inner epoch [{}]".format(inner_epoch)
+
+        for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
+            # if using iter-based runner, we stop after iters_per_epoch iterations.
+            if i >= iters_per_epoch:
+                break
+
+            samples = next(data_loader)
+
+            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
+            samples.update(
+                {
+                    "epoch": inner_epoch,
+                    "num_iters_per_epoch": iters_per_epoch,
+                    "iters": i,
+                }
+            )
+
+            lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)
+
+            with torch.cuda.amp.autocast(enabled=use_amp):
+                loss = self.train_step(model=model, samples=samples)
+
+            # after_train_step()
+            if use_amp:
+                scaler.scale(loss).backward()
+            else:
+                loss.backward()
+
+            # update gradients every accum_grad_iters iterations
+            if (i + 1) % accum_grad_iters == 0:
+                if use_amp:
+                    scaler.step(optimizer)
+                    scaler.update()
+                else:
+                    optimizer.step()
+                optimizer.zero_grad()
+
+            metric_logger.update(loss=loss.item())
+            metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+        # after train_epoch()
+        # gather the stats from all processes
+        metric_logger.synchronize_between_processes()
+        logging.info("Averaged stats: " + str(metric_logger.global_avg()))
+        return {
+            k: "{:.3f}".format(meter.global_avg)
+            for k, meter in metric_logger.meters.items()
+        }
+
+    @staticmethod
+    def save_result(result, result_dir, filename, remove_duplicate=""):
+        import json
+
+        result_file = os.path.join(
+            result_dir, "%s_rank%d.json" % (filename, get_rank())
+        )
+        final_result_file = os.path.join(result_dir, "%s.json" % filename)
+
+        json.dump(result, open(result_file, "w"))
+
+        if is_dist_avail_and_initialized():
+            dist.barrier()
+
+        if is_main_process():
+            logging.warning("rank %d starts merging results." % get_rank())
+            # combine results from all processes
+            result = []
+
+            for rank in range(get_world_size()):
+                result_file = os.path.join(
+                    result_dir, "%s_rank%d.json" % (filename, rank)
+                )
+                res = json.load(open(result_file, "r"))
+                result += res
+
+            if remove_duplicate:
+                result_new = []
+                id_list = []
+                for res in result:
+                    if res[remove_duplicate] not in id_list:
+                        id_list.append(res[remove_duplicate])
+                        result_new.append(res)
+                result = result_new
+
+            json.dump(result, open(final_result_file, "w"))
+            print("result file saved to %s" % final_result_file)
+
+        return final_result_file
+
+
+class BaseProcessor:
+    def __init__(self):
+        self.transform = lambda x: x
+        return
+
+    def __call__(self, item):
+        return self.transform(item)
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        return cls()
+
+    def build(self, **kwargs):
+        cfg = OmegaConf.create(kwargs)
+
+        return self.from_config(cfg)
+
+
+class BaseDatasetBuilder:
+    train_dataset_cls, eval_dataset_cls = None, None
+
+    def __init__(self, cfg=None):
+        super().__init__()
+
+        if cfg is None:
+            # help to create datasets from default config.
+            self.config = load_dataset_config(self.default_config_path())
+        elif isinstance(cfg, str):
+            self.config = load_dataset_config(cfg)
+        else:
+            # when called from task.build_dataset()
+            self.config = cfg
+
+        self.data_type = self.config.data_type
+
+        self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
+        self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
+
+    def build_datasets(self):
+        # download, split, etc...
+        # only called on 1 GPU/TPU in distributed
+
+        if is_main_process():
+            self._download_data()
+
+        if is_dist_avail_and_initialized():
+            dist.barrier()
+
+        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+        logging.info("Building datasets...")
+        datasets = self.build()  # dataset['train'/'val'/'test']
+
+        return datasets
+
+    def build_processors(self):
+        vis_proc_cfg = self.config.get("vis_processor")
+        txt_proc_cfg = self.config.get("text_processor")
+
+        if vis_proc_cfg is not None:
+            vis_train_cfg = vis_proc_cfg.get("train")
+            vis_eval_cfg = vis_proc_cfg.get("eval")
+
+            self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
+            self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
+
+        if txt_proc_cfg is not None:
+            txt_train_cfg = txt_proc_cfg.get("train")
+            txt_eval_cfg = txt_proc_cfg.get("eval")
+
+            self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
+            self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
+
+    @staticmethod
+    def _build_proc_from_cfg(cfg):
+        return (
+            registry.get_processor_class(cfg.name).from_config(cfg)
+            if cfg is not None
+            else None
+        )
+
+    @classmethod
+    def default_config_path(cls, type="default"):
+        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
+
+    def _download_data(self):
+        self._download_ann()
+        self._download_vis()
+
+    def _download_ann(self):
+        """
+        Download annotation files if necessary.
+        All the vision-language datasets should have annotations of unified format.
+
+        storage_path can be:
+        (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
+        (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
+
+        Local annotation paths should be relative.
+        """
+        anns = self.config.build_info.annotations
+
+        splits = anns.keys()
+
+        cache_root = registry.get_path("cache_root")
+
+        for split in splits:
+            info = anns[split]
+
+            urls, storage_paths = info.get("url", None), info.storage
+
+            if isinstance(urls, str):
+                urls = [urls]
+            if isinstance(storage_paths, str):
+                storage_paths = [storage_paths]
+
+            assert len(urls) == len(storage_paths)
+
+            for url_or_filename, storage_path in zip(urls, storage_paths):
+                # if storage_path is relative, make it full by prefixing with cache_root.
+                if not os.path.isabs(storage_path):
+                    storage_path = os.path.join(cache_root, storage_path)
+
+                dirname = os.path.dirname(storage_path)
+                if not os.path.exists(dirname):
+                    os.makedirs(dirname)
+
+                if os.path.isfile(url_or_filename):
+                    src, dst = url_or_filename, storage_path
+                    if not os.path.exists(dst):
+                        shutil.copyfile(src=src, dst=dst)
+                    else:
+                        logging.info("Using existing file {}.".format(dst))
+                else:
+                    if os.path.isdir(storage_path):
+                        # if only dirname is provided, suffix with basename of URL.
+                        raise ValueError(
+                            "Expecting storage_path to be a file path, got directory {}".format(
+                                storage_path
+                            )
+                        )
+                    else:
+                        filename = os.path.basename(storage_path)
+
+                    download_url(url=url_or_filename, root=dirname, filename=filename)
+
+    def _download_vis(self):
+        storage_path = self.config.build_info.get(self.data_type).storage
+        storage_path = utils.get_cache_path(storage_path)
+
+        if not os.path.exists(storage_path):
+            warnings.warn(
+                f"""
+                The specified path {storage_path} for visual inputs does not exist.
+                Please provide a correct path to the visual inputs or
+                refer to datasets/download_scripts/README.md for downloading instructions.
+                """
+            )
+
+    def build(self):
+        """
+        Create datasets, keyed by split, inheriting torch.utils.data.Dataset.
+
+        # build() can be dataset-specific. Overwrite to customize.
+        """
+        self.build_processors()
+
+        build_info = self.config.build_info
+
+        ann_info = build_info.annotations
+        vis_info = build_info.get(self.data_type)
+
+        datasets = dict()
+        for split in ann_info.keys():
+            if split not in ["train", "val", "test"]:
+                continue
+
+            is_train = split == "train"
+
+            # processors
+            vis_processor = (
+                self.vis_processors["train"]
+                if is_train
+                else self.vis_processors["eval"]
+            )
+            text_processor = (
+                self.text_processors["train"]
+                if is_train
+                else self.text_processors["eval"]
+            )
+
+            # annotation path
+            ann_paths = ann_info.get(split).storage
+            if isinstance(ann_paths, str):
+                ann_paths = [ann_paths]
+
+            abs_ann_paths = []
+            for ann_path in ann_paths:
+                if not os.path.isabs(ann_path):
+                    ann_path = utils.get_cache_path(ann_path)
+                abs_ann_paths.append(ann_path)
+            ann_paths = abs_ann_paths
+
+            # visual data storage path
+            vis_path = os.path.join(vis_info.storage, split)
+
+            if not os.path.isabs(vis_path):
+                # vis_path = os.path.join(utils.get_cache_path(), vis_path)
+                vis_path = utils.get_cache_path(vis_path)
+
+            if not os.path.exists(vis_path):
+                warnings.warn("storage path {} does not exist.".format(vis_path))
+
+            # create datasets
+            dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
+            datasets[split] = dataset_cls(
+                vis_processor=vis_processor,
+                text_processor=text_processor,
+                ann_paths=ann_paths,
+                vis_root=vis_path,
+            )
+
+        return datasets
+
+
 class Registry:
     mapping = {
         "builder_name_mapping": {},
@@ -54,7 +816,7 @@ class Registry:
         """
 
         def wrap(builder_cls):
-            from musilingo.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+            # from musilingo.datasets.builders.base_dataset_builder import BaseDatasetBuilder
 
             assert issubclass(
                 builder_cls, BaseDatasetBuilder
@@ -85,7 +847,7 @@ class Registry:
         """
 
        def wrap(task_cls):
-            from musilingo.tasks.base_task import BaseTask
+            # from musilingo.tasks.base_task import BaseTask
 
             assert issubclass(
                 task_cls, BaseTask
@@ -142,7 +904,7 @@ class Registry:
         """
 
         def wrap(processor_cls):
-            from musilingo.processors import BaseProcessor
+            # from musilingo.processors import BaseProcessor
 
             assert issubclass(
                 processor_cls, BaseProcessor
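
Taken together, the added block vendors the LAVIS-style training utilities (download helpers, smoothed metric logging, and the BaseTask / BaseProcessor / BaseDatasetBuilder base classes) directly into modelling_musilingo.py, which is presumably why the three `from musilingo...` imports inside Registry are commented out in favor of the in-file definitions: the file can then be loaded standalone, e.g. via trust_remote_code. A minimal sketch of how the logging helpers are typically driven (assumptions: modelling_musilingo.py is importable from the working directory, no distributed process group is initialized so the sync helpers are no-ops, and the loop body is stand-in work, not part of this commit):

    # usage sketch -- not part of the commit
    import time
    import torch
    from modelling_musilingo import MetricLogger, SmoothedValue

    logger = MetricLogger(delimiter="  ")
    logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))

    # log_every wraps any sized iterable and prints eta/time/meter stats
    # every print_freq steps.
    for step in logger.log_every(range(100), print_freq=10, header="Train:"):
        loss = torch.rand(1).item()  # stand-in for a real training loss
        logger.update(loss=loss)
        time.sleep(0.01)             # stand-in for real work

    print("Averaged stats: " + logger.global_avg())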