OriLib commited on
Commit
4cd0c8b
1 Parent(s): a924c15

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +67 -0
README.md CHANGED
@@ -1,3 +1,70 @@
1
  ---
2
  license: apache-2.0
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+ tags:
4
+ - background-removal
5
+ - Pytorch
6
+ - vision
7
  ---
8
+
9
+ # BRIA Background Removal v1.3
10
+
11
+ ## Usage
12
+
13
+ ```python
14
+ import os
15
+ import numpy as np
16
+ from skimage import io
17
+ from glob import glob
18
+ from tqdm import tqdm
19
+ import cv2
20
+ import torch.nn.functional as F
21
+ from torchvision.transforms.functional import normalize
22
+ from models import BriaRMBG
23
+
24
+ input_size=[1024,1024]
25
+ net=BriaRMBG()
26
+
27
+ model_path = "./model.pth"
28
+ im_path = "./example_image.jpg"
29
+ result_path = "."
30
+
31
+ if torch.cuda.is_available():
32
+ net.load_state_dict(torch.load(model_path))
33
+ net=net.cuda()
34
+ else:
35
+ net.load_state_dict(torch.load(model_path,map_location="cpu"))
36
+ net.eval()
37
+
38
+ # prepare input
39
+ im = io.imread(im_path)
40
+ if len(im.shape) < 3:
41
+ im = im[:, :, np.newaxis]
42
+ im_size=im.shape[0:2]
43
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
44
+ im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=input_size, mode='bilinear').type(torch.uint8)
45
+ image = torch.divide(im_tensor,255.0)
46
+ image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
47
+
48
+ if torch.cuda.is_available():
49
+ image=image.cuda()
50
+
51
+ # inference
52
+ result=net(image)
53
+
54
+ # post process
55
+ result = torch.squeeze(F.interpolate(result[0][0], size=im_size, mode='bilinear') ,0)
56
+ ma = torch.max(result)
57
+ mi = torch.min(result)
58
+ result = (result-mi)/(ma-mi)
59
+
60
+ # save result
61
+ im_name=im_path.split('/')[-1].split('.')[0]
62
+ im_array = (result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8)
63
+ cv2.imwrite(os.path.join(result_path, im_name+".png"), im_array)
64
+ ```
65
+
66
+ ## Training data
67
+ Bria-RMBG model was trained over 12000 high quality, high resolution images.
68
+ All images were manualy labeled pixel-wise accuratly. The images belong to veriety of categories, the majority of them inclues people.
69
+
70
+ ## Qualitative Evaluation