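"""SkyCaptioner-V1 batch structured-caption inference with vLLM.

Reads an input CSV with a 'path' column of video file paths, samples 16 frames
per video, prompts the captioner for a structured JSON caption, and writes the
outputs to a 'structural_caption' column in the result CSV.
"""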
import argparse

import decord
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams
SYSTEM_PROMPT = "I need you to generate a structured and detailed caption for the provided video. The structured output and the requirements for each field are as shown in the following JSON content: {\"subjects\": [{\"appearance\": \"Main subject appearance description\", \"action\": \"Main subject action\", \"expression\": \"Main subject expression (Only for human/animal categories, empty otherwise)\", \"position\": \"Subject position in the video (Can be relative position to other objects or spatial description)\", \"TYPES\": {\"type\": \"Main category (e.g., Human)\", \"sub_type\": \"Sub-category (e.g., Man)\"}, \"is_main_subject\": true}, {\"appearance\": \"Non-main subject appearance description\", \"action\": \"Non-main subject action\", \"expression\": \"Non-main subject expression (Only for human/animal categories, empty otherwise)\", \"position\": \"Position of non-main subject 1\", \"TYPES\": {\"type\": \"Main category (e.g., Vehicles)\", \"sub_type\": \"Sub-category (e.g., Ship)\"}, \"is_main_subject\": false}], \"shot_type\": \"Shot type(Options: long_shot/full_shot/medium_shot/close_up/extreme_close_up/other)\", \"shot_angle\": \"Camera angle(Options: eye_level/high_angle/low_angle/other)\", \"shot_position\": \"Camera position(Options: front_view/back_view/side_view/over_the_shoulder/overhead_view/point_of_view/aerial_view/overlooking_view/other)\", \"camera_motion\": \"Camera movement description\", \"environment\": \"Video background/environment description\", \"lighting\": \"Lighting information in the video\"}"
class VideoTextDataset(torch.utils.data.Dataset):
    def __init__(self, csv_path, model_path):
        self.meta = pd.read_csv(csv_path)
        self._path = 'path'
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.processor = AutoProcessor.from_pretrained(model_path)

    def __getitem__(self, index):
        row = self.meta.iloc[index]
        path = row[self._path]
        real_index = self.meta.index[index]
        vr = decord.VideoReader(path, ctx=decord.cpu(0), width=360, height=420)
        start = 0
        end = len(vr)
        # avg_fps = vr.get_avg_fps()
        # Sample 16 frames uniformly across the clip.
        frame_indices = self.get_index(end - start, 16, st=start)
        frames = vr.get_batch(frame_indices).asnumpy()  # n h w c
        video_inputs = [torch.from_numpy(frames).permute(0, 3, 1, 2)]  # n c h w
        conversation = {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": row['path'],
                    "max_pixels": 360 * 420,  # 151200
                    "fps": 2.0,
                },
                {
                    "type": "text",
                    "text": SYSTEM_PROMPT
                },
            ],
        }
        # Render the conversation through the chat template to build the
        # final text prompt.
        user_input = self.processor.apply_chat_template(
            [conversation],
            tokenize=False,
            add_generation_prompt=True
        )
        results = dict()
        inputs = {
            'prompt': user_input,
            'multi_modal_data': {'video': video_inputs}
        }
        results["index"] = real_index
        results['input'] = inputs
        return results

    def __len__(self):
        return len(self.meta)
    def get_index(self, video_size, num_frames, st=0):
        # Uniformly sample `num_frames` indices from a clip of `video_size`
        # frames, offset by `st`.
        seg_size = max(0., float(video_size - 1) / num_frames)
        max_frame = int(video_size) - 1
        seq = []
        for i in range(num_frames):
            start = int(np.round(seg_size * i))
            idx = min(start, max_frame)
            seq.append(idx + st)
        return seq
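
# Note on get_index: for a 160-frame clip (hypothetical length), seg_size is
# 159 / 16 = 9.9375, so the 16 sampled indices land roughly every 10 frames:
# [0, 10, 20, ..., 149].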
def result_writer(indices_list: list, result_list: list, meta: pd.DataFrame, column):
    # Write results back into `meta` at the given row indices, deduplicating
    # repeated indices (the first occurrence wins), and return only the rows
    # that actually received a result.
    flat_indices = np.array(indices_list)
    flat_results = np.array(result_list)
    unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True)
    meta.loc[unique_indices, column[0]] = flat_results[unique_indices_idx]
    meta = meta.loc[unique_indices]
    return meta
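
# A toy illustration of result_writer (hypothetical values): a duplicated
# index keeps its first result.
#   meta = pd.DataFrame({'path': ['a.mp4', 'b.mp4']})
#   result_writer([0, 1, 0], ['cap_a', 'cap_b', 'cap_a_dup'], meta,
#                 column=['structural_caption'])
#   # -> rows 0 and 1, captioned 'cap_a' and 'cap_b'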
def worker_init_fn(worker_id):
    # Give each dataloader worker its own numpy seed.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    # Keep each worker single-threaded to avoid CPU oversubscription.
    torch.set_num_threads(1)
def main():
    parser = argparse.ArgumentParser(description="SkyCaptioner-V1 vllm batch inference")
    parser.add_argument("--input_csv", default="./examples/test.csv")
    parser.add_argument("--out_csv", default="./examples/test_result.csv")
    parser.add_argument("--bs", type=int, default=4)
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--model_path", required=True, type=str, help="skycaptioner-v1 model path")
    args = parser.parse_args()

    dataset = VideoTextDataset(csv_path=args.input_csv, model_path=args.model_path)
    dataloader = DataLoader(
        dataset,
        batch_size=args.bs,
        num_workers=4,
        worker_init_fn=worker_init_fn,
        persistent_workers=True,
        timeout=180,
    )

    sampling_params = SamplingParams(temperature=0.05, max_tokens=2048)
    llm = LLM(
        model=args.model_path,
        gpu_memory_utilization=0.6,
        max_model_len=31920,
        tensor_parallel_size=args.tp,
    )
    indices_list = []
    caption_save = []
    for video_batch in tqdm(dataloader):
        indices = video_batch["index"]
        inputs = video_batch["input"]
        # Unpack the collated batch into one vLLM request per sample.
        batch_user_inputs = []
        for prompt, video in zip(inputs['prompt'], inputs['multi_modal_data']['video'][0]):
            usi = {'prompt': prompt, 'multi_modal_data': {'video': video}}
            batch_user_inputs.append(usi)
        outputs = llm.generate(batch_user_inputs, sampling_params, use_tqdm=False)
        struct_outputs = [output.outputs[0].text for output in outputs]
        indices_list.extend(indices.tolist())
        caption_save.extend(struct_outputs)

    meta_new = result_writer(indices_list, caption_save, dataset.meta, column=["structural_caption"])
    meta_new.to_csv(args.out_csv, index=False)
    print(f'Saved structural_caption to {args.out_csv}')
if __name__ == '__main__':
    main()
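
# Example invocation (the script file name below is assumed):
#   python vllm_struct_caption.py \
#       --model_path /path/to/SkyCaptioner-V1 \
#       --input_csv ./examples/test.csv \
#       --out_csv ./examples/test_result.csv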