| | |
| | """ |
| | Check codebook range by iterating through videos and extracting codes. |
| | |
| | This script loads videos from the dataset, encodes them to get video codes, |
| | and tracks the min/max values to determine the codebook range. |
| | """ |
| |
|
| | import argparse |
| | import os |
| | import sys |
| | import logging |
| | from tqdm import tqdm |
| | import torch |
| | import numpy as np |
| |
|
| | sys.path.append(os.getcwd()) |
| |
|
| | from train.dataset_utils import OpenVid1MDataset, PrecomputedFeatureDataset |
| | from src.pipeline_video import CosmosVideoTokenizer |
| | from transformers import T5Tokenizer |
| | from torch.utils.data import DataLoader |
| |
|
# Module-wide logging configuration: timestamped, INFO-level messages
# tagged with the emitting logger's name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
| |
|
| |
|
def parse_args():
    """Parse command-line options for the codebook-range checking script.

    Returns:
        argparse.Namespace carrying dataset locations, tokenizer/model
        identifiers, video geometry, and iteration controls.
    """
    parser = argparse.ArgumentParser(description="Check codebook range from video dataset")

    # One entry per CLI option: (flag, add_argument keyword arguments).
    # Help strings and defaults are kept verbatim.
    option_table = [
        ("--csv_path",
         dict(type=str, default=None,
              help="Path to OpenVid1M CSV file (if using raw videos)")),
        ("--video_root_dir",
         dict(type=str, default=None,
              help="Root directory containing video files")),
        ("--features_dir",
         dict(type=str, default=None,
              help="Directory containing pre-extracted features (if using precomputed features)")),
        ("--video_tokenizer_model_id",
         dict(type=str, default="Cosmos-1.0-Tokenizer-DV8x16x16",
              help="HuggingFace model ID for Cosmos video tokenizer")),
        ("--num_frames",
         dict(type=int, default=16,
              help="Number of frames per video")),
        ("--video_height",
         dict(type=int, default=480,
              help="Video height")),
        ("--video_width",
         dict(type=int, default=848,
              help="Video width")),
        ("--text_encoder_architecture",
         dict(type=str, default="umt5-base",
              choices=["umt5-base", "umt5-xxl", "t5"],
              help="Text encoder architecture")),
        ("--batch_size",
         dict(type=int, default=1,
              help="Batch size (use 1 for detailed per-sample tracking)")),
        ("--max_samples",
         dict(type=int, default=None,
              help="Maximum number of samples to check. If None, check all.")),
        ("--check_interval",
         dict(type=int, default=10,
              help="Print statistics every N samples")),
    ]
    for flag, kwargs in option_table:
        parser.add_argument(flag, **kwargs)

    return parser.parse_args()
| |
|
| |
|
def main():
    """Iterate the dataset, obtain discrete video codes, and report the
    observed min/max code values against the tokenizer's expected range.

    Two input modes:
      * raw videos (--csv_path): frames are encoded on the fly with the
        Cosmos video tokenizer;
      * precomputed features (--features_dir): codes are read directly
        from each batch, and the tokenizer (plus its expected-range
        checks) is skipped entirely.
    """
    args = parse_args()

    # Prefer GPU when available; all tokenizer work runs in float32.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    logger.info(f"Using device: {device}")

    # Stays None in precomputed mode; downstream code checks for that.
    video_tokenizer = None
    use_precomputed = args.features_dir is not None

    if not use_precomputed:
        if args.csv_path is None:
            raise ValueError("Either --csv_path or --features_dir must be provided")

        logger.info(f"Loading video tokenizer: {args.video_tokenizer_model_id}")
        video_tokenizer = CosmosVideoTokenizer(
            model_id=args.video_tokenizer_model_id,
            device=device,
            dtype=dtype
        )
        # Inference only: freeze parameters and switch to eval mode.
        video_tokenizer.requires_grad_(False)
        video_tokenizer.eval()

        logger.info(f"Video tokenizer codebook_size: {video_tokenizer.codebook_size}")
        logger.info(f"Video tokenizer mask_token_id: {video_tokenizer.mask_token_id}")

    # ---- Dataset construction -------------------------------------------
    if use_precomputed:
        logger.info(f"Using precomputed features from: {args.features_dir}")
        # PrecomputedFeatureDataset applies the max_samples cap itself
        # via its num_samples argument.
        dataset = PrecomputedFeatureDataset(
            features_dir=args.features_dir,
            num_samples=args.max_samples,
        )
    else:
        # Resolve the video root when not given explicitly: look for a
        # "video_reorg" directory next to the CSV, then one level above,
        # before falling back to the CSV's own directory.
        if args.video_root_dir is None:
            csv_dir = os.path.dirname(args.csv_path)
            if os.path.exists(os.path.join(csv_dir, 'video_reorg')):
                video_root_dir = os.path.join(csv_dir, 'video_reorg')
            elif os.path.exists(os.path.join(os.path.dirname(csv_dir), 'video_reorg')):
                video_root_dir = os.path.join(os.path.dirname(csv_dir), 'video_reorg')
            else:
                video_root_dir = csv_dir
                logger.warning(f"Video directory not found, using CSV directory: {video_root_dir}")
        else:
            video_root_dir = args.video_root_dir

        # Map the architecture flag to a HuggingFace model id. The final
        # else is defensive: argparse's `choices` already restricts input.
        if args.text_encoder_architecture == "umt5-base":
            model_id = "google/umt5-base"
        elif args.text_encoder_architecture == "umt5-xxl":
            model_id = "google/umt5-xxl"
        elif args.text_encoder_architecture == "t5":
            model_id = "t5-base"
        else:
            raise ValueError(f"Unknown text encoder: {args.text_encoder_architecture}")

        tokenizer = T5Tokenizer.from_pretrained(model_id)

        # Deterministic (non-random) crops so repeated runs see the same frames.
        dataset = OpenVid1MDataset(
            csv_path=args.csv_path,
            video_root_dir=video_root_dir,
            tokenizer=tokenizer,
            num_frames=args.num_frames,
            height=args.video_height,
            width=args.video_width,
            text_encoder_architecture=args.text_encoder_architecture,
            use_random_temporal_crop=False,
            use_random_crop=False,
        )

        # NOTE(review): assumes OpenVid1MDataset keeps its samples in a
        # `.data` sequence — confirm against the dataset implementation.
        if args.max_samples is not None:
            dataset.data = dataset.data[:args.max_samples]
            logger.info(f"Limited dataset to {len(dataset)} samples")

    logger.info(f"Dataset size: {len(dataset)}")

    # Single-process, unshuffled loading keeps per-sample tracking ordered
    # and avoids worker overhead for this diagnostic pass.
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
    )

    # Running extrema over all processed codes, plus simple counters.
    global_min = None
    global_max = None
    total_samples = 0
    failed_samples = 0

    logger.info("Starting to check codebook range...")
    logger.info("=" * 80)

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Checking codes")):
            try:
                if use_precomputed:
                    # Codes come straight from feature files; they may be
                    # tensors or numpy arrays depending on how they were saved.
                    video_codes = batch["video_codes"]
                    if isinstance(video_codes, torch.Tensor):
                        video_codes = video_codes.long()
                    else:
                        video_codes = torch.from_numpy(video_codes).long()
                else:
                    # Encode raw frames on the device, then move the discrete
                    # codes to CPU for bookkeeping.
                    videos = batch["video"].to(device, non_blocking=True)
                    video_codes = video_tokenizer.encode(videos)
                    video_codes = video_codes.cpu().long()

                batch_min = video_codes.min().item()
                batch_max = video_codes.max().item()

                # Fold this batch into the global extrema.
                if global_min is None:
                    global_min = batch_min
                    global_max = batch_max
                else:
                    global_min = min(global_min, batch_min)
                    global_max = max(global_max, batch_max)

                total_samples += video_codes.shape[0]

                # Periodic progress report (always shown for the first batch).
                if (batch_idx + 1) % args.check_interval == 0 or batch_idx == 0:
                    print(f"\n[Sample {total_samples}]")
                    print(f" Current batch range: [{batch_min}, {batch_max}]")
                    print(f" Global range so far: [{global_min}, {global_max}]")
                    print(f" Codebook size (expected): {video_tokenizer.codebook_size if video_tokenizer else 'N/A'}")
                    # Expected-range warnings require a tokenizer, which is
                    # absent in precomputed mode.
                    if video_tokenizer:
                        expected_max = video_tokenizer.codebook_size - 1
                        print(f" Expected max (codebook_size - 1): {expected_max}")
                        if global_max > expected_max:
                            print(f" ⚠️ WARNING: Found code {global_max} > expected max {expected_max}!")
                        if global_min < 0:
                            print(f" ⚠️ WARNING: Found code {global_min} < 0!")

                    # Small batches get their full value list; larger ones
                    # are summarized by their extremes.
                    unique_values = torch.unique(video_codes).tolist()
                    print(f" Unique values in batch: {len(unique_values)}")
                    if len(unique_values) <= 20:
                        print(f" Values: {sorted(unique_values)}")
                    else:
                        print(f" Min unique: {min(unique_values)}, Max unique: {max(unique_values)}")
                    print("-" * 80)

            except Exception as e:
                # Best-effort scan: log the failure and keep going. NOTE(review):
                # this counts a full batch_size even for a short final batch.
                failed_samples += args.batch_size
                logger.error(f"Failed to process batch {batch_idx}: {e}")
                continue

    # ---- Final summary ---------------------------------------------------
    logger.info("=" * 80)
    logger.info("FINAL STATISTICS:")
    logger.info(f" Total samples processed: {total_samples}")
    logger.info(f" Failed samples: {failed_samples}")
    logger.info(f" Global min code: {global_min}")
    logger.info(f" Global max code: {global_max}")
    logger.info(f" Code range: [{global_min}, {global_max}]")

    if video_tokenizer:
        expected_max = video_tokenizer.codebook_size - 1
        logger.info(f" Expected max (codebook_size - 1): {expected_max}")
        logger.info(f" Codebook size: {video_tokenizer.codebook_size}")
        logger.info(f" Mask token ID: {video_tokenizer.mask_token_id}")

        # Compare observed extrema with the theoretical codebook bounds.
        if global_max > expected_max:
            logger.warning(f" ⚠️ WARNING: Found code {global_max} > expected max {expected_max}!")
        elif global_max == expected_max:
            logger.info(f" ✓ Max code matches expected max")
        else:
            logger.info(f" Note: Max code {global_max} < expected max {expected_max} (some codes may not be used)")

        if global_min < 0:
            logger.warning(f" ⚠️ WARNING: Found code {global_min} < 0!")
        elif global_min == 0:
            logger.info(f" ✓ Min code is 0 (as expected)")
        else:
            logger.info(f" Note: Min code {global_min} > 0 (some codes may not be used)")

    logger.info("=" * 80)
| |
|
| |
|
# Script entry point.
if __name__ == "__main__":
    main()
| |
|
| |
|