# Dove / inference_script.py
from pathlib import Path
import argparse
import logging
import torch
from torchvision import transforms
from torchvision.io import write_video
from tqdm import tqdm
from diffusers import (
CogVideoXDPMScheduler,
CogVideoXPipeline,
)
from transformers import set_seed
from typing import Dict, Tuple
from diffusers.models.embeddings import get_3d_rotary_pos_embed
import json
import os
import cv2
from PIL import Image
import pyiqa
import imageio.v3 as iio
import glob
# Must be imported after torch: importing decord first can occasionally cause a nasty
# segmentation fault or stack-smashing error. There are few bug reports, but it happens;
# see the decord GitHub issues for more information.
import decord # isort:skip
decord.bridge.set_bridge("torch")
logging.basicConfig(level=logging.INFO)
# transforms.ToTensor converts frames to float tensors with values in [0, 1]
to_tensor = transforms.ToTensor()
video_exts = ['.mp4', '.avi', '.mov', '.mkv']
fr_metrics = ['psnr', 'ssim', 'lpips', 'dists']
def no_grad(func):
def wrapper(*args, **kwargs):
with torch.no_grad():
return func(*args, **kwargs)
return wrapper
def is_video_file(filename):
return any(filename.lower().endswith(ext) for ext in video_exts)
def read_video_frames(video_path):
cap = cv2.VideoCapture(video_path)
frames = []
while True:
ret, frame = cap.read()
if not ret:
break
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(to_tensor(Image.fromarray(rgb)))
cap.release()
return torch.stack(frames)
def read_image_folder(folder_path):
image_files = sorted([
os.path.join(folder_path, f) for f in os.listdir(folder_path)
if f.lower().endswith(('.png', '.jpg', '.jpeg'))
])
frames = [to_tensor(Image.open(p).convert("RGB")) for p in image_files]
return torch.stack(frames)
def load_sequence(path):
    # Returns a tensor of shape [F, C, H, W] with values in [0, 1]
if os.path.isdir(path):
return read_image_folder(path)
elif os.path.isfile(path):
if is_video_file(path):
return read_video_frames(path)
elif path.lower().endswith(('.png', '.jpg', '.jpeg')):
# Treat image as a single-frame video
img = to_tensor(Image.open(path).convert("RGB"))
return img.unsqueeze(0) # [1, C, H, W]
raise ValueError(f"Unsupported input: {path}")
@no_grad
def compute_metrics(pred_frames, gt_frames, metrics_model, metric_accumulator, file_name):
print(f"\n\n[{file_name}] Metrics:", end=" ")
for name, model in metrics_model.items():
scores = []
        for i in range(pred_frames.shape[0]):
            pred = pred_frames[i].unsqueeze(0)
            if name in fr_metrics:
                if gt_frames is None:
                    raise ValueError(f"Full-reference metric '{name}' requires ground-truth frames (--gt_dir)")
                gt = gt_frames[i].unsqueeze(0)
                score = model(pred, gt).item()
            else:
                score = model(pred).item()
            scores.append(score)
val = sum(scores) / len(scores)
metric_accumulator[name].append(val)
print(f"{name.upper()}={val:.4f}", end=" ")
print()
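# Illustrative note (not from the original code): pyiqa full-reference metrics such as
# 'psnr' or 'lpips' are called as model(pred, gt), while no-reference metrics such as
# 'musiq' or 'niqe' are called as model(pred); both expect [1, C, H, W] float tensors
# in [0, 1], which is why compute_metrics unsqueezes each frame before scoring.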
def save_frames_as_png(video, output_dir, fps=8):
"""
Save video frames as PNG sequence.
Args:
video (torch.Tensor): shape [B, C, F, H, W], float in [0, 1]
output_dir (str): directory to save PNG files
fps (int): kept for API compatibility
"""
video = video[0] # Remove batch dimension
video = video.permute(1, 2, 3, 0) # [F, H, W, C]
os.makedirs(output_dir, exist_ok=True)
frames = (video * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
for i, frame in enumerate(frames):
filename = os.path.join(output_dir, f"{i:03d}.png")
Image.fromarray(frame).save(filename)
def save_video_with_imageio_lossless(video, output_path, fps=8):
"""
Save a video tensor to .mkv using imageio.v3.imwrite with ffmpeg backend.
Args:
video (torch.Tensor): shape [B, C, F, H, W], float in [0, 1]
output_path (str): where to save the .mkv file
fps (int): frames per second
"""
video = video[0]
video = video.permute(1, 2, 3, 0)
frames = (video * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
iio.imwrite(
output_path,
frames,
fps=fps,
codec='libx264rgb',
pixelformat='rgb24',
macro_block_size=None,
ffmpeg_params=['-crf', '0'],
)
def save_video_with_imageio(video, output_path, fps=8, format='yuv444p'):
"""
Save a video tensor to .mp4 using imageio.v3.imwrite with ffmpeg backend.
Args:
video (torch.Tensor): shape [B, C, F, H, W], float in [0, 1]
output_path (str): where to save the .mp4 file
fps (int): frames per second
"""
video = video[0]
video = video.permute(1, 2, 3, 0)
frames = (video * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
if format == 'yuv444p':
iio.imwrite(
output_path,
frames,
fps=fps,
codec='libx264',
pixelformat='yuv444p',
macro_block_size=None,
ffmpeg_params=['-crf', '0'],
)
else:
iio.imwrite(
output_path,
frames,
fps=fps,
codec='libx264',
pixelformat='yuv420p',
macro_block_size=None,
ffmpeg_params=['-crf', '10'],
)
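# Illustrative usage of the three writers above (paths and the `restored` name are
# placeholders; `restored` is assumed to be a [1, 3, F, H, W] float tensor in [0, 1]):
#   save_video_with_imageio_lossless(restored, "out.mkv", fps=24)           # lossless RGB (libx264rgb, CRF 0)
#   save_video_with_imageio(restored, "out.mp4", fps=24, format="yuv420p")  # widely playable (CRF 10)
#   save_frames_as_png(restored, "out_frames")                              # per-frame PNG dump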
def preprocess_video_match(
    video_path: Path | str,
    is_match: bool = False,
) -> Tuple[torch.Tensor, int, int, int, Tuple[int, int, int, int]]:
    """
    Loads a single video and, if is_match is True, pads it so the frame count is of the
    form 8k + 1 and the height/width are multiples of 16.
    Args:
        video_path: Path to the video file.
        is_match: Whether to pad the video to the model's size constraints.
    Returns:
        A tuple (frames, pad_f, pad_h, pad_w, original_shape), where frames is a
        torch.Tensor of shape [F, C, H, W]:
            F = number of frames (after padding)
            C = number of channels (3 for RGB)
            H = height (after padding)
            W = width (after padding)
    """
if isinstance(video_path, str):
video_path = Path(video_path)
video_reader = decord.VideoReader(uri=video_path.as_posix())
video_num_frames = len(video_reader)
frames = video_reader.get_batch(list(range(video_num_frames)))
F, H, W, C = frames.shape
original_shape = (F, H, W, C)
pad_f = 0
pad_h = 0
pad_w = 0
if is_match:
remainder = (F - 1) % 8
if remainder != 0:
last_frame = frames[-1:]
pad_f = 8 - remainder
repeated_frames = last_frame.repeat(pad_f, 1, 1, 1)
frames = torch.cat([frames, repeated_frames], dim=0)
pad_h = (16 - H % 16) % 16
pad_w = (16 - W % 16) % 16
if pad_h > 0 or pad_w > 0:
            # F.pad pads from the last dim backwards: (c_before, c_after, w_left, w_right, h_top, h_bottom)
            frames = torch.nn.functional.pad(frames, pad=(0, 0, 0, pad_w, 0, pad_h))  # pad right and bottom only
# to F, C, H, W
return frames.float().permute(0, 3, 1, 2).contiguous(), pad_f, pad_h, pad_w, original_shape
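# Worked example of the padding rule above (values are illustrative): a 100-frame
# 360x640 clip has (100 - 1) % 8 == 3, so pad_f = 5 and the last frame is repeated
# until F = 105 (since (105 - 1) % 8 == 0); pad_h = 8 brings 360 up to 368 (a multiple
# of 16) and pad_w = 0 because 640 is already a multiple of 16.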
def remove_padding_and_extra_frames(video, pad_F, pad_H, pad_W):
if pad_F > 0:
video = video[:, :, :-pad_F, :, :]
if pad_H > 0:
video = video[:, :, :, :-pad_H, :]
if pad_W > 0:
video = video[:, :, :, :, :-pad_W]
return video
def make_temporal_chunks(F, chunk_len, overlap_t=8):
"""
Args:
F: total number of frames
chunk_len: int, chunk length in time (excluding overlap)
overlap: int, number of overlapping frames between chunks
Returns:
time_chunks: List of (start_t, end_t) tuples
"""
if chunk_len == 0:
return [(0, F)]
effective_stride = chunk_len - overlap_t
if effective_stride <= 0:
raise ValueError("chunk_len must be greater than overlap")
chunk_starts = list(range(0, F - overlap_t, effective_stride))
if chunk_starts[-1] + chunk_len < F:
chunk_starts.append(F - chunk_len)
time_chunks = []
for i, t_start in enumerate(chunk_starts):
t_end = min(t_start + chunk_len, F)
time_chunks.append((t_start, t_end))
if len(time_chunks) >= 2 and time_chunks[-1][1] - time_chunks[-1][0] < chunk_len:
last = time_chunks.pop()
prev_start, _ = time_chunks[-1]
time_chunks[-1] = (prev_start, last[1])
return time_chunks
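# Worked example (illustrative): make_temporal_chunks(F=77, chunk_len=29, overlap_t=8)
# yields [(0, 29), (21, 50), (42, 77)]. The stride is 29 - 8 = 21, and the short
# trailing chunk (63, 77) is merged into the previous one so no chunk is shorter
# than chunk_len.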
def make_spatial_tiles(H, W, tile_size_hw, overlap_hw=(32, 32)):
"""
Args:
H, W: height and width of the frame
tile_size_hw: Tuple (tile_height, tile_width)
overlap_hw: Tuple (overlap_height, overlap_width)
Returns:
spatial_tiles: List of (start_h, end_h, start_w, end_w) tuples
"""
tile_height, tile_width = tile_size_hw
overlap_h, overlap_w = overlap_hw
if tile_height == 0 or tile_width == 0:
return [(0, H, 0, W)]
tile_stride_h = tile_height - overlap_h
tile_stride_w = tile_width - overlap_w
if tile_stride_h <= 0 or tile_stride_w <= 0:
raise ValueError("Tile size must be greater than overlap")
h_tiles = list(range(0, H - overlap_h, tile_stride_h))
if not h_tiles or h_tiles[-1] + tile_height < H:
h_tiles.append(H - tile_height)
# Merge last row if needed
if len(h_tiles) >= 2 and h_tiles[-1] + tile_height > H:
h_tiles.pop()
w_tiles = list(range(0, W - overlap_w, tile_stride_w))
if not w_tiles or w_tiles[-1] + tile_width < W:
w_tiles.append(W - tile_width)
# Merge last column if needed
if len(w_tiles) >= 2 and w_tiles[-1] + tile_width > W:
w_tiles.pop()
spatial_tiles = []
for h_start in h_tiles:
h_end = min(h_start + tile_height, H)
if h_end + tile_stride_h > H:
h_end = H
for w_start in w_tiles:
w_end = min(w_start + tile_width, W)
if w_end + tile_stride_w > W:
w_end = W
spatial_tiles.append((h_start, h_end, w_start, w_end))
return spatial_tiles
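# Worked example (illustrative): make_spatial_tiles(H=160, W=256, tile_size_hw=(96, 160),
# overlap_hw=(32, 32)) yields [(0, 96, 0, 256), (64, 160, 0, 256)] -- two vertically
# overlapping tiles, each stretched to the full width because a second 160-wide column
# would overshoot W and is merged into the first.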
def get_valid_tile_region(t_start, t_end, h_start, h_end, w_start, w_end,
video_shape, overlap_t, overlap_h, overlap_w):
_, _, F, H, W = video_shape
t_len = t_end - t_start
h_len = h_end - h_start
w_len = w_end - w_start
valid_t_start = 0 if t_start == 0 else overlap_t // 2
valid_t_end = t_len if t_end == F else t_len - overlap_t // 2
valid_h_start = 0 if h_start == 0 else overlap_h // 2
valid_h_end = h_len if h_end == H else h_len - overlap_h // 2
valid_w_start = 0 if w_start == 0 else overlap_w // 2
valid_w_end = w_len if w_end == W else w_len - overlap_w // 2
out_t_start = t_start + valid_t_start
out_t_end = t_start + valid_t_end
out_h_start = h_start + valid_h_start
out_h_end = h_start + valid_h_end
out_w_start = w_start + valid_w_start
out_w_end = w_start + valid_w_end
return {
"valid_t_start": valid_t_start, "valid_t_end": valid_t_end,
"valid_h_start": valid_h_start, "valid_h_end": valid_h_end,
"valid_w_start": valid_w_start, "valid_w_end": valid_w_end,
"out_t_start": out_t_start, "out_t_end": out_t_end,
"out_h_start": out_h_start, "out_h_end": out_h_end,
"out_w_start": out_w_start, "out_w_end": out_w_end,
}
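# Worked example (illustrative): with video_shape = (1, 3, 50, H, W) and overlap_t = 8,
# the temporal chunk (0, 29) contributes output frames [0, 25) and the chunk (21, 50)
# contributes [25, 50): each interior boundary keeps only half of the overlap from
# either side, so every output location is written exactly once and no blending is needed.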
def prepare_rotary_positional_embeddings(
height: int,
width: int,
num_frames: int,
transformer_config: Dict,
vae_scale_factor_spatial: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
if transformer_config.patch_size_t is None:
base_num_frames = num_frames
else:
base_num_frames = (
num_frames + transformer_config.patch_size_t - 1
) // transformer_config.patch_size_t
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=transformer_config.attention_head_dim,
crops_coords=None,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
grid_type="slice",
max_size=(grid_height, grid_width),
device=device,
)
return freqs_cos, freqs_sin
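# Worked example (assuming a CogVideoX1.5-style config with patch_size=2, patch_size_t=2
# and an 8x spatial VAE factor -- not read from the original file): a 480x720 pixel tile
# gives a 30x45 token grid (480 // (8 * 2), 720 // (8 * 2)), and 9 latent frames give
# base_num_frames = (9 + 2 - 1) // 2 = 5.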
@no_grad
def process_video(
pipe: CogVideoXPipeline,
video: torch.Tensor,
prompt: str = '',
noise_step: int = 0,
sr_noise_step: int = 399,
):
    # Restore (super-resolve) the video frames, conditioned on the prompt.
    # Encode the video into VAE latents.
video = video.to(pipe.vae.device, dtype=pipe.vae.dtype)
latent_dist = pipe.vae.encode(video).latent_dist
latent = latent_dist.sample() * pipe.vae.config.scaling_factor
    patch_size_t = pipe.transformer.config.patch_size_t
    if patch_size_t is not None:
        ncopy = latent.shape[2] % patch_size_t
        if ncopy > 0:
            # Repeat the first frame ncopy times so the latent frame count is divisible by patch_size_t
            first_frame = latent[:, :, :1, :, :]  # [B, C, 1, H, W]
            latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
        assert latent.shape[2] % patch_size_t == 0
batch_size, num_channels, num_frames, height, width = latent.shape
# Get prompt embeddings
prompt_token_ids = pipe.tokenizer(
prompt,
padding="max_length",
max_length=pipe.transformer.config.max_text_seq_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
prompt_token_ids = prompt_token_ids.input_ids
prompt_embedding = pipe.text_encoder(
prompt_token_ids.to(latent.device)
)[0]
_, seq_len, _ = prompt_embedding.shape
prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype)
latent = latent.permute(0, 2, 1, 3, 4)
    # Optionally add noise to the latent (controlled by noise_step)
if noise_step != 0:
noise = torch.randn_like(latent)
add_timesteps = torch.full(
(batch_size,),
fill_value=noise_step,
dtype=torch.long,
device=latent.device,
)
latent = pipe.scheduler.add_noise(latent, noise, add_timesteps)
timesteps = torch.full(
(batch_size,),
fill_value=sr_noise_step,
dtype=torch.long,
device=latent.device,
)
# Prepare rotary embeds
vae_scale_factor_spatial = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
transformer_config = pipe.transformer.config
rotary_emb = (
prepare_rotary_positional_embeddings(
height=height * vae_scale_factor_spatial,
width=width * vae_scale_factor_spatial,
num_frames=num_frames,
transformer_config=transformer_config,
vae_scale_factor_spatial=vae_scale_factor_spatial,
device=latent.device,
)
if pipe.transformer.config.use_rotary_positional_embeddings
else None
)
# Predict noise
predicted_noise = pipe.transformer(
hidden_states=latent,
encoder_hidden_states=prompt_embedding,
timestep=timesteps,
image_rotary_emb=rotary_emb,
return_dict=False,
)[0]
latent_generate = pipe.scheduler.get_velocity(
predicted_noise, latent, timesteps
)
    # Drop the frames that were prepended to satisfy patch_size_t, then decode
if patch_size_t is not None and ncopy > 0:
latent_generate = latent_generate[:, ncopy:, :, :, :]
# [B, C, F, H, W]
video_generate = pipe.decode_latents(latent_generate)
video_generate = (video_generate * 0.5 + 0.5).clamp(0.0, 1.0)
return video_generate
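# Shape/range note for process_video (illustrative, not part of the original code): the
# input `video` chunk is expected as [B, C, F, H, W] in [-1, 1] with H and W divisible
# by 16 and F of the form 8k + 1; the output has the same layout with values in [0, 1].
# Restoration is a single forward pass: one transformer call at `sr_noise_step`, a
# scheduler.get_velocity conversion, and a VAE decode -- there is no iterative
# denoising loop.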
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="VSR using DOVE")
parser.add_argument("--input_dir", type=str)
parser.add_argument("--input_json", type=str, default=None)
parser.add_argument("--gt_dir", type=str, default=None)
parser.add_argument("--eval_metrics", type=str, default='') # 'psnr,ssim,lpips,dists,clipiqa,musiq,maniqa,niqe'
parser.add_argument("--model_path", type=str)
parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
parser.add_argument("--output_path", type=str, default="./results", help="The path save generated video")
parser.add_argument("--fps", type=int, default=24, help="The frames per second for the generated video")
parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation")
parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
parser.add_argument("--upscale_mode", type=str, default="bilinear")
parser.add_argument("--upscale", type=int, default=4)
parser.add_argument("--noise_step", type=int, default=0)
parser.add_argument("--sr_noise_step", type=int, default=399)
parser.add_argument("--is_cpu_offload", action="store_true", help="Enable CPU offload for the model")
parser.add_argument("--is_vae_st", action="store_true", help="Enable VAE slicing and tiling")
parser.add_argument("--png_save", action="store_true", help="Save output as PNG sequence")
parser.add_argument("--save_format", type=str, default="yuv444p", help="Save output as PNG sequence")
# Crop and Tiling Parameters
parser.add_argument("--tile_size_hw", type=int, nargs=2, default=(0, 0), help="Tile size for spatial tiling (height, width)")
parser.add_argument("--overlap_hw", type=int, nargs=2, default=(32, 32))
parser.add_argument("--chunk_len", type=int, default=0, help="Chunk length for temporal chunking")
parser.add_argument("--overlap_t", type=int, default=8)
args = parser.parse_args()
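    # Example invocation (paths are placeholders; the flags mirror the arguments above):
    #   python inference_script.py --input_dir ./LR_videos --model_path ./DOVE \
    #       --output_path ./results --upscale 4 --save_format yuv420p --is_vae_st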
if args.dtype == "float16":
dtype = torch.float16
elif args.dtype == "bfloat16":
dtype = torch.bfloat16
elif args.dtype == "float32":
dtype = torch.float32
else:
raise ValueError("Invalid dtype. Choose from 'float16', 'bfloat16', or 'float32'.")
if args.chunk_len > 0:
print(f"Chunking video into {args.chunk_len} frames with {args.overlap_t} overlap")
overlap_t = args.overlap_t
else:
overlap_t = 0
    if tuple(args.tile_size_hw) != (0, 0):
        print(f"Tiling frames into {args.tile_size_hw} tiles with {args.overlap_hw} overlap")
overlap_hw = args.overlap_hw
else:
overlap_hw = (0, 0)
# Set seed
set_seed(args.seed)
if args.input_json is not None:
with open(args.input_json, 'r') as f:
video_prompt_dict = json.load(f)
else:
video_prompt_dict = {}
# Get all video files from input directory
video_files = []
for ext in video_exts:
video_files.extend(glob.glob(os.path.join(args.input_dir, f'*{ext}')))
video_files = sorted(video_files) # Sort files for consistent ordering
if not video_files:
raise ValueError(f"No video files found in {args.input_dir}")
os.makedirs(args.output_path, exist_ok=True)
    # 1. Load the pre-trained CogVideoX pipeline with the specified precision.
    #    device_map="balanced" spreads the model across all visible GPUs. Remove it (and
    #    move the pipeline to a single device) if you want to use the CPU-offload block below.
pipe = CogVideoXPipeline.from_pretrained(args.model_path, torch_dtype=dtype, device_map="balanced")
    # If a LoRA path is provided, load the LoRA weights and fuse them into the transformer.
if args.lora_path:
print(f"Loading LoRA weights from {args.lora_path}")
pipe.load_lora_weights(
args.lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1"
)
pipe.fuse_lora(components=["transformer"], lora_scale=1.0) # lora_scale = lora_alpha / rank
# 2. Set Scheduler.
# Can be changed to `CogVideoXDPMScheduler` or `CogVideoXDDIMScheduler`.
# We recommend using `CogVideoXDDIMScheduler` for CogVideoX-2B.
# using `CogVideoXDPMScheduler` for CogVideoX-5B / CogVideoX-5B-I2V.
# pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.scheduler = CogVideoXDPMScheduler.from_config(
pipe.scheduler.config, timestep_spacing="trailing"
)
    # 3. (Optional) CPU offload for the model.
    #    Keep this disabled if you have multiple GPUs or enough GPU memory (e.g. an H100);
    #    inference is faster without offload. To enable it, uncomment the block below and
    #    load the pipeline without device_map="balanced" (use pipe.to("cuda") instead).
# if args.is_cpu_offload:
# pipe.enable_model_cpu_offload()
# pipe.enable_sequential_cpu_offload()
# else:
# pipe.to("cuda")
if args.is_vae_st:
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
# pipe.transformer.eval()
# torch.set_grad_enabled(False)
    # 4. Set up the evaluation metrics
if args.eval_metrics != '':
metrics_list = [m.strip().lower() for m in args.eval_metrics.split(',')]
metrics_models = {}
for name in metrics_list:
try:
metrics_models[name] = pyiqa.create_metric(name).to(pipe.device).eval()
except Exception as e:
print(f"Failed to initialize metric '{name}': {e}")
metric_accumulator = {name: [] for name in metrics_list}
else:
metrics_models = None
metric_accumulator = None
for video_path in tqdm(video_files, desc="Processing videos"):
video_name = os.path.basename(video_path)
prompt = video_prompt_dict.get(video_name, "")
if os.path.exists(video_path):
# Read video
# [F, C, H, W]
video, pad_f, pad_h, pad_w, original_shape = preprocess_video_match(video_path, is_match=True)
H_, W_ = video.shape[2], video.shape[3]
            # align_corners is only valid for the interpolating modes (bilinear/bicubic)
            align_corners = False if args.upscale_mode in ("bilinear", "bicubic") else None
            video = torch.nn.functional.interpolate(video, size=(H_ * args.upscale, W_ * args.upscale), mode=args.upscale_mode, align_corners=align_corners)
__frame_transform = transforms.Compose(
[transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)] # -1, 1
)
video = torch.stack([__frame_transform(f) for f in video], dim=0)
video = video.unsqueeze(0)
# [B, C, F, H, W]
video = video.permute(0, 2, 1, 3, 4).contiguous()
_B, _C, _F, _H, _W = video.shape
time_chunks = make_temporal_chunks(_F, args.chunk_len, overlap_t)
spatial_tiles = make_spatial_tiles(_H, _W, args.tile_size_hw, overlap_hw)
output_video = torch.zeros_like(video)
write_count = torch.zeros_like(video, dtype=torch.int)
print(f"Process video: {video_name} | Prompt: {prompt} | Frame: {_F} (ori: {original_shape[0]}; pad: {pad_f}) | Target Resolution: {_H}, {_W} (ori: {original_shape[1]*args.upscale}, {original_shape[2]*args.upscale}; pad: {pad_h}, {pad_w}) | Chunk Num: {len(time_chunks)*len(spatial_tiles)}")
for t_start, t_end in time_chunks:
for h_start, h_end, w_start, w_end in spatial_tiles:
video_chunk = video[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
# print(f"video_chunk: {video_chunk.shape} | t: {t_start}:{t_end} | h: {h_start}:{h_end} | w: {w_start}:{w_end}")
# [B, C, F, H, W]
_video_generate = process_video(
pipe=pipe,
video=video_chunk,
prompt=prompt,
noise_step=args.noise_step,
sr_noise_step=args.sr_noise_step,
)
region = get_valid_tile_region(
t_start, t_end, h_start, h_end, w_start, w_end,
video_shape=video.shape,
overlap_t=overlap_t,
overlap_h=overlap_hw[0],
overlap_w=overlap_hw[1],
)
output_video[:, :, region["out_t_start"]:region["out_t_end"],
region["out_h_start"]:region["out_h_end"],
region["out_w_start"]:region["out_w_end"]] = \
_video_generate[:, :, region["valid_t_start"]:region["valid_t_end"],
region["valid_h_start"]:region["valid_h_end"],
region["valid_w_start"]:region["valid_w_end"]]
write_count[:, :, region["out_t_start"]:region["out_t_end"],
region["out_h_start"]:region["out_h_end"],
region["out_w_start"]:region["out_w_end"]] += 1
video_generate = output_video
            if (write_count == 0).any():
                print("Error: some output region was never written!")
                exit()
            if (write_count > 1).any():
                print("Error: some output region was written more than once!")
                exit()
            video_generate = remove_padding_and_extra_frames(video_generate, pad_f, pad_h * args.upscale, pad_w * args.upscale)
file_name = os.path.basename(video_path)
output_path = os.path.join(args.output_path, file_name)
if metrics_models is not None:
# [1, C, F, H, W] -> [F, C, H, W]
pred_frames = video_generate[0]
pred_frames = pred_frames.permute(1, 0, 2, 3).contiguous()
if args.gt_dir is not None:
gt_frames = load_sequence(os.path.join(args.gt_dir, file_name))
else:
gt_frames = None
compute_metrics(pred_frames, gt_frames, metrics_models, metric_accumulator, file_name)
if args.png_save:
# Save as PNG sequence
output_dir = output_path.rsplit('.', 1)[0] # Remove extension
save_frames_as_png(video_generate, output_dir, fps=args.fps)
else:
                output_path = os.path.splitext(output_path)[0] + '.mp4'
save_video_with_imageio(video_generate, output_path, fps=args.fps, format=args.save_format)
else:
print(f"Warning: {video_name} not found in {args.input_dir}")
if metrics_models is not None:
print("\n=== Overall Average Metrics ===")
count = len(next(iter(metric_accumulator.values())))
overall_avg = {metric: 0 for metric in metrics_list}
out_name = 'metrics_'
for metric in metrics_list:
out_name += f"{metric}_"
scores = metric_accumulator[metric]
if scores:
avg = sum(scores) / len(scores)
overall_avg[metric] = avg
print(f"{metric.upper()}: {avg:.4f}")
out_name = out_name.rstrip('_') + '.json'
out_path = os.path.join(args.output_path, out_name)
output = {
"per_sample": metric_accumulator,
"average": overall_avg,
"count": count
}
with open(out_path, 'w') as f:
json.dump(output, f, indent=2)
print("All videos processed.")