# pllava-7b-demo / dataset/it_dataset.py
# Last commit f239efc by cathyxl: "added"
import logging
import os
import json
import sqlite3
import random
from os.path import basename
import numpy as np
import datetime
from dataset.base_dataset import ImageVideoBaseDataset
from dataset.video_utils import VIDEO_READER_FUNCS
# Logger named after this module, used by the dataset classes below.
logger = logging.getLogger(name=__name__)
# Literal image placeholder token string.
IMAGE_TOKEN = "<image>"
class ITImgTrainDataset(ImageVideoBaseDataset):
    """Instruction-tuning dataset pairing a media file with a QA conversation.

    Each entry of the JSON file ``ann_file[0]`` must carry a media path under
    ``"video"`` or ``"image"`` and a list of QA turns under ``"QA"`` (dicts
    with ``"q"``, ``"a"`` and, optionally on the first turn, ``"i"`` for an
    instruction). Entries whose media cannot be found are dropped when the
    dataset is constructed.
    """

    media_type = "image"

    def __init__(
        self, ann_file, transform,
        system="", role=("Human", "Assistant"),
        mm_alone=True,
        add_second_msg=True,
        start_token="<Image>", end_token="</Image>",
        random_shuffle=True,  # if True, shuffle the QA turns each time a sample is built
        begin_signal=None,
        end_signal=None,
        clip_transform=False,
        skip_short_sample=False,
    ):
        """
        Args:
            ann_file: ``(label_json_path, data_root)`` or
                ``(label_json_path, data_root, "video")``; the optional third
                element switches ``self.media_type`` to ``"video"``.
            transform: stored as ``self.transform``; presumably consumed by
                the base-class media loaders (not visible here).
            system: prefix of every conversation; if non-empty it must end
                with a space.
            role: ``(user_role, assistant_role)`` strings prefixed to turns.
            mm_alone: if True the media placeholder is emitted as its own user
                turn; otherwise it is prepended to the first question.
            add_second_msg: accepted for interface compatibility; unused here.
            start_token, end_token: media placeholder delimiters.
            random_shuffle: shuffle the QA turns every time a sample is built.
            begin_signal, end_signal: a single string applied to every role,
                a per-role sequence, or None (treated as the empty string).
            clip_transform: forwarded to the base-class media loader.
            skip_short_sample: accepted for interface compatibility; unused here.
        """
        super().__init__()
        self.mm_alone = mm_alone
        self.clip_transform = clip_transform
        if len(ann_file) == 3 and ann_file[2] == "video":
            self.media_type = "video"
        else:
            self.media_type = "image"
        self.label_file, self.data_root = ann_file[:2]

        logger.info('Load json file')
        with open(self.label_file, 'r') as f:
            raw_annos = json.load(f)
        self.transform = transform

        # WebVid videos are validated against an id index
        # (``self.keys_indexfile``, which must be set before this runs, e.g.
        # by ITVidTrainDataset) instead of hitting the filesystem per file.
        is_webvid = self.media_type == 'video' and "webvid" in self.data_root
        annos = []
        for ann in raw_annos:
            filename = ann['video'] if 'video' in ann else ann['image']
            if is_webvid:
                video_id, _ = os.path.splitext(os.path.basename(filename))
                if video_id in self.keys_indexfile:
                    annos.append(ann)
            elif filename is not None and filename != "None":
                if os.path.exists(os.path.join(self.data_root, filename)):
                    annos.append(ann)
        logger.info(
            f"Kept {len(annos)} of {len(raw_annos)} annotations after filtering unavailable media"
        )
        self.anno = annos
        self.num_examples = len(self.anno)

        # prompt parameters
        if system:
            assert system[-1] == " ", "' ' should be add in the end of system, thus '###' will be tokenized into one token."
        # currently not support add start_token and end_token in the system, since the msg should be added properly
        # None used to be stored as-is and then crashed on ``[0]`` in
        # process_qa (and __getitem__ retried forever); treat it as "".
        self.begin_signal = self._expand_signal(begin_signal, role)
        self.end_signal = self._expand_signal(end_signal, role)
        self.start_token = start_token
        self.end_token = end_token
        self.system = system
        self.role = role
        self.random_shuffle = random_shuffle
        logger.info(f"Random shuffle: {self.random_shuffle}")

    @staticmethod
    def _expand_signal(signal, role):
        """Return one signal per role: a str is repeated, None becomes "", sequences pass through."""
        if signal is None:
            signal = ""
        if isinstance(signal, str):
            return [signal for _ in role]
        return signal

    def get_anno(self, index):
        """Return ``{"image": path, "qa": turns}`` for ``index``, plus ``"start"``/``"end"`` when both exist.

        The media path is joined onto ``self.data_root`` and is always stored
        under the ``"image"`` key, even for video datasets.
        """
        raw = self.anno[index]
        anno = {"image": os.path.join(self.data_root, raw[self.media_type]), "qa": raw["QA"]}
        if "start" in raw and "end" in raw:
            anno["start"] = raw["start"]
            anno["end"] = raw["end"]
        return anno

    def __len__(self):
        return self.num_examples

    def process_qa(self, qa, msg=""):
        """Render QA turns into a single conversation string.

        Args:
            qa: list of dicts with ``"q"`` and ``"a"``; the first turn may
                carry an ``"i"`` instruction.
            msg: text appended (right-stripped) to the standalone media turn
                when ``mm_alone`` is True.

        Returns:
            ``(conversation, instruction)``; ``instruction`` is the first
            turn's instruction + end signal + question, stripped, or ``""``.
        """
        cur_instruction = ""
        # Shuffle a copy: ``qa`` is the list cached in self.anno, and an
        # in-place shuffle would permanently reorder the stored annotation.
        if self.random_shuffle and len(qa) > 1:
            qa = list(qa)
            random.shuffle(qa)
        first_instruction = qa[0].get("i", "")
        if first_instruction != "":
            cur_instruction = first_instruction + self.end_signal[0]

        conversation = self.system
        # add instruction as system message
        if cur_instruction:
            conversation += cur_instruction

        # rstrip() for the extra " " in msg
        if self.mm_alone:
            conversation += (
                self.begin_signal[0] + self.role[0] +
                self.start_token + self.end_token + msg.rstrip() + self.end_signal[0]
            )

        for i, sentence in enumerate(qa):
            if (not self.mm_alone) and i == 0:
                q = self.start_token + self.end_token + "\n" + sentence["q"]
            else:
                q = sentence["q"]
            a = sentence["a"]
            # NOTE(review): the user question ends with end_signal[1] and the
            # assistant turn starts with begin_signal[0], unlike the media turn
            # above (end_signal[0]). Identical when a single string signal is
            # given; kept as-is so the training prompt format does not change.
            if q != "":
                conversation += (self.begin_signal[0] + self.role[0] + q + self.end_signal[1])
            # an empty question (often in caption datasets) emits only the answer
            conversation += (self.begin_signal[0] + self.role[1] + a + self.end_signal[1])

        if cur_instruction:
            cur_instruction += qa[0]["q"]
        return conversation, cur_instruction.strip()

    def __getitem__(self, index):
        """Return ``(image, conversation, instruction, index)``; on any failure retry a random index."""
        # Bound before ``try`` so the handler can log even when get_anno fails
        # (previously that raised UnboundLocalError and escaped the retry).
        ann = None
        try:
            ann = self.get_anno(index)
            image, index = self.load_and_transform_media_data_image(index, ann["image"], clip_transform=self.clip_transform)
            conversation, instruction = self.process_qa(ann["qa"])
            return image, conversation, instruction, index
        except Exception as e:
            media_path = ann["image"] if ann is not None else f"<annotation {index}>"
            logger.warning(f"Caught exception {e} when loading image {media_path}")
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)
class ITVidTrainDataset(ITImgTrainDataset):
    """Video variant of ITImgTrainDataset that samples frames with a configurable video reader."""

    media_type = "video"
    # Fallback location of the WebVid id -> index-file mapping; override via
    # the ``webvid_keys_indexfile`` argument or by subclassing.
    default_webvid_keys_indexfile = "/mnt/bn/dq-storage-ckpt/xulin/datasets/videos/webvid_10m/keys_indexfile.json"

    def __init__(
        self, ann_file, transform,
        num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3,
        mm_alone=True,
        system="", role=("Human", "Assistant"),
        start_token="<Video>", end_token="</Video>",
        add_second_msg=True,
        random_shuffle=True,
        begin_signal=None,
        end_signal=None,
        clip_transform=False,
        skip_short_sample=False,
        webvid_keys_indexfile=None,
    ):
        """
        Args:
            ann_file, transform, mm_alone, system, role, start_token,
            end_token, random_shuffle, begin_signal, end_signal,
            clip_transform, skip_short_sample: see ITImgTrainDataset.
            num_frames, sample_type, num_tries: stored on the instance;
                presumably read by the base-class video loader.
            video_reader_type: key into VIDEO_READER_FUNCS.
            add_second_msg: if True, append a "The video contains N frames
                sampled at ... seconds." note to the media turn.
            webvid_keys_indexfile: JSON path of the WebVid key index, loaded
                only when ``ann_file[1]`` contains "webvid"; None uses
                ``default_webvid_keys_indexfile``.
        """
        # The parent __init__ filters WebVid annotations against this index,
        # so it must be loaded before super().__init__() runs.
        if "webvid" in ann_file[1]:
            if webvid_keys_indexfile is None:
                webvid_keys_indexfile = self.default_webvid_keys_indexfile
            with open(webvid_keys_indexfile) as f:
                self.keys_indexfile = json.load(f)  # the corresponding index file for each webvid id
        super().__init__(
            ann_file, transform,
            system=system, role=role,
            mm_alone=mm_alone,
            start_token=start_token, end_token=end_token,
            random_shuffle=random_shuffle,
            begin_signal=begin_signal,
            end_signal=end_signal,
            clip_transform=clip_transform,
            skip_short_sample=skip_short_sample,
        )
        self.num_frames = num_frames
        self.video_reader_type = video_reader_type
        self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
        self.sample_type = sample_type
        self.num_tries = num_tries
        self.add_second_msg = add_second_msg

        logger.info(f"Use {video_reader_type} for data in {ann_file}")
        if add_second_msg:
            logger.info(f"Add second message: The video contains X frames sampled at T seconds.")

    def __getitem__(self, index):
        """Return ``(video, conversation, instruction, index)``; on any failure retry a random index."""
        # Bound before ``try`` so the handler can log even when get_anno fails.
        ann = None
        try:
            ann = self.get_anno(index)
            msg = ""
            clip = None
            if "start" in ann and "end" in ann:
                clip = [ann["start"], ann["end"]]
            video, index, sec = self.load_and_transform_media_data_video(
                index, ann["image"], return_fps=True, clip=clip, clip_transform=self.clip_transform
            )
            if self.add_second_msg:
                # " " should be added in the start and end.
                # NOTE(review): join() assumes ``sec`` is a sequence of strings.
                msg = f" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. "
            conversation, instruction = self.process_qa(ann["qa"], msg)
            return video, conversation, instruction, index
        except Exception as e:
            media_path = ann["image"] if ann is not None else f"<annotation {index}>"
            logger.warning(f"Caught exception {e} when loading video {media_path}")
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)