# Spaces: Running
| import gradio as gr, numpy as np | |
| from PIL import Image, ImageOps, ImageDraw, ImageFont | |
| from pathlib import Path | |
| import os, requests | |
| from features import binarize, feat_vec, cosine_sim, stroke_normalize, _ensure_ink_true | |
| from features_preproc import crop_and_center as crop_ref, LO | |
| from huggingface_hub import hf_hub_download | |
| from skimage.color import rgb2gray | |
| import numpy as np | |
# Hugging Face dataset repository that hosts the precomputed template file.
ASSET_REPO = "raycosine/detangutify-data"
# Local cache path and upstream URL of the Tangut font used to draw result thumbnails.
FONT_PATH = "data/NotoSerifTangut-Regular.ttf"
URL = "https://notofonts.github.io/tangut/fonts/NotoSerifTangut/full/ttf/NotoSerifTangut-Regular.ttf"

if not os.path.exists(FONT_PATH):
    os.makedirs(os.path.dirname(FONT_PATH), exist_ok=True)
    # Bound the request and fail loudly on HTTP errors: without this an error
    # page would be saved as the "font" and every later start would skip the
    # download and crash in ImageFont.truetype instead.
    r = requests.get(URL, timeout=60)
    r.raise_for_status()
    # Write under a temporary name and rename, so an interrupted write never
    # leaves a truncated file at FONT_PATH.
    _font_tmp_path = FONT_PATH + ".part"
    with open(_font_tmp_path, "wb") as f:
        f.write(r.content)
    os.replace(_font_tmp_path, FONT_PATH)

DATA_PATH = hf_hub_download(repo_id=ASSET_REPO, repo_type="dataset",
                            filename="templates_aug.npz")
DATA = np.load(DATA_PATH)
# X is what infer() compares the query against; Y[idx] is read there as the
# Unicode code point of template idx.
X = DATA["X"]
Y = DATA["y"]
# Optional normalisation statistics. Membership in `.files` works on every
# NumPy version, whereas NpzFile only offers dict-style .get() on newer ones.
MEAN = DATA["mean"] if "mean" in DATA.files else None
STD = DATA["std"] if "std" in DATA.files else None
SIZE = 64
# Edge length in pixels of the rendered result thumbnails (also the font size).
THUMB = 128
_font = ImageFont.truetype(FONT_PATH, THUMB)
def render_glyph(cp:int) -> Image.Image:
    """Render code point *cp* centred on a THUMB x THUMB grayscale thumbnail.

    The glyph is drawn white-on-black and the finished image is inverted,
    so the returned image shows dark ink on a light background.
    """
    glyph = chr(cp)
    tile = Image.new("L", (THUMB, THUMB), 0)
    pen = ImageDraw.Draw(tile)

    left, top, right, bottom = pen.textbbox((0, 0), glyph, font=_font)
    # Centre the ink bounding box; subtracting left/top compensates for the
    # offset between the drawing origin and where the ink actually starts.
    origin = (
        (THUMB - (right - left)) // 2 - left,
        (THUMB - (bottom - top)) // 2 - top,
    )
    pen.text(origin, glyph, fill=255, font=_font)
    return ImageOps.invert(tile)
| import numpy as np | |
| from scipy.ndimage import center_of_mass | |
| from PIL import Image | |
def crop_and_center_deprecated(bw, size=64, pad=2):
    """Crop *bw* to its ink, rescale it, and shift it to the canvas centre.

    Args:
        bw: 2-D array whose positive entries are ink. The values are
            multiplied by 255 before conversion to uint8, so a 0..1 range is
            presumably expected (TODO confirm against callers).
        size: Edge length of the square output.
        pad: Border, in pixels, left around the rescaled crop before centring.

    Returns:
        A ``(size, size)`` float32 array in 0..1. All zeros when *bw* has no
        ink.
    """
    coords = np.argwhere(bw > 0)
    if coords.size == 0:
        return np.zeros((size, size), dtype=np.float32)

    # Tight bounding box around every positive pixel.
    y0, x0 = coords.min(axis=0)
    y1, x1 = coords.max(axis=0) + 1
    cropped = bw[y0:y1, x0:x1]

    inner = size - 2 * pad
    pil = Image.fromarray((cropped * 255).astype(np.uint8))
    pil = pil.resize((inner, inner), Image.BILINEAR)
    canvas = Image.new("L", (size, size), 0)
    canvas.paste(pil, (pad, pad))
    arr = np.array(canvas, dtype=np.float32) / 255.0

    # Very faint or sparse ink can quantise to all zeros after the uint8
    # conversion and resize. center_of_mass would then return NaN and int(NaN)
    # raises ValueError, so return the (blank) canvas unshifted instead.
    if not arr.any():
        return arr

    cy, cx = center_of_mass(arr)
    shift_y, shift_x = int(size / 2 - cy), int(size / 2 - cx)
    # np.roll wraps around the edges; the shifts are expected to be small.
    arr = np.roll(arr, shift_y, axis=0)
    arr = np.roll(arr, shift_x, axis=1)
    return arr
| def _to_gray_uint8(arr: np.ndarray) -> np.uint8: | |
| if arr.ndim == 2: | |
| if arr.dtype != np.uint8: | |
| a = arr.astype(np.float32) | |
| if a.max() <= 1.0: a *= 255.0 | |
| arr = np.clip(a, 0, 255).astype(np.uint8) | |
| return arr | |
| if arr.ndim == 3: | |
| a = arr.astype(np.float32) | |
| if a.max() > 1.0: | |
| a /= 255.0 | |
| if a.shape[2] == 4: | |
| rgb = a[..., :3] | |
| alpha = a[..., 3:4] | |
| a = rgb * alpha + (1.0 - alpha) * 1.0 | |
| elif a.shape[2] >= 3: | |
| a = a[..., :3] | |
| g = rgb2gray(a) | |
| return (g * 255.0).astype(np.uint8) | |
| # 其它奇怪形状:兜底到 uint8 | |
| return np.clip(arr, 0, 255).astype(np.uint8) | |
def infer(img):
    """Find the template characters most similar to a drawn or uploaded image.

    Args:
        img: ``None``, a PIL image, an array-like, or a dict from which the
            first non-None value among the keys ``"image"``, ``"composite"``
            and ``"background"`` is used.

    Returns:
        A 3-tuple ``(gallery_items, viz_img, results_json)``:
            * gallery_items: list of ``(glyph image, caption)`` pairs, at most
              10, one per distinct code point, best score first.
            * viz_img: PIL image of the preprocessed query, or ``None``.
            * results_json: list of ``{"cp", "char", "score"}`` dicts that
              parallels gallery_items.
        With no usable input the result is ``([], None, [])``. Always three
        values, because every event handler binds three output components.
    """
    max_results = 10

    if img is None:
        return [], None, []
    if isinstance(img, dict):
        img = next(
            (img[k] for k in ("image", "composite", "background")
             if img.get(k) is not None),
            None,
        )
        if img is None:
            return [], None, []

    # _to_gray_uint8 always yields uint8, so no further dtype fix-up is needed.
    arr = _to_gray_uint8(np.asarray(img))

    bw0 = _ensure_ink_true(binarize(arr))
    # Per the original author's note, this is the same crop used in training.
    bw = crop_ref(bw0, out_size=LO, margin_frac=0.08)
    bw = stroke_normalize(bw, target_px=3)
    viz_img = Image.fromarray((bw * 255).astype(np.uint8))

    q = feat_vec(bw)
    if MEAN is not None and STD is not None:
        q = (q - MEAN.ravel()) / STD.ravel()
    s = cosine_sim(q, X)

    # NOTE: an unused low-confidence heuristic (top score < 0.58, or margin to
    # the runner-up < 0.05) used to be computed here. It crashed when only one
    # match existed and fed nothing, so it was dropped.

    # Walk the whole ranking rather than only the first 10 rows: several
    # templates can share a code point, and de-duplicating after truncation
    # would leave fewer than max_results entries.
    seen = set()
    gallery_items = []
    results_json = []
    for idx in np.argsort(-s):
        cp = int(Y[idx])
        if cp in seen:
            continue
        seen.add(cp)
        sc = float(s[idx])
        ch = chr(cp)
        caption = f"U+{cp:05X} {ch}\nScore: {sc:.6f}"
        gallery_items.append((render_glyph(cp), caption))
        results_json.append({"cp": cp, "char": ch, "score": sc})
        if len(gallery_items) >= max_results:
            break
    return gallery_items, viz_img, results_json
# NOTE(review): the source lost its indentation; this layout assumes the Row
# holds only the two input widgets. Confirm against the deployed app.
with gr.Blocks() as demo:
    gr.Markdown("### Detangutify (Tangut Character classifier)")

    # Inputs: a freehand sketch pad and an image upload.
    with gr.Row():
        canvas = gr.Sketchpad(
            label="Draw here",
            image_mode="L",
            brush=gr.Brush(colors="black", default_size=2),
            canvas_size=(192, 192),
            type="numpy",
        )
        upload = gr.Image(
            label="Upload a character image",
            type="numpy",
            height=192,
        )

    # Outputs shared by every handler below.
    gallery = gr.Gallery(
        label="Top-10 Results",
        columns=10,
        preview=False,
        height=320,
    )
    preview = gr.Image(label="stroke_normalize result", type="pil")
    jsonout = gr.JSON(label="Top-10 (JSON)", visible=False)
    _outputs = [gallery, preview, jsonout]

    btn_draw = gr.Button("Search (draw)")
    btn_draw.click(fn=infer, inputs=canvas, outputs=_outputs, api_name="predict")

    btn_upload = gr.Button("Search (upload)")
    btn_upload.click(fn=infer, inputs=upload, outputs=_outputs, api_name="predict_upload")
    # Also run a search whenever the uploaded image changes.
    upload.change(fn=infer, inputs=upload, outputs=_outputs)

    # Hidden widgets that expose a PIL-image endpoint under "predict_img".
    api_img = gr.Image(type="pil", visible=False)
    api_btn = gr.Button(visible=False)
    api_btn.click(fn=infer, inputs=api_img, outputs=_outputs, api_name="predict_img")

if __name__ == "__main__":
    demo.launch()