import base64
import logging
import os

import evaluate
import gradio as gr
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoImageProcessor
# import utils
# from datasets import load_metric

# Only show log messages that are at the ERROR level or above, effectively filtering out any warnings
logging.getLogger('transformers').setLevel(logging.ERROR)

logger = logging.getLogger(__name__)

# The TrOCR processor is used for decoding generated token ids; the image
# processor and the model both come from the fine-tuned checkpoint.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
image_processor = AutoImageProcessor.from_pretrained("pstroe/bullinger-general-model")
model = VisionEncoderDecoderModel.from_pretrained("pstroe/bullinger-general-model")


# Create examples
# Get images and respective transcriptions from the examples directory
def get_example_data(folder_path="./examples/"):
    """Collect example images and their transcriptions from a folder.

    Every ``*.png`` file in ``folder_path`` is paired with the ``.txt`` file of
    the same base name in the same folder.

    Args:
        folder_path: Directory containing the ``.png`` / ``.txt`` pairs.

    Returns:
        A list of ``[image_path, transcription]`` pairs, ordered by file name.
        When the ``.txt`` file is missing, the transcription is the placeholder
        ``"Transcription not found."``.

    Raises:
        OSError: If ``folder_path`` cannot be listed (e.g. it does not exist).
    """
    image_suffix = ".png"
    example_data = []

    # os.listdir() returns entries in arbitrary order; sort so the examples
    # are listed the same way on every run.
    for file_name in sorted(os.listdir(folder_path)):
        if not file_name.endswith(image_suffix):
            continue

        file_path = os.path.join(folder_path, file_name)
        # Swap only the trailing suffix; str.replace() would also rewrite a
        # ".png" occurring earlier in the name.
        text_file_name = file_name[: -len(image_suffix)] + ".txt"
        text_file_path = os.path.join(folder_path, text_file_name)

        try:
            # Explicit UTF-8: the platform default encoding is not guaranteed
            # to decode non-ASCII transcriptions.
            with open(text_file_path, "r", encoding="utf-8") as f:
                transcription = f.read().strip()
        except FileNotFoundError:
            # A missing transcription is tolerated; show a placeholder instead.
            transcription = "Transcription not found."

        example_data.append([file_path, transcription])

    return example_data


# From pstroe's script
# def compute_metrics(pred):
#     labels_ids = pred.label_ids
#     pred_ids = pred.predictions
#     pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
#     labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
#     label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
#     cer = cer_metric.compute(predictions=pred_str, references=label_str)
#     return {"cer": cer}


def process_image(image, ground_truth):
    """Transcribe a text-line image and optionally score it against a reference.

    Args:
        image: The input image (a PIL image when called from the UI), or
            ``None`` when nothing was uploaded.
        ground_truth: Reference transcription; ``None``, empty or
            whitespace-only means no reference is available.

    Returns:
        A ``(generated_text, cer)`` tuple. ``cer`` is the value returned by
        ``cer_metric.compute`` for the single prediction/reference pair, or the
        string ``"Ground truth not provided"`` when there is no reference.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque failure
        # inside the image processor.
        raise gr.Error("Please provide an input image.")

    # Normalise PIL inputs to three channels; other input types are passed
    # through untouched.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # prepare image
    pixel_values = image_processor(image, return_tensors="pt").pixel_values

    # generate with the model's default generation settings
    generated_ids = model.generate(pixel_values)

    # decode
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    if not ground_truth or not ground_truth.strip():
        return generated_text, "Ground truth not provided"

    logger.debug("Prediction: %r", generated_text)
    logger.debug("Reference: %r", ground_truth)

    # cer_metric is created at module level further down in this file.
    cer = cer_metric.compute(predictions=[generated_text], references=[ground_truth])
    # cer = f"{cer:.3f}"

    return generated_text, cer


# One way to use .svg files
# logo_url = "https://www.bullinger-digital.ch/bullinger-digital.svg"
# logo_url = "https://www.cl.uzh.ch/docroot/logos/uzh_logo_e_pos.svg"
# header_html = "".format(
#     utils.img_to_bytes(".uzh_logo_e_pos.svg")
# )

# Encode images
with open("assets/uzh_logo_mod.png", "rb") as img_file:
    logo_html = base64.b64encode(img_file.read()).decode('utf-8')

# with open("assets/bullinger-digital.png", "rb") as img_file:
with open("assets/bullinger_logo.png", "rb") as img_file:
    footer_html = base64.b64encode(img_file.read()).decode('utf-8')

# App header
title = """

TrOCR: Bullinger Dataset

"""

# A trailing backslash joins the next line into the same Markdown paragraph.
description = """
Use of Microsoft's [TrOCR](https://arxiv.org/abs/2109.10282), an encoder-decoder model consisting of an \
image Transformer encoder and a text Transformer decoder for state-of-the-art optical character recognition \
(OCR) and handwritten text recognition (HTR) on text line images. \
This particular model was fine-tuned on [Bullinger Dataset](https://github.com/pstroe/bullinger-htr) \
as part of the project [Bullinger Digital](https://www.bullinger-digital.ch)
([References](https://www.cl.uzh.ch/de/people/team/compling/pstroebel.html#Publications)).

* HF `model card`: [pstroe/bullinger-general-model](https://huggingface.co/pstroe/bullinger-general-model) | \
[Flexible Techniques for Automatic Text Recognition of Historical Documents](https://doi.org/10.5167/uzh-234886)
"""

# articles = """
#
# TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models
# Flexible Techniques for Automatic Text Recognition of Historical Documents
# Bullingers Briefwechsel zugänglich machen: Stand der Handschriftenerkennung
#
# """

# Read .png and the respective .txt files
examples = get_example_data()

# load_metric() is deprecated
# cer_metric = load_metric("cer")
# pip install jiwer
# pip install evaluate
cer_metric = evaluate.load("cer")

# NOTE(review): the nesting of rows/columns below was reconstructed from the
# statement order; confirm it matches the intended layout.
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="TrOCR Bullinger",
) as demo:
    # NOTE(review): the header markup was empty although the logo is encoded
    # above; the <img> tag is a reconstruction and the size is taken from the
    # "174x60" note -- confirm against the intended design.
    gr.HTML(
        f"""
        <img src="data:image/png;base64,{logo_html}" alt="University of Zurich logo" width="174" height="60">
        """
    )  # 174x60

    title = gr.HTML(title)
    description = gr.Markdown(description)

    with gr.Row():
        with gr.Column(variant="panel"):
            # Named input_image so the builtin input() is not shadowed.
            input_image = gr.components.Image(type="pil", label="Input image:")

            with gr.Row():
                btn_clear = gr.Button(value="Clear")
                button = gr.Button(value="Submit")

        with gr.Column(variant="panel"):
            output = gr.components.Textbox(label="Generated text:")
            ground_truth = gr.components.Textbox(
                value="",
                placeholder="Provide the ground truth, if available.",
                label="Ground truth:",
            )
            cer_output = gr.components.Textbox(label="CER:")

    with gr.Row():
        with gr.Accordion(label="Choose an example from test set:", open=False):
            gr.Examples(
                examples=examples,
                inputs=[input_image, ground_truth],
                label=None,
            )

    with gr.Row():
        # gr.HTML(
        #     f"""
        #
        #     Bullinger Digital | Institut für Computerlinguistik, Universität Zürich, 2023
        #
        #     """
        # )
        # NOTE(review): the footer logo was encoded but never shown; the <img>
        # tag is a reconstruction and the size is taken from the "383x85"
        # note -- confirm against the intended design.
        gr.HTML(
            f"""
            <img src="data:image/png;base64,{footer_html}" alt="Bullinger Digital logo" width="383" height="85">

            BullingerDigital | Institut für Computerlinguistik, Universität Zürich, 2023

            """
        )  # 383x85

    button.click(process_image, inputs=[input_image, ground_truth], outputs=[output, cer_output])
    # One reset value per output component, in the order listed in `outputs`.
    btn_clear.click(lambda: [None, "", "", ""], outputs=[input_image, output, ground_truth, cer_output])

    # # Try to force light mode
    # js = """
    # function () {
    #     gradioURL = window.location.href
    #     if (!gradioURL.endsWith('?__theme=light')) {
    #         window.location.replace(gradioURL + '?__theme=light');
    #     }
    # }"""
    # demo.load(_js=js)

if __name__ == "__main__":
    demo.launch(favicon_path="icon.png")