File size: 905 Bytes
12b5ec5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from huggingface_hub import hf_hub_download
from models.isnet import ISNetDIS

# Export the ISNetDIS model, with weights downloaded from the Hugging Face Hub,
# to an ONNX file.
#
# Steps: pick a device, build ISNetDIS, download and load its state dict,
# then trace the model with a fixed-shape dummy input and write the ONNX file.

REPO_ID = "leonelhs/removators"
# NOTE(review): the file name looks like a template leftover (the model is
# ISNet, not a linear model); kept unchanged so anything consuming it still
# finds it.
ONNX_PATH = "linear_model.onnx"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

net = ISNetDIS()

model_path = hf_hub_download(repo_id=REPO_ID, filename='isnet.pth')
# torch.load unpickles the downloaded checkpoint; only load files from a repo
# you trust (recent PyTorch versions also accept weights_only=True here).
net.load_state_dict(torch.load(model_path, map_location=device))
net.to(device)
net.eval()  # trace the inference-mode graph

# The tracing input must live on the same device as the model's parameters;
# a CPU tensor fed to a model moved to CUDA fails with a device-mismatch error.
dummy_input = torch.ones(1, 3, 1024, 1024, device=device)

# Export the model.
# NOTE(review): if ISNetDIS.forward returns more than one tensor (e.g. side
# outputs), only the first ONNX output receives the name "output" and the
# dynamic batch axis below -- confirm against models/isnet.py.
torch.onnx.export(
    net,                         # model to trace
    dummy_input,                 # example input; non-batch dims stay fixed at this shape
    ONNX_PATH,                   # output file
    input_names=["input"],       # name inputs
    output_names=["output"],     # name outputs
    dynamic_axes={               # only axis 0 (batch size) is variable
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
    opset_version=17,            # ONNX opset version
)

print(f"Model exported to {ONNX_PATH}")