File size: 3,234 Bytes
0dfb412
f2019a4
e98a756
f2019a4
 
ffbf266
61dc098
e98a756
31559f1
ffbf266
e98a756
dfbcb2e
dd838d3
ffbf266
e98a756
ffbf266
 
eed441d
0dfb412
21257a3
eed441d
 
 
 
 
 
 
 
21257a3
bafe915
0dfb412
bafe915
eed441d
 
 
 
 
 
 
e98a756
eed441d
0dfb412
e98a756
 
ffbf266
e98a756
ffbf266
e98a756
 
 
 
 
ffbf266
 
e98a756
ffbf266
 
 
c4873ef
ffbf266
 
 
61dc098
 
 
 
 
 
ffbf266
 
 
 
61dc098
ffbf266
61dc098
 
ffbf266
e98a756
 
 
 
 
 
 
 
 
 
 
 
ffbf266
e98a756
 
 
3481362
f2019a4
21257a3
cd9ce00
b100458
21257a3
 
e98a756
7468778
0dfb412
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import difflib
import html
import os
import re
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import torch
import transformers
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# --- OCR correction model ---------------------------------------------------
# Checkpoint identifier handed to both `from_pretrained` calls below.
model_name = "PleIAs/OCRonos-Vintage"

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Weights and tokenizer are loaded once, at import time, and shared by every
# call made through this module.
model = GPT2LMHeadModel.from_pretrained(model_name)
model = model.to(device)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Stylesheet prepended to every rendered result: adds side margins and a
# larger font to the output text and gives inserted words a green background.
css = """
<style>
.generation {
    margin-left: 2em;
    margin-right: 2em;
    font-size: 1.2em;
}
.inserted {
    background-color: #90EE90;
}
</style>
"""

def generate_html_diff(old_text, new_text):
    """Render *new_text* as HTML, highlighting words not taken from *old_text*.

    Both texts are split on whitespace and compared word by word. Words the
    two texts share are emitted unchanged; words that exist only in
    *new_text* (insertions and replacements) are wrapped in
    ``<span class="inserted">``. Words present only in *old_text* are
    dropped, so the output always reads as *new_text*, joined by single
    spaces.

    Every word is HTML-escaped: both inputs are free text (user input and
    model output) that ends up inside an HTML component, so characters such
    as ``<`` or ``&`` must not be interpreted as markup.
    """
    old_words = old_text.split()
    new_words = new_text.split()

    # autojunk=False: with the default heuristic, items that are frequent in a
    # sequence of 200+ elements are treated as junk and cannot anchor a
    # match, which fragments the highlighting on longer texts.
    matcher = difflib.SequenceMatcher(None, old_words, new_words, autojunk=False)

    pieces = []
    for tag, _old_lo, _old_hi, new_lo, new_hi in matcher.get_opcodes():
        if tag == "delete":
            # Removed words have no counterpart in new_text; nothing to show.
            continue
        for word in new_words[new_lo:new_hi]:
            escaped = html.escape(word)
            if tag == "equal":
                pieces.append(escaped)
            else:  # "insert" or "replace"
                pieces.append(f'<span class="inserted">{escaped}</span>')
    return " ".join(pieces)

def split_text(text, max_tokens=400):
    """Split *text* into chunks of at most *max_tokens* tokenizer tokens.

    Each chunk is later corrected independently and the results are
    re-joined with a space. Cutting in the middle of a word would therefore
    split that word in two, so each cut is moved back to the closest
    preceding token that starts a new word, when the current window contains
    one. A run of more than *max_tokens* tokens with no word boundary is
    still hard-split at the limit, so no chunk ever exceeds *max_tokens*.

    Args:
        text: The text to split.
        max_tokens: Maximum number of tokens per chunk. Values below 1 are
            treated as 1, which matches the previous behavior for those values.

    Returns:
        A list of decoded chunk strings; empty when *text* has no tokens.
    """
    # GPT-2's byte-level BPE encodes a leading space / newline / tab as these
    # characters at the start of a token, so such a token begins a new word.
    word_start_markers = ("\u0120", "\u010a", "\u0109")  # space, newline, tab

    limit = max(1, max_tokens)
    tokens = tokenizer.tokenize(text)
    chunks = []
    start = 0
    while start < len(tokens):
        end = min(start + limit, len(tokens))
        if end < len(tokens):
            # tokens[end] would open the next chunk; walk back until it is a
            # word start, but never produce an empty chunk.
            cut = end
            while cut > start and not tokens[cut].startswith(word_start_markers):
                cut -= 1
            if cut > start:
                end = cut
        chunks.append(tokenizer.convert_tokens_to_string(tokens[start:end]))
        start = end
    return chunks

def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
    """Run the OCR-correction model on one chunk of text and return its output.

    Args:
        prompt: The raw text to correct. It is wrapped in the
            ``### Text ###`` / ``### Correction ###`` template before
            generation.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_threads: Passed to ``torch.set_num_threads``. The default is
            evaluated once, when the function is defined. ``os.cpu_count()``
            can return ``None``; a falsy value leaves the setting unchanged.

    Returns:
        The generated correction with surrounding whitespace stripped,
        cut at any further ``### Correction ###`` header the model emits.
    """
    prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    # Calling the tokenizer (rather than .encode) also yields the attention
    # mask, which generate() expects when pad_token_id is set.
    encoded = tokenizer(prompt, return_tensors="pt").to(device)

    # torch.set_num_threads rejects None, which os.cpu_count() may return.
    if num_threads:
        torch.set_num_threads(num_threads)

    # Greedy decoding (do_sample=False), so sampling knobs like top_k would be
    # ignored. generate() is synchronous; it is called directly rather than
    # via a throwaway single-task thread pool.
    output = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=1,
        do_sample=False,
    )

    # Decode only the tokens produced after the prompt, so a header string
    # that happens to appear in the input text cannot shift the result.
    prompt_length = encoded["input_ids"].shape[-1]
    generated = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    # Stop at a repeated correction header, as the previous split-based code did.
    return generated.split("### Correction ###")[0].strip()

def process_text(user_message):
    """Correct *user_message* with the OCR model and return an HTML report.

    The text is split into model-sized chunks. Each chunk is corrected
    independently, and the corrections are joined with single spaces. The
    result is rendered as a word diff against the input, prefixed by the
    page CSS and an "OCR Correction" heading.

    Blank input (``None``, empty, or whitespace only) skips the model and
    renders an empty result. Without this check, whitespace-only input would
    be sent to the model, which would invent text for it.
    """
    if user_message and user_message.strip():
        corrected_text = " ".join(
            ocr_correction(chunk) for chunk in split_text(user_message)
        )
        html_diff = generate_html_diff(user_message, corrected_text)
    else:
        html_diff = ""

    ocr_result = (
        '<h2 style="text-align:center">OCR Correction</h2>\n'
        f'<div class="generation">{html_diff}</div>'
    )
    return f"{css}{ocr_result}"

# --- Gradio UI ----------------------------------------------------------------
# A single text box feeds process_text; its HTML report is shown underneath.
with gr.Blocks(theme="JohnSmith9982/small_and_pretty") as demo:
    gr.HTML('<h1 style="text-align:center">Vintage OCR corrector</h1>')
    text_input = gr.Textbox(lines=5, type="text", label="Your (bad?) text")
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    process_button.click(fn=process_text, inputs=text_input, outputs=[text_output])


if __name__ == "__main__":
    # Turn on Gradio's request queue, then start serving the app.
    demo.queue()
    demo.launch()