import spaces
import transformers
import re
import torch
import gradio as gr
import os
import ctranslate2
from concurrent.futures import ThreadPoolExecutor
# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load CTranslate2 model and tokenizer
model_path = "PleIAs/OCRonos-Vintage-CT2"
generator = ctranslate2.Generator(model_path, device=device)
tokenizer = transformers.AutoTokenizer.from_pretrained("PleIAs/OCRonos-Vintage")
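
# Note (assumption, not part of the original app): PleIAs/OCRonos-Vintage-CT2
# appears to be a CTranslate2 export of PleIAs/OCRonos-Vintage. If it ever
# needs to be regenerated, CTranslate2 ships a converter CLI, e.g.:
#   ct2-transformers-converter --model PleIAs/OCRonos-Vintage --output_dir OCRonos-Vintage-CT2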
# CSS for formatting (unchanged)
css = """
<style>
... (your existing CSS)
</style>
"""
# Helper functions
def generate_html_diff(old_text, new_text):
    # (unchanged)
    ...

def preprocess_text(text):
    # (unchanged)
    ...
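
# --- Illustrative sketch (assumption, not part of the original app) ----------
# The real generate_html_diff body is elided above ("unchanged"). A minimal
# word-level variant built on the standard library's difflib could look like
# this; the name _generate_html_diff_sketch is hypothetical.
import difflib

def _generate_html_diff_sketch(old_text, new_text):
    # Compare the two texts word by word; keep unchanged words as-is and
    # highlight words inserted by the correction with a green background.
    pieces = []
    for token in difflib.ndiff(old_text.split(), new_text.split()):
        if token.startswith("+ "):
            pieces.append(f'<span style="background-color: #90EE90;">{token[2:]}</span>')
        elif token.startswith("  "):
            pieces.append(token[2:])
        # Deleted words ("- ") and diff hints ("? ") are dropped in this sketch.
    return " ".join(pieces)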
def split_text(text, max_tokens=400):
    # Split the input into chunks of at most max_tokens tokens so each one
    # fits in the model's context window.
    encoded = tokenizer.encode(text)
    splits = []
    for i in range(0, len(encoded), max_tokens):
        split = encoded[i:i+max_tokens]
        splits.append(tokenizer.decode(split))
    return splits
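
# Example (assumption): chunking a long page of noisy OCR before correction.
#   chunks = split_text(noisy_page_text, max_tokens=400)   # noisy_page_text is hypothetical
#   print(len(chunks), "chunks of roughly 400 tokens each")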
# Generate corrected text with CTranslate2, one chunk at a time
def ocr_correction(prompt, max_new_tokens=600):
    splits = split_text(prompt, max_tokens=400)
    corrected_splits = []
    for split in splits:
        # Wrap each chunk in the "### Text ### / ### Correction ###" prompt format.
        full_prompt = f"### Text ###\n{split}\n\n\n### Correction ###\n"
        encoded = tokenizer.encode(full_prompt)
        prompt_tokens = tokenizer.convert_ids_to_tokens(encoded)
        result = generator.generate_batch(
            [prompt_tokens],
            max_length=max_new_tokens,
            sampling_temperature=0.7,
            sampling_topk=20,
            include_prompt_in_result=False
        )[0]
        corrected_text = tokenizer.decode(result.sequences_ids[0])
        corrected_splits.append(corrected_text)
    return " ".join(corrected_splits)
# OCR Correction Class
class OCRCorrector:
    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        self.system_prompt = system_prompt

    def correct(self, user_message):
        generated_text = ocr_correction(user_message)
        html_diff = generate_html_diff(user_message, generated_text)
        return generated_text, html_diff
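
# Example usage (assumption; raw_ocr is a hypothetical variable holding noisy OCR text):
#   corrector = OCRCorrector()
#   corrected, diff_html = corrector.correct(raw_ocr)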
# Combined Processing Class
class TextProcessor:
    def __init__(self):
        self.ocr_corrector = OCRCorrector()

    @spaces.GPU(duration=120)
    def process(self, user_message):
        # OCR Correction
        corrected_text, html_diff = self.ocr_corrector.correct(user_message)

        # Combine results
        ocr_result = f'<h2 style="text-align:center">OCR Correction</h2>\n<div class="generation">{html_diff}</div>'
        final_output = f"{css}{ocr_result}"
        return final_output
# Create the TextProcessor instance
text_processor = TextProcessor()
# Define the Gradio interface
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Vintage OCR corrector</h1>""")
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    process_button.click(text_processor.process, inputs=text_input, outputs=[text_output])

if __name__ == "__main__":
    demo.queue().launch()