# kangaroo / data_utils.py
# Added by WEBing in commit 5fe5ca4 ("add data_utils"); 5.54 kB.
import decord
import numpy as np
import torch
from PIL import Image
import random
from eva_clip.transform import image_transform
# Shared eva_clip image transform, built once at import time and applied by
# read_video to every sampled frame. image_size=448 / is_train=False presumably
# select the evaluation-mode (non-augmenting) preprocessing — TODO confirm
# against eva_clip.transform.image_transform.
image_processor = image_transform(image_size=448, is_train=False)
def preprocess_multimodal(sources, num_segments):
    """Expand every ``<video>`` placeholder into the video frame-token sequence.

    Each occurrence of ``<video>`` in a message's ``content`` is replaced, in
    place, by::

        <vi_start> + "<image><eof>" * (num_segments // 2 - 1) + "<image><eov>" + <vi_end>

    This yields ``max(num_segments // 2, 1)`` ``<image>`` tokens.

    Args:
        sources: list of conversations; each conversation is a list of message
            dicts with a ``'content'`` string.
        num_segments: number of sampled video frames.

    Returns:
        ``sources`` itself. The message dicts are mutated in place.
    """
    placeholder = '<video>'
    for conversation in sources:
        for message in conversation:
            if placeholder not in message['content']:
                continue
            # All but the last image token are followed by an end-of-frame marker.
            separated = "".join("<image><eof>" for _ in range(num_segments // 2 - 1))
            video_tokens = '<vi_start>' + separated + "<image>" + "<eov>" + '<vi_end>'
            message["content"] = message["content"].replace(placeholder, video_tokens)
    return sources
def preprocess(
    sources,
    tokenizer,
    s_id=None,
):
    """Prepend a bilingual system prompt to each conversation and tokenize the batch.

    Args:
        sources: list of conversations. Each is a list of message dicts with
            ``'role'`` and ``'content'`` keys, typically already expanded by
            ``preprocess_multimodal``.
        tokenizer: object exposing ``apply_chat_template``, e.g. a Hugging Face
            tokenizer. No explicit template is passed, so the tokenizer's own
            configured template applies.
        s_id: index of the instruction template used in the system prompt.
            ``None`` picks one uniformly at random.

    Returns:
        The result of ``tokenizer.apply_chat_template`` for the batch, called
        with ``add_generation_prompt=True`` and ``return_tensors='pt'``.
    """
    en_qa_templates = [
        "Review the given video and answer the question associated with its visual elements.",
        "Watch the provided video and offer an accurate response to the related question.",
        "Scrutinize the video carefully, identifying relevant details in order to address the linked question.",
        "Take a close look at the presented visuals and deliver a precise answer to the corresponding question.",
        "Observe the video attentively and accurately respond to the associated question.",
        "View the video attentively and provide a suitable answer to the posed question.",
        "Examine the video and approach the connected question with an informed response.",
        "Assess the displayed video and answer the subsequent question with accuracy.",
        "Consider the video content and deliver a relevant answer to the corresponding question.",
        "Go through the video, taking into account key aspects, and respond to the question."
    ]
    # Chinese translations, index-aligned with en_qa_templates.
    ch_qa_templates = [
        "审阅所提供的视频,并回答与其视觉元素相关的问题。",
        "观看所提供的视频,对相关问题给出准确的回答。",
        "仔细审查视频,识别相关的细节,回答与之相关的问题。",
        "仔细观察所展示的视觉内容,并对相应的问题给出精确的回答。",
        "认真观察视频并准确回答相关的问题。",
        "详细观看视频,并且对提出的问题给出合适的回答。",
        "观察视频并用有依据的回答来解答相关的问题。",
        "评估展示的视频,并准确地回答随后的问题。",
        "根据视频内容,对相应的问题给出合理的答案。",
        "浏览视频,根据其中的关键内容回答问题。",
    ]
    if s_id is not None:
        index = s_id
    else:
        index = random.randrange(len(en_qa_templates))
    system_prompt = f"You are a helpful assistant, {en_qa_templates[index]} 你是一个乐于助人的助手,{ch_qa_templates[index]}"
    # The caller's message dicts are reused, not copied.
    messages = [[{'role': 'system', 'content': system_prompt}, *source] for source in sources]
    return tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors='pt')
def get_index(fps, max_frame, num_segments):
    """Choose ``num_segments`` frame indices and their timestamps.

    ``max_frame`` is treated as the number of frames available to sample from
    (``read_video`` passes ``len(vr) - 1``).

    - If there are no more frames than segments, the frames are reused
      cyclically and the result is sorted ascending.
    - Otherwise, positions are spread evenly over ``[0, max_frame - 1]``.

    Args:
        fps: frames per second, used to convert positions to seconds.
        max_frame: number of sampleable frames. Values below 1 are treated as
            1, so a single-frame clip (``max_frame == 0``) yields frame 0.
        num_segments: number of indices to return.

    Returns:
        ``(indices, durations)``:

        - ``indices``: ``int64`` array of frame indices; fractional positions
          are truncated.
        - ``durations``: list of float timestamps in seconds, computed from the
          un-truncated positions.
    """
    # Clamp so a one-frame clip does not trigger a modulo-by-zero below.
    num_frames = max(max_frame, 1)
    if num_frames <= num_segments:
        # Too few frames: cycle through them, then sort so time runs forward.
        positions = np.sort(np.arange(num_segments) % num_frames)
    else:
        positions = np.linspace(0, num_frames - 1, num_segments)
    durations = [pos.item() / fps for pos in positions]
    return positions.astype(np.int64), durations
def read_video(video_path, num_segments):
    """Decode ``num_segments`` sampled frames from a video and preprocess them.

    Args:
        video_path: path accepted by ``decord.VideoReader``.
        num_segments: number of frames to sample (see ``get_index``).

    Returns:
        ``(frames, durations, total_duration)``:

        - ``frames``: ``image_processor`` outputs, each given a leading dim via
          ``unsqueeze(0)``, concatenated along dim 0.
        - ``durations``: ``torch.Tensor`` of per-frame timestamps in seconds.
        - ``total_duration``: frame count divided by the reader's average fps.
    """
    reader = decord.VideoReader(video_path)
    frame_count = len(reader)
    avg_fps = float(reader.get_avg_fps())
    clip_seconds = frame_count / avg_fps
    # get_index is given the last valid frame index, not the frame count.
    frame_indices, timestamps = get_index(avg_fps, frame_count - 1, num_segments)
    frames = [
        image_processor(Image.fromarray(reader[i].asnumpy())).unsqueeze(0)
        for i in frame_indices
    ]
    return torch.cat(frames), torch.Tensor(timestamps), clip_seconds
def get_input(video_path, num_segments, question, history, tokenizer, s_id):
    """Load a video and build the model inputs for one chat turn.

    On the first turn (``history is None``), the question is prefixed with the
    ``<video>`` placeholder. ``preprocess_multimodal`` then expands it in place
    into frame tokens, so the returned ``conversations`` holds the expanded
    text. On later turns, the question is appended to ``history`` unchanged.

    Args:
        video_path: video file to sample frames from.
        num_segments: number of frames to sample.
        question: user question for this turn.
        history: previous conversation (list of message dicts) or ``None``.
            When given, it is mutated in place.
        tokenizer: passed through to ``preprocess``.
        s_id: system-prompt template index passed to ``preprocess``
            (``None`` = random).

    Returns:
        ``(video, durations, input_ids, conversations)``. ``conversations`` is
        the same list object as ``history`` when one was provided.
    """
    video, durations, _total_duration = read_video(video_path, num_segments)
    if history is None:
        conversations = [{'role': 'user', 'content': f'<video>\n{question}'}]
    else:
        # Deliberately extends the caller's list so the chat history accumulates.
        conversations = history
        conversations.append({'role': 'user', 'content': question})
    sources = [conversations]
    # Use the number of frames actually decoded to size the video token run.
    sources = preprocess_multimodal(sources, video.shape[0])
    input_ids = preprocess(sources, tokenizer, s_id=s_id)
    return video, durations, input_ids, conversations
def add_pred_to_history(history, pred):
    """Record a model reply as an assistant turn.

    Appends to ``history`` in place and returns the same list object.
    """
    assistant_turn = {'role': 'assistant', 'content': pred}
    history.append(assistant_turn)
    return history