# sayakpaul's picture
# sayakpaul HF Staff
# up
# bb10560
# raw
# history blame contribute delete
# 635 Bytes
import torch
from huggingface_hub import hf_hub_download
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
from spaces.zero.torch.aoti import ZeroGPUWeights
from spaces.zero.torch.aoti import drain_module_parameters
def aoti_load_(module: torch.nn.Module, repo_id: str, filename: str):
    """Replace ``module.forward`` with a compiled model fetched from the Hub.

    Steps, in order:

    1. Fetch ``filename`` from ``repo_id`` with ``hf_hub_download``.
    2. Copy the module's current ``state_dict()`` into a plain dict and wrap
       it in ``ZeroGPUWeights``.
    3. Build a ``ZeroGPUCompiledModel`` from the downloaded file and those
       weights, and install it as ``module.forward``.
    4. Call ``drain_module_parameters`` on the module.

    The module is modified in place (trailing underscore in the name) and
    nothing is returned.

    Args:
        module: Module whose ``forward`` is replaced and whose state dict
            supplies the weights.
        repo_id: Hub repository that hosts the compiled graph file.
        filename: Name of the compiled graph file inside ``repo_id``.
    """
    # Download first so a failed fetch leaves ``module`` untouched.
    package_path = hf_hub_download(repo_id, filename)

    # Plain-dict copy of the name -> tensor mapping.
    weights = ZeroGPUWeights(dict(module.state_dict()))

    # From here on, calling the module dispatches to the compiled object.
    module.forward = ZeroGPUCompiledModel(package_path, weights)

    # NOTE(review): presumably frees the original parameters now that the
    # weights wrapper holds them -- confirm against the ``spaces`` package.
    drain_module_parameters(module)