w11wo commited on
Commit
6e92463
1 Parent(s): 5ac2b74

initial commit

Browse files
Files changed (5) hide show
  1. README.md +3 -3
  2. app.py +74 -0
  3. example1_x2.jpg +0 -0
  4. model.ort +0 -0
  5. requirements.txt +3 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Image Upscaling Playground
3
- emoji:
4
- colorFrom: gray
5
- colorTo: blue
6
  sdk: gradio
7
  sdk_version: 2.9.4
8
  app_file: app.py
 
1
  ---
2
  title: Image Upscaling Playground
3
+ emoji: 🦆
4
+ colorFrom: yellow
5
+ colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 2.9.4
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import onnxruntime
4
+ from glob import glob
5
+ import os
6
+ import gradio as gr
7
+
8
+ import gradio as gr
9
+
10
+
11
+ def pre_process(img: np.array) -> np.array:
12
+ # H, W, C -> C, H, W
13
+ img = np.transpose(img[:, :, 0:3], (2, 0, 1))
14
+ # C, H, W -> 1, C, H, W
15
+ img = np.expand_dims(img, axis=0).astype(np.float32)
16
+ return img
17
+
18
+
19
+ def post_process(img: np.array) -> np.array:
20
+ # 1, C, H, W -> C, H, W
21
+ img = np.squeeze(img)
22
+ # C, H, W -> H, W, C
23
+ img = np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
24
+ return img
25
+
26
+
27
+ def inference(model_path: str, img_array: np.array) -> np.array:
28
+ ort_session = onnxruntime.InferenceSession(model_path)
29
+ ort_inputs = {ort_session.get_inputs()[0].name: img_array}
30
+ ort_outs = ort_session.run(None, ort_inputs)
31
+
32
+ return ort_outs[0]
33
+
34
+
35
+ def convert_pil_to_cv2(image):
36
+ # pil_image = image.convert("RGB")
37
+ open_cv_image = np.array(image)
38
+ # RGB to BGR
39
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
40
+ return open_cv_image
41
+
42
+
43
+ def main(image):
44
+ model_path = "./model.ort"
45
+ img = convert_pil_to_cv2(image)
46
+ if img.ndim == 2:
47
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
48
+
49
+ if img.shape[2] == 4:
50
+ alpha = img[:, :, 3] # GRAY
51
+ alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR) # BGR
52
+ alpha_output = post_process(inference(model_path, pre_process(alpha))) # BGR
53
+ alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY) # GRAY
54
+
55
+ img = img[:, :, 0:3] # BGR
56
+ image_output = post_process(inference(model_path, pre_process(img))) # BGR
57
+ image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA) # BGRA
58
+ image_output[:, :, 3] = alpha_output
59
+
60
+ elif img.shape[2] == 3:
61
+ image_output = post_process(inference(model_path, pre_process(img))) # BGR
62
+
63
+ return image_output
64
+
65
+
66
+ gr.Interface(
67
+ main,
68
+ gr.inputs.Image(type="pil"),
69
+ "image",
70
+ title="Image Upscaling 🦆",
71
+ allow_flagging="never",
72
+ css=".output-image, .input-image, .image-preview {height: 500px !important} ",
73
+ ).launch()
74
+
example1_x2.jpg ADDED
model.ort ADDED
Binary file (261 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ numpy
2
+ onnxruntime
3
+ opencv-python