w11wo commited on
Commit
a625565
1 Parent(s): b9043e0

initial commit

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -20,7 +20,11 @@ def post_process(img: np.array) -> np.array:
20
  return img
21
 
22
 
23
- def inference(img_array: np.array) -> np.array:
 
 
 
 
24
  ort_inputs = {ort_session.get_inputs()[0].name: img_array}
25
  ort_outs = ort_session.run(None, ort_inputs)
26
 
@@ -28,6 +32,7 @@ def inference(img_array: np.array) -> np.array:
28
 
29
 
30
  def convert_pil_to_cv2(image):
 
31
  open_cv_image = np.array(image)
32
  # RGB to BGR
33
  open_cv_image = open_cv_image[:, :, ::-1].copy()
@@ -35,6 +40,7 @@ def convert_pil_to_cv2(image):
35
 
36
 
37
  def upscale(image):
 
38
  img = convert_pil_to_cv2(image)
39
  if img.ndim == 2:
40
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
@@ -42,26 +48,20 @@ def upscale(image):
42
  if img.shape[2] == 4:
43
  alpha = img[:, :, 3] # GRAY
44
  alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR) # BGR
45
- alpha_output = post_process(inference(pre_process(alpha))) # BGR
46
  alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY) # GRAY
47
 
48
  img = img[:, :, 0:3] # BGR
49
- image_output = post_process(inference(pre_process(img))) # BGR
50
  image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA) # BGRA
51
  image_output[:, :, 3] = alpha_output
52
 
53
  elif img.shape[2] == 3:
54
- image_output = post_process(inference(pre_process(img))) # BGR
55
 
56
  return image_output
57
 
58
 
59
- model_path = "models/model.ort"
60
- options = onnxruntime.SessionOptions()
61
- options.intra_op_num_threads = 1
62
- options.inter_op_num_threads = 1
63
- ort_session = onnxruntime.InferenceSession(model_path, options)
64
-
65
  examples = [f"examples/example_{i+1}.png" for i in range(5)]
66
  css = ".output-image, .input-image, .image-preview {height: 480px !important} "
67
 
 
20
  return img
21
 
22
 
23
+ def inference(model_path: str, img_array: np.array) -> np.array:
24
+ options = onnxruntime.SessionOptions()
25
+ options.intra_op_num_threads = 1
26
+ options.inter_op_num_threads = 1
27
+ ort_session = onnxruntime.InferenceSession(model_path, options)
28
  ort_inputs = {ort_session.get_inputs()[0].name: img_array}
29
  ort_outs = ort_session.run(None, ort_inputs)
30
 
 
32
 
33
 
34
  def convert_pil_to_cv2(image):
35
+ # pil_image = image.convert("RGB")
36
  open_cv_image = np.array(image)
37
  # RGB to BGR
38
  open_cv_image = open_cv_image[:, :, ::-1].copy()
 
40
 
41
 
42
  def upscale(image):
43
+ model_path = "models/model.ort"
44
  img = convert_pil_to_cv2(image)
45
  if img.ndim == 2:
46
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
 
48
  if img.shape[2] == 4:
49
  alpha = img[:, :, 3] # GRAY
50
  alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR) # BGR
51
+ alpha_output = post_process(inference(model_path, pre_process(alpha))) # BGR
52
  alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY) # GRAY
53
 
54
  img = img[:, :, 0:3] # BGR
55
+ image_output = post_process(inference(model_path, pre_process(img))) # BGR
56
  image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA) # BGRA
57
  image_output[:, :, 3] = alpha_output
58
 
59
  elif img.shape[2] == 3:
60
+ image_output = post_process(inference(model_path, pre_process(img))) # BGR
61
 
62
  return image_output
63
 
64
 
 
 
 
 
 
 
65
  examples = [f"examples/example_{i+1}.png" for i in range(5)]
66
  css = ".output-image, .input-image, .image-preview {height: 480px !important} "
67