#!/usr/bin/env python3
"""
DEIM Model Freezing Script - Convert DEIM models to TorchScript

This script loads a trained DEIM model from config and checkpoint files,
then exports it as a TorchScript model for deployment.

Usage:
    python freeze_model.py -c config.yml -ckpt checkpoint.pth -o frozen_model.pt

Features:
    - Supports both torch.jit.trace and torch.jit.script methods
    - Validates frozen model output against the original
    - Reports inference timing for the frozen and original models
    - Multiple input size support
"""
import argparse
import os
import sys
import time
import warnings
from typing import Tuple

import torch
import torch.nn as nn

sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'engine'))
from engine.core import YAMLConfig

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

class FrozenDEIMModel(nn.Module):
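    """
    Wrapper that puts a DEIM model into deploy/eval mode and returns a
    (pred_logits, pred_boxes) tuple, which TorchScript handles more
    reliably than the model's raw output dictionary.
    """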
    def __init__(self, model: nn.Module, input_size: Tuple[int, int] = (1024, 800)):
        super().__init__()
        self.model = model
        self.input_size = input_size
        self.model.eval()
        # Switch to the deployment-friendly forward path if the model provides one
        if hasattr(self.model, 'deploy'):
            self.model.deploy()
    def forward(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        out = self.model(images)
        return out["pred_logits"], out["pred_boxes"]
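
# Example: once saved, the frozen artifact can be loaded without the DEIM
# codebase (paths and sizes below are illustrative):
#     model = torch.jit.load('frozen_model.pt')
#     logits, boxes = model(torch.randn(1, 3, 1024, 800))
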
def load_model_from_config(config_path: str, checkpoint_path: str,
                           device: str = 'cuda') -> Tuple[nn.Module, YAMLConfig]:
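    """
    Build the model from a YAML config, load checkpoint weights,
    and move it to the target device in eval mode.

    Args:
        config_path: Path to the YAML config file
        checkpoint_path: Path to the checkpoint (.pth) file
        device: Device to place the model on

    Returns:
        (model, cfg) tuple with the loaded model and the parsed config
    """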
| print(f"Loading model configuration from: {config_path}") | |
| print(f"Loading checkpoint from: {checkpoint_path}") | |
| # Load configuration | |
| config_overrides = {'HGNetv2': {'pretrained': False}} | |
| cfg = YAMLConfig(config_path, resume=checkpoint_path, **config_overrides) | |
| # Load checkpoint | |
| print("Loading model weights...") | |
| state_dict = torch.load(checkpoint_path, map_location='cpu')['model'] | |
| cfg.model.load_state_dict(state_dict) | |
| # Move to device and set eval mode | |
| model = cfg.model.eval().to(device) | |
| print(f"Model loaded successfully on {device}") | |
| print(f"Model type: {type(model).__name__}") | |
| return model, cfg | |

def create_sample_inputs(batch_size: int = 1, input_size: Tuple[int, int] = (1024, 800),
                         device: str = 'cuda') -> torch.Tensor:
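    """
    Create a random image batch of shape (batch_size, 3, height, width)
    to drive tracing and testing.
    """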
    height, width = input_size
    sample_images = torch.randn(batch_size, 3, height, width,
                                dtype=torch.float32, device=device)
    return sample_images

def freeze_model(model: nn.Module,
                 output_path: str, method: str = 'trace',
                 input_size: Tuple[int, int] = (1024, 800),
                 batch_size: int = 1, device: str = 'cuda') -> bool:
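    """
    Wrap the model, freeze it with torch.jit.trace or torch.jit.script,
    and save the TorchScript artifact to output_path.

    Returns:
        True if freezing and saving succeeded
    """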
| print(f"\n=== Freezing Model ===") | |
| print(f"Method: {method}") | |
| print(f"Input size: {input_size}") | |
| print(f"Batch size: {batch_size}") | |
| print(f"Device: {device}") | |
| try: | |
| wrapper_model = FrozenDEIMModel( | |
| model=model, | |
| input_size=input_size | |
| ) | |
| sample_inputs = create_sample_inputs(batch_size, input_size, device) | |
| print(f"Sample input shape: {sample_inputs.shape}") | |
| print(f"Freezing model using {method}...") | |
| if method == 'trace': | |
| with torch.no_grad(): | |
| frozen_model = torch.jit.trace(wrapper_model, (sample_inputs,)) | |
| elif method == 'script': | |
| frozen_model = torch.jit.script(wrapper_model) | |
| else: | |
| raise ValueError(f"Unknown freezing method: {method}") | |
| print(f"Saving frozen model to: {output_path}") | |
| frozen_model.save(output_path) | |
| file_size = os.path.getsize(output_path) / (1024 * 1024) # MB | |
| print(f"Frozen model size: {file_size:.2f} MB") | |
| print("β Model freezing completed successfully!") | |
| return True | |
| except Exception as e: | |
| print(f"β Error during model freezing: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return False | |

def test_frozen_model(original_model, frozen_model_path: str, device: str = 'cuda',
                      input_size: Tuple[int, int] = (1024, 800)) -> bool:
    """
    Load the frozen model, run inference, and validate its outputs
    against the original model.

    Args:
        original_model: Original model to compare against (None skips validation)
        frozen_model_path: Path to frozen model
        device: Device to load model on
        input_size: Input size as (height, width) tuple for testing

    Returns:
        True if the test (and validation, when requested) succeeded
    """
    print(f"\n=== Testing Frozen Model ===")
    print(f"Loading from: {frozen_model_path}")
    frozen_model = torch.jit.load(frozen_model_path, map_location=device)
    frozen_model.eval()
    print("✓ Frozen model loaded successfully")
    height, width = input_size
    test_images = torch.randn(1, 3, height, width, device=device)
    with torch.no_grad():
        start_time = time.time()
        frozen_outputs = frozen_model(test_images)
        if device == 'cuda':
            torch.cuda.synchronize()  # wait for GPU kernels before stopping the clock
        frozen_time = time.time() - start_time
    print(f"✓ Frozen model inference successful ({frozen_time * 1000:.2f} ms)")
    if original_model is None:
        return True
    with torch.no_grad():
        start_time = time.time()
        original_outputs = original_model(test_images)
        if device == 'cuda':
            torch.cuda.synchronize()
        original_time = time.time() - start_time
    print(f"✓ Original model inference successful ({original_time * 1000:.2f} ms)")
    # The wrapper returns (pred_logits, pred_boxes); the original model returns a dict
    logits_match = torch.allclose(frozen_outputs[0], original_outputs["pred_logits"], atol=1e-4)
    boxes_match = torch.allclose(frozen_outputs[1], original_outputs["pred_boxes"], atol=1e-4)
    if logits_match and boxes_match:
        print("✓ Frozen model outputs match the original model")
        return True
    print("✗ Frozen model outputs differ from the original model")
    return False

def main():
    parser = argparse.ArgumentParser(description="DEIM Model Freezing Script")
    parser.add_argument('-c', '--config', type=str,
                        help='Path to config file (not needed with --test-only)')
    parser.add_argument('-ckpt', '--checkpoint', type=str,
                        help='Path to model checkpoint (not needed with --test-only)')
    parser.add_argument('-o', '--output', type=str, required=True,
                        help='Output path for frozen model')
    parser.add_argument('-d', '--device', type=str, default='cpu',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--method', type=str, default='trace', choices=['trace', 'script'],
                        help='Freezing method (trace or script)')
    parser.add_argument('--input-size', type=str, default='800,1024',
                        help='Input image size for tracing as "height,width" (default: 800,1024)')
    parser.add_argument('--batch-size', type=int, default=1,
                        help='Batch size for tracing')
    parser.add_argument('--no-validate', action='store_true',
                        help='Skip validation of frozen model')
    parser.add_argument('--no-benchmark', action='store_true',
                        help='Skip performance benchmarking')
    parser.add_argument('--test-only', action='store_true',
                        help='Only test existing frozen model')
    args = parser.parse_args()
    # Parse input size
    try:
        if ',' in args.input_size:
            height, width = map(int, args.input_size.split(','))
            input_size = (height, width)
        else:
            # If single number provided, assume square
            size = int(args.input_size)
            input_size = (size, size)
    except ValueError:
        print(f"Error: Invalid input size format: {args.input_size}")
        print("Use format 'height,width' (e.g., '1024,800') or single number for square")
        return
    # Check if files exist (not needed when only testing an existing frozen model)
    if not args.test_only:
        if not args.config or not os.path.exists(args.config):
            print(f"Error: Config file not found: {args.config}")
            return
        if not args.checkpoint or not os.path.exists(args.checkpoint):
            print(f"Error: Checkpoint file not found: {args.checkpoint}")
            return
    # Check device availability
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("Warning: CUDA not available, using CPU")
        args.device = 'cpu'
    print("=== DEIM Model Freezing Script ===")
    if not args.test_only:
        print(f"Config: {args.config}")
        print(f"Checkpoint: {args.checkpoint}")
    print(f"Output: {args.output}")
    print(f"Device: {args.device}")
    print(f"Method: {args.method}")
    print(f"Input size: {input_size[0]}x{input_size[1]} (HxW)")
    print(f"Batch size: {args.batch_size}")
    print("=" * 35)
    try:
        if args.test_only:
            # Test an existing frozen model; there is no original model to compare against
            test_frozen_model(None, args.output, args.device, input_size)
            return
        model, cfg = load_model_from_config(args.config, args.checkpoint, args.device)
        success = freeze_model(
            model=model,
            output_path=args.output,
            method=args.method,
            input_size=input_size,
            batch_size=args.batch_size,
            device=args.device,
        )
        if success and os.path.exists(args.output) and not args.no_validate:
            test_frozen_model(model, args.output, args.device, input_size)
    except Exception as e:
        print(f"Error during execution: {e}")
        import traceback
        traceback.print_exc()

if __name__ == '__main__':
    main()