File size: 635 Bytes
bb10560
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from huggingface_hub import hf_hub_download
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
from spaces.zero.torch.aoti import ZeroGPUWeights
from spaces.zero.torch.aoti import drain_module_parameters


def aoti_load_(module: torch.nn.Module, repo_id: str, filename: str) -> None:
    """Replace ``module.forward`` with an ahead-of-time compiled graph from the Hub.

    Downloads ``filename`` from the Hugging Face Hub repository ``repo_id``.
    It then wraps a snapshot of the module's current state dict in
    ``ZeroGPUWeights`` and builds a ``ZeroGPUCompiledModel`` from the
    downloaded graph file and those weights. That model is assigned as
    ``module.forward``. Finally ``drain_module_parameters(module)`` is called.
    The module is modified in place and nothing is returned.

    Args:
        module: The module whose ``forward`` is replaced in place.
        repo_id: Hugging Face Hub repository id that hosts the compiled graph.
        filename: Path of the compiled graph file inside ``repo_id``.
    """
    compiled_graph_file = hf_hub_download(repo_id, filename)

    # Snapshot the weights BEFORE draining the module: the compiled model is
    # built from this dict. NOTE(review): presumably drain_module_parameters
    # releases the module's own parameter storage — confirm in spaces.zero.
    zerogpu_weights = ZeroGPUWeights(dict(module.state_dict()))
    compiled = ZeroGPUCompiledModel(compiled_graph_file, zerogpu_weights)

    # Route calls to the module through the compiled graph from now on.
    module.forward = compiled
    drain_module_parameters(module)