"""Streamlit demo app for the ByT5 Dutch OCR-correction model.

The user can paste text containing OCR mistakes and have it corrected
("Unscramble"), or paste clean text, add artificial OCR-style noise to it with
nlpaug's ``OcrAug`` ("Scramble"), and then correct it.
"""
from textwrap import wrap

from transformers import pipeline
import nlpaug.augmenter.char as nac
import streamlit as st

# Hugging Face model id, used for both the model weights and the tokenizer.
MODEL_NAME = 'ml6team/byt5-base-dutch-ocr-correction'
# Input is split into chunks of at most this many characters before being fed
# to the model (see ``textwrap.wrap``).
CHUNK_WIDTH = 128

st.markdown('# ByT5 Dutch OCR Corrector :pill:')
st.write('This app corrects common dutch OCR mistakes, to showcase how this could be used in an OCR post-processing pipeline.')

st.markdown("""
To use this:
- Enter a text with OCR mistakes and hit 'unscramble':point_down:
- Or enter a normal text, scramble it :twisted_rightwards_arrows: and then hit 'unscramble' :point_down:""")


def _model_cache(func):
    """Cache *func*'s return value across Streamlit reruns.

    Uses ``st.cache_resource`` when the installed Streamlit provides it, and
    otherwise falls back to the legacy ``st.cache`` with the options this app
    originally used.
    """
    if hasattr(st, 'cache_resource'):
        return st.cache_resource(show_spinner=False)(func)
    return st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)(func)


@_model_cache
def load_model():
    """Build and return the text2text-generation pipeline for ``MODEL_NAME``."""
    return pipeline(
        'text2text-generation',
        model=MODEL_NAME,
        tokenizer=MODEL_NAME,
    )


def _augmented_to_text(augmented):
    """Return the output of ``OcrAug.augment`` as a single string.

    Newer nlpaug releases return a list of augmented strings even for a single
    input string, while older ones return the string itself; accept both.
    """
    if isinstance(augmented, list):
        return augmented[0] if augmented else ''
    return augmented


def _correct_text(text, model, width=CHUNK_WIDTH):
    """Correct *text* by running *model* over ``width``-character chunks.

    The chunks come from ``textwrap.wrap``, which also collapses whitespace
    (including newlines). The corrected chunks are joined with single spaces.
    An input with no chunks (empty or whitespace-only) yields ``''`` without
    calling the model.
    """
    chunks = wrap(text, width)
    if not chunks:
        return ''
    return " ".join(result['generated_text'] for result in model(chunks))


# Keep the spinner outside the cached function so neither cache API has to
# deal with st.* calls inside it.
with st.spinner('Please wait for the model to load...'):
    ocr_pipeline = load_model()

# 'text' holds the value the input text area should be (re)rendered with.
if 'text' not in st.session_state:
    st.session_state['text'] = ""

left_area, right_area = st.columns(2)

# Format the left area
left_area.header("Input")
form = left_area.form(key='ocrcorrector')
# The text area lives in a placeholder so it can be replaced after scrambling.
placeholder = form.empty()
input_text = placeholder.text_area(value=st.session_state.text, label='Insert text:', key='input_text')
scramble_button = form.form_submit_button(label='Scramble')
submit_button = form.form_submit_button(label='Unscramble')

# Right area
right_area.header("Output")

if scramble_button:
    base_text = st.session_state.input_text
    st.session_state.text = base_text
    # Only add OCR noise when there is something to scramble.
    if base_text.strip():
        aug = nac.OcrAug()
        st.session_state.text = _augmented_to_text(aug.augment(base_text))
    # Drop the widget's stored state and re-render the text area so it shows
    # the scrambled text (presumably required because Streamlit otherwise
    # keeps the previous widget value for this key).
    del st.session_state.input_text
    placeholder.empty()
    input_text = placeholder.text_area(value=st.session_state.text, label='Insert text:', key='input_text')

if submit_button:
    base_text = st.session_state.input_text
    if base_text.strip():
        output_text = _correct_text(base_text, ocr_pipeline)
        right_area.markdown('#####')
        right_area.text_area(value=output_text, label="Corrected text:")
    else:
        right_area.warning('Please enter some text to unscramble.')