OriLib commited on
Commit
e9e3ad5
1 Parent(s): d054980

Update example_inference.py

Browse files
Files changed (1) hide show
  1. example_inference.py +3 -1
example_inference.py CHANGED
@@ -3,15 +3,17 @@ import torch, os
3
  from PIL import Image
4
  from briarmbg import BriaRMBG
5
  from utilities import preprocess_image, postprocess_image
 
6
 
7
  def example_inference():
8
 
9
- model_path = f"{os.path.dirname(os.path.abspath(__file__))}/model.pth"
10
  im_path = f"{os.path.dirname(os.path.abspath(__file__))}/example_input.jpg"
11
 
12
  net = BriaRMBG()
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  net.load_state_dict(torch.load(model_path, map_location=device))
 
15
  net.eval()
16
 
17
  # prepare input
 
3
  from PIL import Image
4
  from briarmbg import BriaRMBG
5
  from utilities import preprocess_image, postprocess_image
6
+ from huggingface_hub import hf_hub_download
7
 
8
  def example_inference():
9
 
10
+ model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
11
  im_path = f"{os.path.dirname(os.path.abspath(__file__))}/example_input.jpg"
12
 
13
  net = BriaRMBG()
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  net.load_state_dict(torch.load(model_path, map_location=device))
16
+ net.to(device)
17
  net.eval()
18
 
19
  # prepare input