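"""Dataset and DataLoader construction for image/video instruction tuning."""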
import torch
from torch.utils.data import ConcatDataset, DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from dataset.it_dataset import ITImgTrainDataset, ITVidTrainDataset

def get_media_type(dataset_config):
    """Infer the media type ("video", "only_video", or "image") of a dataset config entry."""
    if len(dataset_config) == 3 and dataset_config[2] == "video":
        return "video"
    elif dataset_config[-1] == "only_video":
        return "only_video"
    else:
        return "image"
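
# A rough illustration of the expected entries (the paths here are hypothetical):
#   get_media_type(["anno/webvid.json", "data/webvid", "video"])       -> "video"
#   get_media_type(["anno/clips.json", "data/clips", "only_video"])    -> "only_video"
#   get_media_type(["anno/caption.json", "data/coco"])                 -> "image"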

def create_dataset(dataset_type, config):
    """Build the training datasets (and their transforms) described by `config`."""
    # pick normalization statistics matching the vision backbone
    if "clip" in config.model.get("vit_model", "vit"):
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    else:
        vision_enc_name = config.model.vision_encoder.name
        if "swin" in vision_enc_name or "vit" in vision_enc_name:
            mean = (0.485, 0.456, 0.406)
            std = (0.229, 0.224, 0.225)
        elif "beit" in vision_enc_name:
            mean = (0.5, 0.5, 0.5)  # for all BEiT models except IN1K fine-tuning
            std = (0.5, 0.5, 0.5)
        elif "clip" in vision_enc_name:
            mean = (0.48145466, 0.4578275, 0.40821073)
            std = (0.26862954, 0.26130258, 0.27577711)
        else:
            raise ValueError(f"unexpected vision encoder: {vision_enc_name}")

    normalize = transforms.Normalize(mean, std)

    # loaded images and videos are torch.Tensor of torch.uint8 format,
    # ordered as (T, 1 or 3, H, W) where T=1 for image
    type_transform = transforms.Lambda(lambda x: x.float().div(255.0))
    if config.inputs.video_input.random_aug:
        aug_transform = transforms.RandAugment()
    else:
        aug_transform = transforms.Lambda(lambda x: x)

    train_transform = transforms.Compose(
        [
            aug_transform,
            transforms.RandomResizedCrop(
                config.inputs.image_res,
                scale=(0.5, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.RandomHorizontalFlip(),
            type_transform,
            normalize,
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.Resize(
                (config.inputs.image_res, config.inputs.image_res),
                interpolation=InterpolationMode.BICUBIC,
            ),
            type_transform,
            normalize,
        ]
    )
    video_reader_type = config.inputs.video_input.get("video_reader_type", "decord")
    video_only_dataset_kwargs_train = dict(
        video_reader_type=video_reader_type,
        sample_type=config.inputs.video_input.sample_type,
        num_frames=config.inputs.video_input.num_frames,
        num_tries=3,  # fault tolerance: retry a failed video read this many times
    )
    if dataset_type == "pt_train":
        raise ValueError("NOT PRETRAINING YET")
    elif dataset_type == "it_train":
        # convert to list of lists
        train_files = (
            [config.train_file] if isinstance(config.train_file[0], str) else config.train_file
        )
        train_media_types = sorted({get_media_type(e) for e in train_files})

        train_datasets = []
        for m in train_media_types:
            dataset_cls = ITImgTrainDataset if m == "image" else ITVidTrainDataset
            # datasets of the same media_type are mixed into a single Dataset object
            _train_files = [e for e in train_files if get_media_type(e) == m]

            datasets = []
            for train_file in _train_files:
                dataset_kwargs = dict(
                    ann_file=train_file,
                    transform=train_transform,
                    mm_alone=config.preprocess.get("mm_alone", True),
                    add_second_msg=config.preprocess.get("add_second_msg", True),
                    skip_short_sample=config.preprocess.get("skip_short_sample", False),
                    clip_transform=config.preprocess.get("clip_transform", False),
                    random_shuffle=config.preprocess.get("random_shuffle", True),
                    system=config.preprocess.get("system", ""),
                    role=config.preprocess.get("roles", ("Human", "Assistant")),
                    end_signal=config.preprocess.get("end_signal", "###"),
                    begin_signal=config.preprocess.get("begin_signal", ""),
                )
                if m == "video":
                    dataset_kwargs.update(video_only_dataset_kwargs_train)
                    dataset_kwargs.update({
                        "start_token": config.model.get("start_token", "<Video>"),
                        "end_token": config.model.get("end_token", "</Video>"),
                    })
                    # choose the video reader backend per annotation source; set it
                    # on dataset_kwargs directly so one file's choice cannot leak
                    # into the next iteration through the shared kwargs dict
                    if "tgif" in train_file[1]:
                        dataset_kwargs["video_reader_type"] = "gif"
                    elif "webvid" in train_file[1]:
                        dataset_kwargs["video_reader_type"] = "hdfs"
                    else:
                        dataset_kwargs["video_reader_type"] = "decord"
                datasets.append(dataset_cls(**dataset_kwargs))
            dataset = ConcatDataset(datasets)
            train_datasets.append(dataset)
        return train_datasets

def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Wrap each dataset in a DataLoader; all list arguments are matched element-wise."""
    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    ):
        if is_train:
            # shuffle in the loader only when no (distributed) sampler is provided
            shuffle = sampler is None
            drop_last = True
        else:
            shuffle = False
            drop_last = False
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=False,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
            persistent_workers=(n_worker > 0),
        )
        loaders.append(loader)
    return loaders
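
# A minimal usage sketch (not part of the original module): wiring create_dataset
# and create_loader together for single-GPU instruction tuning. `cfg` is a
# hypothetical config object exposing the fields read above (model, inputs,
# preprocess, train_file); the batch size, worker count, and collate function
# are illustrative assumptions.
#
#     from torch.utils.data import default_collate
#
#     train_datasets = create_dataset("it_train", cfg)
#     n = len(train_datasets)  # one ConcatDataset per media type
#     train_loaders = create_loader(
#         train_datasets,
#         samplers=[None] * n,  # no DistributedSampler on a single GPU
#         batch_size=[8] * n,
#         num_workers=[4] * n,
#         is_trains=[True] * n,
#         collate_fns=[default_collate] * n,
#     )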