"""Gradio app for Detangutify, a Tangut character classifier.

A drawn or uploaded character image is binarized, cropped, stroke-normalized
and turned into a feature vector, which is compared by cosine similarity
against a set of pre-computed template features. The best-matching code
points are rendered with the Noto Serif Tangut font and shown in a gallery.
"""

import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import requests
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw, ImageFont, ImageOps
from scipy.ndimage import center_of_mass
from skimage.color import rgb2gray

from features import (
    _ensure_ink_true,
    binarize,
    cosine_sim,
    feat_vec,
    stroke_normalize,
)
from features_preproc import LO, crop_and_center as crop_ref

ASSET_REPO = "raycosine/detangutify-data"
FONT_PATH = "data/NotoSerifTangut-Regular.ttf"
URL = "https://notofonts.github.io/tangut/fonts/NotoSerifTangut/full/ttf/NotoSerifTangut-Regular.ttf"

SIZE = 64
THUMB = 128
TOP_K = 10  # number of distinct characters reported per query


def _ensure_font(path: str, url: str) -> None:
    """Download the font at *url* to *path* unless the file already exists.

    The payload is written to a temporary sibling file and renamed into place
    so that an interrupted download is never mistaken for a complete font on
    the next start.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeout.
    """
    target = Path(path)
    if target.exists():
        return
    target.parent.mkdir(parents=True, exist_ok=True)
    response = requests.get(url, timeout=60)
    # Without this check an HTML error page would be saved as the font.
    response.raise_for_status()
    partial = target.with_name(target.name + ".part")
    partial.write_bytes(response.content)
    os.replace(partial, target)


def _npz_optional(archive, key: str) -> Optional[np.ndarray]:
    """Return ``archive[key]`` or ``None`` when the archive has no such entry.

    Uses the ``files`` listing instead of ``.get`` so it also works with
    NumPy versions whose ``NpzFile`` is not a full mapping.
    """
    return archive[key] if key in archive.files else None


_ensure_font(FONT_PATH, URL)

DATA_PATH = hf_hub_download(
    repo_id=ASSET_REPO, repo_type="dataset", filename="templates_aug.npz"
)
DATA = np.load(DATA_PATH)
X = DATA["X"]  # template feature vectors, compared against the query
Y = DATA["y"]  # code point of each template, indexed like X
MEAN = _npz_optional(DATA, "mean")
STD = _npz_optional(DATA, "std")

_font = ImageFont.truetype(FONT_PATH, THUMB)


def render_glyph(cp: int) -> Image.Image:
    """Render code point *cp* as a centred black-on-white THUMB x THUMB image."""
    ch = chr(cp)
    canvas = Image.new("L", (THUMB, THUMB), 0)
    draw = ImageDraw.Draw(canvas)
    left, top, right, bottom = draw.textbbox((0, 0), ch, font=_font)
    # Centre the ink box; subtracting the bbox origin compensates for the
    # glyph's own offset from the drawing position.
    x = (THUMB - (right - left)) // 2 - left
    y = (THUMB - (bottom - top)) // 2 - top
    draw.text((x, y), ch, fill=255, font=_font)
    # Drawn white on black, then inverted to black on white.
    return ImageOps.invert(canvas)


def crop_and_center_deprecated(bw, size=64, pad=2):
    """Crop *bw* to its non-zero pixels, rescale and centre it by mass.

    Deprecated: ``infer`` uses ``features_preproc.crop_and_center`` instead.

    Args:
        bw: 2-D array; pixels greater than zero count as ink.
        size: side length of the square output.
        pad: border left free on each side before the centring shift.

    Returns:
        A ``(size, size)`` float32 array with values in ``[0, 1]``; all zeros
        when *bw* contains no ink.
    """
    coords = np.argwhere(bw > 0)
    if coords.size == 0:
        return np.zeros((size, size), dtype=np.float32)

    y0, x0 = coords.min(axis=0)
    y1, x1 = coords.max(axis=0) + 1
    cropped = bw[y0:y1, x0:x1]

    pil = Image.fromarray((cropped * 255).astype(np.uint8))
    pil = pil.resize((size - 2 * pad, size - 2 * pad), Image.BILINEAR)
    canvas = Image.new("L", (size, size), 0)
    canvas.paste(pil, (pad, pad))
    arr = np.array(canvas, dtype=np.float32) / 255.0

    if not arr.any():
        # Faint ink can quantize to zero; center_of_mass would then return
        # NaN and int(NaN) raises, so return the blank image unshifted.
        return arr

    cy, cx = center_of_mass(arr)
    shift_y, shift_x = int(size / 2 - cy), int(size / 2 - cx)
    arr = np.roll(arr, shift_y, axis=0)
    arr = np.roll(arr, shift_x, axis=1)
    return arr


def _to_gray_uint8(arr: np.ndarray) -> np.ndarray:
    """Convert an image array to a 2-D uint8 grayscale array.

    2-D input is treated as grayscale; non-uint8 data whose maximum is at
    most 1 is scaled up by 255. 3-D input is treated as height x width x
    channels: an alpha channel (2 or 4 channels) is composited over a white
    background, a single channel is used directly, and otherwise the first
    three channels are converted with ``rgb2gray``. Any other rank is only
    clipped and cast.
    """
    if arr.ndim == 2:
        if arr.dtype != np.uint8:
            scaled = arr.astype(np.float32)
            if scaled.size and scaled.max() <= 1.0:
                scaled *= 255.0
            arr = np.clip(scaled, 0, 255).astype(np.uint8)
        return arr

    if arr.ndim == 3:
        a = arr.astype(np.float32)
        # Bring the data to [0, 1]; values above 1 are taken to be 0-255.
        if a.size and a.max() > 1.0:
            a /= 255.0
        channels = a.shape[2]
        if channels in (2, 4):
            # Last channel is alpha: blend the colour onto white (1.0).
            colour, alpha = a[..., :-1], a[..., -1:]
            a = colour * alpha + (1.0 - alpha) * 1.0
        elif channels >= 3:
            a = a[..., :3]
        # rgb2gray needs three channels; a single channel is already gray.
        gray = a[..., 0] if a.shape[2] == 1 else rgb2gray(a)
        return np.clip(gray * 255.0, 0, 255).astype(np.uint8)

    # Other odd shapes: fall back to a plain clip-and-cast to uint8.
    return np.clip(arr, 0, 255).astype(np.uint8)


def infer(img: Any) -> Tuple[List[Tuple[Image.Image, str]], Optional[Image.Image], List[Dict[str, Any]]]:
    """Classify a character image and return the best-matching code points.

    Args:
        img: ``None``, a PIL image, an array-like image, or a dict holding
            the image under ``"image"``, ``"composite"`` or ``"background"``
            (checked in that order).

    Returns:
        A 3-tuple matching the three Gradio outputs:
        gallery items as ``(glyph image, caption)`` pairs, the preprocessed
        query image (``None`` when there was no usable input), and a list of
        ``{"cp", "char", "score"}`` dicts. At most ``TOP_K`` distinct
        characters are returned, best score first.
    """
    if isinstance(img, dict):
        img = next(
            (
                img[key]
                for key in ("image", "composite", "background")
                if img.get(key) is not None
            ),
            None,
        )
    if img is None:
        # Every listener binds three outputs, so always return three values.
        return [], None, []

    arr = np.array(img) if isinstance(img, Image.Image) else np.asarray(img)
    if arr.size == 0:
        return [], None, []
    arr = _to_gray_uint8(arr)

    bw0 = _ensure_ink_true(binarize(arr))
    # Same cropping routine as used when the templates were built.
    bw = crop_ref(bw0, out_size=LO, margin_frac=0.08)
    bw = stroke_normalize(bw, target_px=3)
    viz_img = Image.fromarray((bw * 255).astype(np.uint8))

    q = feat_vec(bw)
    if MEAN is not None and STD is not None:
        q = (q - MEAN.ravel()) / STD.ravel()
    s = cosine_sim(q, X)

    gallery_items = []
    results_json = []
    seen = set()
    # Walk the whole ranking, not just the first TOP_K templates: several
    # templates can share a code point, and duplicates must not use up
    # result slots.
    for idx in np.argsort(-s):
        cp = int(Y[idx])
        if cp in seen:
            continue
        seen.add(cp)
        sc = float(s[idx])
        caption = f"U+{cp:05X} {chr(cp)}\nScore: {sc:.6f}"
        gallery_items.append((render_glyph(cp), caption))
        results_json.append({"cp": cp, "char": chr(cp), "score": sc})
        if len(gallery_items) >= TOP_K:
            break

    return gallery_items, viz_img, results_json


with gr.Blocks() as demo:
    gr.Markdown("### Detangutify (Tangut Character classifier)")
    # NOTE(review): the original indentation was lost; this assumes the row
    # holds the two input widgets only - confirm against the deployed layout.
    with gr.Row():
        canvas = gr.Sketchpad(
            label="Draw here",
            image_mode="L",
            brush=gr.Brush(colors="black", default_size=2),
            canvas_size=(192, 192),
            type="numpy",
        )
        upload = gr.Image(
            label="Upload a character image",
            type="numpy",
            height=192,
        )
    gallery = gr.Gallery(
        label="Top-10 Results",
        columns=10,
        preview=False,
        height=320,
    )
    preview = gr.Image(label="stroke_normalize result", type="pil")
    jsonout = gr.JSON(label="Top-10 (JSON)", visible=False)

    btn_draw = gr.Button("Search (draw)")
    btn_draw.click(
        fn=infer,
        inputs=canvas,
        outputs=[gallery, preview, jsonout],
        api_name="predict",
    )

    btn_upload = gr.Button("Search (upload)")
    btn_upload.click(
        fn=infer,
        inputs=upload,
        outputs=[gallery, preview, jsonout],
        api_name="predict_upload",
    )
    upload.change(fn=infer, inputs=upload, outputs=[gallery, preview, jsonout])

    # Hidden widgets that only exist to expose a PIL-image API endpoint.
    api_img = gr.Image(type="pil", visible=False)
    api_btn = gr.Button(visible=False)
    api_btn.click(
        fn=infer,
        inputs=api_img,
        outputs=[gallery, preview, jsonout],
        api_name="predict_img",
    )


if __name__ == "__main__":
    demo.launch()