# zdou0830's picture
# desco
# 749745d
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import bisect
import copy
import logging
import os
import torch.utils.data
import torch.distributed as dist
from maskrcnn_benchmark.utils.comm import get_world_size
from maskrcnn_benchmark.utils.imports import import_file
from . import datasets as D
from . import samplers
from .collate_batch import BatchCollator, BBoxAugCollator
from .transforms import build_transforms
from transformers import AutoTokenizer
from .datasets.duplicate_dataset import create_duplicate_dataset
def _dataset_copy_factor(cfg, dataset_name, is_train):
    """
    Return the duplication factor configured for `dataset_name`.

    The first matching rule wins. Names not matched by any specific rule fall back to
    DATASETS.GENERAL_COPY (training) or DATASETS.GENERAL_COPY_TEST (testing).
    """
    if "flickr30k_train" in dataset_name:
        return cfg.DATASETS.FLICKR_COPY
    if "mixed_train" in dataset_name:
        return cfg.DATASETS.MIXED_COPY
    if dataset_name in ("COCO_odinw_train_8copy_dt_train", "coco_dt_train", "coco_grounding_train"):
        return cfg.DATASETS.COCO_COPY
    if dataset_name in ("LVIS_odinw_train_8copy_dt_train", "lvisv1_dt_train", "lvis_grounding_train"):
        return cfg.DATASETS.LVIS_COPY
    if dataset_name in ("object365_odinw_2copy_dt_train", "object365_dt_train"):
        return cfg.DATASETS.OBJECT365_COPY
    if dataset_name in ("vg_odinw_clipped_8copy_dt_train", "vg_vgoi6_clipped_8copy_dt_train"):
        return cfg.DATASETS.VG_COPY
    if dataset_name == "imagenetod_train_odinw_2copy_dt":
        return cfg.DATASETS.IN_COPY
    if dataset_name == "oi_train_odinw_dt":
        return cfg.DATASETS.OI_COPY
    if "refcoco" in dataset_name:
        return cfg.DATASETS.REFCOCO_COPY
    return cfg.DATASETS.GENERAL_COPY if is_train else cfg.DATASETS.GENERAL_COPY_TEST


def build_dataset(cfg, dataset_list, transforms, dataset_catalog, is_train=True, class_concat=False, extra_args=None):
    """
    Instantiate every dataset named in `dataset_list`.

    Arguments:
        cfg: config node. Read for the train/test dataset-name suffixes, MODEL.MASK_ON,
            TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM, MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN
            and the DATASETS.*_COPY duplication factors.
        dataset_list (list[str]): Contains the names of the datasets, i.e.,
            coco_2014_train, coco_2014_val, etc. The configured train/test name suffix
            is appended to each name before it is looked up in the catalog.
        transforms (callable): transforms to apply to each (image, target) sample
        dataset_catalog (DatasetCatalog): contains the information on how to
            construct a dataset.
        is_train (bool): whether to setup the dataset for training or testing
        class_concat (bool): if True, renumber each dataset's contiguous category ids so that
            they continue after the previous dataset's ids instead of all starting at 1.
        extra_args (dict, optional): extra keyword arguments merged into every factory call.

    Returns:
        list: for testing, one dataset per name. For training, a one-element list holding the
        dataset, wrapped in D.ConcatDataset when more than one dataset was built.

    Raises:
        RuntimeError: if `dataset_list` is not a list or tuple.
    """
    if not isinstance(dataset_list, (list, tuple)):
        raise RuntimeError("dataset_list should be a list of strings, got {}".format(dataset_list))
    if extra_args is None:
        extra_args = {}
    datasets = []
    num_category = 1
    for dataset_id, dataset_name in enumerate(dataset_list, 1):
        if is_train:
            dataset_name = dataset_name + cfg.DATASETS.TRAIN_DATASETNAME_SUFFIX
        else:
            dataset_name = dataset_name + cfg.DATASETS.TEST_DATASETNAME_SUFFIX
        data = dataset_catalog.get(dataset_name)
        factory = getattr(D, data["factory"])
        # NOTE(review): this mutates the dict returned by the catalog in place.
        args = data["args"]
        # for COCODataset, we want to remove images without annotations
        # during training
        if data["factory"] == "COCODataset":
            args["remove_images_without_annotations"] = is_train
        if data["factory"] == "PascalVOCDataset":
            args["use_difficult"] = not is_train
        if data["factory"] in ["VGTSVDataset", "CocoDetectionTSV", "ODTSVDataset"]:
            args["extra_fields"] = ["class"]
            if cfg.MODEL.MASK_ON:
                args["extra_fields"].append("mask")
        if data["factory"] in [
            "CocoGrounding",
            "CocoDetectionTSV",
            "CaptionTSV",
            "MixedDataset",
            "FlickrDataset",
            "RefExpDataset",
            "GQADataset",
            "PseudoData",
            "PhrasecutDetection",
        ]:
            # grounding-style datasets also return tokens and (optionally) masks
            args["return_masks"] = cfg.MODEL.MASK_ON
            args["return_tokens"] = True
            args["max_num_labels"] = cfg.TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM
            args["max_query_len"] = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN
        args["transforms"] = transforms
        args.update(extra_args)

        # A factor of 1 or -1 means "build the dataset once"; anything else wraps the factory
        # so that the dataset is repeated `copy_factor` times.
        copy_factor = _dataset_copy_factor(cfg, dataset_name, is_train)
        if copy_factor not in (-1, 1):
            new_factory = create_duplicate_dataset(factory)
            dataset = new_factory(copy=copy_factor, **args)
        else:
            # make dataset from factory
            dataset = factory(**args)
        print(dataset_name, "has the {} data points".format(len(dataset)), data["factory"])

        if class_concat:
            # Shift this dataset's contiguous ids so they start at `num_category`.
            category = list(dataset.contiguous_category_id_to_json_id.values())
            dataset.contiguous_category_id_to_json_id = {}
            dataset.json_category_id_to_contiguous_id = {}
            for contiguous_id, cat in enumerate(category, start=num_category):
                dataset.json_category_id_to_contiguous_id[cat] = contiguous_id
                dataset.contiguous_category_id_to_json_id[contiguous_id] = cat
            num_category += len(category)
            print("Found {} #category after group {}, concating ...".format(num_category, dataset_id))
        datasets.append(dataset)

    # for testing, return a list of datasets
    if not is_train:
        return datasets

    # for training, concatenate all datasets into a single one
    dataset = datasets[0]
    if len(datasets) > 1:
        dataset = D.ConcatDataset(datasets)
    return [dataset]
def build_dataset_by_group(
    dataset_list, transforms, dataset_catalog, is_train=True, class_by_group=True, class_concat=False, extra_args=None
):
    """
    Instantiate datasets organised in groups (a list of lists of dataset names).

    Arguments:
        dataset_list (list[list[str]]): groups of dataset names, i.e.,
            [[coco_2014_train, ...], [...]]
        transforms (callable): transforms to apply to each (image, target) sample
        dataset_catalog (DatasetCatalog): contains the information on how to
            construct a dataset.
        is_train (bool): whether to setup the dataset for training or testing
        class_by_group (bool): if True (and the factory is not "Background"), the dataset's
            contiguous category id 1 is remapped to the group's 1-based index.
        class_concat (bool): if True, renumber contiguous category ids so that they continue
            across datasets instead of all starting at 1.
        extra_args (dict, optional): extra keyword arguments merged into every factory call.

    Returns:
        list: for testing, a flat list of all datasets. For training with class_concat, a
        one-element list holding a D.ConcatDataset over the per-group ConcatDatasets.
        Otherwise, one D.ConcatDataset per group.

    Raises:
        RuntimeError: if `dataset_list` is not a list or tuple.
    """
    if not isinstance(dataset_list, (list, tuple)):
        raise RuntimeError("dataset_list should be a list of strings, got {}".format(dataset_list))
    if extra_args is None:
        extra_args = {}

    num_category = 1
    grouped_datasets = []
    for group_id, group in enumerate(dataset_list, 1):
        datasets = []
        for dataset_name in group:
            data = dataset_catalog.get(dataset_name)
            factory = getattr(D, data["factory"])
            args = data["args"]
            # for COCODataset, we want to remove images without annotations
            # during training
            if data["factory"] == "COCODataset":
                args["remove_images_without_annotations"] = is_train
            if data["factory"] == "PascalVOCDataset":
                args["use_difficult"] = not is_train
            args["transforms"] = transforms
            args.update(extra_args)
            # make dataset from factory
            dataset = factory(**args)

            # check if dataset is grouped by task, assume one class per task
            if class_by_group and data["factory"] != "Background":
                category = dataset.contiguous_category_id_to_json_id[1]
                del dataset.contiguous_category_id_to_json_id[1]
                dataset.json_category_id_to_contiguous_id[category] = group_id
                dataset.contiguous_category_id_to_json_id[group_id] = category

            datasets.append(dataset)

        if class_concat:
            for dataset in datasets:
                # Shift this dataset's contiguous ids so they start at `num_category`.
                category = list(dataset.contiguous_category_id_to_json_id.values())
                dataset.contiguous_category_id_to_json_id = {}
                dataset.json_category_id_to_contiguous_id = {}
                for contiguous_id, cat in enumerate(category, start=num_category):
                    dataset.json_category_id_to_contiguous_id[cat] = contiguous_id
                    dataset.contiguous_category_id_to_json_id[contiguous_id] = cat
                num_category += len(category)
                print("Found {} #category after group {}, concating ...".format(num_category, group_id))

        if is_train:
            datasets = D.ConcatDataset(datasets)

        grouped_datasets.append(datasets)

    # for testing, return a list of datasets
    if not is_train:
        datasets = [dataset for group in grouped_datasets for dataset in group]
        return datasets
    if class_concat:
        grouped_datasets = D.ConcatDataset(grouped_datasets)
        return [grouped_datasets]

    # for training, return one concatenated dataset per group
    return grouped_datasets
def make_data_sampler(dataset, shuffle, distributed, num_replicas=None, rank=None, use_random_seed=True):
    """
    Pick the index sampler for `dataset`.

    In distributed mode a `samplers.DistributedSampler` is returned; it receives `shuffle`,
    `num_replicas`, `rank` and `use_random_seed` (as `use_random`). Otherwise a torch
    RandomSampler (shuffle) or SequentialSampler (no shuffle) is returned.
    """
    if not distributed:
        sampler_cls = (
            torch.utils.data.sampler.RandomSampler
            if shuffle
            else torch.utils.data.sampler.SequentialSampler
        )
        return sampler_cls(dataset)
    return samplers.DistributedSampler(
        dataset,
        shuffle=shuffle,
        num_replicas=num_replicas,
        rank=rank,
        use_random=use_random_seed,
    )
def _quantize(x, bins):
bins = copy.copy(bins)
bins = sorted(bins)
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
return quantized
def _compute_aspect_ratios(dataset):
aspect_ratios = []
for i in range(len(dataset)):
img_info = dataset.get_img_info(i)
aspect_ratio = float(img_info["height"]) / float(img_info["width"])
aspect_ratios.append(aspect_ratio)
return aspect_ratios
def make_batch_data_sampler(
    dataset, sampler, aspect_grouping, images_per_batch, num_iters=None, start_iter=0, drop_last=False
):
    """
    Wrap `sampler` into a batch sampler.

    If `aspect_grouping` is truthy it is used as bin edges (a scalar becomes a one-element
    list) over the dataset's height/width ratios, and batches are formed per bin by
    `samplers.GroupedBatchSampler`. Otherwise a plain torch BatchSampler is used.
    When `num_iters` is given, the result is wrapped in `samplers.IterationBasedBatchSampler`
    starting at `start_iter`.
    """
    if not aspect_grouping:
        batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, images_per_batch, drop_last=drop_last)
    else:
        bins = aspect_grouping if isinstance(aspect_grouping, (list, tuple)) else [aspect_grouping]
        group_ids = _quantize(_compute_aspect_ratios(dataset), bins)
        batch_sampler = samplers.GroupedBatchSampler(sampler, group_ids, images_per_batch, drop_uneven=drop_last)
    if num_iters is None:
        return batch_sampler
    return samplers.IterationBasedBatchSampler(batch_sampler, num_iters, start_iter)
def _expand_bing_datasets(dataset_list, bing_index_list):
    """
    Replace the "bing_caption_train" / "bing_caption_train_no_coco" placeholder names with one
    entry per index in `bing_index_list`, appended at the end.

    Returns `dataset_list` unchanged when no placeholder is present or the index list is
    empty. Otherwise returns a new tuple.
    """
    for placeholder, template in (
        ("bing_caption_train", "bing_caption_{}_train"),
        ("bing_caption_train_no_coco", "bing_caption_{}_train_no_coco"),
    ):
        if placeholder in dataset_list and len(bing_index_list) > 0:
            expanded = list(dataset_list)
            expanded.remove(placeholder)
            expanded.extend(template.format(bing_index) for bing_index in bing_index_list)
            dataset_list = tuple(expanded)
    return dataset_list


def _build_tokenizer(cfg):
    """Create the text tokenizer selected by MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE."""
    if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
        from transformers import CLIPTokenizerFast

        if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
            # NOTE(review): presumably an existing CLIP vocab entry repurposed as the mask token
            # for the MLM loss; confirm before changing this literal.
            return CLIPTokenizerFast.from_pretrained(
                "openai/clip-vit-base-patch32", from_slow=True, mask_token="ðŁĴij</w>"
            )
        return CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True)
    return AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)


def _build_dataset_extra_args(cfg, is_train):
    """
    Collect the extra keyword arguments forwarded to every dataset factory.

    Most options are only set for training. DATASETS.ONE_HOT, the override-category option,
    the caption/description settings and the tokenizer are set for training and testing.
    """
    data_cfg = cfg.DATASETS
    extra_args = {}
    if is_train:
        if data_cfg.USE_CROWD:
            extra_args["ignore_crowd"] = False
        if data_cfg.MAX_BOX > 0:
            extra_args["max_box"] = data_cfg.MAX_BOX
        if data_cfg.FEW_SHOT > 0:
            extra_args["few_shot"] = data_cfg.FEW_SHOT
        if data_cfg.SHUFFLE_SEED != 0:
            extra_args["shuffle_seed"] = data_cfg.SHUFFLE_SEED
        if cfg.AUGMENT.MOSAIC_PROB > 0:
            extra_args["mosaic_prob"] = cfg.AUGMENT.MOSAIC_PROB
            extra_args["mosaic_shift"] = cfg.AUGMENT.MOSAIC_SHIFT
            extra_args["mosaic_size"] = cfg.AUGMENT.MOSAIC_SIZE
        if cfg.AUGMENT.PASTE_PROB > 0:
            extra_args["paste_prob"] = cfg.AUGMENT.PASTE_PROB
            extra_args["paste_cat"] = cfg.AUGMENT.PASTE_CAT
            extra_args["paste_num"] = cfg.AUGMENT.PASTE_NUM
        # od to grounding
        if data_cfg.RANDOM_SAMPLE_NEG > 0:
            extra_args["random_sample_negative"] = data_cfg.RANDOM_SAMPLE_NEG
        if data_cfg.ADD_DET_PROMPT:
            extra_args["add_detection_prompt"] = True
        if data_cfg.USE_OD_AUG:
            extra_args["use_od_data_aug"] = True
        if data_cfg.USE_COCO_FORMAT:
            extra_args["use_coco_format"] = True
        if data_cfg.DISABLE_SHUFFLE:
            extra_args["disable_shuffle"] = True
        if len(data_cfg.PROMPT_VERSION) > 0:
            extra_args["prompt_engineer_version"] = data_cfg.PROMPT_VERSION
        if len(data_cfg.CONTROL_PROB) == 4:
            extra_args["control_probabilities"] = data_cfg.CONTROL_PROB
        if data_cfg.DISABLE_CLIP_TO_IMAGE:
            extra_args["disable_clip_to_image"] = data_cfg.DISABLE_CLIP_TO_IMAGE
        if data_cfg.NO_MINUS_ONE_FOR_ONE_HOT:
            extra_args["no_minus_one_for_one_hot"] = data_cfg.NO_MINUS_ONE_FOR_ONE_HOT
        extra_args["separation_tokens"] = data_cfg.SEPARATION_TOKENS
        # caption
        if data_cfg.CAPTION_MIN_BOX > 0:
            extra_args["caption_min_box"] = data_cfg.CAPTION_MIN_BOX
        if data_cfg.REPLACE_CLEAN_LABEL:
            extra_args["replace_clean_label"] = True
        if data_cfg.FURTHER_SCREEN:
            extra_args["further_screen"] = True
        if data_cfg.CAPTION_CONF > 0.0:
            extra_args["caption_conf"] = data_cfg.CAPTION_CONF
        extra_args["caption_nms"] = data_cfg.CAPTION_NMS
        if data_cfg.PACK_RANDOM_CAPTION_NUMBER > 0:
            extra_args["pack_random_caption_number"] = data_cfg.PACK_RANDOM_CAPTION_NUMBER
        if data_cfg.INFERENCE_CAPTION:
            extra_args["inference_caption"] = True
        if data_cfg.SAMPLE_NEGATIVE_FOR_GROUNDING_DATA > 0:
            extra_args["sample_negative_for_grounding_data"] = data_cfg.SAMPLE_NEGATIVE_FOR_GROUNDING_DATA
        if data_cfg.RANDOM_PACK_PROB > 0:
            extra_args["random_pack_prob"] = data_cfg.RANDOM_PACK_PROB
        if data_cfg.NO_RANDOM_PACK_PROBABILITY > 0:
            extra_args["no_random_pack_probability"] = data_cfg.NO_RANDOM_PACK_PROBABILITY
        extra_args["safeguard_positive_caption"] = data_cfg.SAFEGUARD_POSITIVE_CAPTION
        extra_args["local_debug"] = data_cfg.LOCAL_DEBUG
        fuse_cfg = cfg.MODEL.DYHEAD.FUSE_CONFIG
        extra_args["no_mask_for_od"] = fuse_cfg.NO_MASK_FOR_OD
        extra_args["no_mask_for_gold"] = fuse_cfg.NO_MASK_FOR_GOLD
        extra_args["mlm_obj_for_only_positive"] = fuse_cfg.MLM_OBJ_FOR_ONLY_POSITIVE
        extra_args["caption_format_version"] = data_cfg.CAPTION_FORMAT_VERSION
        extra_args["special_safeguard_for_coco_grounding"] = data_cfg.SPECIAL_SAFEGUARD_FOR_COCO_GROUNDING
        extra_args["diver_box_for_vqa"] = data_cfg.DIVER_BOX_FOR_VQA

    # options forwarded for both training and testing
    if data_cfg.ONE_HOT:
        extra_args["one_hot"] = True
    if data_cfg.OVERRIDE_CATEGORY and data_cfg.USE_OVERRIDE_CATEGORY:
        extra_args["override_category"] = data_cfg.OVERRIDE_CATEGORY
    extra_args["od_to_grounding_version"] = data_cfg.OD_TO_GROUNDING_VERSION
    extra_args["caption_prompt"] = data_cfg.CAPTION_PROMPT
    extra_args["use_caption_prompt"] = data_cfg.USE_CAPTION_PROMPT
    extra_args["description_file"] = data_cfg.DESCRIPTION_FILE
    extra_args["similarity_file"] = data_cfg.SIMILARITY_FILE
    extra_args["caption_vocab_file"] = data_cfg.CAPTION_VOCAB_FILE
    extra_args["caption_augmentation_version"] = data_cfg.CAPTION_AUGMENTATION_VERSION
    extra_args["cc_caption_augmentation_version"] = data_cfg.CC_CAPTION_AUGMENTATION_VERSION
    extra_args["tokenizer"] = _build_tokenizer(cfg)
    return extra_args


def make_data_loader(cfg, is_train=True, is_distributed=False, num_replicas=None, rank=None, start_iter=0):
    """
    Build the DataLoader(s) for training or evaluation.

    Arguments:
        cfg: config node. NOTE: during training, if SOLVER.MAX_EPOCH > 0 or
            SOLVER.MULTI_MAX_EPOCH is set, this function defrosts `cfg` and writes
            SOLVER.MAX_ITER / SOLVER.DATASET_LENGTH / SOLVER.MULTI_MAX_ITER back into it.
        is_train (bool): build training loaders (True) or test loaders (False).
        is_distributed (bool): whether to use a distributed sampler.
        num_replicas (int, optional): number of processes; defaults to get_world_size().
        rank (int, optional): rank forwarded to the distributed sampler.
        start_iter (int): iteration to resume from (training only; forced to 0 for testing).

    Returns:
        A single DataLoader when training without DATASETS.ALTERNATIVE_TRAINING /
        DATASETS.MULTISTAGE_TRAINING, otherwise a list with one DataLoader per dataset.
    """
    num_gpus = num_replicas or get_world_size()
    if is_train:
        images_per_batch = cfg.SOLVER.IMS_PER_BATCH
        # The parentheses make the two literals concatenate before .format(). Without them,
        # the second literal was a separate, discarded statement and the message kept raw "{}".
        assert images_per_batch % num_gpus == 0, (
            "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number "
            "of GPUs ({}) used.".format(images_per_batch, num_gpus)
        )
        shuffle = True
        num_iters = cfg.SOLVER.MAX_ITER
    else:
        images_per_batch = cfg.TEST.IMS_PER_BATCH
        assert images_per_batch % num_gpus == 0, (
            "TEST.IMS_PER_BATCH ({}) must be divisible by the number "
            "of GPUs ({}) used.".format(images_per_batch, num_gpus)
        )
        # Evaluation data is only shuffled when it is sharded across processes.
        shuffle = bool(is_distributed)
        num_iters = None
        start_iter = 0
    images_per_gpu = images_per_batch // num_gpus

    if images_per_gpu > 1:
        logger = logging.getLogger(__name__)
        logger.warning(
            "When using more than one image per GPU you may encounter "
            "an out-of-memory (OOM) error if your GPU does not have "
            "sufficient memory. If this happens, you can reduce "
            "SOLVER.IMS_PER_BATCH (for training) or "
            "TEST.IMS_PER_BATCH (for inference). For training, you must "
            "also adjust the learning rate and schedule length according "
            "to the linear scaling rule. See for example: "
            "https://github.com/facebookresearch/Detectron/blob/master/configs/getting_started/tutorial_1gpu_e2e_faster_rcnn_R-50-FPN.yaml#L14"
        )

    # group images which have similar aspect ratio. In this case, we only
    # group in two cases: those with width / height > 1, and the other way around,
    # but the code supports more general grouping strategy
    aspect_grouping = [1] if cfg.DATALOADER.ASPECT_RATIO_GROUPING else []

    paths_catalog = import_file("maskrcnn_benchmark.config.paths_catalog", cfg.PATHS_CATALOG, True)
    DatasetCatalog = paths_catalog.DatasetCatalog
    # Register datasets declared directly in the config, keyed by name + train/test suffix.
    for new_dataset in cfg.DATASETS.REGISTER:
        attrs = dict(cfg.DATASETS.REGISTER[new_dataset])
        if is_train:
            registered_name = new_dataset + cfg.DATASETS.TRAIN_DATASETNAME_SUFFIX
        else:
            registered_name = new_dataset + cfg.DATASETS.TEST_DATASETNAME_SUFFIX
        DatasetCatalog.set(registered_name, attrs)

    dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST
    # Haotian: expand bing dataset
    dataset_list = _expand_bing_datasets(dataset_list, cfg.DATASETS.BING_INDEX_LIST)
    print("The combined datasets are: {}.".format(dataset_list))

    transforms = None if not is_train and cfg.TEST.USE_MULTISCALE else build_transforms(cfg, is_train)
    extra_args = _build_dataset_extra_args(cfg, is_train)

    if isinstance(dataset_list[0], (tuple, list)):
        datasets = build_dataset_by_group(
            dataset_list,
            transforms,
            DatasetCatalog,
            is_train,
            class_by_group=cfg.DATASETS.ALTERNATIVE_TRAINING,
            class_concat=cfg.DATASETS.CLASS_CONCAT,
            extra_args=extra_args,
        )
    else:
        datasets = build_dataset(
            cfg,
            dataset_list,
            transforms,
            DatasetCatalog,
            is_train,
            class_concat=cfg.DATASETS.CLASS_CONCAT,
            extra_args=extra_args,
        )

    data_loaders = []
    for di, dataset in enumerate(datasets):
        if is_train and cfg.SOLVER.MAX_EPOCH > 0:
            # derive the iteration budget from an epoch count
            num_iters = cfg.SOLVER.MAX_EPOCH * len(dataset) // cfg.SOLVER.IMS_PER_BATCH
            print("Number of iterations are {}".format(num_iters))
            cfg.defrost()
            cfg.SOLVER.MAX_ITER = num_iters
            cfg.SOLVER.DATASET_LENGTH = len(dataset)
            cfg.freeze()
        if is_train and cfg.SOLVER.MULTI_MAX_EPOCH:
            # per-stage budgets are recorded in MULTI_MAX_ITER; the loaders are not iteration-capped
            num_iters = None
            cfg.defrost()
            cfg.SOLVER.MULTI_MAX_ITER += (cfg.SOLVER.MULTI_MAX_EPOCH[di] * len(dataset) // cfg.SOLVER.IMS_PER_BATCH,)
            cfg.freeze()

        if is_train and cfg.DATALOADER.DISTRIBUTE_CHUNK_AMONG_NODE:
            from .datasets.custom_distributed_sampler import DistributedSamplerChunkByNode

            chunk_or_not = ["bing_caption" in name for name in dataset_list]
            assert len(chunk_or_not) == len(dataset.datasets)
            # e.g. when training on 4 nodes, each with 8 GPUs
            num_nodes = int(os.getenv("NODE_COUNT", os.getenv("OMPI_COMM_WORLD_SIZE", 1)))
            local_size = cfg.num_gpus // num_nodes
            node_rank = int(os.getenv("NODE_RANK", os.getenv("OMPI_COMM_WORLD_RANK", 0)))
            local_rank = cfg.local_rank
            sampler = DistributedSamplerChunkByNode(
                dataset=dataset,
                all_datasets=dataset.datasets,  # Assuming dataset is a ConcatDataset instance
                chunk_or_not=chunk_or_not,
                num_replicas=cfg.num_gpus,  # total GPU number, e.g., 32
                rank=dist.get_rank(),  # Global Rank, e.g., 0~31
                node_rank=node_rank,  # Node Rank, e.g., 0~3
                node_number=num_nodes,  # how many node e.g., 4
                process_num_per_node=local_size,  # e.g., 8
                rank_within_local_node=local_rank,  # e.g., 0~7
            )
        else:
            sampler = make_data_sampler(
                dataset,
                shuffle,
                is_distributed,
                num_replicas=num_replicas,
                rank=rank,
                use_random_seed=cfg.DATALOADER.USE_RANDOM_SEED,
            )
        batch_sampler = make_batch_data_sampler(
            dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter, drop_last=is_train
        )
        collator = (
            BBoxAugCollator()
            if not is_train and cfg.TEST.USE_MULTISCALE
            else BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY)
        )
        data_loader = torch.utils.data.DataLoader(
            dataset,
            num_workers=cfg.DATALOADER.NUM_WORKERS,
            batch_sampler=batch_sampler,
            collate_fn=collator,
        )
        data_loaders.append(data_loader)

    if is_train and cfg.SOLVER.MULTI_MAX_EPOCH:
        # final stage: budget based on the smallest dataset
        cfg.defrost()
        cfg.SOLVER.MULTI_MAX_ITER += (
            cfg.SOLVER.MULTI_MAX_EPOCH[-1] * min([len(dataset) // cfg.SOLVER.IMS_PER_BATCH for dataset in datasets]),
        )
        cfg.freeze()

    if is_train and not cfg.DATASETS.ALTERNATIVE_TRAINING and not cfg.DATASETS.MULTISTAGE_TRAINING:
        # during training, a single (possibly concatenated) data_loader is returned
        assert len(data_loaders) == 1
        return data_loaders[0]
    return data_loaders