TrOCR-digit / app.py
aico's picture
Update app.py
91678c8
raw
history blame
No virus
1.22 kB
import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import requests
from PIL import Image
import numpy as np
import cv2
# Hub identifiers of the two checkpoints this demo loads.
PROCESSOR_CHECKPOINT = "microsoft/trocr-base-handwritten"
MODEL_CHECKPOINT = "aico/TrOCR-MNIST"

# Both objects are created once at import time; process_image() reads them
# as module-level globals to prepare pixel values, generate and decode.
processor = TrOCRProcessor.from_pretrained(PROCESSOR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_CHECKPOINT)
def process_image(image):
    """Recognise the text drawn in ``image`` using the module-level TrOCR model.

    Args:
        image: The image delivered by the Gradio input component. It must
            expose ``astype`` (e.g. a NumPy array) and be convertible by
            ``PIL.Image.fromarray``. ``None`` is accepted and treated as
            "nothing submitted".

    Returns:
        The decoded text for the image as a string, or an empty string when
        ``image`` is ``None``.
    """
    # NOTE(review): the UI may hand over None when nothing was drawn; without
    # this guard ``image.astype`` would raise AttributeError.
    if image is None:
        return ""

    # Cast to 8-bit and convert to RGB before handing the image to the processor.
    img = Image.fromarray(image.astype("uint8")).convert("RGB")

    # Prepare the image: the processor returns the tensor fed to the model.
    pixel_values = processor(img, return_tensors="pt").pixel_values

    # Generate token ids using the model's default generation settings.
    generated_ids = model.generate(pixel_values)

    # Decode the first (only) sequence, dropping special tokens.
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0]
    return generated_text
# Heading and sub-heading passed to the Gradio interface below.
title = "Interactive demo: Single Digits MNIST"
description = "Aico - University Utrecht"

# Wire the sketchpad input to process_image and show its return value as a label.
iface = gr.Interface(
    fn=process_image,
    inputs="sketchpad",
    outputs="label",
    title=title,
    description=description,
)

# Launch the app with Gradio's debug mode switched on.
iface.launch(debug=True)