NEXTGPT / code /dataset /concat_dataset.py
osamaifti's picture
Upload 83 files
7cdf421 verified
raw
history blame contribute delete
No virus
1.16 kB
from torch.utils.data import ConcatDataset, Dataset
from .catalog import DatasetCatalog
from .utils import instantiate_from_config
class MyConcatDataset(Dataset):
    """Present several catalog-registered datasets as one map-style dataset.

    Each name in the constructor argument is looked up as an attribute on a
    ``DatasetCatalog`` instance, built through ``instantiate_from_config``,
    and the resulting datasets are chained with
    ``torch.utils.data.ConcatDataset``.
    """

    def __init__(self, dataset_name_list):
        """Build every named dataset and concatenate them.

        Args:
            dataset_name_list: Iterable of attribute names defined on
                ``DatasetCatalog``. Each attribute is indexed with
                ``'target'`` and ``'params'``; those two values are
                forwarded to ``instantiate_from_config``.

        Raises:
            AttributeError: If a name is not an attribute of the catalog
                (``getattr`` is called without a default).
        """
        super().__init__()
        catalog = DatasetCatalog()
        built = []
        for dataset_name in dataset_name_list:
            dataset_dict = getattr(catalog, dataset_name)
            target = dataset_dict['target']
            params = dataset_dict['params']
            # Echo the resolved configuration to stdout before building it.
            print(target)
            print(params)
            built.append(
                instantiate_from_config(dict(target=target, params=params))
            )
        self.datasets = ConcatDataset(built)

    def __len__(self):
        """Return the length reported by the wrapped ``ConcatDataset``."""
        return len(self.datasets)

    def __getitem__(self, item):
        """Return the sample at index ``item`` of the wrapped ``ConcatDataset``."""
        return self.datasets[item]

    def collate(self, instances):
        """Regroup a list of per-sample dicts into one dict of lists.

        Args:
            instances: Sequence of dict samples. The keys of the first
                sample define the keys of the result.

        Returns:
            A dict mapping each key to the list of that key's values, in
            the order the samples were given. An empty (falsy) ``instances``
            yields an empty dict.

        Raises:
            KeyError: If a later sample carries a key that the first sample
                does not have. (A sample that *lacks* one of the first
                sample's keys is not an error; that key's list simply ends
                up shorter.)
        """
        if not instances:
            return {}
        batch = {key: [] for key in instances[0]}
        for instance in instances:
            for key, value in instance.items():
                batch[key].append(value)
        return batch