Spaces:
Sleeping
Sleeping
| import sys | |
| sys.path.append('.') | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| from huggingface_hub import hf_hub_download | |
| from model.model import ImageToDigitTransformer | |
| from utils.tokenizer import START, FINISH, decode | |
| from gradio_ui.preprocess import preprocess_canvases | |
# Load model.
# Prefer a CUDA GPU, then Apple-Silicon MPS, and fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model_path = hf_hub_download(repo_id="nico-x/vit-transformer-mnist", filename="transformer_mnist.pt")
model = ImageToDigitTransformer(vocab_size=13).to(device)
# The checkpoint is downloaded from the Hub, so refuse arbitrary pickled
# objects: weights_only=True restricts unpickling to tensors and plain
# containers, which is all a state_dict needs.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference only: disables dropout / training-mode behaviour
def split_into_quadrants(image):
    """Cut an image into its four quadrants.

    Args:
        image: a PIL ``Image``, or a NumPy array (converted with
            ``Image.fromarray`` first).

    Returns:
        A list of four NumPy arrays ordered top-left, top-right,
        bottom-left, bottom-right. The split point is ``size // 2`` on each
        axis, so with an odd dimension the extra row/column falls in the
        right/bottom quadrants.
    """
    pil_img = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    width, height = pil_img.size
    mid_x, mid_y = width // 2, height // 2
    boxes = (
        (0, 0, mid_x, mid_y),           # top-left
        (mid_x, 0, width, mid_y),       # top-right
        (0, mid_y, mid_x, height),      # bottom-left
        (mid_x, mid_y, width, height),  # bottom-right
    )
    return [np.array(pil_img.crop(box)) for box in boxes]
def predict_digit_sequence(editor_data):
    """Predict the digit sequence drawn on the 2×2 canvas.

    The composite drawing is split into four quadrants, preprocessed, and fed
    to the model, which is decoded greedily one token at a time starting from
    ``START``.

    Args:
        editor_data: the value emitted by ``gr.ImageEditor``; expected to be
            a mapping whose ``"composite"`` entry holds the flattened drawing
            (PIL image or NumPy array).

    Returns:
        The decoded prediction as a string (at most 4 tokens are generated,
        stopping early at ``FINISH``), or ``"No image provided."`` when there
        is no composite image.
    """
    # The editor can send None, a dict without "composite", or a dict whose
    # composite is None (e.g. after the canvas is cleared): all mean "no image".
    if editor_data is None or editor_data.get("composite") is None:
        return "No image provided."

    quadrants = split_into_quadrants(editor_data["composite"])
    image_tensor = preprocess_canvases(quadrants).to(device)

    max_digits = 4
    tokens = [START]
    # Greedy autoregressive decoding; gradients are never needed here.
    with torch.no_grad():
        for _ in range(max_digits):
            input_ids = torch.tensor([tokens], dtype=torch.long, device=device)
            logits = model(image_tensor, input_ids)
            next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
            if next_token == FINISH:
                # Stop without passing the end-of-sequence marker to decode().
                break
            tokens.append(next_token)

    return "".join(decode(tokens[1:]))  # tokens[0] is START
def create_black_canvas(size=(800, 800)):
    """Return a black grayscale canvas divided into a 2×2 grid.

    Args:
        size: ``(width, height)`` of the canvas in pixels.

    Returns:
        A mode ``"L"`` PIL image filled with 0 (black), crossed by one
        vertical and one horizontal 2-px mid-gray (value 128) line through
        the centre.
    """
    canvas = Image.new("L", size, color=0)
    pen = ImageDraw.Draw(canvas)
    width, height = size
    mid_x, mid_y = width // 2, height // 2
    vertical = ((mid_x, 0), (mid_x, height))
    horizontal = ((0, mid_y), (width, mid_y))
    for segment in (vertical, horizontal):
        pen.line(list(segment), fill=128, width=2)
    return canvas
# === UI ===
# Single source of truth for the canvas dimensions: used both for the
# component size and for the background image it starts from / resets to.
canvas_size = 800

with gr.Blocks() as demo:
    gr.Markdown("## Draw 4 digits in a 2×2 grid using a white brush")
    canvas = gr.ImageEditor(
        label="White brush only on black canvas (no uploads)",
        value=create_black_canvas((canvas_size, canvas_size)),
        image_mode="L",
        height=canvas_size,
        width=canvas_size,
        sources=[],  # disables uploads
        type="pil",
        brush=gr.Brush(colors=["#FFFFFF"], default_color="#FFFFFF", default_size=15, color_mode="fixed"),
    )
    predict_btn = gr.Button("Predict")
    clear_btn = gr.Button("Erase")
    # Read-only: this box only displays the model's prediction.
    output = gr.Textbox(label="Predicted 4-digit sequence", interactive=False)

    predict_btn.click(fn=predict_digit_sequence, inputs=[canvas], outputs=[output])
    # "Erase" replaces the editor contents with a fresh gridded canvas.
    clear_btn.click(
        fn=lambda: create_black_canvas((canvas_size, canvas_size)),
        inputs=[],
        outputs=[canvas],
    )

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()