Spaces:
Sleeping
Sleeping
import sys | |
sys.path.append('.') | |
import gradio as gr | |
import torch | |
import numpy as np | |
from PIL import Image, ImageDraw | |
from huggingface_hub import hf_hub_download | |
from model.model import ImageToDigitTransformer | |
from utils.tokenizer import START, FINISH, decode | |
from gradio_ui.preprocess import preprocess_canvases | |
# Model setup: choose a device, fetch the checkpoint from the Hub, restore the weights.
# Use the MPS backend when PyTorch reports it as available; otherwise fall back to CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
model_path = hf_hub_download(
    repo_id="nico-x/vit-transformer-mnist",
    filename="transformer_mnist.pt",
)
model = ImageToDigitTransformer(vocab_size=13)
model = model.to(device)
# NOTE(review): torch.load unpickles the downloaded file. If the checkpoint is a
# plain state dict and the installed PyTorch supports it, weights_only=True would
# be safer -- confirm before changing.
model.load_state_dict(torch.load(model_path, map_location=device))
# Switch to evaluation mode; this script only runs inference.
model.eval()
def split_into_quadrants(image):
    """Split an image into four quadrants (TL, TR, BL, BR).

    The image is sliced directly as a numpy array instead of being
    round-tripped through PIL, so the input dtype is preserved.

    Args:
        image: A PIL Image or a numpy array with at least two dimensions,
            laid out as (height, width[, channels]).

    Returns:
        A list of four numpy arrays in the order top-left, top-right,
        bottom-left, bottom-right. Each array is an independent copy. When
        a dimension is odd, the extra row/column goes to the bottom/right
        quadrants.

    Raises:
        ValueError: If the image has fewer than two dimensions.
    """
    pixels = np.asarray(image)
    if pixels.ndim < 2:
        raise ValueError(
            f"Expected an image with at least 2 dimensions, got {pixels.ndim}."
        )

    height, width = pixels.shape[:2]
    mid_y, mid_x = height // 2, width // 2
    row_spans = (slice(0, mid_y), slice(mid_y, height))
    col_spans = (slice(0, mid_x), slice(mid_x, width))

    # Rows in the outer loop and columns in the inner loop yield TL, TR, BL, BR.
    # Copy so callers get writable arrays that do not alias the input.
    return [pixels[rows, cols].copy() for rows in row_spans for cols in col_spans]
def predict_digit_sequence(editor_data):
    """Predict the digit sequence drawn on the 2×2 canvas.

    Args:
        editor_data: The image editor payload. It is expected to be a
            mapping with a "composite" entry holding the drawn image.

    Returns:
        "No image provided." when the payload is None, has no "composite"
        entry, or that entry is None. Otherwise, the string obtained by
        joining the output of ``decode`` for up to four greedily decoded
        tokens.
    """
    no_image_message = "No image provided."
    if editor_data is None or "composite" not in editor_data:
        return no_image_message
    img = editor_data["composite"]
    # The key can be present with no image behind it; without this check
    # the quadrant split below would fail on None.
    if img is None:
        return no_image_message

    quadrants = split_into_quadrants(img)
    image_tensor = preprocess_canvases(quadrants).to(device)

    # Greedy decoding: begin with START and append the arg-max token at each
    # step, for at most four steps or until FINISH is produced.
    decoded = [START]
    with torch.no_grad():
        for _ in range(4):
            input_ids = torch.tensor(decoded, dtype=torch.long).unsqueeze(0).to(device)
            logits = model(image_tensor, input_ids)
            next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
            decoded.append(next_token)
            if next_token == FINISH:
                break

    # Drop START. The loop bound already limits the rest to four tokens.
    return "".join(decode(decoded[1:]))
def create_black_canvas(size=(800, 800)):
    """Build a black grayscale image split into a 2×2 grid by gray lines.

    Args:
        size: (width, height) of the image in pixels.

    Returns:
        A mode "L" PIL image filled with 0, with one vertical and one
        horizontal line (value 128, 2 px wide) through the centre.
    """
    canvas = Image.new("L", size, color=0)
    width, height = size
    mid_x = width // 2
    mid_y = height // 2

    # Vertical divider first, then the horizontal one.
    grid_lines = (
        [(mid_x, 0), (mid_x, height)],
        [(0, mid_y), (width, mid_y)],
    )
    pen = ImageDraw.Draw(canvas)
    for endpoints in grid_lines:
        pen.line(endpoints, fill=128, width=2)
    return canvas
# === UI ===
# Side length, in pixels, shared by the editor widget and the canvas image.
canvas_size = 800

with gr.Blocks() as demo:
    gr.Markdown("## Draw 4 digits in a 2×2 grid using a white brush")

    canvas = gr.ImageEditor(
        label="White brush only on black canvas (no uploads)",
        # Build the starting image from canvas_size so it cannot drift out of
        # sync with the editor's height/width below.
        value=create_black_canvas(size=(canvas_size, canvas_size)),
        image_mode="L",
        height=canvas_size,
        width=canvas_size,
        sources=[],  # disables uploads
        type="pil",
        brush=gr.Brush(
            colors=["#FFFFFF"],
            default_color="#FFFFFF",
            default_size=15,
            color_mode="fixed",
        ),
    )
    predict_btn = gr.Button("Predict")
    clear_btn = gr.Button("Erase")
    output = gr.Textbox(label="Predicted 4-digit sequence", interactive=True)

    # Predict: run the model on the current editor contents.
    predict_btn.click(fn=predict_digit_sequence, inputs=[canvas], outputs=[output])
    # Erase: replace the editor contents with a fresh gridded canvas of the same size.
    clear_btn.click(
        fn=lambda: create_black_canvas(size=(canvas_size, canvas_size)),
        inputs=[],
        outputs=[canvas],
    )

demo.launch()