Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
import bisect | |
import copy | |
import logging | |
import os | |
import torch.utils.data | |
import torch.distributed as dist | |
from maskrcnn_benchmark.utils.comm import get_world_size | |
from maskrcnn_benchmark.utils.imports import import_file | |
from . import datasets as D | |
from . import samplers | |
from .collate_batch import BatchCollator, BBoxAugCollator | |
from .transforms import build_transforms | |
from transformers import AutoTokenizer | |
from .datasets.duplicate_dataset import create_duplicate_dataset | |
def build_dataset(cfg, dataset_list, transforms, dataset_catalog, is_train=True, class_concat=False, extra_args={}): | |
""" | |
Arguments: | |
dataset_list (list[str]): Contains the names of the datasets, i.e., | |
coco_2014_trian, coco_2014_val, etc | |
transforms (callable): transforms to apply to each (image, target) sample | |
dataset_catalog (DatasetCatalog): contains the information on how to | |
construct a dataset. | |
is_train (bool): whether to setup the dataset for training or testing | |
""" | |
if not isinstance(dataset_list, (list, tuple)): | |
raise RuntimeError( | |
"dataset_list should be a list of strings, got {}".format(dataset_list) | |
) | |
datasets = [] | |
num_category = 1 | |
for dataset_id, dataset_name in enumerate(dataset_list, 1): | |
if is_train: | |
dataset_name = dataset_name + cfg.DATASETS.TRAIN_DATASETNAME_SUFFIX | |
else: | |
dataset_name = dataset_name + cfg.DATASETS.TEST_DATASETNAME_SUFFIX | |
data = dataset_catalog.get(dataset_name) | |
factory = getattr(D, data["factory"]) | |
args = data["args"] | |
# for COCODataset, we want to remove images without annotations | |
# during training | |
if data["factory"] == "COCODataset": | |
args["remove_images_without_annotations"] = is_train | |
if data["factory"] == "PascalVOCDataset": | |
args["use_difficult"] = not is_train | |
if data["factory"] in ["VGTSVDataset", "CocoDetectionTSV", "ODTSVDataset"]: | |
args["extra_fields"] = ["class"] | |
if cfg.MODEL.MASK_ON: | |
args["extra_fields"].append("mask") | |
if data["factory"] in ["CocoGrounding", "CocoDetectionTSV", "CaptionTSV", "MixedDataset", "FlickrDataset", "RefExpDataset", "GQADataset", "PseudoData", "PhrasecutDetection"]: | |
# args["return_masks"] = False | |
args["return_masks"] = cfg.MODEL.MASK_ON | |
args["return_tokens"] = True | |
args["max_num_labels"] = cfg.TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM | |
args["max_query_len"] = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN | |
args["transforms"] = transforms | |
args.update(extra_args) | |
if dataset_name == "flickr30k_train": | |
copy = cfg.DATASETS.FLICKR_COPY | |
elif dataset_name in ["mixed_train", "mixed_train_no_coco"]: | |
copy = cfg.DATASETS.MIXED_COPY | |
elif dataset_name == "COCO_odinw_train_8copy_dt_train": | |
copy = cfg.DATASETS.COCO_COPY | |
elif dataset_name == "LVIS_odinw_train_8copy_dt_train": | |
copy = cfg.DATASETS.LVIS_COPY | |
elif dataset_name == "object365_odinw_2copy_dt_train": | |
copy = cfg.DATASETS.OBJECT365_COPY | |
elif dataset_name == "vg_odinw_clipped_8copy_dt_train": | |
copy = cfg.DATASETS.VG_COPY | |
elif dataset_name == "vg_vgoi6_clipped_8copy_dt_train": | |
copy = cfg.DATASETS.VG_COPY | |
elif dataset_name == "imagenetod_train_odinw_2copy_dt": | |
copy = cfg.DATASETS.IN_COPY | |
elif dataset_name == "oi_train_odinw_dt": | |
copy = cfg.DATASETS.OI_COPY | |
elif is_train: | |
copy = cfg.DATASETS.GENERAL_COPY | |
elif not is_train: | |
copy = cfg.DATASETS.GENERAL_COPY_TEST | |
else: | |
copy = -1 # do not ever copy test | |
if copy != -1: | |
new_factory = create_duplicate_dataset(factory) | |
dataset = new_factory(copy=copy, **args) | |
else: | |
# make dataset from factory | |
dataset = factory(**args) | |
print(dataset_name, 'has the {} data points'.format(len(dataset)), data["factory"]) | |
if class_concat: | |
category = list(dataset.contiguous_category_id_to_json_id.values()) | |
dataset.contiguous_category_id_to_json_id = {} | |
dataset.json_category_id_to_contiguous_id = {} | |
for id, cat in enumerate(category, start=num_category): | |
dataset.json_category_id_to_contiguous_id[cat] = id | |
dataset.contiguous_category_id_to_json_id[id] = cat | |
num_category += len(category) | |
print("Found {} #category after group {}, concating ...".format(num_category, dataset_id)) | |
datasets.append(dataset) | |
# for testing, return a list of datasets | |
if not is_train: | |
return datasets | |
# for training, concatenate all datasets into a single one | |
dataset = datasets[0] | |
if len(datasets) > 1: | |
dataset = D.ConcatDataset(datasets) | |
return [dataset] | |
def build_dataset_by_group(dataset_list, transforms, dataset_catalog, is_train=True, class_by_group=True, | |
class_concat=False, extra_args={}): | |
""" | |
Arguments: | |
dataset_list (list[str]): Contains the names of the datasets, i.e., | |
coco_2014_trian, coco_2014_val, etc | |
transforms (callable): transforms to apply to each (image, target) sample | |
dataset_catalog (DatasetCatalog): contains the information on how to | |
construct a dataset. | |
is_train (bool): whether to setup the dataset for training or testing | |
""" | |
if not isinstance(dataset_list, (list, tuple)): | |
raise RuntimeError( | |
"dataset_list should be a list of strings, got {}".format(dataset_list) | |
) | |
num_category = 1 | |
grouped_datasets = [] | |
for group_id, group in enumerate(dataset_list, 1): | |
datasets = [] | |
for dataset_name in group: | |
data = dataset_catalog.get(dataset_name) | |
factory = getattr(D, data["factory"]) | |
args = data["args"] | |
# for COCODataset, we want to remove images without annotations | |
# during training | |
if data["factory"] == "COCODataset": | |
args["remove_images_without_annotations"] = is_train | |
if data["factory"] == "PascalVOCDataset": | |
args["use_difficult"] = not is_train | |
args["transforms"] = transforms | |
args.update(extra_args) | |
# make dataset from factory | |
dataset = factory(**args) | |
# check if dataset is grouped by task, assume one class per task | |
if class_by_group and data["factory"] != "Background": | |
category = dataset.contiguous_category_id_to_json_id[1] | |
del dataset.contiguous_category_id_to_json_id[1] | |
dataset.json_category_id_to_contiguous_id[category] = group_id | |
dataset.contiguous_category_id_to_json_id[group_id] = category | |
datasets.append(dataset) | |
if class_concat: | |
for dataset in datasets: | |
category = list(dataset.contiguous_category_id_to_json_id.values()) | |
dataset.contiguous_category_id_to_json_id = {} | |
dataset.json_category_id_to_contiguous_id = {} | |
for id, cat in enumerate(category, start=num_category): | |
dataset.json_category_id_to_contiguous_id[cat] = id | |
dataset.contiguous_category_id_to_json_id[id] = cat | |
num_category += len(category) | |
print("Found {} #category after group {}, concating ...".format(num_category, group_id)) | |
if is_train: | |
datasets = D.ConcatDataset(datasets) | |
grouped_datasets.append(datasets) | |
# for testing, return a list of datasets | |
if not is_train: | |
datasets = [dataset for group in grouped_datasets for dataset in group] | |
return datasets | |
if class_concat: | |
grouped_datasets = D.ConcatDataset(grouped_datasets) | |
return [grouped_datasets] | |
# for training, concatenate all datasets into a single one | |
return grouped_datasets | |
def make_data_sampler(dataset, shuffle, distributed, num_replicas=None, rank=None, use_random_seed=True): | |
if distributed: | |
return samplers.DistributedSampler(dataset, shuffle=shuffle, num_replicas=num_replicas, rank=rank, | |
use_random=use_random_seed) | |
if shuffle: | |
sampler = torch.utils.data.sampler.RandomSampler(dataset) | |
else: | |
sampler = torch.utils.data.sampler.SequentialSampler(dataset) | |
return sampler | |
def _quantize(x, bins): | |
bins = copy.copy(bins) | |
bins = sorted(bins) | |
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x)) | |
return quantized | |
def _compute_aspect_ratios(dataset): | |
aspect_ratios = [] | |
for i in range(len(dataset)): | |
img_info = dataset.get_img_info(i) | |
aspect_ratio = float(img_info["height"]) / float(img_info["width"]) | |
aspect_ratios.append(aspect_ratio) | |
return aspect_ratios | |
def make_batch_data_sampler( | |
dataset, sampler, aspect_grouping, images_per_batch, num_iters=None, start_iter=0, drop_last=False | |
): | |
if aspect_grouping: | |
if not isinstance(aspect_grouping, (list, tuple)): | |
aspect_grouping = [aspect_grouping] | |
aspect_ratios = _compute_aspect_ratios(dataset) | |
group_ids = _quantize(aspect_ratios, aspect_grouping) | |
batch_sampler = samplers.GroupedBatchSampler( | |
sampler, group_ids, images_per_batch, drop_uneven=drop_last | |
) | |
else: | |
batch_sampler = torch.utils.data.sampler.BatchSampler( | |
sampler, images_per_batch, drop_last=drop_last | |
) | |
if num_iters is not None: | |
batch_sampler = samplers.IterationBasedBatchSampler( | |
batch_sampler, num_iters, start_iter | |
) | |
return batch_sampler | |
def make_data_loader(cfg, is_train=True, is_distributed=False, num_replicas=None, rank=None, start_iter=0): | |
num_gpus = num_replicas or get_world_size() | |
if is_train: | |
images_per_batch = cfg.SOLVER.IMS_PER_BATCH | |
assert ( | |
images_per_batch % num_gpus == 0 | |
), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number " | |
"of GPUs ({}) used.".format(images_per_batch, num_gpus) | |
images_per_gpu = images_per_batch // num_gpus | |
shuffle = True | |
num_iters = cfg.SOLVER.MAX_ITER | |
else: | |
images_per_batch = cfg.TEST.IMS_PER_BATCH | |
assert ( | |
images_per_batch % num_gpus == 0 | |
), "TEST.IMS_PER_BATCH ({}) must be divisible by the number " | |
"of GPUs ({}) used.".format(images_per_batch, num_gpus) | |
images_per_gpu = images_per_batch // num_gpus | |
shuffle = False if not is_distributed else True | |
num_iters = None | |
start_iter = 0 | |
if images_per_gpu > 1: | |
logger = logging.getLogger(__name__) | |
logger.warning( | |
"When using more than one image per GPU you may encounter " | |
"an out-of-memory (OOM) error if your GPU does not have " | |
"sufficient memory. If this happens, you can reduce " | |
"SOLVER.IMS_PER_BATCH (for training) or " | |
"TEST.IMS_PER_BATCH (for inference). For training, you must " | |
"also adjust the learning rate and schedule length according " | |
"to the linear scaling rule. See for example: " | |
"https://github.com/facebookresearch/Detectron/blob/master/configs/getting_started/tutorial_1gpu_e2e_faster_rcnn_R-50-FPN.yaml#L14" | |
) | |
# group images which have similar aspect ratio. In this case, we only | |
# group in two cases: those with width / height > 1, and the other way around, | |
# but the code supports more general grouping strategy | |
aspect_grouping = [1] if cfg.DATALOADER.ASPECT_RATIO_GROUPING else [] | |
paths_catalog = import_file( | |
"maskrcnn_benchmark.config.paths_catalog", cfg.PATHS_CATALOG, True | |
) | |
DatasetCatalog = paths_catalog.DatasetCatalog | |
if len(cfg.DATASETS.REGISTER) > 0: | |
for new_dataset in cfg.DATASETS.REGISTER: | |
# img_dir = cfg.DATASETS.REGISTER[new_dataset]["img_dir"] | |
# if "ann_file" in cfg.DATASETS.REGISTER[new_dataset]: | |
# ann_file = cfg.DATASETS.REGISTER[new_dataset]["ann_file"] | |
# else: | |
# ann_file = None | |
attrs = dict(cfg.DATASETS.REGISTER[new_dataset]) | |
if is_train: | |
new_dataset = new_dataset + cfg.DATASETS.TRAIN_DATASETNAME_SUFFIX | |
else: | |
new_dataset = new_dataset + cfg.DATASETS.TEST_DATASETNAME_SUFFIX | |
DatasetCatalog.set(new_dataset, attrs) | |
dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST | |
# Haotian: expand bing dataset | |
if "bing_caption_train" in dataset_list and len(cfg.DATASETS.BING_INDEX_LIST) > 0: | |
dataset_list = list(dataset_list) | |
dataset_list.remove("bing_caption_train") | |
for bing_index in cfg.DATASETS.BING_INDEX_LIST: | |
dataset_list.insert(len(dataset_list), "bing_caption_{}_train".format(bing_index)) | |
dataset_list = tuple(dataset_list) | |
if "bing_caption_train_no_coco" in dataset_list and len(cfg.DATASETS.BING_INDEX_LIST) > 0: | |
dataset_list = list(dataset_list) | |
dataset_list.remove("bing_caption_train_no_coco") | |
for bing_index in cfg.DATASETS.BING_INDEX_LIST: | |
dataset_list.insert(len(dataset_list), "bing_caption_{}_train_no_coco".format(bing_index)) | |
dataset_list = tuple(dataset_list) | |
print("The combined datasets are: {}.".format(dataset_list)) | |
transforms = None if not is_train and cfg.TEST.USE_MULTISCALE else build_transforms(cfg, is_train) | |
extra_args = {} | |
if is_train and cfg.DATASETS.USE_CROWD: | |
extra_args['ignore_crowd'] = False | |
if is_train and cfg.DATASETS.MAX_BOX > 0: | |
extra_args['max_box'] = cfg.DATASETS.MAX_BOX | |
if is_train and cfg.DATASETS.FEW_SHOT>0: | |
extra_args['few_shot'] = cfg.DATASETS.FEW_SHOT | |
if is_train and cfg.DATASETS.SHUFFLE_SEED != 0: | |
extra_args['shuffle_seed'] = cfg.DATASETS.SHUFFLE_SEED | |
# od to grounding | |
if is_train and cfg.DATASETS.RANDOM_SAMPLE_NEG > 0: | |
extra_args['random_sample_negative'] = cfg.DATASETS.RANDOM_SAMPLE_NEG | |
if is_train and cfg.DATASETS.ADD_DET_PROMPT: | |
extra_args["add_detection_prompt"] = True | |
if is_train and cfg.DATASETS.USE_OD_AUG: | |
extra_args["use_od_data_aug"] = True | |
if is_train and cfg.DATASETS.DISABLE_SHUFFLE: | |
extra_args["disable_shuffle"] = True | |
if cfg.DATASETS.ONE_HOT: | |
extra_args["one_hot"] = True | |
if is_train and len(cfg.DATASETS.PROMPT_VERSION) > 0: | |
extra_args["prompt_engineer_version"] = cfg.DATASETS.PROMPT_VERSION | |
if is_train and len(cfg.DATASETS.CONTROL_PROB) == 4: | |
extra_args["control_probabilities"] = cfg.DATASETS.CONTROL_PROB | |
if is_train and cfg.DATASETS.DISABLE_CLIP_TO_IMAGE: | |
extra_args["disable_clip_to_image"] = cfg.DATASETS.DISABLE_CLIP_TO_IMAGE | |
if is_train and cfg.DATASETS.NO_MINUS_ONE_FOR_ONE_HOT: | |
extra_args["no_minus_one_for_one_hot"] = cfg.DATASETS.NO_MINUS_ONE_FOR_ONE_HOT | |
if is_train: | |
extra_args["separation_tokens"] = cfg.DATASETS.SEPARATION_TOKENS | |
# caption | |
if is_train and cfg.DATASETS.CAPTION_MIN_BOX > 0: | |
extra_args["caption_min_box"] = cfg.DATASETS.CAPTION_MIN_BOX | |
if is_train and cfg.DATASETS.REPLACE_CLEAN_LABEL: | |
extra_args["replace_clean_label"] = True | |
if is_train and cfg.DATASETS.FURTHER_SCREEN: | |
extra_args["further_screen"] = True | |
if is_train and cfg.DATASETS.CAPTION_CONF > 0.0: | |
extra_args["caption_conf"] = cfg.DATASETS.CAPTION_CONF | |
if is_train: | |
extra_args["caption_nms"] = cfg.DATASETS.CAPTION_NMS | |
if is_train and cfg.DATASETS.PACK_RANDOM_CAPTION_NUMBER > 0: | |
extra_args["pack_random_caption_number"] = cfg.DATASETS.PACK_RANDOM_CAPTION_NUMBER | |
if is_train and cfg.DATASETS.INFERENCE_CAPTION: | |
extra_args["inference_caption"] = True | |
if is_train and cfg.DATASETS.SAMPLE_NEGATIVE_FOR_GROUNDING_DATA > 0: | |
extra_args["sample_negative_for_grounding_data"] = cfg.DATASETS.SAMPLE_NEGATIVE_FOR_GROUNDING_DATA | |
if is_train and cfg.DATASETS.RANDOM_PACK_PROB > 0: | |
extra_args["random_pack_prob"] = cfg.DATASETS.RANDOM_PACK_PROB | |
if is_train and cfg.DATASETS.NO_RANDOM_PACK_PROBABILITY > 0: | |
extra_args["no_random_pack_probability"] = cfg.DATASETS.NO_RANDOM_PACK_PROBABILITY | |
if is_train: | |
extra_args["safeguard_positive_caption"] = cfg.DATASETS.SAFEGUARD_POSITIVE_CAPTION | |
if is_train: | |
extra_args["local_debug"] = cfg.DATASETS.LOCAL_DEBUG | |
if is_train: | |
extra_args["no_mask_for_od"] = cfg.MODEL.DYHEAD.FUSE_CONFIG.NO_MASK_FOR_OD | |
if is_train: | |
extra_args["no_mask_for_gold"] = cfg.MODEL.DYHEAD.FUSE_CONFIG.NO_MASK_FOR_GOLD | |
if is_train: | |
extra_args["mlm_obj_for_only_positive"] = cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_OBJ_FOR_ONLY_POSITIVE | |
if cfg.DATASETS.OVERRIDE_CATEGORY and cfg.DATASETS.USE_OVERRIDE_CATEGORY: | |
extra_args["override_category"] = cfg.DATASETS.OVERRIDE_CATEGORY | |
if is_train: | |
extra_args["caption_format_version"] = cfg.DATASETS.CAPTION_FORMAT_VERSION | |
if is_train: | |
extra_args["special_safeguard_for_coco_grounding"] = cfg.DATASETS.SPECIAL_SAFEGUARD_FOR_COCO_GROUNDING | |
if is_train: | |
extra_args["diver_box_for_vqa"] = cfg.DATASETS.DIVER_BOX_FOR_VQA | |
extra_args["caption_prompt"] = cfg.DATASETS.CAPTION_PROMPT | |
extra_args["use_caption_prompt"] = cfg.DATASETS.USE_CAPTION_PROMPT | |
# extra_args['tokenizer'] = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE) | |
if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip": | |
# extra_args['tokenizer'] = build_tokenizer("clip") | |
from transformers import CLIPTokenizerFast | |
if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS: | |
extra_args["tokenizer"] = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True, mask_token='ðŁĴij</w>') | |
else: | |
extra_args["tokenizer"] = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True) | |
else: | |
extra_args['tokenizer'] = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE) | |
if isinstance(dataset_list[0], (tuple, list)): | |
datasets = build_dataset_by_group(dataset_list, transforms, DatasetCatalog, is_train, | |
class_by_group=cfg.DATASETS.ALTERNATIVE_TRAINING, | |
class_concat=cfg.DATASETS.CLASS_CONCAT, | |
extra_args=extra_args) | |
else: | |
datasets = build_dataset(cfg, dataset_list, transforms, DatasetCatalog, is_train, | |
class_concat=cfg.DATASETS.CLASS_CONCAT, | |
extra_args=extra_args) | |
data_loaders = [] | |
for di, dataset in enumerate(datasets): | |
if is_train and cfg.SOLVER.MAX_EPOCH > 0: | |
num_iters = cfg.SOLVER.MAX_EPOCH * len(dataset) // cfg.SOLVER.IMS_PER_BATCH | |
print("Number of iterations are {}".format(num_iters)) | |
cfg.defrost() | |
cfg.SOLVER.MAX_ITER = num_iters | |
cfg.SOLVER.DATASET_LENGTH = len(dataset) | |
cfg.freeze() | |
if is_train and cfg.SOLVER.MULTI_MAX_EPOCH: | |
num_iters = None | |
cfg.defrost() | |
cfg.SOLVER.MULTI_MAX_ITER += (cfg.SOLVER.MULTI_MAX_EPOCH[di] * len(dataset) // cfg.SOLVER.IMS_PER_BATCH,) | |
cfg.freeze() | |
if is_train and cfg.DATALOADER.DISTRIBUTE_CHUNK_AMONG_NODE: | |
from .datasets.custom_distributed_sampler import DistributedSamplerChunkByNode | |
chunk_or_not = [] | |
for i in dataset_list: | |
if "bing_caption" in i: | |
chunk_or_not.append(True) | |
else: | |
chunk_or_not.append(False) | |
assert(len(chunk_or_not) == len(dataset.datasets)) | |
''' | |
If we are training on 4 nodes, each with 8 GPUs | |
''' | |
num_nodes = int(os.getenv('NODE_COUNT', os.getenv('OMPI_COMM_WORLD_SIZE', 1))) | |
local_size = cfg.num_gpus//num_nodes | |
node_rank = int(os.getenv('NODE_RANK', os.getenv('OMPI_COMM_WORLD_RANK', 0))) | |
local_rank = cfg.local_rank | |
sampler = DistributedSamplerChunkByNode( | |
dataset = dataset, | |
all_datasets = dataset.datasets, # Assumming dataset is a ConcateDataset instance, | |
chunk_or_not = chunk_or_not, | |
num_replicas = cfg.num_gpus, # total GPU number, e.g., 32 | |
rank = dist.get_rank(), # Global Rank, e.g., 0~31 | |
node_rank = node_rank, # Node Rank, e.g., 0~3 | |
node_number = num_nodes, # how many node e.g., 4 | |
process_num_per_node = local_size, # e.g., 8 | |
rank_within_local_node = local_rank, # e.g., 0~7 | |
) | |
else: | |
sampler = make_data_sampler(dataset, shuffle, is_distributed, num_replicas=num_replicas, rank=rank, | |
use_random_seed=cfg.DATALOADER.USE_RANDOM_SEED) | |
batch_sampler = make_batch_data_sampler( | |
dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter, drop_last=is_train | |
) | |
collator = BBoxAugCollator() if not is_train and cfg.TEST.USE_MULTISCALE else BatchCollator( | |
cfg.DATALOADER.SIZE_DIVISIBILITY) | |
num_workers = cfg.DATALOADER.NUM_WORKERS | |
data_loader = torch.utils.data.DataLoader( | |
dataset, | |
num_workers=num_workers, | |
batch_sampler=batch_sampler, | |
collate_fn=collator, | |
) | |
data_loaders.append(data_loader) | |
if is_train and cfg.SOLVER.MULTI_MAX_EPOCH: | |
cfg.defrost() | |
cfg.SOLVER.MULTI_MAX_ITER += ( | |
cfg.SOLVER.MULTI_MAX_EPOCH[-1] * min([len(dataset) // cfg.SOLVER.IMS_PER_BATCH for dataset in datasets]),) | |
cfg.freeze() | |
if is_train and not cfg.DATASETS.ALTERNATIVE_TRAINING and not cfg.DATASETS.MULTISTAGE_TRAINING: | |
# during training, a single (possibly concatenated) data_loader is returned | |
assert len(data_loaders) == 1 | |
return data_loaders[0] | |
return data_loaders | |