Spaces:
Running
Running
File size: 3,158 Bytes
d731e09 0dfb412 f2019a4 b6cc9e1 1fca231 31559f1 0dfb412 208476f b6cc9e1 22b51ff b6cc9e1 0dfb412 21257a3 b6cc9e1 21257a3 bafe915 0dfb412 21257a3 bafe915 b6cc9e1 0dfb412 21257a3 b6cc9e1 21257a3 0dfb412 21257a3 fa86caf bafe915 21257a3 08d035b 21257a3 b6cc9e1 21257a3 bafe915 cd9ce00 21257a3 0dfb412 21257a3 3481362 f2019a4 21257a3 cd9ce00 b100458 21257a3 7468778 0dfb412 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import spaces
import transformers
import re
import torch
import gradio as gr
import os
import ctranslate2
from concurrent.futures import ThreadPoolExecutor
# Pick the inference device: CUDA when torch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the CTranslate2 generator and the Hugging Face tokenizer that is
# used to encode prompts for it and decode its output.
# NOTE(review): ctranslate2.Generator expects a path to a converted model
# directory; "PleIAs/OCRonos-Vintage-CT2" looks like a hub repo id, so
# confirm that this resolves to a local directory at runtime.
model_path = "PleIAs/OCRonos-Vintage-CT2"
generator = ctranslate2.Generator(model_path, device=device)
tokenizer = transformers.AutoTokenizer.from_pretrained("PleIAs/OCRonos-Vintage")
# Stylesheet fragment that TextProcessor.process prepends to the HTML it
# returns.
# NOTE(review): the actual rules are elided in this copy of the file; only
# the <style> wrapper and a placeholder line are present.
css = """
<style>
... (your existing CSS)
</style>
"""
# Helper functions
def generate_html_diff(old_text, new_text):
    """Placeholder for the diff renderer comparing two versions of a text.

    The body is elided in this copy of the file (marked "unchanged"), so as
    written the function returns None. OCRCorrector.correct passes the
    original message as ``old_text`` and the corrected text as ``new_text``
    and its result is embedded in an HTML ``<div>``, so the real
    implementation presumably returns an HTML string -- TODO confirm
    against the full source.
    """
    # (unchanged)
    ...
def preprocess_text(text):
    """Placeholder for a text preprocessing step.

    The body is elided in this copy of the file (marked "unchanged"), so as
    written the function returns None. Nothing in the visible code calls it.
    """
    # (unchanged)
    ...
def split_text(text, max_tokens=400):
    """Cut *text* into consecutive chunks of at most *max_tokens* tokens.

    The text is encoded with the module-level ``tokenizer``; the resulting
    id sequence is sliced into back-to-back windows of ``max_tokens`` ids
    and every window is decoded back into a string.

    Args:
        text: The text to split.
        max_tokens: Upper bound on the number of token ids per chunk.

    Returns:
        The decoded chunks, in their original order. The list is empty when
        the encoder produces no ids.
    """
    token_ids = tokenizer.encode(text)
    window_starts = range(0, len(token_ids), max_tokens)
    return [
        tokenizer.decode(token_ids[start:start + max_tokens])
        for start in window_starts
    ]
# Function to generate text using CTranslate2
def ocr_correction(prompt, max_new_tokens=600):
    """Run the CTranslate2 generator over *prompt*, chunk by chunk.

    The prompt is cut into chunks of at most 400 tokens with ``split_text``.
    Each chunk is wrapped in the "### Text ### / ### Correction ###"
    template, tokenized, and sent to ``generator.generate_batch`` as a
    one-item batch, sampling with temperature 0.7 and top-k 20. The
    generated ids (prompt excluded) are decoded with ``tokenizer``.

    Args:
        prompt: The raw text to correct.
        max_new_tokens: Value forwarded as ``max_length`` to the generator
            for every chunk.

    Returns:
        The decoded outputs of all chunks joined with single spaces; an
        empty string when there are no chunks.
    """
    corrected_chunks = []
    for chunk in split_text(prompt, max_tokens=400):
        templated = f"### Text ###\n{chunk}\n\n\n### Correction ###\n"

        # generate_batch is fed token strings here, so map ids back to tokens.
        prompt_ids = tokenizer.encode(templated)
        start_tokens = tokenizer.convert_ids_to_tokens(prompt_ids)

        batch_results = generator.generate_batch(
            [start_tokens],
            max_length=max_new_tokens,
            sampling_temperature=0.7,
            sampling_topk=20,
            include_prompt_in_result=False,
        )
        # One prompt was submitted, so only the first result is relevant,
        # and only its first returned sequence is used.
        first_result = batch_results[0]
        corrected_chunks.append(tokenizer.decode(first_result.sequences_ids[0]))

    return " ".join(corrected_chunks)
# OCR Correction Class
class OCRCorrector:
    """Thin wrapper pairing an OCR correction with its HTML diff."""

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        """Remember *system_prompt*.

        The value is only stored; no method of this class reads it.
        """
        self.system_prompt = system_prompt

    def correct(self, user_message):
        """Correct *user_message* and diff the result against the input.

        Returns:
            A ``(corrected_text, html_diff)`` tuple, where ``corrected_text``
            comes from ``ocr_correction`` and ``html_diff`` from
            ``generate_html_diff(user_message, corrected_text)``.
        """
        corrected = ocr_correction(user_message)
        return corrected, generate_html_diff(user_message, corrected)
# Combined Processing Class
class TextProcessor:
    """Turns a user message into the HTML fragment shown in the UI."""

    def __init__(self):
        """Create the OCRCorrector that ``process`` delegates to."""
        self.ocr_corrector = OCRCorrector()

    # NOTE(review): presumably asks the Spaces runtime for a GPU for up to
    # 120 seconds per call -- confirm against the `spaces` package docs.
    @spaces.GPU(duration=120)
    def process(self, user_message):
        """Correct *user_message* and render the diff as HTML.

        Only the HTML diff returned by the corrector is used; the corrected
        text itself is discarded.

        Returns:
            The module-level ``css`` string followed by a centered
            "OCR Correction" heading and the diff wrapped in a
            ``<div class="generation">``.
        """
        _, diff_markup = self.ocr_corrector.correct(user_message)
        section = f'<h2 style="text-align:center">OCR Correction</h2>\n<div class="generation">{diff_markup}</div>'
        return f"{css}{section}"
# Single processor instance shared by every UI callback invocation.
text_processor = TextProcessor()

# Gradio interface: a text box for the raw text, a button, and an HTML pane
# that displays whatever text_processor.process returns.
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Vintage OCR corrector</h1>""")
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    # Event listeners have to be registered inside the Blocks context.
    process_button.click(text_processor.process, inputs=text_input, outputs=[text_output])

# Start the app only when the file is executed as a script. The stray
# trailing "|" that followed launch() was a syntax error and is removed.
if __name__ == "__main__":
    demo.queue().launch()