Use from_pretrained

#1
by nielsr HF staff - opened
Files changed (1) hide show
  1. app.py +4 -10
app.py CHANGED
@@ -2,7 +2,6 @@ import numpy as np
2
  import torch
3
  import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
5
- from huggingface_hub import hf_hub_download
6
  import gradio as gr
7
  from gradio_imageslider import ImageSlider
8
  from briarmbg import BriaRMBG
@@ -10,15 +9,10 @@ import PIL
10
  from PIL import Image
11
  from typing import Tuple
12
 
13
- net=BriaRMBG()
14
- # model_path = "./model1.pth"
15
- model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
16
- if torch.cuda.is_available():
17
- net.load_state_dict(torch.load(model_path))
18
- net=net.cuda()
19
- else:
20
- net.load_state_dict(torch.load(model_path,map_location="cpu"))
21
- net.eval()
22
 
23
 
24
  def resize_image(image):
 
2
  import torch
3
  import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
 
5
  import gradio as gr
6
  from gradio_imageslider import ImageSlider
7
  from briarmbg import BriaRMBG
 
9
  from PIL import Image
10
  from typing import Tuple
11
 
12
+ net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
13
+
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ net.to(device)
 
 
 
 
 
16
 
17
 
18
  def resize_image(image):