# HuggingFace Spaces app: TrOCR handwritten text recognition demo.
# Dependencies, deduplicated (TrOCRProcessor and VisionEncoderDecoderModel
# were previously imported twice) and grouped per PEP 8.
import warnings

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

warnings.filterwarnings("ignore")

# Load the pretrained TrOCR weights once at import time so each request
# pays only for inference, not model loading.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
def extract_text(image):
    """Run the TrOCR model on one image crop and return the decoded text.

    Parameters
    ----------
    image : PIL.Image or numpy.ndarray
        An image containing (ideally) a single line of handwriting.

    Returns
    -------
    str
        The text generated by the model, special tokens stripped.
    """
    # Calling the processor is equivalent to calling the feature extractor:
    # it resizes/normalizes the crop into model-ready pixel values.
    pixel_values = processor(image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    # batch_decode returns a list; a single input means we want element 0.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
def hand_written(image_raw):
    """Segment a handwriting image into text lines and OCR each with TrOCR.

    Parameters
    ----------
    image_raw : PIL.Image or numpy.ndarray
        The full page image. Gradio supplies a PIL image, which is RGB.

    Returns
    -------
    str
        The recognized text, one segmented region per output line.
        Empty string when no text regions are found.
    """
    image_raw = np.array(image_raw)

    # --- Line segmentation: binarize, then dilate horizontally so the
    # characters of one text line merge into a single blob. ---
    # PIL arrays are RGB; the original used COLOR_BGR2GRAY, which applies
    # the luminance weights to the wrong channels.
    gray = cv2.cvtColor(image_raw, cv2.COLOR_RGB2GRAY)
    gray = cv2.GaussianBlur(gray, (5, 5), 0)
    binary = cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY_INV)[1]
    # Wide flat kernel + repeated dilation joins characters sharing a
    # baseline while keeping separate lines apart.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (10, 1))
    binary = cv2.dilate(binary, kernel, iterations=5)

    contours, _ = cv2.findContours(binary, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    all_boxes = [cv2.boundingRect(c) for c in contours]
    if not all_boxes:
        # Blank page: nothing to OCR. (The original crashed here taking
        # np.max of an empty array.)
        return ""

    # Assign a line number to each box: a box whose y lies within one
    # maximum-box-height of the current line's y belongs to that line.
    max_height = max(h for _, _, _, h in all_boxes)
    by_y = sorted(all_boxes, key=lambda box: box[1])  # sort by y value
    line_y = by_y[0][1]  # y of the first (topmost) box
    line_no = 1
    by_line = []
    for x, y, w, h in by_y:
        if y > line_y + max_height:
            line_y = y
            line_no += 1
        by_line.append((line_no, x, y, w, h))
    # Sorting the (line, x, ...) tuples orders boxes by line, then by x.
    contours_sorted = [(x, y, w, h) for _, x, y, w, h in sorted(by_line)]

    lines = []
    for x, y, w, h in contours_sorted:
        cropped = image_raw[y:y + h, x:x + w]
        try:
            extracted = extract_text(cropped)
        except Exception as exc:  # narrowed from the original bare except
            print(f"skipping region ({x}, {y}, {w}, {h}): {exc}")
            continue
        # TrOCR tends to emit "0 0" / "0 1" for empty or noisy crops; drop
        # those rather than polluting the transcript.
        if extracted not in ("0 0", "0 1"):
            lines.append(extracted)
    # Joining a collected list also removes the spurious leading newline
    # the original produced by joining onto an empty accumulator.
    return "\n".join(lines)
# Gradio UI metadata; the example image is expected alongside this script.
title = "TrOCR + EN_ICR demo"
description = "TrOCR Handwritten Recognizer"
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2109.10282'>TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models</a> | <a href='https://github.com/microsoft/unilm/tree/master/trocr'>Github Repo</a></p>"
examples = [["image_0.png"]]

# The gr.inputs / gr.outputs namespaces were removed in modern Gradio
# (a likely cause of the Space's "Runtime error"); the component classes
# are used directly instead.
iface = gr.Interface(
    fn=hand_written,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    title=title,
    description=description,
    article=article,
    examples=examples,
)
iface.launch(debug=True, share=True)