import logging
import os
import json
import sqlite3
import random
from os.path import basename
import numpy as np
import datetime
from dataset.base_dataset import ImageVideoBaseDataset
from dataset.video_utils import VIDEO_READER_FUNCS
logger = logging.getLogger(__name__)
IMAGE_TOKEN="<image>"
class ITImgTrainDataset(ImageVideoBaseDataset):
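    """Instruction-tuning dataset over image (or video) conversation data.

    Each annotation entry provides a media path ("image" or "video") plus a
    "QA" list of turns with "q"/"a" fields and an optional instruction "i";
    entries may also carry "start"/"end" clip timestamps. The dataset builds
    a role-tagged conversation string from these turns.
    """
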
media_type = "image"
def __init__(
self, ann_file, transform,
system="", role=("Human", "Assistant"),
mm_alone=True,
add_second_msg=True,
start_token="<Image>", end_token="</Image>",
random_shuffle=True, # if True, shuffle the QA list ##xl:????? why need random shuffle
begin_signal=None,
end_signal=None,
clip_transform=False,
skip_short_sample=False,
):
        super().__init__()
        self.mm_alone = mm_alone
        self.clip_transform = clip_transform
        if len(ann_file) == 3 and ann_file[2] == "video":
            self.media_type = "video"
        else:
            self.media_type = "image"
        self.label_file, self.data_root = ann_file[:2]

        logger.info('Load json file')
        with open(self.label_file, 'r') as f:
            self.anno = json.load(f)
        self.num_examples = len(self.anno)
        self.transform = transform

        # keep only annotations whose media actually exists
        annos = []
        for ann in self.anno:
            filename = ann['video'] if 'video' in ann else ann['image']
            if self.media_type == 'video' and "webvid" in self.data_root:
                # for WebVid, keep only clips listed in the id->index file
                video_id, _ = os.path.splitext(os.path.basename(filename))
                if video_id in self.keys_indexfile:
                    annos.append(ann)
            else:
                # skip empty entries and files missing from data_root
                if filename is not None and filename != "None" \
                        and os.path.exists(os.path.join(self.data_root, filename)):
                    annos.append(ann)
        self.anno = annos
        self.num_examples = len(self.anno)
        # prompt parameters
        if system:
            assert system[-1] == " ", "' ' should be added at the end of system, so that '###' is tokenized into one token."
        # adding start_token/end_token inside the system prompt is currently not
        # supported, since the media message has to be inserted separately
        self.begin_signal = [begin_signal for _ in role] if isinstance(begin_signal, str) else begin_signal
        self.end_signal = [end_signal for _ in role] if isinstance(end_signal, str) else end_signal
        self.start_token = start_token
        self.end_token = end_token
        self.system = system
        self.role = role
        self.random_shuffle = random_shuffle
        # instruction location and number
        logger.info(f"Random shuffle: {self.random_shuffle}")

    def get_anno(self, index):
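        """Return the media path (joined with data_root), the QA list, and,
        when present, the clip "start"/"end" timestamps for one annotation."""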
        filename = self.anno[index][self.media_type]
        qa = self.anno[index]["QA"]
        if "start" in self.anno[index] and "end" in self.anno[index]:
            anno = {
                "image": os.path.join(self.data_root, filename), "qa": qa,
                "start": self.anno[index]["start"], "end": self.anno[index]["end"],
            }
        else:
            anno = {"image": os.path.join(self.data_root, filename), "qa": qa}
        return anno

    def __len__(self):
        return self.num_examples

    def process_qa(self, qa, msg=""):
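        """Assemble the QA turns into a single role-tagged conversation string.

        Optionally shuffles the turns, appends the instruction of the first
        turn (if any) to the system prompt, and inserts the media tokens either
        as a standalone first message (mm_alone) or in front of the first
        question. Returns (conversation, instruction)."""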
cur_instruction = ""
# randomly shuffle qa for conversation
if self.random_shuffle and len(qa) > 1:
random.shuffle(qa)
if "i" in qa[0].keys() and qa[0]["i"] != "":
cur_instruction = qa[0]["i"] + self.end_signal[0]
conversation = self.system
# add instruction as system message
if cur_instruction:
conversation += cur_instruction
# rstrip() for the extra " " in msg
if self.mm_alone:
conversation += (
self.begin_signal[0] + self.role[0] +
self.start_token + self.end_token + msg.rstrip() + self.end_signal[0]
)
for i, sentence in enumerate(qa):
q = self.start_token + self.end_token+"\n"+ qa[0]["q"] if (not self.mm_alone) and (i == 0) else sentence["q"]
a = sentence["a"]
if q != "":
conversation += (self.begin_signal[0] + self.role[0] + q + self.end_signal[1])
else:
# no question, often in caption dataset
pass
conversation += (self.begin_signal[0] + self.role[1] + a + self.end_signal[1])
if cur_instruction:
cur_instruction += qa[0]["q"]
return conversation, cur_instruction.strip()

    def __getitem__(self, index):
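        """Load and transform one image and build its conversation; on failure,
        retry with a randomly chosen index."""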
        try:
            ann = self.get_anno(index)
            image, index = self.load_and_transform_media_data_image(
                index, ann["image"], clip_transform=self.clip_transform)
            conversation, instruction = self.process_qa(ann["qa"])
            return image, conversation, instruction, index
        except Exception as e:
            logger.warning(f"Caught exception {e} when loading image {ann['image']}")
            # fall back to a random sample rather than failing the batch
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)

class ITVidTrainDataset(ITImgTrainDataset):
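    """Video variant of ITImgTrainDataset.

    Reads num_frames frames per clip with the configured video reader and can
    prepend a message describing how many frames were sampled and at which
    timestamps (add_second_msg).
    """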
media_type = "video"
def __init__(
self, ann_file, transform,
num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3,
mm_alone=True,
system="", role=("Human", "Assistant"),
start_token="<Video>", end_token="</Video>",
add_second_msg=True,
random_shuffle=True,
begin_signal=None,
end_signal=None,
clip_transform=False,
skip_short_sample=False,
):
# "id index file for webvid"
if "webvid" in ann_file[1]:
with open("/mnt/bn/dq-storage-ckpt/xulin/datasets/videos/webvid_10m/keys_indexfile.json") as f:
self.keys_indexfile = json.load(f) # the correponding index file for each webvid id
        super().__init__(
            ann_file, transform,
            system=system, role=role,
            mm_alone=mm_alone,
            start_token=start_token, end_token=end_token,
            random_shuffle=random_shuffle,
            begin_signal=begin_signal,
            end_signal=end_signal,
            clip_transform=clip_transform,
            skip_short_sample=skip_short_sample,
        )
        self.num_frames = num_frames
        self.video_reader_type = video_reader_type
        self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
        self.sample_type = sample_type
        self.num_tries = num_tries
        self.add_second_msg = add_second_msg

        logger.info(f"Use {video_reader_type} for data in {ann_file}")
        if add_second_msg:
            logger.info("Add second message: The video contains X frames sampled at T seconds.")

    def __getitem__(self, index):
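        """Load and transform one video clip (optionally restricted to the
        annotated start/end window) and build its conversation; on failure,
        retry with a randomly chosen index."""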
        try:
            ann = self.get_anno(index)
            msg = ""
            clip = None
            if "start" in ann and "end" in ann:
                clip = [ann["start"], ann["end"]]
            video, index, sec = self.load_and_transform_media_data_video(
                index, ann["image"], return_fps=True, clip=clip, clip_transform=self.clip_transform)
            if self.add_second_msg:
                # the leading and trailing " " around msg are required
                msg = f" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. "
            conversation, instruction = self.process_qa(ann["qa"], msg)
            return video, conversation, instruction, index
        except Exception as e:
            logger.warning(f"Caught exception {e} when loading video {ann['image']}")
            # fall back to a random sample rather than failing the batch
            index = np.random.randint(0, len(self))
            return self.__getitem__(index)
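

# Rough usage sketch (not part of the training pipeline): the annotation file,
# data root, transform, and signal/role values below are placeholders; the real
# values come from the project config.
#
#     from torchvision import transforms
#     transform = transforms.Compose([
#         transforms.Resize((224, 224)),
#         transforms.ToTensor(),
#     ])
#     dataset = ITImgTrainDataset(
#         ann_file=["/path/to/annotations.json", "/path/to/images"],
#         transform=transform,
#         system="You are a helpful assistant. ",
#         begin_signal="###",
#         end_signal=[" ", "\n"],
#     )
#     image, conversation, instruction, index = dataset[0]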