import os
from concurrent.futures import ThreadPoolExecutor
from difflib import Differ

import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

description = """# 🙋🏻♂️Welcome to Tonic's On-Device 📲⌚🎅🏻OCR Corrector (CPU)

📲⌚🎅🏻OCRonos-Vintage is a small specialized model for OCR correction of cultural heritage archives, pre-trained with llm.c. At only 124 million parameters, it runs easily on CPU and can correct text at scale on GPU (>10k tokens/second), with correction quality comparable to GPT-4 or the Llama version of OCRonos on English-language cultural archives.

### Join us:

🌟TeamTonic🌟 is always making cool demos! Join our active builders' 🛠️community 👻 on [Discord](https://discord.gg/qdfnvSPcqP), find us on 🤗Hugging Face at [MultiTransformer](https://huggingface.co/MultiTransformer) and on 🌐GitHub at [Tonic-AI](https://github.com/tonic-ai), and contribute to 🌟[Build Tonic](https://git.tonic-ai.com/contribute). 🤗Big thanks to Yuvi Sharma and all the folks at Hugging Face for the community grant 🤗
"""

model_name = "PleIAs/OCRonos-Vintage"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model and tokenizer once at startup so every request reuses them.
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)


def diff_texts(text1, text2):
    """Word-level diff of two strings as (token, tag) pairs for gr.HighlightedText."""
    d = Differ()
    return [
        (token[2:], token[0] if token[0] != " " else None)
        for token in d.compare(text1.split(), text2.split())
    ]
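
# A minimal usage sketch (illustrative input):
#   diff_texts("the cat sat", "the cat sit")
#   -> [("the", None), ("cat", None), ("sat", "-"), ("sit", "+")]
# Unchanged words carry a None tag; "-" marks words removed from the original
# and "+" marks words introduced by the correction.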


def split_text(text, max_tokens=400):
    """Split text into chunks of at most max_tokens tokenizer tokens."""
    tokens = tokenizer.tokenize(text)
    chunks = []
    current_chunk = []
    for token in tokens:
        current_chunk.append(token)
        if len(current_chunk) >= max_tokens:
            chunks.append(tokenizer.convert_tokens_to_string(current_chunk))
            current_chunk = []
    if current_chunk:
        chunks.append(tokenizer.convert_tokens_to_string(current_chunk))
    return chunks
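
# For example, a 900-token input yields chunks of roughly 400, 400, and 100
# tokens. Each chunk is corrected independently in process_text below, so the
# model never sees context across chunk boundaries.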


def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
    """Wrap raw OCR text in the model's prompt format and generate a correction."""
    prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    torch.set_num_threads(num_threads)
    # Run generation in a worker thread so the Gradio event loop is not blocked.
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        future = executor.submit(
            model.generate,
            input_ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            top_k=50,
            num_return_sequences=1,
            do_sample=False,  # greedy decoding; top_k has no effect without sampling
        )
        output = future.result()
    result = tokenizer.decode(output[0], skip_special_tokens=True)
    # The decoded output still contains the prompt, so keep only the part
    # after the "### Correction ###" marker.
    return result.split("### Correction ###")[1].strip()
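
# For an input of "soem noizy text" (an illustrative value), the model receives:
#   ### Text ###
#   soem noizy text
#
#
#   ### Correction ###
# and greedy decoding continues from there; everything generated after the
# "### Correction ###" marker is returned as the corrected text.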


def process_text(user_message):
    """Correct the input chunk by chunk, then diff the result against the original."""
    chunks = split_text(user_message)
    corrected_chunks = []
    for chunk in chunks:
        corrected_chunk = ocr_correction(chunk)
        corrected_chunks.append(corrected_chunk)
    corrected_text = " ".join(corrected_chunks)
    return diff_texts(user_message, corrected_text)
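
# The returned list of (token, tag) pairs is exactly the format
# gr.HighlightedText accepts; with the color_map below, corrected words ("+")
# render green and removed words ("-") render red.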


with gr.Blocks(theme=gr.themes.Base()) as demo:
    gr.Markdown(description)
    text_input = gr.Textbox(
        label="↘️Enter 👁️OCR'ed Text Outputs Here",
        # The info string shows a sample of the kind of noisy OCR output to paste in.
        info="""Hi there, ;fémy name à`gis tonic 45and i like to ride my vpotz""",
        lines=5,
    )
    process_button = gr.Button("Correct using 📲⌚🎅🏻OCRonos")
    text_output = gr.HighlightedText(
        label="📲⌚🎅🏻OCRonos Correction:",
        combine_adjacent=True,
        show_legend=True,
        color_map={"+": "green", "-": "red"},
    )
    process_button.click(process_text, inputs=text_input, outputs=[text_output])

if __name__ == "__main__":
    # queue() enables request queuing so slow CPU generations don't drop connections.
    demo.queue().launch()