schirrmacher commited on
Commit
04566b4
1 Parent(s): 08aed96

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. utils/inference.py +33 -7
utils/inference.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  import numpy as np
3
  from PIL import Image
4
  from skimage import io
@@ -6,6 +7,31 @@ from ormbg import ORMBG
6
  import torch.nn.functional as F
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
10
  if len(im.shape) < 3:
11
  im = im[:, :, np.newaxis]
@@ -27,19 +53,19 @@ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
27
  return im_array
28
 
29
 
30
- def example_inference():
31
-
32
- image_path = "example.png"
33
- result_name = "no-background.png"
34
 
35
  net = ORMBG()
36
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
 
38
  if torch.cuda.is_available():
39
- net.load_state_dict(torch.load("models/ormbg.pth"))
40
  net = net.cuda()
41
  else:
42
- net.load_state_dict(torch.load("models/ormbg.pth", map_location="cpu"))
43
  net.eval()
44
 
45
  model_input_size = [1024, 1024]
@@ -61,4 +87,4 @@ def example_inference():
61
 
62
 
63
  if __name__ == "__main__":
64
- example_inference()
 
1
  import torch
2
+ import argparse
3
  import numpy as np
4
  from PIL import Image
5
  from skimage import io
 
7
  import torch.nn.functional as F
8
 
9
 
10
+ def parse_args():
11
+ parser = argparse.ArgumentParser(
12
+ description="Remove background from images using ORMBG model."
13
+ )
14
+ parser.add_argument(
15
+ "--input",
16
+ type=str,
17
+ default="example.png",
18
+ help="Path to the input image file.",
19
+ )
20
+ parser.add_argument(
21
+ "--output",
22
+ type=str,
23
+ default="no-background.png",
24
+ help="Path to the output image file.",
25
+ )
26
+ parser.add_argument(
27
+ "--model-path",
28
+ type=str,
29
+ default="models/ormbg.pth",
30
+ help="Path to the model file.",
31
+ )
32
+ return parser.parse_args()
33
+
34
+
35
  def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
36
  if len(im.shape) < 3:
37
  im = im[:, :, np.newaxis]
 
53
  return im_array
54
 
55
 
56
+ def inference(args):
57
+ image_path = args.input
58
+ result_name = args.output
59
+ model_path = args.model_path
60
 
61
  net = ORMBG()
62
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
 
64
  if torch.cuda.is_available():
65
+ net.load_state_dict(torch.load(model_path))
66
  net = net.cuda()
67
  else:
68
+ net.load_state_dict(torch.load(model_path, map_location="cpu"))
69
  net.eval()
70
 
71
  model_input_size = [1024, 1024]
 
87
 
88
 
89
  if __name__ == "__main__":
90
+ inference(parse_args())