import argparse
import os
import re
from tqdm import tqdm
import pandas as pd
from vllm import LLM, SamplingParams
from utils.logger import logger
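
# Example invocation (a sketch only: the script name and the metadata/output paths
# below are placeholders, not files shipped with this repository):
#
#   python summarize_frame_captions.py \
#       --video_metadata_path metadata.jsonl \
#       --caption_column sampled_frame_caption \
#       --batch_size 10 \
#       --summary_model_name mistralai/Mistral-7B-Instruct-v0.2 \
#       --saved_path summary_caption.jsonl \
#       --saved_freq 1000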


def parse_args():
    parser = argparse.ArgumentParser(description="Recaption the video frame.")
    parser.add_argument(
        "--video_metadata_path", type=str, required=True, help="The path to the video dataset metadata (csv/jsonl)."
    )
    parser.add_argument(
        "--video_path_column",
        type=str,
        default="video_path",
        help="The column containing the video path (an absolute path or a relative path w.r.t. the video_folder).",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="sampled_frame_caption",
        help="The column containing the sampled frame caption.",
    )
    parser.add_argument(
        "--remove_quotes",
        action="store_true",
        help="Whether to remove quoted text from the caption.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=10,
        required=False,
        help="The batch size for summarizing the captions.",
    )
    parser.add_argument(
        "--summary_model_name",
        type=str,
        default="mistralai/Mistral-7B-Instruct-v0.2",
    )
    parser.add_argument(
        "--summary_prompt",
        type=str,
        default=(
            "You are a helpful video description generator. I'll give you a description of the middle frame of the video clip, "
            "which you need to summarize into a description of the video clip. "
            "Please provide your video description following these requirements: "
            "1. Describe the basic and necessary information of the video in the third person, and be as concise as possible. "
            "2. Output the video description directly. Begin with 'In this video'. "
            "3. Limit the video description to 100 words. "
            "Here is the mid-frame description: "
        ),
    )
    parser.add_argument("--saved_path", type=str, required=True, help="The path to save the output results (csv/jsonl).")
    parser.add_argument("--saved_freq", type=int, default=1000, help="The frequency (in batches) at which to save the output results.")

    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    if args.video_metadata_path.endswith(".csv"):
        video_metadata_df = pd.read_csv(args.video_metadata_path)
    elif args.video_metadata_path.endswith(".jsonl"):
        video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
    else:
        raise ValueError("The video_metadata_path must end with .csv or .jsonl.")
    video_path_list = video_metadata_df[args.video_path_column].tolist()
    sampled_frame_caption_list = video_metadata_df[args.caption_column].tolist()

    if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
        raise ValueError("The saved_path must end with .csv or .jsonl.")

    # Resume from a previous run: skip videos that already have a saved summary caption.
    if os.path.exists(args.saved_path):
        if args.saved_path.endswith(".csv"):
            saved_metadata_df = pd.read_csv(args.saved_path)
        elif args.saved_path.endswith(".jsonl"):
            saved_metadata_df = pd.read_json(args.saved_path, lines=True)
        saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()
        video_path_list = list(set(video_path_list) - set(saved_video_path_list))

        video_metadata_df.set_index(args.video_path_column, inplace=True)
        video_metadata_df = video_metadata_df.loc[video_path_list]
        sampled_frame_caption_list = video_metadata_df[args.caption_column].tolist()
        logger.info(
            f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed."
        )

    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)
    summary_model = LLM(model=args.summary_model_name, trust_remote_code=True)

    result_dict = {"video_path": [], "summary_model": [], "summary_caption": []}

    for i in tqdm(range(0, len(sampled_frame_caption_list), args.batch_size)):
        batch_video_path = video_path_list[i : i + args.batch_size]
        batch_caption = sampled_frame_caption_list[i : i + args.batch_size]

        batch_prompt = []
        for caption in batch_caption:
            if args.remove_quotes:
                # Strip any text wrapped in matching single or double quotes from the frame caption.
                caption = re.sub(r'(["\']).*?\1', "", caption)
            batch_prompt.append("user:" + args.summary_prompt + str(caption) + "\n assistant:")
        batch_output = summary_model.generate(batch_prompt, sampling_params)

        result_dict["video_path"].extend(batch_video_path)
        result_dict["summary_model"].extend([args.summary_model_name] * len(batch_caption))
        result_dict["summary_caption"].extend([output.outputs[0].text.rstrip() for output in batch_output])

        # Save the intermediate results every args.saved_freq batches.
        if i != 0 and ((i // args.batch_size) % args.saved_freq) == 0:
            result_df = pd.DataFrame(result_dict)
            if args.saved_path.endswith(".csv"):
                header = not os.path.exists(args.saved_path)
                result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
            elif args.saved_path.endswith(".jsonl"):
                result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
            logger.info(f"Save the result to {args.saved_path}.")
            result_dict = {"video_path": [], "summary_model": [], "summary_caption": []}

    # Save any remaining results that were not written inside the loop.
    result_df = pd.DataFrame(result_dict)
    if args.saved_path.endswith(".csv"):
        header = not os.path.exists(args.saved_path)
        result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
    elif args.saved_path.endswith(".jsonl"):
        result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
    logger.info(f"Save the final result to {args.saved_path}.")


if __name__ == "__main__":
    main()