Spaces:
Running
on
Zero
Running
on
Zero
| #!/usr/bin/env python | |
| """ | |
| Utility script to stress-test LLaVA-Video frame decoding in isolation. | |
| This runs the `VideoCaptionDataset` loader on a single node so that we can | |
| watch for files that consistently time out or wedged dataloader workers. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, Iterable, List, Optional | |
| import torch | |
| from torch.utils.data import DataLoader | |
# Repository root: this script is expected to sit two directories below it
# (<repo>/<dir>/<subdir>/<script>.py), so take the third ancestor of the file.
ROOT_DIR = Path(__file__).resolve().parents[2]
# Prepend the repo root (once) so the in-repo `training` package imported below
# resolves even when the script is launched from another working directory.
if not any(entry == str(ROOT_DIR) for entry in sys.path):
    sys.path[:0] = [str(ROOT_DIR)]
| from training import data as video_data_module # noqa: E402 | |
| from training.data import VideoCaptionDataset # noqa: E402 | |
| from training.utils import image_transform as default_image_transform # noqa: E402 | |
| def _resolve_llavavid_root(root_arg: Optional[str]) -> Path: | |
| if root_arg: | |
| root = Path(root_arg).expanduser().resolve() | |
| else: | |
| root = ROOT_DIR / "data" / "video" / "LLaVA-Video-178K" | |
| if not root.exists(): | |
| raise FileNotFoundError(f"LLaVA-Video root directory not found: {root}") | |
| return root | |
| def _identity_collate(batch: List[Optional[Dict[str, Any]]]) -> List[Dict[str, Any]]: | |
| """Drop `None` samples that VideoCaptionDataset returns after repeated failures.""" | |
| filtered = [sample for sample in batch if sample is not None] | |
| return filtered | |
| def _parse_args() -> argparse.Namespace: | |
| parser = argparse.ArgumentParser( | |
| description="Decode-check LLaVA-Video samples with the existing dataset logic." | |
| ) | |
| parser.add_argument( | |
| "--llavavid-root", | |
| type=str, | |
| default=None, | |
| help="Path to the LLaVA-Video-178K cache directory. Defaults to data/video/LLaVA-Video-178K relative to repo root.", | |
| ) | |
| parser.add_argument( | |
| "--num-samples", | |
| type=int, | |
| default=256, | |
| help=( | |
| "Number of samples to attempt decoding (per DataLoader worker collectively). " | |
| "Set to -1 to sweep the entire dataset once." | |
| ), | |
| ) | |
| parser.add_argument( | |
| "--batch-size", | |
| type=int, | |
| default=1, | |
| help="Batch size for the diagnostic DataLoader.", | |
| ) | |
| parser.add_argument( | |
| "--num-workers", | |
| type=int, | |
| default=4, | |
| help="Number of DataLoader workers to spawn. Set to match your training run.", | |
| ) | |
| parser.add_argument( | |
| "--num-frames", | |
| type=int, | |
| default=8, | |
| help="Number of frames to request from load_video_mp4.", | |
| ) | |
| parser.add_argument( | |
| "--resolution", | |
| type=int, | |
| default=256, | |
| help="Resolution passed to the dataset transform.", | |
| ) | |
| parser.add_argument( | |
| "--sample-method", | |
| type=str, | |
| default="uniform", | |
| choices=("uniform", "random"), | |
| help="Frame sampling strategy.", | |
| ) | |
| parser.add_argument( | |
| "--report-every", | |
| type=int, | |
| default=10, | |
| help="Print a progress line every N successfully decoded samples.", | |
| ) | |
| parser.add_argument( | |
| "--timeout", | |
| type=float, | |
| default=30.0, | |
| help="Maximum seconds to allow a batch to hang before treating it as a stall.", | |
| ) | |
| return parser.parse_args() | |
| def _maybe_set_thread_limits() -> None: | |
| # Avoid oversubscribing CPU threads when the loader uses multiple workers. | |
| os.environ.setdefault("OMP_NUM_THREADS", "1") | |
| os.environ.setdefault("MKL_NUM_THREADS", "1") | |
| os.environ.setdefault("OPENBLAS_NUM_THREADS", "1") | |
| os.environ.setdefault("NUMEXPR_NUM_THREADS", "1") | |
def _has_content(value: Any) -> bool:
    """Return whether *value* is present and non-empty.

    ``bool()`` on a multi-element ``torch.Tensor`` (or NumPy array) raises an
    "ambiguous truth value" error, so tensors are checked with ``numel()`` and
    other sized objects with ``len()``; anything else falls back to ``bool()``.
    """
    if value is None:
        return False
    if isinstance(value, torch.Tensor):
        return value.numel() > 0
    try:
        return len(value) > 0
    except TypeError:
        return bool(value)


def _make_traced_loader(original_loader: Any) -> Any:
    """Wrap *original_loader* so each call logs its video path, outcome and duration."""

    def traced_loader(*loader_args, **loader_kwargs):
        # NOTE(review): assumes the path is passed as `video_path=` or as the
        # first positional argument of load_video_mp4.
        video_path = loader_kwargs.get("video_path")
        if video_path is None and loader_args:
            video_path = loader_args[0]
        # Log *before* decoding, and flush every line, so a decode that wedges a
        # worker still leaves its path in the output (stdout is block-buffered
        # when redirected, and a killed worker would lose pending lines).
        print(f"[TRACE] START | {video_path}", flush=True)
        start = time.perf_counter()
        try:
            frames = original_loader(*loader_args, **loader_kwargs)
        except Exception as exc:  # pylint: disable=broad-except
            duration = time.perf_counter() - start
            print(
                f"[ERROR] {video_path} raised {exc.__class__.__name__} after {duration:.2f}s: {exc}",
                flush=True,
            )
            raise
        duration = time.perf_counter() - start
        status = "OK" if _has_content(frames) else "NONE"
        print(f"[TRACE] {status:>4} | {duration:6.2f}s | {video_path}", flush=True)
        return frames

    return traced_loader


def _build_dataloader(dataset: "VideoCaptionDataset", args: argparse.Namespace) -> DataLoader:
    """Create the shuffled diagnostic DataLoader with the `None`-dropping collate."""
    # `--timeout` is a per-batch stall limit. DataLoader only enforces `timeout`
    # with worker processes (its single-process iterator requires timeout == 0),
    # so stall detection is disabled when num_workers == 0.
    batch_timeout = args.timeout if args.num_workers > 0 and args.timeout > 0 else 0
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=_identity_collate,
        pin_memory=False,
        drop_last=False,
        timeout=batch_timeout,
    )


def _run_sweep(dataloader: DataLoader, dataset_size: int, args: argparse.Namespace) -> None:
    """Iterate *dataloader*, tally decoded/failed samples, and print progress and a summary.

    Stops after ``args.num_samples`` attempted samples (when > 0), at the end of
    the dataset, or when no batch arrives within the DataLoader's ``timeout``.
    """
    batch_size = args.batch_size
    report_every = args.report_every
    decoded = attempted = failed = 0
    start_time = time.monotonic()
    last_report = start_time
    next_report = report_every
    try:
        for batch_idx, batch in enumerate(dataloader, start=1):
            # With drop_last=False the final batch can be short; count only the
            # samples the sampler actually requested so `failed` is not inflated.
            expected = min(batch_size, dataset_size - (batch_idx - 1) * batch_size)
            attempted += expected
            failed += max(expected - len(batch), 0)
            # NOTE(review): assumes each sample is a dict holding a "video" entry.
            decoded += sum(1 for sample in batch if _has_content(sample.get("video")))

            now = time.monotonic()
            # Threshold instead of `decoded % report_every == 0`: one batch can add
            # several samples (skipping the exact multiple) or none (repeating it).
            if report_every > 0 and decoded >= next_report:
                elapsed = now - last_report
                total_elapsed = now - start_time
                print(
                    f"[INFO] {decoded} successful samples "
                    f"(attempted={attempted}, failed={failed}) "
                    f"in {total_elapsed:.1f}s (+{elapsed:.1f}s since last report)."
                )
                last_report = now
                next_report = (decoded // report_every + 1) * report_every

            if args.num_samples > 0 and attempted >= args.num_samples:
                break
    except RuntimeError as exc:
        # DataLoader raises "DataLoader timed out after N seconds" when no batch
        # arrives within `timeout`, which is the stall this script hunts for.
        # Any other RuntimeError is a genuine failure and propagates.
        if not str(exc).startswith("DataLoader timed out"):
            raise
        print(f"[WARN] Stall detected, stopping sweep: {exc}", flush=True)

    total_elapsed = time.monotonic() - start_time
    print(
        f"[RESULT] Completed sweep: decoded={decoded}, attempted={attempted}, "
        f"failed={failed}, elapsed={total_elapsed:.1f}s."
    )


def main() -> None:
    """Run the decode stress test: trace every video load through the real dataset and loader.

    ``--timeout`` bounds how long a single batch may take (stall detection),
    not the total duration of the sweep.
    """
    args = _parse_args()
    _maybe_set_thread_limits()
    llavavid_root = _resolve_llavavid_root(args.llavavid_root)
    print(f"[INFO] Using LLaVA-Video root: {llavavid_root}")

    # Swap in a tracing wrapper around the dataset module's video loader; it is
    # restored in `finally`.
    # NOTE(review): this only intercepts calls that look `load_video_mp4` up via
    # the `training.data` module globals at call time, and reaches DataLoader
    # workers only when they are forked (not spawned). Confirm for your setup.
    original_loader = video_data_module.load_video_mp4
    video_data_module.load_video_mp4 = _make_traced_loader(original_loader)
    try:
        dataset = VideoCaptionDataset(
            transform=default_image_transform,
            tokenizer=None,
            max_seq_length=256,
            resolution=args.resolution,
            dataset_name="llavavid",
            llavavid_path=str(llavavid_root),
            llavavid_local_files_only=True,
            sample_method=args.sample_method,
            num_frames=args.num_frames,
        )
        dataset_size = len(dataset)
        if dataset_size == 0:
            print("[ERROR] Dataset returned zero length. Check the root directory/config.")
            sys.exit(1)

        dataloader = _build_dataloader(dataset, args)
        print(
            f"[INFO] Starting decode sweep: "
            f"{args.num_samples} samples, batch_size={args.batch_size}, num_workers={args.num_workers}"
        )
        if args.num_workers == 0 and args.timeout > 0:
            print("[INFO] --timeout is only enforced when --num-workers > 0; stall detection is off.")
        _run_sweep(dataloader, dataset_size, args)
    finally:
        video_data_module.load_video_mp4 = original_loader


if __name__ == "__main__":
    main()