from os.path import basename import numpy as np import logging import torch from dataset.base_dataset import BaseDataset from dataset.utils import load_anno, pre_text from dataset.video_utils import VIDEO_READER_FUNCS from dataset.text_prompt import kinetics_templates_action_clip as kinetics_templates logger = logging.getLogger(__name__) class AudioTxtRetTrainDataset(BaseDataset): media_type = "audio" def __init__( self, ann_file, transform, audio_sample_rate, audio_reader_type='librosa', max_audio_length=0, num_tries=3): super(AudioTxtRetTrainDataset, self).__init__() self.anno_list = load_anno(ann_file) self.transform = transform self.audio_reader_type = audio_reader_type self.num_tries = num_tries self.has_multi_audio_gt = ann_file.get("has_multi_audio_gt", False) self.trimmed30 = ann_file.get("trimmed30", False) self.max_audio_length = max_audio_length self.audio_sample_rate = audio_sample_rate self.match_ids = {} n = 0 for ann in self.anno_list: key = ann["caption"] if self.has_multi_audio_gt else basename(ann["image"]) if key not in self.match_ids: self.match_ids[key] = n n += 1 def __len__(self): return len(self.anno_list) def __getitem__(self, index): try: ann = self.anno_list[index] audio, index = self.load_and_transform_media_data(index, ann['image']) caption = pre_text(ann["caption"]) key = ann["caption"] if self.has_multi_audio_gt else basename(ann["image"]) return audio, caption, self.match_ids[key] except Exception as e: logger.error(e) print(e, flush=True) index = np.random.randint(0, len(self)) return self.__getitem__(index) class AudioTxtRetEvalDataset(BaseDataset): media_type = "audio" def __init__( self, ann_file, transform, audio_sample_rate, audio_reader_type='librosa', max_audio_length=0, num_tries=3): super(AudioTxtRetEvalDataset, self).__init__() self.anno_list = load_anno(ann_file) self.transform = transform self.audio_sample_rate = audio_sample_rate self.max_audio_length = max_audio_length self.audio_reader_type = audio_reader_type self.num_tries = num_tries self.has_multi_audio_gt = ann_file.get("has_multi_audio_gt", False) self.trimmed30 = ann_file.get("trimmed30", False) self.max_txt_l = ann_file.get("max_txt_l", 32) self.text = None self.audio = None self.txt2img = None self.img2txt = None self.build_data() def build_data(self): self.text = [] self.audio = [] self.txt2img = {} self.img2txt = {} if self.has_multi_audio_gt: self.build_data_multi_audio_gt() else: self.build_data_multi_txt_gt() def build_data_multi_audio_gt(self): """each text may have multiple ground_truth audio, e.g., ssv2""" audio_id = 0 for txt_id, ann in enumerate(self.anno_list): self.text.append(pre_text(ann["caption"])) self.txt2img[txt_id] = [] _audios = ann["image"] \ if isinstance(ann["image"], list) else [ann["image"], ] for i, audio in enumerate(_audios): self.audio.append(audio) self.txt2img[txt_id].append(audio_id) self.img2txt[audio_id] = txt_id audio_id += 1 def build_data_multi_txt_gt(self): """each audio may have multiple ground_truth text, e.g., COCO and Flickr30K""" txt_id = 0 for audio_id, ann in enumerate(self.anno_list): self.audio.append(ann["image"]) self.img2txt[audio_id] = [] _captions = ann["caption"] \ if isinstance(ann["caption"], list) else [ann["caption"], ] for i, caption in enumerate(_captions): self.text.append(pre_text(caption)) self.img2txt[audio_id].append(txt_id) self.txt2img[txt_id] = audio_id txt_id += 1 def __len__(self): return len(self.anno_list) def __getitem__(self, index): ann = self.anno_list[index] audio, index = self.load_and_transform_media_data(index, ann["image"]) return audio, index class ImgTxtRetTrainDataset(BaseDataset): media_type = "image" def __init__(self, ann_file, transform): super(ImgTxtRetTrainDataset, self).__init__() self.anno_list = load_anno(ann_file) self.transform = transform # each caption has multiple image as ground_truth, e.g., ssv2 self.has_multi_txt_gt = ann_file.get("has_multi_txt_gt", False) self.has_multi_vision_gt = ann_file.get("has_multi_vision_gt", False) if self.has_multi_txt_gt: logger.info("The dataset has multiple ground truth for a image/video!") tmp_anno_list = [] for ann in self.anno_list: img_path = ann["image"] for caption in ann["caption"]: tmp_anno_list.append({ "image": img_path, "caption": caption }) self.anno_list = tmp_anno_list self.match_ids = {} n = 0 for ann in self.anno_list: key = ann["caption"] if self.has_multi_vision_gt else basename(ann["image"]) if key not in self.match_ids: self.match_ids[key] = n n += 1 def __len__(self): return len(self.anno_list) def __getitem__(self, index): try: ann = self.anno_list[index] image, index = self.load_and_transform_media_data(index, ann["image"]) caption = pre_text(ann["caption"]) key = ann["caption"] if self.has_multi_vision_gt else basename(ann["image"]) return image, caption, self.match_ids[key] except Exception as e: logger.error(e) print(e, flush=True) index = np.random.randint(0, len(self)) return self.__getitem__(index) class VidTxtRetTrainDataset(ImgTxtRetTrainDataset): media_type = "video" def __init__( self, ann_file, transform, num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3): super(VidTxtRetTrainDataset, self).__init__(ann_file, transform) self.num_frames = num_frames self.video_reader_type = video_reader_type self.video_reader = VIDEO_READER_FUNCS[video_reader_type] self.sample_type = sample_type self.num_tries = num_tries self.read_clip_from_video = ann_file.get("read_clip_from_video", False) if self.read_clip_from_video: raise NotImplementedError("key for match_ids is not implemented!") self.is_paragraph_retrieval = ann_file.get("is_paragraph_retrieval", False) if self.is_paragraph_retrieval: self.anno_list = preprocess_para_retrieval_data(self.anno_list) self.trimmed30 = ann_file.get("trimmed30", False) if self.trimmed30: logger.info("Trimming the video, only use the first 30s!") class AudioVidTxtRetTrainDataset(VidTxtRetTrainDataset): media_type = "audio_video" def __init__( self, ann_file, transform, audio_sample_rate=16000, audio_reader_type='torchaudio', max_audio_length=10, num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3): super(AudioVidTxtRetTrainDataset, self).__init__(ann_file, transform, num_frames=num_frames, video_reader_type=video_reader_type, sample_type=sample_type, num_tries=num_tries) assert self.media_type == 'audio_video', self.media_type self.audio_sample_rate = audio_sample_rate self.audio_reader_type = audio_reader_type self.max_audio_length = max_audio_length self.has_multi_audio_gt = ann_file.get("has_multi_audio_gt", False) self.read_audio_from_video = ann_file.get("read_audio_from_video", False) self.zero_audio_padding_for_video = ann_file.get("zero_audio_padding_for_video", False) def __getitem__(self, index): try: ann = self.anno_list[index] caption = pre_text(ann["caption"]) data_path = {'video': ann["image"]} data_path["read_clip_from_video"] = self.read_clip_from_video if "audio" in ann.keys(): data_path["read_audio_from_video"] = False data_path["audio"] = ann["audio"] else: data_path["read_audio_from_video"] = self.read_audio_from_video media, index = self.load_and_transform_media_data(index, data_path) audio = media[0] if audio is None and self.zero_audio_padding_for_video: logger.warning(f"No audio in {data_path}") media[0] = torch.zeros((998, 64), dtype=torch.float32) key = ann["caption"] if self.has_multi_vision_gt else basename(ann["image"]) return media, caption, self.match_ids[key] except Exception as e: logger.error(e) print(e, flush=True) index = np.random.randint(0, len(self)) return self.__getitem__(index) class ImgTxtRetEvalDataset(BaseDataset): media_type = "image" def __init__(self, ann_file, transform): super(ImgTxtRetEvalDataset, self).__init__() self.raw_anno_list = load_anno(ann_file) self.transform = transform self.has_multi_vision_gt = ann_file.get("has_multi_vision_gt", False) # each caption has multiple image as ground_truth self.is_act_rec = ann_file.get("is_act_rec", False) self.max_txt_l = ann_file.get("max_txt_l", 32) # NOTE self.text = None self.image = None self.txt2img = None self.img2txt = None self.build_data() def build_data(self): self.text = [] self.image = [] self.txt2img = {} self.img2txt = {} if self.is_act_rec: self.build_data_act_rec() elif self.has_multi_vision_gt: self.build_data_multi_img_gt() else: self.build_data_multi_txt_gt() self.anno_list = [dict(image=e) for e in self.image] def build_data_act_rec(self): """action recognition task, e.g., kinetics400""" text = list(set([e["caption"] for e in self.raw_anno_list])) text2label = {e: i for i, e in enumerate(text)} text = [[t.format(e) for t in kinetics_templates] for e in text] text = [e for l in text for e in l] self.text = [pre_text(e) for e in text] self.num_prompts = len(kinetics_templates) self.img2txt = {i: text2label[e["caption"]] for i, e in enumerate(self.raw_anno_list)} self.txt2img = [[] for _ in range(len(text) // len(kinetics_templates))] for i, e in enumerate(self.raw_anno_list): self.image.append(e["image"]) self.txt2img[text2label[e["caption"]]].append(i) logger.info(f"Action recognition, number of prompts: {self.num_prompts}") logger.info(f"Action recognition, number of classes: {len(self.text)}") def build_data_multi_img_gt(self): """each text may have multiple ground_truth image, e.g., ssv2""" img_id = 0 for txt_id, ann in enumerate(self.raw_anno_list): self.text.append(pre_text(ann["caption"])) self.txt2img[txt_id] = [] _images = ann["image"] \ if isinstance(ann["image"], list) else [ann["image"], ] for i, image in enumerate(_images): self.image.append(image) self.txt2img[txt_id].append(img_id) self.img2txt[img_id] = txt_id img_id += 1 def build_data_multi_txt_gt(self): """each image may have multiple ground_truth text, e.g., COCO and Flickr30K""" txt_id = 0 for img_id, ann in enumerate(self.raw_anno_list): self.image.append(ann["image"]) self.img2txt[img_id] = [] _captions = ann["caption"] \ if isinstance(ann["caption"], list) else [ann["caption"], ] for i, caption in enumerate(_captions): self.text.append(pre_text(caption)) self.img2txt[img_id].append(txt_id) self.txt2img[txt_id] = img_id txt_id += 1 def __len__(self): return len(self.anno_list) def __getitem__(self, index): ann = self.anno_list[index] image, index = self.load_and_transform_media_data(index, ann["image"]) return image, index class VidTxtRetEvalDataset(ImgTxtRetEvalDataset): media_type = "video" def __init__( self, ann_file, transform, num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=1): super(VidTxtRetEvalDataset, self).__init__(ann_file, transform) self.num_frames = num_frames self.video_reader_type = video_reader_type self.video_reader = VIDEO_READER_FUNCS[video_reader_type] self.sample_type = sample_type self.num_tries = num_tries self.is_paragraph_retrieval = ann_file.get("is_paragraph_retrieval", False) if self.is_paragraph_retrieval: logger.info("Preprocess paragraph retrieval data!!!") self.anno_list = preprocess_para_retrieval_data(self.raw_anno_list) self.trimmed30 = ann_file.get("trimmed30", False) if self.trimmed30: logger.info("Trimming the video, only use the first 30s!!!") self.read_clip_from_video = ann_file.get("read_clip_from_video", False) self.use_subtitle = ann_file.get("use_subtitle", False) if self.use_subtitle: if self.is_act_rec: raise NotImplementedError self.build_subtitle_data() self.build_data() def __getitem__(self, index): ann = self.anno_list[index] if self.read_clip_from_video: raise NotImplementedError("key for match_ids is not implemented!") else: data_path = ann["image"] image, index = self.load_and_transform_media_data(index, data_path) return image, index def build_subtitle_data(self): self.subtitle = [] for _, ann in enumerate(self.raw_anno_list): if self.trimmed30: if "asr_trimmed_30" in ann.keys(): self.subtitle.append(pre_text(ann["asr_trimmed_30"])) else: self.subtitle.append("") else: if "asr" in ann.keys(): self.subtitle.append(pre_text(ann["asr"])) else: self.subtitle.append("") def preprocess_para_retrieval_data(anno_list): processed_anno_list = [] for d in anno_list: d["caption"] = " ".join(d.pop("caption")) processed_anno_list.append(d) return processed_anno_list class VidTxtRetMCEvalDataset(BaseDataset): """For MSRVTT-MC test task""" media_type = "video" def __init__(self, ann_file, transform, num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=1): super(VidTxtRetMCEvalDataset, self).__init__() self.anno_list = load_anno(ann_file) self.transform = transform # video args self.num_frames = num_frames self.video_reader_type = video_reader_type self.video_reader = VIDEO_READER_FUNCS[video_reader_type] self.sample_type = sample_type self.num_tries = num_tries def __len__(self): return len(self.anno_list) def __getitem__(self, index): ann = self.anno_list[index] image, index = self.load_and_transform_media_data(index, ann["image"]) caption = [pre_text(e) for e in ann["caption"]] # len=5 answer = ann["answer"] return image, caption, answer, ann class VidTxtRetMCNewEvalDataset(BaseDataset): """For SSV2-MC and Charades-MC test task""" media_type = "video" def __init__(self, ann_file, transform, num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=1): super(VidTxtRetMCNewEvalDataset, self).__init__() self.anno_list = load_anno(ann_file) self.transform = transform # video args self.num_frames = num_frames self.video_reader_type = video_reader_type self.video_reader = VIDEO_READER_FUNCS[video_reader_type] self.sample_type = sample_type self.num_tries = num_tries def __len__(self): return len(self.anno_list) def __getitem__(self, index): ann = self.anno_list[index] image, index = self.load_and_transform_media_data(index, ann["image"]) option = [pre_text(e) for e in ann["option"]] # len=174 answer = ann["answer"] if isinstance(answer, list): answer = torch.Tensor(answer) return image, option, answer, ann class AudioVidTxtRetEvalDataset(VidTxtRetEvalDataset): media_type = "audio_video" def __init__( self, ann_file, transform, num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=1, audio_sample_rate=16000, audio_reader_type='torchaudio', max_audio_length=10): super(AudioVidTxtRetEvalDataset, self).__init__(ann_file, transform, num_frames=num_frames, video_reader_type=video_reader_type, sample_type=sample_type, num_tries=num_tries) self.audio_sample_rate = audio_sample_rate self.audio_reader_type = audio_reader_type self.max_audio_length = max_audio_length self.read_clip_from_video = ann_file.get("read_clip_from_video", False) self.read_audio_from_video = ann_file.get("read_audio_from_video", False) self.zero_audio_padding_for_video = ann_file.get("zero_audio_padding_for_video", False) def __getitem__(self, index): ann = self.anno_list[index] data_path = {'video': ann["image"]} if self.read_clip_from_video: raise NotImplementedError("Need to modify load_anno!") if not self.read_audio_from_video: raise NotImplementedError("Need to modify load_anno!") data_path["read_clip_from_video"] = self.read_clip_from_video data_path["read_audio_from_video"] = self.read_audio_from_video media, index = self.load_and_transform_media_data(index, data_path) audio = media[0] if audio is None and self.zero_audio_padding_for_video: media[0] = torch.zeros((998, 64), dtype=torch.float32) return media, index