Spaces:
Sleeping
Sleeping
File size: 3,234 Bytes
0dfb412 f2019a4 e98a756 f2019a4 ffbf266 61dc098 e98a756 31559f1 ffbf266 e98a756 dfbcb2e dd838d3 ffbf266 e98a756 ffbf266 eed441d 0dfb412 21257a3 eed441d 21257a3 bafe915 0dfb412 bafe915 eed441d e98a756 eed441d 0dfb412 e98a756 ffbf266 e98a756 ffbf266 e98a756 ffbf266 e98a756 ffbf266 c4873ef ffbf266 61dc098 ffbf266 61dc098 ffbf266 61dc098 ffbf266 e98a756 ffbf266 e98a756 3481362 f2019a4 21257a3 cd9ce00 b100458 21257a3 e98a756 7468778 0dfb412 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import transformers
import re
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import gradio as gr
import difflib
from concurrent.futures import ThreadPoolExecutor
import os
# --- OCR correction model ----------------------------------------------------
# Hugging Face Hub identifier of the checkpoint used for OCR correction.
model_name = "PleIAs/OCRonos-Vintage"

# Prefer the GPU whenever PyTorch reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pre-trained weights (moved onto `device`) and the matching tokenizer.
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# --- Output styling ----------------------------------------------------------
# Inline stylesheet prepended to each rendered result: `.generation` indents and
# enlarges the output block, `.inserted` gives inserted words a light-green
# background.
css = """
<style>
.generation {
margin-left: 2em;
margin-right: 2em;
font-size: 1.2em;
}
.inserted {
background-color: #90EE90;
}
</style>
"""
def generate_html_diff(old_text, new_text):
    """Render *new_text* as HTML, highlighting words that are not in *old_text*.

    Both texts are split on whitespace and compared word by word with
    ``difflib.Differ``:

    - Words common to both are emitted as-is.
    - Words present only in *new_text* are wrapped in ``<span class="inserted">``.
    - Words present only in *old_text*, and Differ's ``?`` hint lines, are dropped.

    Every word is HTML-escaped before it is placed in the markup. The texts come
    from user input and model output and end up rendered as raw HTML, so
    unescaped ``<``/``&`` would break the page or inject markup.

    Args:
        old_text: The original (uncorrected) text.
        new_text: The corrected text.

    Returns:
        The resulting words/spans joined by single spaces.
    """
    import html  # local import keeps this block self-contained

    pieces = []
    for line in difflib.Differ().compare(old_text.split(), new_text.split()):
        # Differ lines are a two-character tag ("  ", "+ ", "- ", "? ")
        # followed by the word itself.
        tag, word = line[:2], html.escape(line[2:])
        if tag == "  ":
            pieces.append(word)
        elif tag == "+ ":
            pieces.append(f'<span class="inserted">{word}</span>')
        # "- " (deleted) and "? " (intraline hint) lines are intentionally skipped.
    return " ".join(pieces)
def split_text(text, max_tokens=400):
    """Split *text* into consecutive pieces of at most *max_tokens* tokens.

    The text is tokenized with the module-level ``tokenizer``. Each run of
    ``max_tokens`` tokens (the last run may be shorter) is converted back to a
    string. Piece boundaries fall on token edges, not on word or sentence edges.

    Args:
        text: Input text to split.
        max_tokens: Number of tokens after which a piece is closed.

    Returns:
        The pieces in their original order; empty if *text* produces no tokens.
    """
    tokens = tokenizer.tokenize(text)
    to_text = tokenizer.convert_tokens_to_string
    chunks = []
    chunk_start = 0
    for chunk_end in range(1, len(tokens) + 1):
        # Close the pending run tokens[chunk_start:chunk_end] once it reaches the limit.
        if chunk_end - chunk_start >= max_tokens:
            chunks.append(to_text(tokens[chunk_start:chunk_end]))
            chunk_start = chunk_end
    # Whatever is left over after the last full run becomes the final piece.
    if chunk_start < len(tokens):
        chunks.append(to_text(tokens[chunk_start:]))
    return chunks
def ocr_correction(prompt, max_new_tokens=600, num_threads=None):
    """Ask the OCR-correction model to correct one chunk of text.

    The chunk is wrapped in the "### Text ### ... ### Correction ###" prompt
    template and decoded greedily (``do_sample=False``) with the module-level
    ``model``/``tokenizer`` on ``device``. Only the newly generated tokens are
    returned.

    Args:
        prompt: Raw (possibly OCR-damaged) text to correct.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_threads: Intra-op thread count passed to ``torch.set_num_threads``.
            ``None`` (the default) means ``os.cpu_count()``, resolved at call time.

    Returns:
        The model's correction with surrounding whitespace stripped.
    """
    if num_threads is None:
        # Resolved per call instead of once at import. os.cpu_count() may
        # return None, and torch.set_num_threads(None) would raise.
        num_threads = os.cpu_count() or 1
    torch.set_num_threads(num_threads)

    template = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    # Calling the tokenizer (not encode) also yields an attention mask for generate().
    encoded = tokenizer(template, return_tensors="pt").to(device)
    prompt_length = encoded["input_ids"].shape[-1]

    # generate() is called directly. Submitting one task to a thread pool and
    # waiting on it immediately adds overhead without any concurrency.
    # top_k is omitted because greedy decoding (do_sample=False) ignores it.
    output = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=1,
        do_sample=False,
    )

    # Decode only the tokens produced after the prompt, so a "### Correction ###"
    # marker inside the user's own text cannot shift the extracted answer.
    generated = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    # If the model starts another correction section, keep only the first one.
    return generated.partition("### Correction ###")[0].strip()
def process_text(user_message):
    """Correct *user_message* chunk by chunk and render the result as HTML.

    Steps:

    1. Split the text into token-bounded chunks with ``split_text``.
    2. Correct each chunk with ``ocr_correction``.
    3. Join the corrections with single spaces.
    4. Diff the joined text against the input with ``generate_html_diff``.
    5. Wrap the diff in a titled ``<div class="generation">``, preceded by the
       module-level ``css`` stylesheet.

    Args:
        user_message: Text entered by the user. ``None`` is treated as empty.

    Returns:
        An HTML string for the Gradio ``HTML`` output component.
    """
    text = user_message or ""
    if text.strip():
        corrected_text = " ".join(
            ocr_correction(chunk) for chunk in split_text(text)
        )
    else:
        # Blank input has nothing to correct. Skipping the model keeps it from
        # inventing text that would then be highlighted as "inserted".
        corrected_text = ""

    html_diff = generate_html_diff(text, corrected_text)
    ocr_result = (
        '<h2 style="text-align:center">OCR Correction</h2>\n'
        f'<div class="generation">{html_diff}</div>'
    )
    return f"{css}{ocr_result}"
# Define the Gradio interface: one text box in, one HTML panel out.
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Vintage OCR corrector</h1>""")
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    # Clicking the button runs the correction pipeline on the text box content
    # and shows the rendered HTML diff in the output panel.
    process_button.click(process_text, inputs=text_input, outputs=[text_output])

if __name__ == "__main__":
    # Enable Gradio's request queue before serving; each correction involves
    # model generation, which can take a while.
    demo.queue().launch()