import streamlit as st
import os
import torch
from transformers import pipeline
from transformers import AutoConfig, AutoTokenizer, AutoModelForTokenClassification

from libs.normalizer import Normalizer
from libs.examples import LANGUAGES, EXAMPLES
from libs.dummy import outputs as dummy_outputs
from libs.utils import local_css, remote_css

import meta

MODELS = {
    "English (en)": "m3hrdadfi/typo-detector-distilbert-en",
    "Persian (fa)": "m3hrdadfi/typo-detector-distilbert-fa",
    "Icelandic (is)": "m3hrdadfi/typo-detector-distilbert-is",
}
API_TOKEN = os.environ.get("API_TOKEN")


class TypoDetector:
    def __init__(
            self,
            model_name_or_path: str = "m3hrdadfi/typo-detector-distilbert-en"
    ) -> None:
        self.debug = False
        self.dummy_outputs = dummy_outputs
        self.model_name_or_path = model_name_or_path
        self.task_name = "token-classification"

        self.tokenizer = None
        self.config = None
        self.model = None
        self.nlp = None
        self.normalizer = None

    def load(self, api_token=None):
        # transformers expects either a token string or False here.
        api_token = api_token if api_token else False
        if not self.debug:
            self.config = AutoConfig.from_pretrained(
                self.model_name_or_path,
                use_auth_token=api_token
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name_or_path,
                use_auth_token=api_token
            )
            self.model = AutoModelForTokenClassification.from_pretrained(
                self.model_name_or_path,
                config=self.config,
                use_auth_token=api_token
            )
            self.nlp = pipeline(
                self.task_name,
                model=self.model,
                tokenizer=self.tokenizer,
                aggregation_strategy="average"
            )

        self.normalizer = Normalizer()

    def detect(self, sentence):
        if self.debug:
            return self.dummy_outputs[0]

        # Each pipeline result carries character offsets into the sentence.
        typos = [sentence[r["start"]:r["end"]] for r in self.nlp(sentence)]

        detected = sentence
        for typo in typos:
            # Wrap every occurrence of the detected span in highlight markup;
            # the "typo" CSS class is assumed to be styled in assets/style.css.
            detected = detected.replace(typo, f'<span class="typo">{typo}</span>')

        return detected


@st.cache(allow_output_mutation=True)
def load_typo_detectors():
    en_detector = TypoDetector(MODELS["English (en)"])
    en_detector.load()

    is_detector = TypoDetector(MODELS["Icelandic (is)"])
    is_detector.load()

    # Only the Persian model is loaded with the API token
    # (presumably a private or gated repository).
    fa_detector = TypoDetector(MODELS["Persian (fa)"])
    fa_detector.load(api_token=API_TOKEN)

    return {
        "en": en_detector,
        "fa": fa_detector,
        "is": is_detector
    }


def main():
    st.set_page_config(
        page_title="Typo Detector",
        page_icon="⚡",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    remote_css("https://cdn.jsdelivr.net/gh/rastikerdar/vazir-font/dist/font-face.css")
    local_css("assets/style.css")

    detectors = load_typo_detectors()

    col1, col2 = st.beta_columns([6, 4])
    with col2:
        st.markdown(meta.INFO, unsafe_allow_html=True)

    with col1:
        language = st.selectbox(
            'Language (select from this list)',
            LANGUAGES,
            index=0
        )
        detector = detectors[language]

        # Persian is rendered right-to-left; the other languages left-to-right.
        is_rtl = "rtl" if language == "fa" else "ltr"
        if language == "fa":
            local_css("assets/rtl.css")
        else:
            local_css("assets/ltr.css")

        prompts = list(EXAMPLES[language].keys()) + ["Custom"]
        prompt = st.selectbox(
            'Examples (select from this list)',
            prompts,
            # index=len(prompts) - 1,
            index=0
        )

        if prompt == "Custom":
            prompt_box = ""
        else:
            prompt_box = EXAMPLES[language][prompt]

        text = st.text_area(
            'Insert your text:',
            detector.normalizer(prompt_box),
            height=100
        )
        text = detector.normalizer(text)
        entered_text = st.empty()
        detect_typos = st.button('Detect Typos!')

        # Divider between the controls and the output; the exact markup
        # here is an assumption.
        st.markdown(
            "<hr />",
            unsafe_allow_html=True
        )

        if detect_typos:
            words = text.split()
            with st.spinner("Detecting..."):
                # Require at least three words, matching the hint below.
                if len(words) < 3:
                    entered_text.markdown(
                        "Insert your text (at least three words)"
                    )
                else:
                    detected = detector.detect(text)
                    # Wrap the result for display. The container markup and the
                    # "typo-box" class are assumptions; is_rtl selects the
                    # rtl/ltr styling loaded above.
                    detected = f'<p class="typo-box {is_rtl}">{detected}</p>'
                    st.markdown(
                        detected,
                        unsafe_allow_html=True
                    )


if __name__ == '__main__':
    main()