import argparse
import os
import re

import pandas as pd
from tqdm import tqdm
from vllm import LLM, SamplingParams

from utils.logger import logger


def parse_args():
    parser = argparse.ArgumentParser(description="Summarize the sampled frame caption into a video caption.")
    parser.add_argument(
        "--video_metadata_path", type=str, required=True, help="The path to the video dataset metadata (csv/jsonl)."
    )
    parser.add_argument(
        "--video_path_column",
        type=str,
        default="video_path",
        help="The column containing the video path (an absolute path or a relative path w.r.t. the video_folder).",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="sampled_frame_caption",
        help="The column containing the sampled frame caption.",
    )
    parser.add_argument(
        "--remove_quotes",
        action="store_true",
        help="Whether to remove quoted substrings from the caption.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=10,
        required=False,
        help="The batch size for the video caption.",
    )
    parser.add_argument(
        "--summary_model_name",
        type=str,
        default="mistralai/Mistral-7B-Instruct-v0.2",
    )
    parser.add_argument(
        "--summary_prompt",
        type=str,
        default=(
            "You are a helpful video description generator. I'll give you a description of the middle frame of the video clip, "
            "which you need to summarize into a description of the video clip. "
            "Please provide your video description following these requirements: "
            "1. Describe the basic and necessary information of the video in the third person, be as concise as possible. "
            "2. Output the video description directly. Begin with 'In this video'. "
            "3. Limit the video description within 100 words. "
            "Here is the mid-frame description: "
        ),
    )
    parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
    parser.add_argument("--saved_freq", type=int, default=1000, help="The frequency (in batches) to save the output results.")

    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # Load the metadata (csv/jsonl) and extract the video paths and frame captions.
    if args.video_metadata_path.endswith(".csv"):
        video_metadata_df = pd.read_csv(args.video_metadata_path)
    elif args.video_metadata_path.endswith(".jsonl"):
        video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
    else:
        raise ValueError("The video_metadata_path must end with .csv or .jsonl.")
    video_path_list = video_metadata_df[args.video_path_column].tolist()
    sampled_frame_caption_list = video_metadata_df[args.caption_column].tolist()

    if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
        raise ValueError("The saved_path must end with .csv or .jsonl.")

    # Resume from a previous run: skip videos that already appear in the saved results.
    if os.path.exists(args.saved_path):
        if args.saved_path.endswith(".csv"):
            saved_metadata_df = pd.read_csv(args.saved_path)
        elif args.saved_path.endswith(".jsonl"):
            saved_metadata_df = pd.read_json(args.saved_path, lines=True)
        saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()

        video_path_list = list(set(video_path_list) - set(saved_video_path_list))
        video_metadata_df.set_index(args.video_path_column, inplace=True)
        video_metadata_df = video_metadata_df.loc[video_path_list]
        sampled_frame_caption_list = video_metadata_df[args.caption_column].tolist()

        logger.info(
            f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed."
        )

    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)
    summary_model = LLM(model=args.summary_model_name, trust_remote_code=True)

    result_dict = {"video_path": [], "summary_model": [], "summary_caption": []}
    for i in tqdm(range(0, len(sampled_frame_caption_list), args.batch_size)):
        batch_video_path = video_path_list[i : i + args.batch_size]
        batch_caption = sampled_frame_caption_list[i : i + args.batch_size]

        # Build one prompt per caption, optionally stripping quoted substrings.
        batch_prompt = []
        for caption in batch_caption:
            if args.remove_quotes:
                caption = re.sub(r'(["\']).*?\1', "", caption)
            batch_prompt.append("user:" + args.summary_prompt + str(caption) + "\n assistant:")
        batch_output = summary_model.generate(batch_prompt, sampling_params)

        result_dict["video_path"].extend(batch_video_path)
        result_dict["summary_model"].extend([args.summary_model_name] * len(batch_caption))
        result_dict["summary_caption"].extend([output.outputs[0].text.rstrip() for output in batch_output])

        # Append the intermediate results every args.saved_freq batches, then reset the buffer.
        if i != 0 and ((i // args.batch_size) % args.saved_freq) == 0:
            result_df = pd.DataFrame(result_dict)
            if args.saved_path.endswith(".csv"):
                header = not os.path.exists(args.saved_path)
                result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
            elif args.saved_path.endswith(".jsonl"):
                result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
            logger.info(f"Save result to {args.saved_path}.")
            result_dict = {"video_path": [], "summary_model": [], "summary_caption": []}

    # Append whatever is left in the buffer after the last full saving interval.
    result_df = pd.DataFrame(result_dict)
    if args.saved_path.endswith(".csv"):
        header = not os.path.exists(args.saved_path)
        result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
    elif args.saved_path.endswith(".jsonl"):
        result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
    logger.info(f"Save the final result to {args.saved_path}.")


if __name__ == "__main__":
    main()
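# Example invocation (a minimal sketch: the script filename and data paths below are
# hypothetical placeholders, while the flags correspond to the arguments defined in parse_args):
#
#   python caption_summary.py \
#       --video_metadata_path data/sampled_frame_caption.jsonl \
#       --caption_column sampled_frame_caption \
#       --batch_size 16 \
#       --saved_freq 100 \
#       --saved_path data/summary_caption.jsonl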