#!/usr/bin/env python3
"""
DEIM Model Freezing Script - Convert DEIM models to TorchScript
This script loads a trained DEIM model from config and checkpoint files,
then exports it as a TorchScript model for deployment.
Usage:
python freeze_model.py -c config.yml -ckpt checkpoint.pth -o frozen_model.pt
Features:
- Supports both torch.jit.trace and torch.jit.script methods
- Validates frozen model output against original
- Includes postprocessor in frozen model (optional)
- Performance benchmarking
- Multiple input size support
"""
import argparse
import os
import sys
import time
import warnings
from typing import Optional, Tuple
import torch
import torch.nn as nn
# Make the local 'engine' package importable regardless of the working directory
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from engine.core import YAMLConfig
# Suppress warnings for cleaner output
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)
class FrozenDEIMModel(nn.Module):
    """Wrapper that exposes the DEIM model through a TorchScript-friendly tensor interface."""

    def __init__(self, model: nn.Module, input_size: Tuple[int, int] = (1024, 800)):
        super().__init__()
        self.model = model
        self.input_size = input_size
        self.model.eval()
        # Switch to deploy mode if the model supports it (e.g. re-parameterized backbones)
        if hasattr(self.model, 'deploy'):
            self.model.deploy()

    def forward(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Return plain tensors instead of the model's output dict so tracing stays simple
        out = self.model(images)
        return out["pred_logits"], out["pred_boxes"]
def load_model_from_config(config_path: str, checkpoint_path: str,
                           device: str = 'cuda') -> Tuple[nn.Module, YAMLConfig]:
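    """Build the model from a YAML config, load checkpoint weights, and return (model, cfg)."""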
print(f"Loading model configuration from: {config_path}")
print(f"Loading checkpoint from: {checkpoint_path}")
# Load configuration
config_overrides = {'HGNetv2': {'pretrained': False}}
cfg = YAMLConfig(config_path, resume=checkpoint_path, **config_overrides)
# Load checkpoint
print("Loading model weights...")
state_dict = torch.load(checkpoint_path, map_location='cpu')['model']
cfg.model.load_state_dict(state_dict)
# Move to device and set eval mode
model = cfg.model.eval().to(device)
print(f"Model loaded successfully on {device}")
print(f"Model type: {type(model).__name__}")
return model, cfg
def create_sample_inputs(batch_size: int = 1, input_size: Tuple[int, int] = (1024, 800),
device: str = 'cuda') -> torch.Tensor:
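    """Create a random NCHW float32 tensor used as the example input for tracing."""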
height, width = input_size
sample_images = torch.randn(batch_size, 3, height, width,
dtype=torch.float32, device=device)
return sample_images
def freeze_model(model: nn.Module,
output_path: str, method: str = 'trace',
input_size: Tuple[int, int] = (1024, 800),
batch_size: int = 1, device: str = 'cuda') -> bool:
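    """Export the model to TorchScript via trace or script and save it to output_path."""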
print(f"\n=== Freezing Model ===")
print(f"Method: {method}")
print(f"Input size: {input_size}")
print(f"Batch size: {batch_size}")
print(f"Device: {device}")
try:
wrapper_model = FrozenDEIMModel(
model=model,
input_size=input_size
)
sample_inputs = create_sample_inputs(batch_size, input_size, device)
print(f"Sample input shape: {sample_inputs.shape}")
print(f"Freezing model using {method}...")
if method == 'trace':
with torch.no_grad():
frozen_model = torch.jit.trace(wrapper_model, (sample_inputs,))
elif method == 'script':
frozen_model = torch.jit.script(wrapper_model)
else:
raise ValueError(f"Unknown freezing method: {method}")
print(f"Saving frozen model to: {output_path}")
frozen_model.save(output_path)
file_size = os.path.getsize(output_path) / (1024 * 1024) # MB
print(f"Frozen model size: {file_size:.2f} MB")
print("βœ… Model freezing completed successfully!")
return True
except Exception as e:
print(f"❌ Error during model freezing: {e}")
import traceback
traceback.print_exc()
return False
def test_frozen_model(original_model: Optional[nn.Module], frozen_model_path: str,
                      device: str = 'cuda',
                      input_size: Tuple[int, int] = (1024, 800)) -> bool:
    """
    Load the frozen model, run a test inference, and compare it against the original model.

    Args:
        original_model: Original (unfrozen) model for comparison, or None to skip the comparison
        frozen_model_path: Path to the frozen TorchScript model
        device: Device to load the model on
        input_size: Input size as (height, width) tuple for testing
    Returns:
        True if the test ran successfully
    """
    print("\n=== Testing Frozen Model ===")
    print(f"Loading from: {frozen_model_path}")
    frozen_model = torch.jit.load(frozen_model_path, map_location=device)
    frozen_model.eval()
    print("✓ Frozen model loaded successfully")
    height, width = input_size
    test_images = torch.randn(1, 3, height, width, device=device)
    # Time the frozen model
    with torch.no_grad():
        start_time = time.time()
        frozen_logits, frozen_boxes = frozen_model(test_images)
        frozen_time = time.time() - start_time
    print(f"✓ Frozen model inference successful ({frozen_time * 1000:.2f} ms)")
    if original_model is not None:
        # Time the original model and compare outputs against the frozen ones
        with torch.no_grad():
            start_time = time.time()
            outputs = original_model(test_images)
            original_time = time.time() - start_time
        print(f"✓ Original model inference successful ({original_time * 1000:.2f} ms)")
        logits_diff = (outputs["pred_logits"] - frozen_logits).abs().max().item()
        boxes_diff = (outputs["pred_boxes"] - frozen_boxes).abs().max().item()
        print(f"Max abs difference - pred_logits: {logits_diff:.6f}, pred_boxes: {boxes_diff:.6f}")
    return True
def main():
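    """Parse CLI arguments, load the model, export it to TorchScript, and test the result."""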
parser = argparse.ArgumentParser(description="DEIM Model Freezing Script")
parser.add_argument('-c', '--config', type=str, required=True,
help='Path to config file')
parser.add_argument('-ckpt', '--checkpoint', type=str, required=True,
help='Path to model checkpoint')
parser.add_argument('-o', '--output', type=str, required=True,
help='Output path for frozen model')
parser.add_argument('-d', '--device', type=str, default='cpu',
help='Device to use (cuda/cpu)')
parser.add_argument('--method', type=str, default='trace', choices=['trace', 'script'],
help='Freezing method (trace or script)')
parser.add_argument('--input-size', type=str, default='800,1024',
help='Input image size for tracing as "height,width" (default: 800,1024)')
parser.add_argument('--batch-size', type=int, default=1,
help='Batch size for tracing')
parser.add_argument('--no-validate', action='store_true',
help='Skip validation of frozen model')
parser.add_argument('--no-benchmark', action='store_true',
help='Skip performance benchmarking')
parser.add_argument('--test-only', action='store_true',
help='Only test existing frozen model')
args = parser.parse_args()
# Parse input size
try:
if ',' in args.input_size:
height, width = map(int, args.input_size.split(','))
input_size = (height, width)
else:
# If single number provided, assume square
size = int(args.input_size)
input_size = (size, size)
except ValueError:
print(f"Error: Invalid input size format: {args.input_size}")
print("Use format 'height,width' (e.g., '1024,800') or single number for square")
return
# Check if files exist
if not args.test_only:
if not os.path.exists(args.config):
print(f"Error: Config file not found: {args.config}")
return
if not os.path.exists(args.checkpoint):
print(f"Error: Checkpoint file not found: {args.checkpoint}")
return
# Check device availability
if args.device == 'cuda' and not torch.cuda.is_available():
print("Warning: CUDA not available, using CPU")
args.device = 'cpu'
print("=== DEIM Model Freezing Script ===")
if not args.test_only:
print(f"Config: {args.config}")
print(f"Checkpoint: {args.checkpoint}")
print(f"Output: {args.output}")
print(f"Device: {args.device}")
print(f"Method: {args.method}")
print(f"Input size: {input_size[0]}x{input_size[1]} (HxW)")
print(f"Batch size: {args.batch_size}")
print("=" * 35)
    try:
        if args.test_only:
            # Only exercise an existing frozen model; no config/checkpoint needed
            test_frozen_model(None, args.output, args.device, input_size)
            return
        model, cfg = load_model_from_config(args.config, args.checkpoint, args.device)
        success = freeze_model(
            model=model,
            output_path=args.output,
            method=args.method,
            input_size=input_size,
            batch_size=args.batch_size,
            device=args.device,
        )
        if success and os.path.exists(args.output) and not args.no_validate:
            test_frozen_model(model, args.output, args.device, input_size)
except Exception as e:
print(f"Error during execution: {e}")
import traceback
traceback.print_exc()
if __name__ == '__main__':
main()