# pllava-7b-demo/dataset/__init__.py
# Last commit by cathyxl: "added" (f239efc)
import torch
from torch.utils.data import ConcatDataset, DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from dataset.it_dataset import ITImgTrainDataset, ITVidTrainDataset
def get_media_type(dataset_config):
    """Classify a dataset spec as "video", "only_video" or "image".

    A spec counts as "video" only when it has exactly three entries and the
    third one is ``"video"``. Otherwise it is "only_video" when its last entry
    is ``"only_video"``, and "image" in every other case.
    """
    is_three_part_video = len(dataset_config) == 3 and dataset_config[2] == "video"
    if is_three_part_video:
        return "video"
    return "only_video" if dataset_config[-1] == "only_video" else "image"
def _normalization_stats(config):
    """Return the ``(mean, std)`` per-channel normalization statistics for ``config``.

    If ``config.model.vit_model`` (default ``'vit'``) contains "clip", ImageNet
    statistics are returned. Otherwise the statistics are chosen from substrings
    of ``config.model.vision_encoder.name`` ("swin"/"vit", then "beit", then "clip").

    Raises:
        ValueError: if the vision-encoder name matches none of the known families.
    """
    imagenet_stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    if "clip" in config.model.get("vit_model", 'vit'):
        # NOTE(review): ImageNet statistics are used for a "clip" vit_model, whereas
        # the vision_encoder "clip" branch below uses CLIP statistics — confirm intended.
        return imagenet_stats
    vision_enc_name = config.model.vision_encoder.name
    # Order matters: e.g. a name containing both "clip" and "vit" gets ImageNet stats.
    if "swin" in vision_enc_name or "vit" in vision_enc_name:
        return imagenet_stats
    if "beit" in vision_enc_name:
        # for all beit models except IN1K finetuning
        return (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
    if "clip" in vision_enc_name:
        return (
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        )
    raise ValueError(
        f"Unsupported vision encoder for normalization: {vision_enc_name!r}"
    )


def _video_reader_type(train_file):
    """Return the video reader backend for one video dataset spec.

    Chosen from substrings of the spec's second entry (``train_file[1]``,
    presumably the data root — confirm): "tgif" -> "gif", "webvid" -> "hdfs",
    anything else -> "decord".
    """
    source = train_file[1]
    if "tgif" in source:
        return "gif"
    if "webvid" in source:
        return "hdfs"
    return "decord"


def create_dataset(dataset_type, config):
    """Build the instruction-tuning training datasets described by ``config``.

    Args:
        dataset_type: Only ``"it_train"`` is supported; ``"pt_train"`` and any
            other value are rejected.
        config: Experiment config; reads ``config.model``, ``config.inputs``,
            ``config.preprocess`` and ``config.train_file``.

    Returns:
        A list with one ``ConcatDataset`` per media type (ordered by sorted
        media-type name). Each one mixes every ``train_file`` spec of that media
        type, in config order.

    Raises:
        ValueError: for ``"pt_train"``, any other unsupported ``dataset_type``,
            or an unrecognised vision encoder.
    """
    mean, std = _normalization_stats(config)
    normalize = transforms.Normalize(mean, std)
    # loaded images and videos are torch.Tensor of torch.uint8 format,
    # ordered as (T, 1 or 3, H, W) where T=1 for image
    type_transform = transforms.Lambda(lambda x: x.float().div(255.0))
    video_input = config.inputs.video_input
    if video_input.random_aug:
        aug_transform = transforms.RandAugment()
    else:
        aug_transform = transforms.Lambda(lambda x: x)
    train_transform = transforms.Compose(
        [
            aug_transform,
            transforms.RandomResizedCrop(
                config.inputs.image_res,
                scale=(0.5, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.RandomHorizontalFlip(),
            type_transform,
            normalize,
        ]
    )
    # Frame-sampling options shared by every "video" spec. The reader backend is
    # not read from config: it is chosen per spec by _video_reader_type below.
    video_only_dataset_kwargs_train = dict(
        sample_type=video_input.sample_type,
        num_frames=video_input.num_frames,
        num_tries=3,  # fault tolerance (presumably retries on a failed load)
    )

    if dataset_type == "pt_train":
        raise ValueError("NOT PRETRAINING YET")
    if dataset_type != "it_train":
        raise ValueError(f"Unsupported dataset_type: {dataset_type!r}")

    # A single spec (a list whose first entry is a str) is wrapped into a list of specs.
    train_files = (
        [config.train_file] if isinstance(config.train_file[0], str) else config.train_file
    )
    # Group specs by media type in one pass, keeping config order inside each group.
    files_by_media_type = {}
    for train_file in train_files:
        files_by_media_type.setdefault(get_media_type(train_file), []).append(train_file)

    preprocess = config.preprocess
    video_tokens = dict(
        start_token=config.model.get("start_token", "<Video>"),
        end_token=config.model.get("end_token", "</Video>"),
    )

    train_datasets = []
    for media_type in sorted(files_by_media_type):
        dataset_cls = ITImgTrainDataset if media_type == "image" else ITVidTrainDataset
        datasets = []
        for train_file in files_by_media_type[media_type]:
            dataset_kwargs = dict(
                ann_file=train_file,
                transform=train_transform,
                mm_alone=preprocess.get("mm_alone", True),
                add_second_msg=preprocess.get("add_second_msg", True),
                skip_short_sample=preprocess.get("skip_short_sample", False),
                clip_transform=preprocess.get("clip_transform", False),
                random_shuffle=preprocess.get("random_shuffle", True),
                system=preprocess.get("system", ""),
                role=preprocess.get('roles', ("Human", "Assistant")),
                end_signal=preprocess.get('end_signal', "###"),
                begin_signal=preprocess.get('begin_signal', ""),
            )
            if media_type == "video":
                # Merged fresh for every spec, so one spec's reader type can never
                # leak into (or be missing from) another spec's kwargs.
                dataset_kwargs.update(
                    video_only_dataset_kwargs_train,
                    video_reader_type=_video_reader_type(train_file),
                    **video_tokens,
                )
            # NOTE(review): "only_video" specs are also built with ITVidTrainDataset
            # but receive none of the video kwargs above (num_frames, sample_type,
            # reader type, tokens) — confirm that is intended.
            datasets.append(dataset_cls(**dataset_kwargs))
        # dataset of the same media_type will be mixed in a single Dataset object
        train_datasets.append(ConcatDataset(datasets))
    return train_datasets
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Wrap each dataset in a ``DataLoader``, one per position of the parallel lists.

    ``datasets``, ``samplers``, ``batch_size``, ``num_workers``, ``is_trains`` and
    ``collate_fns`` are consumed in lockstep. Training loaders drop the final
    incomplete batch and shuffle only when no sampler is supplied. Evaluation
    loaders neither shuffle nor drop. Workers are kept alive between epochs
    whenever any are used.
    """

    def _build(dataset, sampler, bs, n_worker, is_train, collate_fn):
        training = bool(is_train)
        return DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=False,
            sampler=sampler,
            # a custom sampler and shuffle=True are mutually exclusive in DataLoader
            shuffle=training and sampler is None,
            collate_fn=collate_fn,
            drop_last=training,
            persistent_workers=n_worker > 0,
        )

    return [
        _build(*settings)
        for settings in zip(
            datasets, samplers, batch_size, num_workers, is_trains, collate_fns
        )
    ]