Detangutify / app.py
raycosine
add image uploading option
a6d1f86
import gradio as gr, numpy as np
from PIL import Image, ImageOps, ImageDraw, ImageFont
from pathlib import Path
import os, requests
from features import binarize, feat_vec, cosine_sim, stroke_normalize, _ensure_ink_true
from features_preproc import crop_and_center as crop_ref, LO
from huggingface_hub import hf_hub_download
from skimage.color import rgb2gray
import numpy as np
ASSET_REPO = "raycosine/detangutify-data"
FONT_PATH = "data/NotoSerifTangut-Regular.ttf"
URL = "https://notofonts.github.io/tangut/fonts/NotoSerifTangut/full/ttf/NotoSerifTangut-Regular.ttf"

# Fetch the Noto Serif Tangut font on first start; later starts reuse the file on disk.
if not os.path.exists(FONT_PATH):
    os.makedirs(os.path.dirname(FONT_PATH), exist_ok=True)
    # A timeout keeps startup from hanging indefinitely on a stalled connection.
    r = requests.get(URL, timeout=60)
    # Without this check an HTTP error page would be saved as the "font"; the
    # exists() guard above would then never re-download it and loading the font
    # with ImageFont.truetype would fail on every subsequent start.
    r.raise_for_status()
    # Write to a side file and rename it into place so an interrupted write never
    # leaves a truncated file at FONT_PATH.
    _font_tmp = FONT_PATH + ".part"
    with open(_font_tmp, "wb") as f:
        f.write(r.content)
    os.replace(_font_tmp, FONT_PATH)
# Template archive from the Hugging Face dataset repo. X holds the template
# feature rows compared against the query with cosine_sim; Y holds the matching
# Unicode code point for each row.
DATA_PATH = hf_hub_download(
    repo_id=ASSET_REPO,
    repo_type="dataset",
    filename="templates_aug.npz",
)
DATA = np.load(DATA_PATH)
X, Y = DATA["X"], DATA["y"]
# Optional standardisation statistics for query features; None when the archive
# does not contain them.
MEAN, STD = (DATA.get(key, None) for key in ("mean", "std"))

SIZE = 64  # not referenced by any live code visible in this file
THUMB = 128  # side length (px) of rendered result thumbnails; also the font size
_font = ImageFont.truetype(FONT_PATH, THUMB)
def render_glyph(cp: int) -> Image.Image:
    """Render code point ``cp`` as a THUMB x THUMB grayscale ("L") tile.

    The glyph's ink bounding box is centred on the tile, and the result is a dark
    glyph on a light background.
    """
    glyph = chr(cp)
    tile = Image.new("L", (THUMB, THUMB), 0)
    pen = ImageDraw.Draw(tile)
    left, top, right, bottom = pen.textbbox((0, 0), glyph, font=_font)
    # Subtract the bbox origin so the ink box, not the text anchor, is centred.
    origin = (
        (THUMB - (right - left)) // 2 - left,
        (THUMB - (bottom - top)) // 2 - top,
    )
    pen.text(origin, glyph, fill=255, font=_font)
    # Drawn white-on-black above; inverting gives black-on-white.
    return ImageOps.invert(tile)
import numpy as np
from scipy.ndimage import center_of_mass
from PIL import Image
def crop_and_center_deprecated(bw, size=64, pad=2):
    """Legacy preprocessing: crop to the ink, stretch to a square, recentre by mass.

    Not called anywhere in this file; ``infer`` uses ``crop_ref`` instead.

    Args:
        bw: 2-D array; pixels > 0 count as ink. The crop is multiplied by 255 and
            cast to uint8, so values are presumably expected in 0..1 (bool or
            float) -- TODO confirm.
        size: side length of the returned square array.
        pad: border (px) kept around the stretched crop before recentring.

    Returns:
        A ``(size, size)`` float32 array in [0, 1]; all zeros when ``bw`` has no ink.
    """
    ink = bw > 0
    if not np.any(ink):
        return np.zeros((size, size), dtype=np.float32)
    rows, cols = np.nonzero(ink)
    top, bottom = rows.min(), rows.max() + 1
    left, right = cols.min(), cols.max() + 1

    inner = size - 2 * pad
    # Stretch the ink bounding box to fill the padded square (aspect ratio not kept).
    patch = Image.fromarray((bw[top:bottom, left:right] * 255).astype(np.uint8))
    patch = patch.resize((inner, inner), Image.BILINEAR)
    board = Image.new("L", (size, size), 0)
    board.paste(patch, (pad, pad))
    out = np.array(board, dtype=np.float32) / 255.0

    # Move the centre of mass to the array centre; np.roll wraps pixels across edges.
    cy, cx = center_of_mass(out)
    shift = (int(size / 2 - cy), int(size / 2 - cx))
    return np.roll(out, shift, axis=(0, 1))
def _to_gray_uint8(arr: np.ndarray) -> np.ndarray:
    """Convert image data to a uint8 grayscale array.

    Supported layouts: 2-D grayscale, (H, W, 1) single channel, (H, W, 2)
    gray + alpha, (H, W, 3) RGB and (H, W, 4) RGBA. Transparent pixels are
    composited onto white.

    Value range is guessed from the data. For 2-D non-uint8 input, a maximum
    <= 1.0 is taken as a 0..1 scale. For 3-D input, a maximum > 1.0 is taken as
    a 0..255 scale. Arrays of any other rank are clipped to 0..255 and cast.

    Args:
        arr: Image array (uint8, bool, integer or float).

    Returns:
        A uint8 array. It is 2-D for 2-D/3-D input, and 2-D uint8 input is
        returned unchanged (same object).
    """
    # A trailing single channel is plain grayscale; rgb2gray cannot handle it.
    if arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr[..., 0]

    if arr.ndim == 2:
        if arr.dtype == np.uint8:
            return arr
        a = arr.astype(np.float32)
        # `a.size` guard: .max() raises on empty arrays.
        if a.size and a.max() <= 1.0:
            a *= 255.0
        return np.clip(a, 0, 255).astype(np.uint8)

    if arr.ndim == 3:
        a = arr.astype(np.float32)
        if a.size and a.max() > 1.0:
            a /= 255.0
        channels = a.shape[2]
        if channels == 2:
            # Gray + alpha: composite onto a white background.
            lum, alpha = a[..., 0], a[..., 1]
            g = lum * alpha + (1.0 - alpha)
        else:
            if channels == 4:
                # RGBA: composite onto a white background before conversion.
                rgb = a[..., :3]
                alpha = a[..., 3:4]
                a = rgb * alpha + (1.0 - alpha) * 1.0
            elif channels >= 3:
                a = a[..., :3]
            g = rgb2gray(a)
        # Clip before casting: out-of-range values (e.g. >8-bit input) would
        # otherwise wrap around in the float -> uint8 conversion.
        return (np.clip(g, 0.0, 1.0) * 255.0).astype(np.uint8)

    # Any other rank: best-effort cast to uint8.
    return np.clip(arr, 0, 255).astype(np.uint8)
def _pick_image(img):
    """Unwrap a Gradio editor dict; return non-dict input unchanged.

    For a dict, the first non-None value among the keys "image", "composite" and
    "background" is returned, or None when none is present.
    """
    if not isinstance(img, dict):
        return img
    for key in ("image", "composite", "background"):
        value = img.get(key)
        if value is not None:
            return value
    return None


def _top_unique(scores, labels, k=10):
    """Return up to ``k`` ``(codepoint, score)`` pairs, best score first.

    Several template rows can share a label (the archive holds augmented
    templates), so rows are walked in descending score order and only the best
    row per label is kept until ``k`` distinct labels are collected.
    """
    if k <= 0:
        return []
    ranked = []
    seen = set()
    for idx in np.argsort(-scores):
        cp = int(labels[idx])
        if cp in seen:
            continue
        seen.add(cp)
        ranked.append((cp, float(scores[idx])))
        if len(ranked) >= k:
            break
    return ranked


def infer(img):
    """Classify a drawn or uploaded Tangut character against the templates.

    Args:
        img: None, a numpy array, a PIL image, or a Gradio editor dict (unwrapped
            by ``_pick_image``).

    Returns:
        ``(gallery_items, viz_img, results_json)``:
          * gallery_items: up to 10 ``(thumbnail, caption)`` tuples, one per
            distinct code point, best first;
          * viz_img: PIL image of the preprocessed query passed to ``feat_vec``;
          * results_json: list of ``{"cp", "char", "score"}`` dicts, same order.
        When there is no image, ``([], None, [])``. Every event binds three
        output components, so every path must return three values.
    """
    img = _pick_image(img)
    if img is None:
        return [], None, []

    arr = np.array(img) if isinstance(img, Image.Image) else np.asarray(img)
    arr = _to_gray_uint8(arr)

    # Binarize, force ink to True, crop/centre with crop_ref (the author noted
    # these are the same settings as training), then normalise stroke width.
    bw = _ensure_ink_true(binarize(arr))
    bw = crop_ref(bw, out_size=LO, margin_frac=0.08)
    bw = stroke_normalize(bw, target_px=3)
    viz_img = Image.fromarray((bw * 255).astype(np.uint8))

    q = feat_vec(bw)
    if MEAN is not None and STD is not None:
        q = (q - MEAN.ravel()) / STD.ravel()
    s = cosine_sim(q, X)

    # Deduplicate over the full ranking, not just the first 10 rows, so the
    # gallery gets 10 distinct characters even when augmented rows repeat.
    # NOTE(review): a low-confidence flag (best < 0.58 or margin to the runner-up
    # < 0.05) was drafted here but never surfaced; rebuild it from `ranked` if needed.
    ranked = _top_unique(s, Y, k=10)

    gallery_items = []
    results_json = []
    for cp, sc in ranked:
        ch = chr(cp)
        gallery_items.append((render_glyph(cp), f"U+{cp:05X} {ch}\nScore: {sc:.6f}"))
        results_json.append({"cp": cp, "char": ch, "score": sc})
    return gallery_items, viz_img, results_json
# Gradio UI. A sketchpad and an image upload both feed `infer`. Its three return
# values fill the results gallery, the preprocessing preview and a hidden JSON list.
with gr.Blocks() as demo:
    gr.Markdown("### Detangutify (Tangut Character classifier)")
    with gr.Row():
        # Freehand drawing area: 192x192 grayscale canvas, black brush of default size 2.
        canvas = gr.Sketchpad(
            label="Draw here",
            image_mode="L",
            brush=gr.Brush(colors="black", default_size=2),
            canvas_size=(192,192),
            type="numpy",
        )
        # Image upload. The commented-out image_mode/sources settings below are
        # inactive, so the component defaults apply.
        upload = gr.Image(
            label="Upload a character image",
            type="numpy",
            #image_mode="L",
            #sources=["upload"],
            height=192
        )
    # Receives the (thumbnail, caption) pairs returned by infer.
    gallery = gr.Gallery(
        label="Top-10 Results",
        columns=10,
        preview=False,
        height=320
    )
    # Shows the preprocessed query image that infer actually matched.
    preview = gr.Image(label="stroke_normalize result", type="pil")
    # Machine-readable results: hidden in the UI but still one of each event's outputs.
    jsonout = gr.JSON(label="Top-10 (JSON)", visible=False)
    # Explicit search buttons. api_name also exposes each handler as a named API endpoint.
    btn_draw = gr.Button("Search (draw)")
    btn_draw.click(fn=infer, inputs=canvas, outputs=[gallery, preview, jsonout], api_name="predict")
    btn_upload = gr.Button("Search (upload)")
    btn_upload.click(fn=infer, inputs=upload, outputs=[gallery, preview, jsonout], api_name="predict_upload")
    # Also search automatically whenever the uploaded image changes. Pressing the
    # upload button afterwards re-runs the same search.
    upload.change(fn=infer, inputs=upload, outputs=[gallery, preview, jsonout])
    # Live search while drawing is disabled:
    #canvas.change(fn=infer, inputs=canvas, outputs=gallery)
    # Hidden input + button that exist only to register a "predict_img" API
    # endpoint taking a PIL image.
    api_img = gr.Image(type="pil", visible=False)
    api_btn = gr.Button(visible=False)
    api_btn.click(fn=infer, inputs=api_img, outputs=[gallery, preview, jsonout], api_name="predict_img")

if __name__ == "__main__":
    demo.launch()