YAML Metadata Warning:empty or missing yaml metadata in repo card
Check out the documentation for more information.
vrthinker
A video reward model that compares two videos against a text prompt and outputs per-dimension preferences:
TA (Text Alignment), MQ (Motion Quality), VQ (Visual Quality), OA (Overall).
Each label is one of 1 (Video 1 wins), 2 (Video 2 wins), 0 (tie).
The model reasons step-by-step and may call a select_frames tool to request additional frames from the videos
before committing to an answer.
Install
pip install torch transformers accelerate pillow opencv-python
Inference
Save the snippet below as infer.py inside the model directory (the script loads the processor and weights from its own folder), then:
python infer.py --video1 path/to/v1.mp4 --video2 path/to/v2.mp4 \
--prompt "A robot rides a unicorn across a rainbow bridge."
# infer.py
import argparse, json, re
from pathlib import Path
import cv2
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
# Directory the processor and model are loaded from; it is resolved from this
# script's own location, so infer.py must live next to the model files.
MODEL_DIR = str(Path(__file__).resolve().parent) # the dir containing this README
# Size of the virtual frame-index space: indices in [0, FRAMES_PER_VIDEO) are
# mapped proportionally onto each video's real frame count.
FRAMES_PER_VIDEO = 128
# Number of frames sampled from each video for the first turn.
INITIAL_PER_VIDEO = 4
# Upper bound on generate / tool-call rounds before giving up without an answer.
MAX_TURNS = 6
# A select_frames request is truncated to this many indices.
MAX_FRAMES_PER_CALL = 12
# Every extracted frame is resized to IMAGE_SIDE x IMAGE_SIDE pixels.
IMAGE_SIDE = 448
# System prompt sent as the first message of every conversation. It fixes the
# task, the evaluation dimensions, the response tags and the tool-call format.
# NOTE: the literal numbers in the text (4 frames per video, 128 indices,
# 12 indices per call) mirror INITIAL_PER_VIDEO, FRAMES_PER_VIDEO and
# MAX_FRAMES_PER_CALL; keep them in sync if those constants change.
SYSTEM_PROMPT = """Task Description:
Your task is to compare two videos generated based on the same text prompt by analyzing their frames in detail and provide an overall judgment along with a judgment for each evaluation dimension.
The provided frames are downsampled from these videos:
- Video 1: First four input frames.
- Video 2: Next four input frames.
Evaluation Dimensions:
1. Text Alignment (TA): How faithfully each video reflects the text prompt.
2. Visual Quality (VQ): Aesthetics, artifacts, blurriness, distortion, color, resolution, flickering.
3. Motion Quality (MQ): Smoothness, jitter, unnatural motion, temporal consistency.
4. Overall Assessment (OA): Holistic judgment across the above.
Frames and Analysis Rules:
- 8 sampled frames are provided initially (4 per video), evenly downsampled from 128 frames per video. The first 4 are Video 1, the next 4 are Video 2.
- Each video has 128 frames (indices 0-127). To inspect more frames, call select_frames with the indices you need; the tool retrieves the same indices from both videos symmetrically.
- Tool returns are paired: for [i, j, k] you get (v1[i], v2[i], v1[j], v2[j], v1[k], v2[k]). Use this pairing to compare the same moment across both videos.
- Each tool call accepts at most 12 indices.
Format Requirement:
1. <Snapshot></Snapshot> — summarize useful visual details after receiving frames.
2. <Think></Think> — reasoning.
3. <Answer></Answer> — final judgment.
Label semantics: 1 = Video 1 better, 2 = Video 2 better, 0 = tie.
Examples:
<Answer>TA=1, VQ=1, MQ=0, OA=1</Answer>
Tool call format:
When you want to inspect more frames, emit a tool call inside <tool_call></tool_call> tags:
<tool_call>{"name": "select_frames", "arguments": {"frame_indices": [10, 30, 60, 90]}}</tool_call>
"""
def extract_frames(video_path: str, indices: list[int]) -> list[Image.Image]:
    """Return RGB PIL frames for the given virtual indices.

    Each index lives in the virtual range [0, FRAMES_PER_VIDEO) and is mapped
    proportionally onto the video's real frame count, so videos of any length
    share the same index space. Every returned frame is resized to
    IMAGE_SIDE x IMAGE_SIDE.

    Args:
        video_path: Path of the video file to read.
        indices: Virtual frame indices to fetch, in the order wanted.

    Returns:
        One image per index that could be decoded, in request order. Frames
        that fail to decode are skipped, so the result may be shorter than
        ``indices``.

    Raises:
        RuntimeError: If OpenCV cannot open ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        # Fail loudly instead of silently returning no frames for a bad path.
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video: {video_path}")
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Some containers report a frame count of 0; never seek below frame 0.
        last = max(total - 1, 0)
        out: list[Image.Image] = []
        for idx in indices:
            # map idx in [0, FRAMES_PER_VIDEO) -> real frame in [0, total),
            # clamping out-of-range (including negative) requests.
            real = int(idx / FRAMES_PER_VIDEO * total)
            real = min(max(real, 0), last)
            cap.set(cv2.CAP_PROP_POS_FRAMES, real)
            ok, frame = cap.read()
            if not ok:
                # Best effort: skip frames the decoder cannot deliver.
                continue
            # OpenCV decodes to BGR; PIL expects RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            out.append(Image.fromarray(frame).resize((IMAGE_SIDE, IMAGE_SIDE)))
        return out
    finally:
        # Release the capture handle even if decoding raised.
        cap.release()
def initial_frames(v1: str, v2: str) -> list[Image.Image]:
    """Sample the first-turn frames: Video 1's frames followed by Video 2's.

    INITIAL_PER_VIDEO virtual indices are taken at the midpoints of equal
    segments of the [0, FRAMES_PER_VIDEO) range, and the same indices are used
    for both videos.
    """
    sample_points: list[int] = []
    for slot in range(INITIAL_PER_VIDEO):
        # Midpoint of the slot-th segment of the virtual index range.
        sample_points.append(int(FRAMES_PER_VIDEO * (slot + 0.5) / INITIAL_PER_VIDEO))
    first_video = extract_frames(v1, sample_points)
    second_video = extract_frames(v2, sample_points)
    return first_video + second_video
def tool_frames(v1: str, v2: str, indices: list[int]) -> list[Image.Image]:
    """Fetch the requested frames from both videos, interleaved per index.

    At most MAX_FRAMES_PER_CALL indices are honoured. For each kept index the
    frame from ``v1`` comes first, then the frame from ``v2``.
    """
    paired: list[Image.Image] = []
    for frame_idx in indices[:MAX_FRAMES_PER_CALL]:
        # Video 1 first, then Video 2, so the same moment sits side by side.
        for source in (v1, v2):
            paired.extend(extract_frames(source, [frame_idx]))
    return paired
def parse_tool_call(text: str) -> dict | None:
m = re.search(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", text, re.DOTALL)
if not m:
return None
try:
obj = json.loads(m.group(1))
return obj.get("arguments", {})
except json.JSONDecodeError:
return None
def parse_answer(text: str) -> dict | None:
m = re.search(r"<Answer>(.*?)</Answer>", text, re.DOTALL | re.IGNORECASE)
if not m:
return None
body = m.group(1)
return {d: int(re.search(rf"{d}\s*=\s*(\d)", body).group(1))
for d in ("TA", "MQ", "VQ", "OA")
if re.search(rf"{d}\s*=\s*(\d)", body)}
@torch.inference_mode()
def run(video1: str, video2: str, prompt: str) -> dict:
    """Compare two videos against ``prompt`` and return per-dimension labels.

    Loads the processor and model from MODEL_DIR, shows the model the initial
    frames of both videos, then loops for at most MAX_TURNS rounds. Each round
    either yields a final ``<Answer>`` or a ``select_frames`` tool call whose
    frames are appended to the conversation.

    Args:
        video1: Path of the first video.
        video2: Path of the second video.
        prompt: Text prompt both videos were generated from.

    Returns:
        A dict with the keys "TA", "MQ", "VQ" and "OA". Each value is the
        parsed label (1 = Video 1, 2 = Video 2, 0 = tie) or ``None`` when the
        model gave no usable value for that dimension.
    """
    # Template for "no verdict"; also fills dimensions a partial answer omits.
    no_answer = {"TA": None, "MQ": None, "VQ": None, "OA": None}

    processor = AutoProcessor.from_pretrained(MODEL_DIR, trust_remote_code=True)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_DIR, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
    ).eval()
    images = initial_frames(video1, video2)
    user_text = (
        f"Compare the two videos generated from the following prompt and evaluate them "
        f"across Text Alignment (TA), Motion Quality (MQ), Visual Quality (VQ), and "
        f"Overall Assessment (OA).\n\nPrompt: {prompt}\n\n"
        f"The first 4 images are uniformly sampled from Video 1, and the next 4 are from "
        f"Video 2. Each video has 128 frames (indices 0-127). "
        f"Use the select_frames tool to request additional frames if needed."
    )
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "image"}] * len(images)
                                    + [{"type": "text", "text": user_text}]},
    ]
    for _ in range(MAX_TURNS):
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[text], images=images, return_tensors="pt", padding=True).to(model.device)
        # Greedy decoding; a temperature has no effect when sampling is off.
        output_ids = model.generate(**inputs, max_new_tokens=2048, do_sample=False)
        # Decode only the newly generated tokens, not the echoed prompt.
        reply = processor.batch_decode(
            output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
        )[0]
        messages.append({"role": "assistant", "content": [{"type": "text", "text": reply}]})

        answer = parse_answer(reply)
        if answer:
            # Always return all four keys, even for a partial answer.
            return {**no_answer, **answer}

        call = parse_tool_call(reply)
        requested = call.get("frame_indices") if isinstance(call, dict) else None
        if not isinstance(requested, list):
            # Neither an answer nor a usable tool call: give up.
            return dict(no_answer)
        # The reply is untrusted: keep only integer indices and apply the
        # per-call cap here so the tool response describes what was fetched.
        requested = [
            i for i in requested if isinstance(i, int) and not isinstance(i, bool)
        ][:MAX_FRAMES_PER_CALL]
        new_imgs = tool_frames(video1, video2, requested)
        images += new_imgs
        messages.append({
            "role": "user",
            "content": [{"type": "image"}] * len(new_imgs)
                       + [{"type": "text",
                           "text": f"<tool_response>Retrieved {len(requested)} "
                                   f"frame pairs ({requested}) symmetrically from both "
                                   f"videos.</tool_response>"}],
        })
    return dict(no_answer)
if __name__ == "__main__":
    # Command-line entry point: all three options are mandatory, and the
    # resulting label dict is printed as indented JSON.
    parser = argparse.ArgumentParser()
    for flag in ("--video1", "--video2", "--prompt"):
        parser.add_argument(flag, required=True)
    cli = parser.parse_args()
    result = run(cli.video1, cli.video2, cli.prompt)
    print(json.dumps(result, indent=2))
Output
{
"TA": 1,
"MQ": 0,
"VQ": 2,
"OA": 1
}
1 = Video 1 wins on that dimension, 2 = Video 2 wins, 0 = tie.
Hardware
Requires ~16 GB GPU memory in bf16. Tested on a single A100/H100.
- Downloads last month
- -
Inference Providers NEW
This model isn't deployed by any Inference Provider. Ask for provider support