# NOTE(review): removed stray export/VCS residue that preceded the license
# header ("舟勤", "v1", commit hash "45d16e9") — the bare hash was a syntax error.
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import os
from video_llama.datasets.datasets.base_dataset import BaseDataset
from video_llama.datasets.datasets.caption_datasets import CaptionDataset
import pandas as pd
import decord
from decord import VideoReader
import random
import torch
from torch.utils.data.dataloader import default_collate
class WebvidDataset(BaseDataset):
    """Video-caption training dataset over WebVid-style CSV annotations.

    Annotations are read from every ``.csv`` file under ``ann_root`` and
    concatenated into one DataFrame; each row must provide at least
    ``page_dir``, ``videoid`` and ``name`` (the caption) columns.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_root):
        """
        vis_root (string): Root directory of videos (e.g. webvid_eval/video/)
        ann_root (string): Directory containing annotation .csv files
            (e.g. webvid_eval/annotations/)
        """
        super().__init__(vis_processor=vis_processor, text_processor=text_processor)

        # Read and merge every .csv annotation file found under ann_root.
        ts_df = []
        for file_name in os.listdir(ann_root):
            if file_name.endswith('.csv'):
                df = pd.read_csv(os.path.join(ann_root, file_name))
                ts_df.append(df)

        merged_df = pd.concat(ts_df)
        self.annotation = merged_df
        self.vis_root = vis_root
        # Expected spatial size of processed frames; also used to validate
        # the tensor returned by vis_processor in __getitem__.
        self.resize_size = 224
        self.num_frm = 8
        self.frm_sampling_strategy = 'headtail'

    def _get_video_path(self, sample):
        """Return the absolute path ``vis_root/page_dir/<videoid>.mp4`` for a row dict."""
        rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
        full_video_fp = os.path.join(self.vis_root, rel_video_fp)
        return full_video_fp

    def __getitem__(self, index):
        """Load one (video, caption) pair, resampling a random index on failure.

        Returns a dict with keys ``image`` (video tensor of shape
        [3, n_frms, resize_size, resize_size]), ``text_input`` (processed
        caption) and ``type`` ('video').
        Raises RuntimeError after ``num_retries`` consecutive failures.
        """
        num_retries = 10  # skip error videos
        for _ in range(num_retries):
            sample = self.annotation.iloc[index]
            sample_dict = sample.to_dict()
            video_id = sample_dict['videoid']

            if 'name' in sample_dict.keys():
                text = sample_dict['name'].strip()
            else:
                raise NotImplementedError("Un-supported text annotation format.")

            # fetch video
            video_path = self._get_video_path(sample_dict)
            try:
                video = self.vis_processor(video_path)
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; decoding errors of any kind fall through to a
            # random replacement sample.
            except Exception:
                print(f"Failed to load examples with video: {video_path}. "
                      f"Will randomly sample an example as a replacement.")
                index = random.randint(0, len(self) - 1)
                continue

            caption = self.text_processor(text)

            # Reject samples whose processed tensor does not have the expected
            # [C, T, H, W] shape (resize_size replaces the former hard-coded 224).
            if video is None or caption is None \
                    or video.size() != torch.Size([3, self.vis_processor.n_frms,
                                                   self.resize_size, self.resize_size]):
                print(f"Failed to load examples with video: {video_path}. "
                      f"Will randomly sample an example as a replacement.")
                index = random.randint(0, len(self) - 1)
                continue
            else:
                break
        else:
            raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")

        # "image_id" is kept to stay compatible with the COCO evaluation format
        return {
            "image": video,
            "text_input": caption,
            "type": 'video',
        }

    def __len__(self):
        """Number of annotation rows (one sample per row)."""
        return len(self.annotation)
class WebvidDatasetEvalDataset(BaseDataset):
    """Evaluation split: returns processed videos with their annotation IDs."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of videos
        ann_paths (list): Annotation file paths handled by BaseDataset
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

    def __getitem__(self, index):
        """Return the processed video plus its image/instance identifiers."""
        ann = self.annotation[index]
        full_path = os.path.join(self.vis_root, ann["video"])
        processed = self.vis_processor(full_path)
        return {
            "video": processed,
            "image_id": ann["image_id"],
            "instance_id": ann["instance_id"],
        }