""" Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import os from video_llama.datasets.datasets.base_dataset import BaseDataset from video_llama.datasets.datasets.caption_datasets import CaptionDataset import pandas as pd import decord from decord import VideoReader import random import torch from torch.utils.data.dataloader import default_collate class WebvidDataset(BaseDataset): def __init__(self, vis_processor, text_processor, vis_root, ann_root): """ vis_root (string): Root directory of video (e.g. webvid_eval/video/) ann_root (string): Root directory of video (e.g. webvid_eval/annotations/) split (string): val or test """ super().__init__(vis_processor=vis_processor, text_processor=text_processor) # 读取一个路径下所有的 ts_df = [] for file_name in os.listdir(ann_root): if file_name.endswith('.csv'): df = pd.read_csv(os.path.join(ann_root, file_name)) ts_df.append(df) merged_df = pd.concat(ts_df) self.annotation = merged_df self.vis_root = vis_root self.resize_size = 224 self.num_frm = 8 self.frm_sampling_strategy = 'headtail' def _get_video_path(self, sample): rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4') full_video_fp = os.path.join(self.vis_root, rel_video_fp) return full_video_fp def __getitem__(self, index): num_retries = 10 # skip error videos for _ in range(num_retries): sample = self.annotation.iloc[index] sample_dict = sample.to_dict() video_id = sample_dict['videoid'] if 'name' in sample_dict.keys(): text = sample_dict['name'].strip() else: raise NotImplementedError("Un-supported text annotation format.") # fetch video video_path = self._get_video_path(sample_dict) # if os.path.exists(video_path): try: video = self.vis_processor(video_path) except: print(f"Failed to load examples with video: {video_path}. " f"Will randomly sample an example as a replacement.") index = random.randint(0, len(self) - 1) continue caption = self.text_processor(text) # print(video.size()) if video is None or caption is None \ or video.size()!=torch.Size([3,self.vis_processor.n_frms,224,224]): print(f"Failed to load examples with video: {video_path}. " f"Will randomly sample an example as a replacement.") index = random.randint(0, len(self) - 1) continue else: break else: raise RuntimeError(f"Failed to fetch video after {num_retries} retries.") # "image_id" is kept to stay compatible with the COCO evaluation format return { "image": video, "text_input": caption, "type":'video', } def __len__(self): return len(self.annotation) # def collater(self, samples): # new_result = {} # new_result['image'] = default_collate( [sample["image"] for sample in samples]) # new_result['text_input'] = default_collate( [sample["text_input"] for sample in samples]) # return new_result class WebvidDatasetEvalDataset(BaseDataset): def __init__(self, vis_processor, text_processor, vis_root, ann_paths): """ vis_root (string): Root directory of images (e.g. coco/images/) ann_root (string): directory to store the annotation file split (string): val or test """ super().__init__(vis_processor, text_processor, vis_root, ann_paths) def __getitem__(self, index): ann = self.annotation[index] vname = ann["video"] video_path = os.path.join(self.vis_root, vname) video = self.vis_processor(video_path) return { "video": video, "image_id": ann["image_id"], "instance_id": ann["instance_id"], }