| """ |
| Evaluate checkpoints using run_evaluation_heavy. |
| Thin wrapper that discovers checkpoints and calls run_evaluation_heavy. |
| """ |
|
|
import argparse
import json
import os
import traceback
from glob import glob

from .pipeline import run_evaluation_heavy
|
|
|
|
def discover_checkpoints(checkpoint_pattern: str) -> list:
    """Return every checkpoint path matching *checkpoint_pattern*, sorted lexicographically."""
    return sorted(glob(checkpoint_pattern))
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Build the CLI parser for checkpoint evaluation and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Evaluate checkpoints")
    parser.add_argument(
        "--checkpoint_pattern",
        type=str,
        required=True,
        help="Glob pattern for checkpoints, e.g., 'output/baseline/checkpoint_step_*/ema_weights.pt'",
    )
    parser.add_argument(
        "--model_config",
        type=str,
        required=True,
        help='Model config as JSON string, e.g., \'{"in_channels":16,"hidden_size":768,...}\'',
    )
    parser.add_argument(
        "--vae_path",
        type=str,
        default="REPA-E/e2e-qwenimage-vae",
        help="Path to VAE model",
    )
    parser.add_argument(
        "--text_encoder_path",
        type=str,
        default="Qwen/Qwen3-0.6B",
        help="Path to text encoder",
    )
    parser.add_argument(
        "--pooling",
        action="store_true",
        help="Use pooled text embeddings",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="./precomputed_dataset/heavy-eval@256p",
        help="Path to evaluation dataset",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=2000,
        help="Number of samples for CLIP",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="Device",
    )
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        help="Skip checkpoints with existing results",
    )
    return parser.parse_args()


def main():
    """Discover checkpoints matching a glob pattern and evaluate each one.

    For every matching checkpoint, ``run_evaluation_heavy`` is invoked and its
    results are written to ``evaluation_results.json`` in the checkpoint's
    directory. A failure on one checkpoint is logged and does not stop the
    evaluation of the remaining checkpoints.
    """
    args = _parse_args()

    # The model config arrives on the command line as a JSON string.
    model_config = json.loads(args.model_config)

    checkpoints = discover_checkpoints(args.checkpoint_pattern)
    print(f"Found {len(checkpoints)} checkpoints")
    if not checkpoints:
        # Nothing to do; make the empty sweep explicit rather than silent.
        print(f"No checkpoints match pattern: {args.checkpoint_pattern}")
        return

    for checkpoint_path in checkpoints:
        checkpoint_dir = os.path.dirname(checkpoint_path)
        results_file = os.path.join(checkpoint_dir, "evaluation_results.json")

        # Allow cheap resumption of an interrupted sweep.
        if args.skip_existing and os.path.exists(results_file):
            print(f"Skipping {checkpoint_path} (results exist)")
            continue

        print(f"\nEvaluating {checkpoint_path}")
        os.makedirs(checkpoint_dir, exist_ok=True)

        try:
            results = run_evaluation_heavy(
                checkpoint_path=checkpoint_path,
                model_config=model_config,
                vae_path=args.vae_path,
                text_encoder_path=args.text_encoder_path,
                pooling=args.pooling,
                save_path=checkpoint_dir,
                dataset_path=args.dataset_path,
                num_samples=args.num_samples,
                batch_size=args.batch_size,
                device=args.device,
            )

            with open(results_file, "w") as f:
                json.dump(results, f, indent=2)
            print(f"Results: {results}")

        except Exception as e:
            # Broad catch is intentional at this batch boundary: one bad
            # checkpoint must not abort the whole sweep.
            print(f"Error: {e}")
            traceback.print_exc()
|
|
|
|
# Script entry point: run the checkpoint-evaluation sweep only when this
# module is executed directly (not when it is imported).
if __name__ == "__main__":
    main()
|
|
|
|