Spaces:
Running
on
A10G
Running
on
A10G
import os | |
from PIL import Image | |
import webdataset as wds | |
from video_llama.datasets.datasets.base_dataset import BaseDataset | |
from video_llama.datasets.datasets.caption_datasets import CaptionDataset | |
class CCSBUDataset(BaseDataset):
    """Image-caption dataset streamed from WebDataset tar shards.

    The constructor assembles a ``wds.DataPipeline`` and stores it on
    ``self.inner_dataset``. Each item produced by that pipeline is the dict
    built by :meth:`to_dict`.
    """

    def __init__(self, vis_processor, text_processor, location):
        """Assemble the streaming pipeline.

        Args:
            vis_processor: Callable handed to ``wds.map_tuple``; it is applied
                to the decoded ``jpg`` entry of each sample.
            text_processor: Callable applied to each sample's caption in
                :meth:`to_dict`.
            location: Shard specification passed to ``wds.ResampledShards``.
        """
        # NOTE(review): the processors are read back from ``self`` below, so
        # the base constructor is presumably what stores them -- confirm in
        # BaseDataset.
        super().__init__(vis_processor=vis_processor, text_processor=text_processor)

        # One handler for every stage: webdataset's warn-and-continue policy,
        # so a failing sample does not abort the whole stream.
        on_error = wds.warn_and_continue

        stages = [
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=on_error),
            wds.shuffle(1000, handler=on_error),
            wds.decode("pilrgb", handler=on_error),
            # Keep only the image and its JSON metadata, in that order.
            wds.to_tuple("jpg", "json", handler=on_error),
            wds.map_tuple(self.vis_processor, handler=on_error),
            wds.map(self.to_dict, handler=on_error),
        ]
        self.inner_dataset = wds.DataPipeline(*stages)

    def to_dict(self, sample):
        """Convert an ``(image, metadata)`` sample into the output dict.

        Args:
            sample: Indexable whose element 0 is the image and whose element 1
                is a mapping containing a ``"caption"`` entry.

        Returns:
            Dict with keys ``"image"``, ``"text_input"`` (the caption after
            ``self.text_processor``) and ``"type"`` (always ``'image'``).
        """
        picture = sample[0]
        processed_caption = self.text_processor(sample[1]["caption"])
        return {
            "image": picture,
            "text_input": processed_caption,
            "type": 'image',
        }
class CCSBUAlignDataset(CaptionDataset):
    """Caption dataset whose items are loaded from ``<image_id>.jpg`` files.

    NOTE(review): ``self.annotation``, ``self.vis_root``, ``self.img_ids`` and
    ``self.vis_processor`` are not defined in this class; they are presumably
    set up by ``CaptionDataset`` -- confirm against the parent class.
    """

    def __getitem__(self, index):
        """Load and return the sample at ``index``.

        Args:
            index: Position in ``self.annotation``.

        Returns:
            Dict with keys ``"image"`` (the RGB image after
            ``self.vis_processor``), ``"text_input"`` (the annotation's
            caption, returned as stored), ``"image_id"`` (the value
            ``self.img_ids`` maps the annotation's image id to) and
            ``"type"`` (always ``'image'``).
        """
        # TODO this assumes image input, not general enough
        record = self.annotation[index]
        raw_id = record["image_id"]

        # Images are stored under vis_root as "<image_id>.jpg".
        jpg_path = os.path.join(self.vis_root, '{}.jpg'.format(raw_id))
        pixels = self.vis_processor(Image.open(jpg_path).convert("RGB"))

        sample = {"image": pixels}
        sample["text_input"] = record["caption"]
        sample["image_id"] = self.img_ids[raw_id]
        sample["type"] = 'image'
        return sample