import gc
import os
from typing import Any

import torch

from .models import BaseModel
from .utilities import (
    build_engine,
    export_onnx,
    handle_onnx_batch_norm,
    optimize_onnx,
)


class EngineBuilder:
    """Builds a serialized inference engine from a PyTorch network via ONNX export and optimization."""

    def __init__(
        self,
        model: BaseModel,
        network: Any,
        device: torch.device = torch.device("cuda"),
    ):
        self.device = device
        self.model = model
        self.network = network
    def build(
        self,
        onnx_path: str,
        onnx_opt_path: str,
        engine_path: str,
        opt_image_height: int = 512,
        opt_image_width: int = 512,
        opt_batch_size: int = 1,
        min_image_resolution: int = 256,
        max_image_resolution: int = 1024,
        build_enable_refit: bool = False,
        build_static_batch: bool = False,
        build_dynamic_shape: bool = False,
        build_all_tactics: bool = False,
        onnx_opset: int = 17,
        force_engine_build: bool = False,
        force_onnx_export: bool = False,
        force_onnx_optimize: bool = False,
        ignore_onnx_optimize: bool = False,
        auto_cast: bool = True,
        handle_batch_norm: bool = False,
    ):
        # Step 1: export the PyTorch network to ONNX, reusing a cached export if present.
        if not force_onnx_export and os.path.exists(onnx_path):
            print(f"Found cached model: {onnx_path}")
        else:
            print(f"Exporting model: {onnx_path}")
            export_onnx(
                self.network,
                onnx_path=onnx_path,
                model_data=self.model,
                opt_image_height=opt_image_height,
                opt_image_width=opt_image_width,
                opt_batch_size=opt_batch_size,
                onnx_opset=onnx_opset,
                auto_cast=auto_cast,
            )
        # The PyTorch network is no longer needed; free GPU memory before the engine build.
        del self.network
        gc.collect()
        torch.cuda.empty_cache()
        # Step 2 (optional): patch batch-norm nodes in the exported ONNX graph.
        if handle_batch_norm:
            print(f"Handling batch norm for {onnx_path}")
            handle_onnx_batch_norm(onnx_path)

        # Step 3: graph-level ONNX optimization, skipped or reused from cache when possible.
        if ignore_onnx_optimize:
            print(f"Skipping ONNX optimization for {onnx_path}.")
            onnx_opt_path = onnx_path
        elif not force_onnx_optimize and os.path.exists(onnx_opt_path):
            print(f"Found cached model: {onnx_opt_path}")
        else:
            print(f"Generating optimized model: {onnx_opt_path}")
            optimize_onnx(
                onnx_path=onnx_path,
                onnx_opt_path=onnx_opt_path,
                model_data=self.model,
            )
        # Step 4: build the TensorRT engine. Latent shapes are the image resolutions
        # divided by the latent downsampling factor of 8.
        self.model.min_latent_shape = min_image_resolution // 8
        self.model.max_latent_shape = max_image_resolution // 8
        if not force_engine_build and os.path.exists(engine_path):
            print(f"Found cached engine: {engine_path}")
        else:
            build_engine(
                engine_path=engine_path,
                onnx_opt_path=onnx_opt_path,
                model_data=self.model,
                opt_image_height=opt_image_height,
                opt_image_width=opt_image_width,
                opt_batch_size=opt_batch_size,
                build_static_batch=build_static_batch,
                build_dynamic_shape=build_dynamic_shape,
                build_all_tactics=build_all_tactics,
                build_enable_refit=build_enable_refit,
            )
        gc.collect()
        torch.cuda.empty_cache()
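

# Example usage (a minimal sketch; `UNetModel` and `unet_torch` are hypothetical
# stand-ins for a concrete BaseModel subclass and its PyTorch module, and the
# paths are placeholders):
#
#     builder = EngineBuilder(model=UNetModel(...), network=unet_torch)
#     builder.build(
#         onnx_path="onnx/unet.onnx",
#         onnx_opt_path="onnx/unet.opt.onnx",
#         engine_path="engine/unet.plan",
#         opt_image_height=512,
#         opt_image_width=512,
#         opt_batch_size=1,
#         build_dynamic_shape=True,
#     )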