# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import bisect
import copy
import logging
import os

import torch
import torch.utils.data
import torch.distributed
from torch.utils.data.dataset import ConcatDataset

from .catalog import DatasetCatalog
from .clip_datasets.clip_img_txt_pair_tsv import CLIPImgTxtPairTSVDataset
from .transforms.build import build_clip_transforms


def config_tsv_dataset_args(cfg, dataset_file, factory_name=None, is_train=True):
    ############ code removed, as tsv_dataset_name = factory_name = "CLIPImgTxtPairTSVDataset" ############
    if factory_name is not None:
        tsv_dataset_name = factory_name

    if tsv_dataset_name in ["CLIPImgTxtPairTSVDataset"]:
        # no need for extra arguments
        args = {}
        args['args'] = cfg
        args['seq_len'] = cfg.DATASETS.MAX_SEQ_LENGTH  # cfg.max_seq_length

    # note: args and tsv_dataset_name are only bound when factory_name is
    # "CLIPImgTxtPairTSVDataset"; any other factory name raises a NameError here
    return args, tsv_dataset_name


def build_dataset(cfg, transforms, dataset_catalog, is_train=True, is_aux=False):
    """
    Arguments:
        cfg: config file.
        transforms (callable): transforms to apply to each (image, target) sample
        dataset_catalog (DatasetCatalog): contains the information on how to construct a dataset.
        is_train (bool): whether to set up the dataset for training or testing
        is_aux (bool): during training, build the auxiliary datasets
            (cfg.DATASETS.AUX) instead of the main training datasets
    """
    dataset_list = (cfg.DATASETS.TRAIN if not is_aux else cfg.DATASETS.AUX) if is_train else cfg.DATASETS.TEST
    factory_list = (cfg.DATASETS.FACTORY_TRAIN if not is_aux else cfg.DATASETS.FACTORY_AUX) if is_train else cfg.DATASETS.FACTORY_TEST
    path_list = (cfg.DATASETS.PATH_TRAIN if not is_aux else cfg.DATASETS.PATH_AUX) if is_train else cfg.DATASETS.PATH_TEST

    if not isinstance(dataset_list, (list, tuple)):
        raise RuntimeError(
            "dataset_list should be a list of strings, got {}".format(dataset_list))
    if not isinstance(factory_list, (list, tuple)):
        raise RuntimeError(
            "factory_list should be a list of strings, got {}".format(factory_list))

    datasets = []
    target_offset = 0
    for i, dataset_name in enumerate(dataset_list):
        factory_name = factory_list[i] if i < len(factory_list) else None
if factory_name == "CLIPImgTxtPairTSVDataset": | |
dataset_names_merged = dataset_name.split('+') | |
path_lists_merged = path_list[i].split('+') | |
assert len(dataset_names_merged) == len(path_lists_merged), "number of datasets must match that of dataset paths" | |
image_tsv_list = [] | |
text_tsv_list = [] | |
dataset_name_list = [] | |
map_files = [] | |
max_num_tsv = 20 # maximum tsv files to load within a given folder | |
for dname, dpath in zip(dataset_names_merged, path_lists_merged): | |
args, tsv_dataset_name = config_tsv_dataset_args( | |
cfg, dataset_name, factory_name, is_train | |
) | |
factory = CLIPImgTxtPairTSVDataset if tsv_dataset_name in ["CLIPImgTxtPairTSVDataset"] else None | |
prev_len = len(image_tsv_list) | |
isFile = os.path.isfile(dpath) | |
if isFile: | |
dpath_listed_files = [os.path.basename(dpath)] | |
dpath = os.path.dirname(dpath) | |
else: | |
dpath_listed_files = sorted(os.listdir(dpath)) | |
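
                # Pair each image TSV with its text TSV by filename convention:
                # "*images*.tsv" <-> "*text*.tsv", "*image*.tsv" <-> "*text*.tsv",
                # and "*img*.tsv" <-> "*caption*.tsv"; stop once max_num_tsv files
                # have been collected for this folder.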
                for filename in dpath_listed_files:
                    if ("images" in filename or "image" in filename or "img" in filename) and filename.endswith(".tsv"):
                        image_tsv_list.append(os.path.join(dpath, filename))
                        if "images" in filename:  # "images" - "text"
                            text_tsv_list.append(os.path.join(dpath, filename.replace("images", "text")))
                        elif "image" in filename:  # "image" - "text"
                            text_tsv_list.append(os.path.join(dpath, filename.replace("image", "text")))
                        elif "img" in filename:  # "img" - "caption"
                            text_tsv_list.append(os.path.join(dpath, filename.replace("img", "caption")))
                        if len(image_tsv_list) - prev_len == max_num_tsv:
                            break

                dataset_name_list += [dname] * (len(image_tsv_list) - prev_len)
                if dname == "imagenet22k":
                    map_files += [os.path.join(dpath, 'darknet_data_imagenet.labels.list')] * (len(image_tsv_list) - prev_len)
                else:
                    map_files += [None] * (len(image_tsv_list) - prev_len)
                assert len(image_tsv_list) == len(text_tsv_list), \
                    "the number of image tsv files must be equal to that of text tsv files, otherwise check your data!"

            args["image_tsv_file"] = image_tsv_list
            args["text_tsv_file"] = text_tsv_list
            args["dataset_name"] = dataset_name_list
            args["map_file"] = map_files
            args["filtered_datasets"] = cfg.DATASETS.FILTERED_CLASSIFICATION_DATASETS
            assert len(image_tsv_list) == len(text_tsv_list) == len(dataset_name_list) == len(map_files)
            print("number of image tsv files: ", len(image_tsv_list))
            print("number of text tsv files: ", len(text_tsv_list))
args["is_train"] = is_train | |
args["transforms"] = transforms | |
args["target_offset"] = target_offset | |
if "bpe" in cfg.INPUT.TEXT_TOKENIZER: | |
from detectron2.data.datasets.clip_prompt_utils import SimpleTokenizer as _Tokenizer | |
tokenizer = _Tokenizer() | |
args["tokenizer_type"] = "bpe" | |
args["tokenizer"] = tokenizer | |
# make dataset from factory | |
dataset = factory(**args) | |
datasets.append(dataset) | |
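
    # Collect optional per-dataset metadata: pre-tokenized class prompts
    # (stored under the "imagenet" key) and class-name lists, when a dataset
    # exposes them.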
    precomputed_tokens = {}
    dataset_classes = {}
    for dataset in datasets:
        if hasattr(dataset, "input_ids_all_classes"):
            precomputed_tokens["imagenet"] = \
                [dataset.input_ids_all_classes, dataset.input_mask_all_classes, dataset.segment_ids_all_classes]
        if hasattr(dataset, "classnames"):
            if isinstance(dataset.classnames, dict):
                dataset_classes.update(dataset.classnames)
            else:
                dataset_classes[dataset.dataset_name] = dataset.classnames

    # for testing, return a list of datasets
    if not is_train:
        return datasets, precomputed_tokens, dataset_classes

    if len(datasets) == 0:
        return None, None, None

    # for training, concatenate all datasets into a single one
    dataset = datasets[0]
    if len(datasets) > 1:
        dataset = ConcatDataset(datasets)
    return [dataset], precomputed_tokens, dataset_classes


def make_clip_dataset(cfg, is_train=True, is_aux=False, transforms=None):
    if transforms is None:
        transforms = build_clip_transforms(cfg, is_train)
    print("data transforms: ")
    print(transforms)
    datasets, precomputed_tokens, dataset_classes = build_dataset(cfg, transforms, DatasetCatalog, is_train, is_aux)

    if not datasets:
        return None, None, None
    return datasets, precomputed_tokens, dataset_classes
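

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the pipeline). It
# assumes a detectron2-style CfgNode `cfg` that already defines the keys read
# above (DATASETS.TRAIN / FACTORY_TRAIN / PATH_TRAIN, DATASETS.MAX_SEQ_LENGTH,
# DATASETS.FILTERED_CLASSIFICATION_DATASETS, INPUT.TEXT_TOKENIZER, ...); how
# that config is built lives elsewhere in the project.
#
#     datasets, precomputed_tokens, dataset_classes = make_clip_dataset(
#         cfg, is_train=True
#     )
#     # during training, `datasets` is a one-element list holding a single
#     # (possibly concatenated) dataset that a standard DataLoader can consume:
#     loader = torch.utils.data.DataLoader(datasets[0], batch_size=32)
# ---------------------------------------------------------------------------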