Spaces:
Runtime error
Runtime error
import gc | |
import os | |
from typing import * | |
import torch | |
from .models import BaseModel | |
from .utilities import ( | |
build_engine, | |
export_onnx, | |
handle_onnx_batch_norm, | |
optimize_onnx, | |
) | |
class EngineBuilder:
    """Run the ONNX export, ONNX optimization and engine build steps for a network.

    Every step writes a file; a step is skipped when its output file already
    exists on disk, unless the matching ``force_*`` flag of :meth:`build` is
    set.
    """

    def __init__(
        self,
        model: BaseModel,
        network: Any,
        device=torch.device("cuda"),
    ):
        """Remember the objects that :meth:`build` works with.

        Args:
            model: Passed to the export / optimize / build helpers as
                ``model_data``; :meth:`build` also sets its
                ``min_latent_shape`` and ``max_latent_shape`` attributes.
            network: Passed to ``export_onnx`` as the object to export. The
                attribute is deleted from the instance after an export runs.
            device: Stored on the instance; nothing else in this class reads it.
        """
        self.network = network
        self.model = model
        self.device = device

    def build(
        self,
        onnx_path: str,
        onnx_opt_path: str,
        engine_path: str,
        opt_image_height: int = 512,
        opt_image_width: int = 512,
        opt_batch_size: int = 1,
        min_image_resolution: int = 256,
        max_image_resolution: int = 1024,
        build_enable_refit: bool = False,
        build_static_batch: bool = False,
        build_dynamic_shape: bool = False,
        build_all_tactics: bool = False,
        onnx_opset: int = 17,
        force_engine_build: bool = False,
        force_onnx_export: bool = False,
        force_onnx_optimize: bool = False,
        ignore_onnx_optimize: bool = False,
        auto_cast: bool = True,
        handle_batch_norm: bool = False,
    ):
        """Produce the engine file, reusing intermediate files found on disk.

        Args:
            onnx_path: Where the exported ONNX model is written / looked up.
            onnx_opt_path: Where the optimized ONNX model is written / looked
                up. Replaced by ``onnx_path`` when ``ignore_onnx_optimize``
                is set.
            engine_path: Where the built engine is written / looked up.
            opt_image_height: Forwarded to ``export_onnx`` and ``build_engine``.
            opt_image_width: Forwarded to ``export_onnx`` and ``build_engine``.
            opt_batch_size: Forwarded to ``export_onnx`` and ``build_engine``.
            min_image_resolution: Floor-divided by 8 and stored as
                ``self.model.min_latent_shape``.
            max_image_resolution: Floor-divided by 8 and stored as
                ``self.model.max_latent_shape``.
            build_enable_refit: Forwarded to ``build_engine``.
            build_static_batch: Forwarded to ``build_engine``.
            build_dynamic_shape: Forwarded to ``build_engine``.
            build_all_tactics: Forwarded to ``build_engine``.
            onnx_opset: Forwarded to ``export_onnx``.
            force_engine_build: Build the engine even if ``engine_path`` exists.
            force_onnx_export: Export even if ``onnx_path`` exists.
            force_onnx_optimize: Optimize even if ``onnx_opt_path`` exists.
            ignore_onnx_optimize: Skip the optimization step entirely and
                build the engine from the un-optimized ``onnx_path``.
            auto_cast: Forwarded to ``export_onnx``.
            handle_batch_norm: Call ``handle_onnx_batch_norm(onnx_path)``
                after the export step.
        """
        model_data = self.model
        # The same three "opt" sizes go to both the exporter and the engine
        # builder, so collect them once.
        opt_sizes = {
            "opt_image_height": opt_image_height,
            "opt_image_width": opt_image_width,
            "opt_batch_size": opt_batch_size,
        }

        # --- Step 1: ONNX export -------------------------------------------
        # `or` short-circuits, so the filesystem is not touched when forced.
        if force_onnx_export or not os.path.exists(onnx_path):
            print(f"Exporting model: {onnx_path}")
            export_onnx(
                self.network,
                onnx_path=onnx_path,
                model_data=model_data,
                **opt_sizes,
                onnx_opset=onnx_opset,
                auto_cast=auto_cast,
            )
            # Drop this builder's reference to the network and ask Python and
            # the CUDA allocator to release whatever that frees.
            del self.network
            gc.collect()
            torch.cuda.empty_cache()
        else:
            print(f"Found cached model: {onnx_path}")

        if handle_batch_norm:
            print(f"Handle Batch Norm for {onnx_path}")
            handle_onnx_batch_norm(onnx_path)

        # --- Step 2: ONNX optimization -------------------------------------
        if ignore_onnx_optimize:
            print(f"Ignore onnx optimize for {onnx_path}.")
            # The engine is then built straight from the exported file.
            onnx_opt_path = onnx_path
        else:
            if force_onnx_optimize or not os.path.exists(onnx_opt_path):
                print(f"Generating optimizing model: {onnx_opt_path}")
                optimize_onnx(
                    onnx_path=onnx_path,
                    onnx_opt_path=onnx_opt_path,
                    model_data=model_data,
                )
            else:
                print(f"Found cached model: {onnx_opt_path}")

        # --- Step 3: engine build ------------------------------------------
        # Set before `build_engine` runs, which receives the model as
        # `model_data`.
        # NOTE(review): the // 8 is presumably the image-to-latent
        # downscale factor -- confirm against the model definition.
        model_data.min_latent_shape = min_image_resolution // 8
        model_data.max_latent_shape = max_image_resolution // 8

        if force_engine_build or not os.path.exists(engine_path):
            build_engine(
                engine_path=engine_path,
                onnx_opt_path=onnx_opt_path,
                model_data=model_data,
                **opt_sizes,
                build_static_batch=build_static_batch,
                build_dynamic_shape=build_dynamic_shape,
                build_all_tactics=build_all_tactics,
                build_enable_refit=build_enable_refit,
            )
        else:
            print(f"Found cached engine: {engine_path}")

        # Final cleanup, run on every call regardless of which steps executed.
        gc.collect()
        torch.cuda.empty_cache()