raycosine commited on
Commit
a6d1f86
·
1 Parent(s): e821964

add image uploading option

Browse files
Files changed (1) hide show
  1. app.py +50 -9
app.py CHANGED
@@ -5,6 +5,8 @@ import os, requests
5
  from features import binarize, feat_vec, cosine_sim, stroke_normalize, _ensure_ink_true
6
  from features_preproc import crop_and_center as crop_ref, LO
7
  from huggingface_hub import hf_hub_download
 
 
8
  ASSET_REPO = "raycosine/detangutify-data"
9
  FONT_PATH = "data/NotoSerifTangut-Regular.ttf"
10
  URL = "https://notofonts.github.io/tangut/fonts/NotoSerifTangut/full/ttf/NotoSerifTangut-Regular.ttf"
@@ -63,6 +65,34 @@ def crop_and_center_deprecated(bw, size=64, pad=2):
63
  arr = np.roll(arr, shift_y, axis=0)
64
  arr = np.roll(arr, shift_x, axis=1)
65
  return arr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def infer(img):
67
  if img is None:
68
  return [], None
@@ -82,8 +112,7 @@ def infer(img):
82
  arr = np.array(img)
83
  else:
84
  arr = np.asarray(img)
85
- if arr.ndim == 3:
86
- arr = arr[..., 0]
87
 
88
  if arr.dtype != np.uint8:
89
  arr = np.clip(arr, 0, 255).astype(np.uint8)
@@ -135,16 +164,28 @@ with gr.Blocks() as demo:
135
  type="numpy",
136
  )
137
 
138
- gallery = gr.Gallery(
139
- label="Top-10 Results",
140
- columns=10,
141
- preview=False,
142
- height=320
 
143
  )
 
 
 
 
 
 
144
  preview = gr.Image(label="stroke_normalize result", type="pil")
145
  jsonout = gr.JSON(label="Top-10 (JSON)", visible=False)
146
- btn = gr.Button("Search")
147
- btn.click(fn=infer, inputs=canvas, outputs=[gallery, preview, jsonout], api_name="predict")
 
 
 
 
 
148
  #canvas.change(fn=infer, inputs=canvas, outputs=gallery)
149
  api_img = gr.Image(type="pil", visible=False)
150
  api_btn = gr.Button(visible=False)
 
5
  from features import binarize, feat_vec, cosine_sim, stroke_normalize, _ensure_ink_true
6
  from features_preproc import crop_and_center as crop_ref, LO
7
  from huggingface_hub import hf_hub_download
8
+ from skimage.color import rgb2gray
9
+ import numpy as np
10
  ASSET_REPO = "raycosine/detangutify-data"
11
  FONT_PATH = "data/NotoSerifTangut-Regular.ttf"
12
  URL = "https://notofonts.github.io/tangut/fonts/NotoSerifTangut/full/ttf/NotoSerifTangut-Regular.ttf"
 
65
  arr = np.roll(arr, shift_y, axis=0)
66
  arr = np.roll(arr, shift_x, axis=1)
67
  return arr
68
+ def _to_gray_uint8(arr: np.ndarray) -> np.uint8:
69
+
70
+ if arr.ndim == 2:
71
+ if arr.dtype != np.uint8:
72
+ a = arr.astype(np.float32)
73
+ if a.max() <= 1.0: a *= 255.0
74
+ arr = np.clip(a, 0, 255).astype(np.uint8)
75
+ return arr
76
+
77
+ if arr.ndim == 3:
78
+ a = arr.astype(np.float32)
79
+ if a.max() > 1.0:
80
+ a /= 255.0
81
+
82
+
83
+ if a.shape[2] == 4:
84
+ rgb = a[..., :3]
85
+ alpha = a[..., 3:4]
86
+ a = rgb * alpha + (1.0 - alpha) * 1.0
87
+
88
+ elif a.shape[2] >= 3:
89
+ a = a[..., :3]
90
+
91
+ g = rgb2gray(a)
92
+ return (g * 255.0).astype(np.uint8)
93
+
94
+ # 其它奇怪形状:兜底到 uint8
95
+ return np.clip(arr, 0, 255).astype(np.uint8)
96
  def infer(img):
97
  if img is None:
98
  return [], None
 
112
  arr = np.array(img)
113
  else:
114
  arr = np.asarray(img)
115
+ arr = _to_gray_uint8(arr)
 
116
 
117
  if arr.dtype != np.uint8:
118
  arr = np.clip(arr, 0, 255).astype(np.uint8)
 
164
  type="numpy",
165
  )
166
 
167
+ upload = gr.Image(
168
+ label="Upload a character image",
169
+ type="numpy",
170
+ #image_mode="L",
171
+ #sources=["upload"],
172
+ height=192
173
  )
174
+ gallery = gr.Gallery(
175
+ label="Top-10 Results",
176
+ columns=10,
177
+ preview=False,
178
+ height=320
179
+ )
180
  preview = gr.Image(label="stroke_normalize result", type="pil")
181
  jsonout = gr.JSON(label="Top-10 (JSON)", visible=False)
182
+ btn_draw = gr.Button("Search (draw)")
183
+ btn_draw.click(fn=infer, inputs=canvas, outputs=[gallery, preview, jsonout], api_name="predict")
184
+
185
+ btn_upload = gr.Button("Search (upload)")
186
+ btn_upload.click(fn=infer, inputs=upload, outputs=[gallery, preview, jsonout], api_name="predict_upload")
187
+ upload.change(fn=infer, inputs=upload, outputs=[gallery, preview, jsonout])
188
+
189
  #canvas.change(fn=infer, inputs=canvas, outputs=gallery)
190
  api_img = gr.Image(type="pil", visible=False)
191
  api_btn = gr.Button(visible=False)