File size: 905 Bytes
12b5ec5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import torch
from huggingface_hub import hf_hub_download
from models.isnet import ISNetDIS
# Export script: download the pretrained ISNet checkpoint from the Hugging
# Face Hub, load it into ISNetDIS, and write the network out as an ONNX file.

# Hub repository that hosts the checkpoint file.
REPO_ID = "leonelhs/removators"
# Destination of the exported graph; also echoed in the final message.
ONNX_PATH = "linear_model.onnx"

# Prefer the GPU when one is available, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

net = ISNetDIS()
model_path = hf_hub_download(repo_id=REPO_ID, filename='isnet.pth')
# NOTE(review): torch.load unpickles a file downloaded from the network.
# Consider passing weights_only=True once it is confirmed that the checkpoint
# contains nothing but tensors, so a tampered file cannot execute code.
net.load_state_dict(torch.load(model_path, map_location=device))
net.to(device)
net.eval()  # export must trace the network in evaluation mode

# The example input has to live on the same device as the weights; a CPU
# tensor fed to a CUDA model makes the export trace fail with a device
# mismatch.
dummy_input = torch.ones(1, 3, 1024, 1024, device=device)

# Export the model.
# NOTE(review): only one output name is declared; if ISNetDIS.forward returns
# several tensors, the remaining outputs get auto-generated names and a fixed
# batch dimension - confirm against models/isnet.py.
torch.onnx.export(
    net,                        # model to trace
    dummy_input,                # example input that drives the trace
    ONNX_PATH,                  # output file
    input_names=["input"],      # name of the graph input
    output_names=["output"],    # name of the graph output
    dynamic_axes={              # allow a variable batch size on axis 0
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
    opset_version=17,           # ONNX operator-set version
)
print(f"Model exported to {ONNX_PATH}")
|