"""Export the ISNet-DIS segmentation model to ONNX.

Downloads the ``isnet.pth`` checkpoint from the Hugging Face Hub repo
``leonelhs/removators``, loads it into ``ISNetDIS``, traces the model with a
fixed 1x3x1024x1024 example input and writes the graph to
``linear_model.onnx``.
"""
import torch
from huggingface_hub import hf_hub_download

from models.isnet import ISNetDIS

REPO_ID = "leonelhs/removators"
CHECKPOINT_FILENAME = "isnet.pth"
# NOTE(review): "linear_model" looks like a leftover from a tutorial; the name is
# kept unchanged so anything consuming the exported file still finds it.
ONNX_PATH = "linear_model.onnx"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

net = ISNetDIS()
model_path = hf_hub_download(repo_id=REPO_ID, filename=CHECKPOINT_FILENAME)
# weights_only=True limits unpickling to tensors and basic containers/primitives,
# so a tampered checkpoint fetched from the Hub cannot run arbitrary code on load.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
net.load_state_dict(state_dict)
net.to(device)
net.eval()

# The example input must live on the same device as the model's parameters;
# a CPU tensor fed to a CUDA model makes tracing fail with a device mismatch.
# Only the batch axis is declared dynamic below, so height and width stay
# fixed at 1024 in the exported graph.
dummy_input = torch.ones(1, 3, 1024, 1024, device=device)

# Export the model
torch.onnx.export(
    net,                              # model
    dummy_input,                      # example input
    ONNX_PATH,                        # output file
    input_names=["input"],            # name inputs
    # NOTE(review): if ISNetDIS.forward returns several tensors, only the first
    # is named "output" and the rest get auto-generated names — confirm.
    output_names=["output"],          # name outputs
    dynamic_axes={                    # allow variable batch size
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
    opset_version=17,                 # ONNX version
)
print(f"Model exported to {ONNX_PATH}")