abhicake commited on
Commit
d069ac3
1 Parent(s): 9796abc

Upload example_inference.py

Browse files
Files changed (1) hide show
  1. example_inference.py +39 -0
example_inference.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from skimage import io
2
+ import torch, os
3
+ from PIL import Image
4
+ from briarmbg import BriaRMBG
5
+ from utilities import preprocess_image, postprocess_image
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ def example_inference():
9
+
10
+ im_path = f"{os.path.dirname(os.path.abspath(__file__))}/example_input.jpg"
11
+
12
+ net = BriaRMBG()
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
15
+ net.to(device)
16
+ net.eval()
17
+
18
+ # prepare input
19
+ model_input_size = [1024,1024]
20
+ orig_im = io.imread(im_path)
21
+ orig_im_size = orig_im.shape[0:2]
22
+ image = preprocess_image(orig_im, model_input_size).to(device)
23
+
24
+ # inference
25
+ result=net(image)
26
+
27
+ # post process
28
+ result_image = postprocess_image(result[0][0], orig_im_size)
29
+
30
+ # save result
31
+ pil_im = Image.fromarray(result_image)
32
+ no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
33
+ orig_image = Image.open(im_path)
34
+ no_bg_image.paste(orig_image, mask=pil_im)
35
+ no_bg_image.save("example_image_no_bg.png")
36
+
37
+
38
+ if __name__ == "__main__":
39
+ example_inference()