import datetime
import json
import logging
import os
import random
import sqlite3
from os.path import basename

import numpy as np

from dataset.base_dataset import ImageVideoBaseDataset
from dataset.video_utils import VIDEO_READER_FUNCS

logger = logging.getLogger(__name__)

IMAGE_TOKEN = ""


class ITImgTrainDataset(ImageVideoBaseDataset):
    """Instruction-tuning dataset that pairs a media file with a QA conversation.

    Annotations are read from a JSON file; each entry must provide the media
    path under the ``"image"`` or ``"video"`` key and a list of QA dicts under
    ``"QA"`` (each with ``"q"`` and ``"a"``, optionally ``"i"`` for an
    instruction).  Entries whose media cannot be located are dropped at
    construction time.
    """

    media_type = "image"

    def __init__(
        self,
        ann_file,
        transform,
        system="",
        role=("Human", "Assistant"),
        mm_alone=True,
        add_second_msg=True,
        start_token="",
        end_token="",
        random_shuffle=True,  # if True, shuffle the QA list of every sample
        begin_signal=None,
        end_signal=None,
        clip_transform=False,
        skip_short_sample=False,
    ):
        """Load and filter the annotation file and store the prompt settings.

        Args:
            ann_file: Sequence ``(label_file, data_root)`` or
                ``(label_file, data_root, "video")``.  A third element equal
                to ``"video"`` switches ``self.media_type`` to ``"video"``;
                anything else yields ``"image"``.
            transform: Stored as ``self.transform``.
            system: System prompt prepended to every conversation.  When
                non-empty it must end with a single space.
            role: Pair of speaker names ``(questioner, answerer)``.
            mm_alone: If True, the media tokens are emitted in their own turn
                before the QA turns; otherwise they are prefixed to the first
                question.
            add_second_msg: Accepted for signature compatibility; not used by
                this class.
            start_token: Text emitted before the media placeholder.
            end_token: Text emitted after the media placeholder.
            random_shuffle: If True, QA pairs are shuffled per sample.
            begin_signal: String (broadcast to one entry per role) or a
                sequence indexed per role.
            end_signal: Same convention as ``begin_signal``.
            clip_transform: Forwarded to the media loader.
            skip_short_sample: Accepted for signature compatibility; not used
                by this class.
        """
        super().__init__()
        self.mm_alone = mm_alone
        self.clip_transform = clip_transform

        # The instance attribute overrides the class-level ``media_type``.
        if len(ann_file) == 3 and ann_file[2] == "video":
            self.media_type = "video"
        else:
            self.media_type = "image"
        self.label_file, self.data_root = ann_file[:2]

        logger.info('Load json file')
        with open(self.label_file, 'r') as f:
            self.anno = json.load(f)
        total = len(self.anno)
        self.num_examples = total
        self.transform = transform

        # Keep only the annotations whose media can be located.
        kept = []
        for ann in self.anno:
            filename = ann['video'] if 'video' in ann else ann['image']
            if self.media_type == 'video' and "webvid" in self.data_root:
                # ``self.keys_indexfile`` is not created in this class; the
                # video subclass assigns it before calling this constructor.
                video_id, _ = os.path.splitext(os.path.basename(filename))
                if video_id in self.keys_indexfile:
                    kept.append(ann)
            elif filename is not None and filename != "None":
                if os.path.exists(os.path.join(self.data_root, filename)):
                    kept.append(ann)
        self.anno = kept
        self.num_examples = len(self.anno)
        logger.info("Kept %d of %d annotations after filtering", self.num_examples, total)

        # prompt parameters
        if system:
            assert system[-1] == " ", "' ' should be add in the end of system, thus '###' will be tokenized into one token."
        # Start/end tokens are currently not supported inside the system
        # prompt, since the message has to be added at the proper position.
        # NOTE(review): with the default ``None`` the signals stay ``None``
        # and ``process_qa`` will fail when indexing them -- callers appear
        # to be expected to always pass them; confirm.
        self.begin_signal = [begin_signal for _ in role] if isinstance(begin_signal, str) else begin_signal
        self.end_signal = [end_signal for _ in role] if isinstance(end_signal, str) else end_signal
        self.start_token = start_token
        self.end_token = end_token
        self.system = system
        self.role = role
        self.random_shuffle = random_shuffle
        # instruction location and number
        logger.info(f"Random shuffle: {self.random_shuffle}")

    def get_anno(self, index):
        """Return the media path and QA list of annotation ``index``.

        The result always has the keys ``"image"`` (media path joined onto
        ``self.data_root``, also used for videos) and ``"qa"``; ``"start"``
        and ``"end"`` are copied over when the annotation has both.
        """
        record = self.anno[index]
        anno = {
            "image": os.path.join(self.data_root, record[self.media_type]),
            "qa": record["QA"],
        }
        if "start" in record and "end" in record:
            anno["start"] = record["start"]
            anno["end"] = record["end"]
        return anno

    def __len__(self):
        """Return the number of annotations that survived filtering."""
        return self.num_examples

    def process_qa(self, qa, msg=""):
        """Render a QA list into a conversation string.

        Args:
            qa: Non-empty list of dicts with ``"q"`` and ``"a"`` and an
                optional ``"i"`` instruction.
            msg: Extra text placed after the media tokens; only used when
                ``self.mm_alone`` is True.

        Returns:
            Tuple ``(conversation, instruction)`` where ``instruction`` is the
            stripped instruction of the first QA followed by its question, or
            an empty string when that QA has no instruction.
        """
        cur_instruction = ""
        # Shuffle a copy so the stored annotation list is not reordered as a
        # side effect of reading a sample.
        if self.random_shuffle and len(qa) > 1:
            qa = list(qa)
            random.shuffle(qa)
        if "i" in qa[0].keys() and qa[0]["i"] != "":
            cur_instruction = qa[0]["i"] + self.end_signal[0]

        parts = [self.system]
        # add instruction as system message
        if cur_instruction:
            parts.append(cur_instruction)

        # rstrip() removes the extra trailing " " in msg
        if self.mm_alone:
            parts.append(
                self.begin_signal[0] + self.role[0]
                + self.start_token + self.end_token
                + msg.rstrip() + self.end_signal[0]
            )

        # NOTE(review): every turn below uses begin_signal[0] and
        # end_signal[1] regardless of the speaker; kept as in the original --
        # confirm this asymmetry is intended.
        for i, sentence in enumerate(qa):
            if (not self.mm_alone) and (i == 0):
                q = self.start_token + self.end_token + "\n" + qa[0]["q"]
            else:
                q = sentence["q"]
            a = sentence["a"]
            # An empty question (common in caption datasets) emits no
            # questioner turn.
            if q != "":
                parts.append(self.begin_signal[0] + self.role[0] + q + self.end_signal[1])
            parts.append(self.begin_signal[0] + self.role[1] + a + self.end_signal[1])

        if cur_instruction:
            cur_instruction += qa[0]["q"]
        return "".join(parts), cur_instruction.strip()

    def __getitem__(self, index):
        """Return ``(image, conversation, instruction, index)`` for a sample.

        On any exception a warning is logged and a randomly chosen index is
        tried instead.
        """
        # Pre-bind so the handler can log safely even if get_anno() raised.
        ann = None
        try:
            ann = self.get_anno(index)
            image, index = self.load_and_transform_media_data_image(
                index, ann["image"], clip_transform=self.clip_transform
            )
            conversation, instruction = self.process_qa(ann["qa"])
            return image, conversation, instruction, index
        except Exception as e:
            source = ann["image"] if ann is not None else f"<index {index}>"
            logger.warning(f"Caught exception {e} when loading image {source}")
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)


class ITVidTrainDataset(ITImgTrainDataset):
    """Video variant of :class:`ITImgTrainDataset`.

    Adds the video-reader configuration and, optionally, a message listing
    the sampled frame timestamps to each conversation.
    """

    media_type = "video"

    def __init__(
        self,
        ann_file,
        transform,
        num_frames=4,
        video_reader_type="decord",
        sample_type="rand",
        num_tries=3,
        mm_alone=True,
        system="",
        role=("Human", "Assistant"),
        start_token="",
        add_second_msg=True,
        random_shuffle=True,
        begin_signal=None,
        end_signal=None,
        clip_transform=False,
        skip_short_sample=False,
        end_token="",
    ):
        """Set up the video dataset.

        ``end_token`` is appended as the last parameter so that existing
        positional call sites keep their meaning; it is forwarded to the
        parent constructor together with the other prompt settings.  The
        remaining prompt arguments have the same meaning as in
        :class:`ITImgTrainDataset`.

        Args:
            num_frames: Stored as ``self.num_frames``.
            video_reader_type: Key into ``VIDEO_READER_FUNCS``.
            sample_type: Stored as ``self.sample_type``.
            num_tries: Stored as ``self.num_tries``.
            add_second_msg: If True, each conversation gets a sentence
                listing the sampled frame timestamps.
        """
        # Index file for webvid ids.  It must be loaded before the parent
        # constructor runs, because the parent's annotation filter reads it.
        if "webvid" in ann_file[1]:
            with open("/mnt/bn/dq-storage-ckpt/xulin/datasets/videos/webvid_10m/keys_indexfile.json") as f:
                self.keys_indexfile = json.load(f)

        super().__init__(
            ann_file,
            transform,
            system=system,
            role=role,
            mm_alone=mm_alone,
            start_token=start_token,
            end_token=end_token,
            random_shuffle=random_shuffle,
            begin_signal=begin_signal,
            end_signal=end_signal,
            clip_transform=clip_transform,
            skip_short_sample=skip_short_sample,
        )
        self.num_frames = num_frames
        self.video_reader_type = video_reader_type
        self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
        self.sample_type = sample_type
        self.num_tries = num_tries
        self.add_second_msg = add_second_msg

        logger.info(f"Use {video_reader_type} for data in {ann_file}")
        if add_second_msg:
            logger.info("Add second message: The video contains X frames sampled at T seconds.")

    def __getitem__(self, index):
        """Return ``(video, conversation, instruction, index)`` for a sample.

        On any exception a warning is logged and a randomly chosen index is
        tried instead.
        """
        # Pre-bind so the handler can log safely even if get_anno() raised.
        ann = None
        try:
            ann = self.get_anno(index)
            msg = ""
            clip = None
            if "start" in ann and "end" in ann:
                clip = [ann["start"], ann["end"]]
            video, index, sec = self.load_and_transform_media_data_video(
                index, ann["image"], return_fps=True, clip=clip,
                clip_transform=self.clip_transform,
            )
            if self.add_second_msg:
                # " " should be added at the start and end
                # assumes ``sec`` is a sequence of strings (it is passed to
                # str.join) -- TODO confirm against the loader
                msg = f" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. "
            conversation, instruction = self.process_qa(ann["qa"], msg)
            return video, conversation, instruction, index
        except Exception as e:
            source = ann["image"] if ann is not None else f"<index {index}>"
            logger.warning(f"Caught exception {e} when loading video {source}")
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)