|
""" |
|
Copyright (c) 2022, salesforce.com, inc. |
|
All rights reserved. |
|
SPDX-License-Identifier: BSD-3-Clause |
|
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
|
""" |
|
|
|
import gzip |
|
import logging |
|
import os |
|
import random as rnd |
|
import tarfile |
|
import zipfile |
|
import random |
|
from typing import List |
|
from tqdm import tqdm |
|
|
|
import decord |
|
from decord import VideoReader |
|
import webdataset as wds |
|
import numpy as np |
|
import torch |
|
from torch.utils.data.dataset import IterableDataset |
|
|
|
from medomni.common.registry import registry |
|
from medomni.datasets.datasets.base_dataset import ConcatDataset |
|
|
|
|
|
# Configure decord's bridge so frames read through decord (e.g. VideoReader)
# come back as torch tensors. This is a process-wide, import-time side effect.
decord.bridge.set_bridge("torch")

# Look up the shared MAX_INT constant from the global registry at import time.
# NOTE(review): presumably registered elsewhere before this module is imported — confirm.
MAX_INT = registry.get("MAX_INT")
|
|
|
|
|
class ChainDataset(wds.DataPipeline):
    r"""Dataset that randomly interleaves samples from multiple :class:`DataPipeline` s.

    On every step one of the wrapped pipelines is drawn at random, weighted by
    its ``sample_ratio`` attribute (1 when the attribute is missing), and the
    next sample from that pipeline is yielded. Streams are consumed lazily, so
    combining large-scale datasets this way is cheap.

    Iteration ends as soon as the drawn pipeline is exhausted (or immediately
    when no pipelines were given). This avoids leaking ``StopIteration`` out of
    the generator, which Python 3.7+ would turn into a ``RuntimeError`` (PEP 479).

    Args:
        datasets (list of DataPipeline): pipelines to interleave.
    """

    def __init__(self, datasets: List[wds.DataPipeline]) -> None:
        super().__init__()
        self.datasets = datasets
        # Sampling weight per pipeline, index-aligned with ``self.datasets``.
        self.prob = []
        # Display name per pipeline, index-aligned with ``self.datasets``.
        self.names = []
        for dataset in self.datasets:
            self.names.append(getattr(dataset, 'name', 'Unknown'))
            if hasattr(dataset, 'sample_ratio'):
                self.prob.append(dataset.sample_ratio)
            else:
                self.prob.append(1)
                logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")

    def __iter__(self):
        datastreams = [iter(dataset) for dataset in self.datasets]
        if not datastreams:
            # random.choices cannot draw from an empty population.
            return
        while True:
            select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
            try:
                sample = next(select_datastream)
            except StopIteration:
                # A finite sub-pipeline ran dry: end this iteration cleanly.
                return
            yield sample
|
|
|
|
|
def apply_to_sample(f, sample):
    """Recursively apply ``f`` to every tensor inside ``sample``.

    Dicts, lists and tuples (including namedtuples) are traversed and rebuilt
    with the same keys / order. Tensors are replaced by ``f(tensor)``. Any
    other value is returned unchanged.

    Args:
        f: Callable applied to each ``torch.Tensor`` leaf.
        sample: A tensor, a (possibly nested) dict / list / tuple containing
            tensors, or any other value.

    Returns:
        The transformed structure. An empty non-tensor sized value (e.g. ``{}``,
        ``[]``) yields ``{}``, preserving the historical contract of this helper.
    """
    if not torch.is_tensor(sample):
        # len() raises TypeError for unsized values (None, ints, 0-d arrays);
        # those are simply not "empty" and fall through to _apply.
        try:
            is_empty = len(sample) == 0
        except TypeError:
            is_empty = False
        if is_empty:
            return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        if isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        if isinstance(x, list):
            return [_apply(item) for item in x]
        if isinstance(x, tuple):
            items = [_apply(item) for item in x]
            # namedtuples are rebuilt positionally so field access keeps working.
            return type(x)(*items) if hasattr(x, "_fields") else tuple(items)
        return x

    return _apply(sample)
|
|
|
|
|
def move_to_cuda(sample):
    """Return ``sample`` with every tensor it contains replaced by ``tensor.cuda()``.

    Traversal of nested containers is delegated to ``apply_to_sample``.
    """
    return apply_to_sample(lambda tensor: tensor.cuda(), sample)
|
|
|
|
|
def prepare_sample(samples, cuda_enabled=True):
    """Prepare a batch for the model.

    Args:
        samples: Batch structure (tensors possibly nested in dicts/lists).
        cuda_enabled: When true, tensors are moved to the GPU via ``move_to_cuda``.

    Returns:
        The (possibly moved) batch; returned unchanged when ``cuda_enabled`` is false.
    """
    return move_to_cuda(samples) if cuda_enabled else samples
|
|
|
|
|
def reorg_datasets_by_split(datasets):
    """
    Regroups per-dataset splits into per-split lists of datasets.

    Args:
        datasets: dict mapping a dataset name to a dict of
            {split_name: dataset} for that dataset.

    Returns:
        Dict {split_name: List[dataset]}; each list keeps the order in which the
        named datasets were iterated, and split keys appear in first-seen order.
    """
    by_split = {}
    for per_split in datasets.values():
        for split_name, split_dataset in per_split.items():
            by_split.setdefault(split_name, []).append(split_dataset)
    return by_split
|
|
|
|
|
def _combine_train_datasets(train_datasets):
    """Merge a list of training datasets into one map-style dataset and/or one DataPipeline.

    Returns a single dataset when only one kind is present, a
    ``(map_style, pipeline)`` tuple when both are present, and an empty tuple
    when ``train_datasets`` is empty.

    Raises:
        NotImplementedError: for a generic IterableDataset that is not a DataPipeline.
    """
    iterable_datasets, map_datasets = [], []
    for dataset in train_datasets:
        # DataPipeline is itself an IterableDataset, so it must be tested first.
        if isinstance(dataset, wds.DataPipeline):
            logging.info(
                "Dataset {} is IterableDataset, can't be concatenated.".format(
                    dataset
                )
            )
            iterable_datasets.append(dataset)
        elif isinstance(dataset, IterableDataset):
            raise NotImplementedError(
                "Do not support concatenation of generic IterableDataset."
            )
        else:
            map_datasets.append(dataset)

    if len(iterable_datasets) > 1:
        chained = ChainDataset(iterable_datasets)
    elif len(iterable_datasets) == 1:
        chained = iterable_datasets[0]
    else:
        chained = None

    concatenated = ConcatDataset(map_datasets) if map_datasets else None

    combined = tuple(x for x in (concatenated, chained) if x is not None)
    return combined[0] if len(combined) == 1 else combined


def concat_datasets(datasets):
    """
    Concatenates multiple datasets into a single dataset.

    It supports map-style datasets and DataPipeline from WebDataset. Currently, does not support
    generic IterableDataset because it requires creating separate samplers.

    Now only supports concatenating training datasets and assuming validation and testing
    have only a single dataset. This is because metrics should not be computed on the concatenated
    datasets.

    Args:
        datasets: dict mapping split name to a list of datasets for that split
            (the shape produced by ``reorg_datasets_by_split``). It is modified
            in place.

    Returns:
        The same dict. "train" becomes the concatenation of its datasets, and
        every other split is unwrapped to its single dataset.

        If the input training datasets contain both map-style and DataPipeline datasets, "train"
        is a tuple, where the first element is a concatenated map-style dataset and the second
        element is a chained DataPipeline dataset.

    Raises:
        ValueError: if a non-train split does not contain exactly one dataset.
        NotImplementedError: if a training dataset is a generic IterableDataset.
    """
    for split_name in datasets:
        if split_name == "train":
            datasets[split_name] = _combine_train_datasets(datasets[split_name])
            continue

        split_datasets = datasets[split_name]
        # Explicit check (not assert) so validation survives `python -O`.
        if len(split_datasets) > 1:
            raise ValueError("Do not support multiple {} datasets.".format(split_name))
        if len(split_datasets) == 0:
            raise ValueError("Expected one {} dataset, got none.".format(split_name))
        datasets[split_name] = split_datasets[0]

    return datasets
|
|
|
|