import torch
from huggingface_hub import hf_hub_download
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
from spaces.zero.torch.aoti import ZeroGPUWeights
from spaces.zero.torch.aoti import drain_module_parameters


def aoti_load_(module: torch.nn.Module, repo_id: str, filename: str):
    """Swap ``module.forward`` for a compiled model fetched from the Hub.

    The file ``filename`` is fetched from the Hub repository ``repo_id``
    and combined with the module's current ``state_dict()`` tensors to
    build a ``ZeroGPUCompiledModel``, which is then installed as the
    module's ``forward`` attribute. The module is modified in place and
    nothing is returned.

    Args:
        module: Module whose ``forward`` is replaced; its ``state_dict()``
            supplies the weights handed to the compiled model.
        repo_id: Hub repository identifier passed to ``hf_hub_download``.
        filename: Name of the compiled graph file inside that repository.
    """
    graph_path = hf_hub_download(repo_id, filename)

    # Snapshot the name -> tensor mapping as a plain dict before wrapping it.
    weights = ZeroGPUWeights(dict(module.state_dict()))

    module.forward = ZeroGPUCompiledModel(graph_path, weights)

    # NOTE(review): presumably releases the module's own parameter storage now
    # that the compiled model holds the weights -- confirm against the
    # `spaces` implementation of drain_module_parameters.
    drain_module_parameters(module)