"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import os
from video_llama.datasets.datasets.base_dataset import BaseDataset
from video_llama.datasets.datasets.caption_datasets import CaptionDataset
import pandas as pd
import decord
from decord import VideoReader
import random
import torch
from torch.utils.data.dataloader import default_collate
class WebvidDataset(BaseDataset):
    """Training dataset over WebVid video-caption pairs.

    Annotations are loaded by concatenating every ``*.csv`` file found
    directly under ``ann_root``. Each row is expected to provide at least
    the columns ``page_dir``, ``videoid`` and ``name`` (the caption).
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_root):
        """
        vis_processor: callable mapping a video path to a frame tensor.
        text_processor: callable mapping a raw caption string to model input.
        vis_root (string): Root directory of videos (e.g. webvid_eval/video/).
        ann_root (string): Directory containing the annotation csv files
            (e.g. webvid_eval/annotations/).

        Raises:
            FileNotFoundError: if ``ann_root`` contains no ``.csv`` file.
        """
        super().__init__(vis_processor=vis_processor, text_processor=text_processor)

        # Read every csv annotation file found under ann_root.
        frames = [
            pd.read_csv(os.path.join(ann_root, file_name))
            for file_name in os.listdir(ann_root)
            if file_name.endswith('.csv')
        ]
        if not frames:
            # Fail early with a clear message instead of the opaque
            # ValueError that pd.concat raises on an empty list.
            raise FileNotFoundError(f"No .csv annotation files found in {ann_root}")
        self.annotation = pd.concat(frames)

        self.vis_root = vis_root
        self.resize_size = 224   # spatial size the visual processor is expected to emit
        self.num_frm = 8         # frames sampled per clip
        self.frm_sampling_strategy = 'headtail'

    def _get_video_path(self, sample):
        """Return the absolute .mp4 path for an annotation row/dict ``sample``."""
        rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
        return os.path.join(self.vis_root, rel_video_fp)

    def __getitem__(self, index):
        """Return a dict with keys ``image`` (video tensor), ``text_input``, ``type``.

        Corrupt/unreadable clips are skipped by resampling a random index,
        up to ``num_retries`` times.
        """
        num_retries = 10  # skip error videos
        for _ in range(num_retries):
            sample_dict = self.annotation.iloc[index].to_dict()

            if 'name' in sample_dict.keys():
                text = sample_dict['name'].strip()
            else:
                raise NotImplementedError("Un-supported text annotation format.")

            # fetch video
            video_path = self._get_video_path(sample_dict)
            try:
                video = self.vis_processor(video_path)
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt /
                # SystemExit still propagate; decode failures are retried.
                print(f"Failed to load examples with video: {video_path}. "
                      f"Will randomly sample an example as a replacement.")
                index = random.randint(0, len(self) - 1)
                continue

            caption = self.text_processor(text)
            # Use the configured resize_size rather than a hard-coded 224.
            expected_size = torch.Size(
                [3, self.vis_processor.n_frms, self.resize_size, self.resize_size])
            if video is None or caption is None or video.size() != expected_size:
                print(f"Failed to load examples with video: {video_path}. "
                      f"Will randomly sample an example as a replacement.")
                index = random.randint(0, len(self) - 1)
                continue
            break
        else:
            raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
        # "image_id" is kept to stay compatible with the COCO evaluation format
        return {
            "image": video,
            "text_input": caption,
            "type": 'video',
        }

    def __len__(self):
        return len(self.annotation)
class WebvidDatasetEvalDataset(BaseDataset):
    """Evaluation dataset over WebVid clips.

    Each annotation entry must carry ``video`` (relative clip path),
    ``image_id`` and ``instance_id`` fields.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_processor: callable mapping a video path to a frame tensor.
        text_processor: callable for caption text (unused at eval read time).
        vis_root (string): Root directory holding the video files.
        ann_paths: annotation file paths handed to BaseDataset.
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

    def __getitem__(self, index):
        """Return the processed clip plus its evaluation identifiers."""
        record = self.annotation[index]
        clip_path = os.path.join(self.vis_root, record["video"])
        return {
            "video": self.vis_processor(clip_path),
            "image_id": record["image_id"],
            "instance_id": record["instance_id"],
        }