hugbump's picture
Input Image has to be greyscale, otherwise broadcast error. Thus add .convert('L')
ab27eac
raw
history blame
1.03 kB
import streamlit as st
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
# Single source of truth for the pretrained checkpoint used by both objects.
TROCR_CHECKPOINT = "microsoft/trocr-base-handwritten"

# Loaded at module level; process_image() reads both as globals.
# NOTE(review): Streamlit re-executes the script on user interaction, so these
# loads may repeat — consider a resource cache if the installed Streamlit
# version provides one (TODO confirm version).
processor = TrOCRProcessor.from_pretrained(TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(TROCR_CHECKPOINT)
def process_image(image) -> str:
    """Transcribe the text in a single image with the module-level TrOCR model.

    Args:
        image: The image to transcribe; it is handed unchanged to the
            module-level ``processor``.

    Returns:
        The first decoded string produced for the image, with special
        tokens stripped.
    """
    # Preprocess into PyTorch tensors ("pt") and pull out the pixel batch.
    encoding = processor(image, return_tensors="pt")
    pixel_batch = encoding.pixel_values

    # Generate token ids with the model's default generation settings.
    token_ids = model.generate(pixel_batch)

    # batch_decode yields one string per sequence; only one image was given.
    decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]
########################## Streamlit Code ##########################

st.title('Streamlit Replication of nielsr/TrOCR-handwritten')

uploaded_file = st.file_uploader("Choose an image...")

# Nothing to do until the user has actually uploaded a file.
if uploaded_file is not None:
    # Normalise every upload to 3-channel RGB before inference. This drops
    # alpha/palette channels (the likely source of the broadcast error that
    # the earlier greyscale conversion worked around) and matches the input
    # format used in the Transformers TrOCR inference example, instead of
    # feeding the model a single-channel greyscale image.
    input_image = Image.open(uploaded_file).convert('RGB')

    # Show the upload exactly as the user provided it.
    st.image(uploaded_file, caption='Input Image', use_column_width=True)

    # Run OCR and display the recognised text.
    generated_text = process_image(input_image)
    st.write(generated_text)