import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import spaces
import subprocess
import json
from PIL import Image, ImageDraw
import os
import tempfile
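
# Beyond the Python imports above, this Space assumes the `kraken` OCR CLI
# (invoked below via subprocess) is installed and available on PATH.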
# Dictionary of model names and their corresponding HuggingFace model IDs
MODEL_OPTIONS = {
"Microsoft Handwritten": "microsoft/trocr-base-handwritten",
"Medieval Base": "medieval-data/trocr-medieval-base",
"Medieval Latin Caroline": "medieval-data/trocr-medieval-latin-caroline",
"Medieval Castilian Hybrida": "medieval-data/trocr-medieval-castilian-hybrida",
"Medieval Humanistica": "medieval-data/trocr-medieval-humanistica",
"Medieval Textualis": "medieval-data/trocr-medieval-textualis",
"Medieval Cursiva": "medieval-data/trocr-medieval-cursiva",
"Medieval Semitextualis": "medieval-data/trocr-medieval-semitextualis",
"Medieval Praegothica": "medieval-data/trocr-medieval-praegothica",
"Medieval Semihybrida": "medieval-data/trocr-medieval-semihybrida",
"Medieval Print": "medieval-data/trocr-medieval-print"
}
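
# All of these are TrOCR checkpoints, so the same pair of classes
# (TrOCRProcessor / VisionEncoderDecoderModel) can load any of them.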
# Global variables to store the current model and processor
current_model = None
current_processor = None
current_model_name = None

def load_model(model_name):
    global current_model, current_processor, current_model_name
    if model_name != current_model_name:
        model_id = MODEL_OPTIONS[model_name]
        current_processor = TrOCRProcessor.from_pretrained(model_id)
        current_model = VisionEncoderDecoderModel.from_pretrained(model_id)
        current_model_name = model_name
        # Move model to GPU
        current_model = current_model.to('cuda')
    return current_processor, current_model
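
# On a ZeroGPU Space, @spaces.GPU attaches a GPU only for the duration of the
# decorated call, which is why the .to('cuda') moves above happen on this code
# path rather than at import time.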
@spaces.GPU
def process_image(image, model_name):
    # Save the uploaded image to a temporary file (convert to RGB first,
    # since JPEG cannot store an alpha channel)
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_img:
        image.convert("RGB").save(temp_img, format="JPEG")
        temp_img_path = temp_img.name
    # Run Kraken: binarize the page, then detect lines with the baseline
    # segmenter (-bl), writing the result to a JSON file. Passing the command
    # as a list avoids shell quoting issues with the temp file path.
    lines_json_path = "lines.json"
    kraken_command = ["kraken", "-i", temp_img_path, lines_json_path,
                      "binarize", "segment", "-bl"]
    subprocess.run(kraken_command, check=True)
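    # The segmentation JSON written by kraken looks roughly like this
    # (exact fields vary with the kraken version; 'boundary' polygons are
    # produced by the baseline segmenter):
    # {
    #   "text_direction": "horizontal-lr",
    #   "lines": [
    #     {"baseline": [[x0, y0], ..., [xn, yn]],
    #      "boundary": [[x, y], ...]},
    #     ...
    #   ]
    # }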
    # Load the detected lines from the JSON file
    with open(lines_json_path, 'r') as f:
        lines_data = json.load(f)
    processor, model = load_model(model_name)
    # Transcribe each detected line
    transcriptions = []
    for line in lines_data['lines']:
        # Crop the line from the original image. The baseline endpoints alone
        # would give a box of (near-)zero height, so use the bounding box of
        # the line's boundary polygon when present, falling back to the
        # baseline points otherwise.
        points = line.get('boundary') or line['baseline']
        xs = [p[0] for p in points]
        ys = [p[1] for p in points]
        line_image = image.crop((min(xs), min(ys), max(xs), max(ys)))
        # Prepare the crop for TrOCR
        pixel_values = processor(line_image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to('cuda')
        # Generate (greedy decoding, no beam search)
        with torch.no_grad():
            generated_ids = model.generate(pixel_values)
        # Decode the token IDs back to text
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcriptions.append(generated_text)
    # Clean up temporary files
    os.unlink(temp_img_path)
    os.unlink(lines_json_path)
    # Draw the detected baselines onto the image
    draw = ImageDraw.Draw(image)
    for line in lines_data['lines']:
        # ImageDraw.line expects a sequence of (x, y) tuples
        coords = [tuple(point) for point in line['baseline']]
        draw.line(coords, fill="red", width=2)
    return image, "\n".join(transcriptions)

# Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Medieval Document Transcription")
    gr.Markdown("Upload an image of a medieval document and select a model to transcribe it. The tool will detect lines and transcribe each line separately.")
    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model", value="Medieval Base")
    with gr.Row():
        output_image = gr.Image(type="pil", label="Detected Lines")
        transcription_output = gr.Textbox(label="Transcription", lines=10)
    submit_button = gr.Button("Transcribe")
    submit_button.click(fn=process_image, inputs=[input_image, model_dropdown], outputs=[output_image, transcription_output])

iface.launch()
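
# A minimal sketch of running this file locally, assuming the kraken CLI is
# installed (e.g. `pip install kraken`) and a CUDA GPU is available; the
# `spaces` import and @spaces.GPU decorator are only needed on Hugging Face
# ZeroGPU hardware:
#   python app.py   # assuming this file is saved as app.py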