"""Streamlit demo: controllable text simplification.

A BART text2text model rewrites the input sentence, steered by control
tokens; an optional lexical step identifies complex words with a CWI
regressor and replaces them via a fill-mask model.
"""
import re

import numpy as np
import streamlit as st
from textstat import textstat
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          BartForConditionalGeneration, BartTokenizer,
                          pipeline)

MAX_LEN = 256
NUM_BEAMS = 4
EARLY_STOPPING = True
N_OUT = 4

cwi_tok = AutoTokenizer.from_pretrained('twigs/cwi-regressor')
cwi_model = AutoModelForSequenceClassification.from_pretrained(
    'twigs/cwi-regressor')
simpl_tok = BartTokenizer.from_pretrained('twigs/bart-text2text-simplifier')
simpl_model = BartForConditionalGeneration.from_pretrained(
    'twigs/bart-text2text-simplifier')
cwi_pipe = pipeline('text-classification', model=cwi_model,
                    tokenizer=cwi_tok, function_to_apply='none')
fill_pipe = pipeline('fill-mask', top_k=1)


def id_replace_complex(s, threshold=0.2):
    # score every word in the sentence with the CWI regressor
    tokens = re.findall(r'\w+', s)
    cands = [f"{t}. {s}" for t in tokens]
    # keep the tokens whose complexity score reaches the threshold
    compl_tok = [tokens[idx] for idx, x in enumerate(
        cwi_pipe(cands)) if x['score'] >= threshold]
    # mask each complex word so the fill-mask model can propose a substitute
    mask = fill_pipe.tokenizer.mask_token
    masked = [s[:s.index(t)] + mask + s[s.index(t) + len(t):]
              for t in compl_tok]
    cands = fill_pipe(masked)
    # output structure differs between 1 and n masked sentences
    replacements = [el['token_str'] if isinstance(el, dict)
                    else el[0]['token_str'] for el in cands]
    # some predicted tokens come prefixed with a space
    replacements = [tok.lstrip() for tok in replacements]
    for i, el in enumerate(compl_tok):
        idx = s.index(el)
        s = s[:idx] + replacements[i] + s[idx + len(el):]
    return s, compl_tok, replacements


def generate_candidate_text(s, model, tokenizer, tokenized=False):
    out = tokenizer([s], max_length=MAX_LEN, padding="max_length",
                    truncation=True, return_tensors='pt') if not tokenized else s
    generated_ids = model.generate(
        input_ids=out['input_ids'],
        attention_mask=out['attention_mask'],
        use_cache=True,
        decoder_start_token_id=model.config.pad_token_id,
        num_beams=NUM_BEAMS,
        max_length=MAX_LEN,
        early_stopping=EARLY_STOPPING,
        num_return_sequences=N_OUT,
    )
    # drop the leading space the decoder emits before each sentence
    return [tokenizer.decode(ids, skip_special_tokens=True,
                             clean_up_tokenization_spaces=True)[1:]
            for ids in generated_ids]


def rank_candidate_text(sentences):
    # lowest Flesch-Kincaid grade level wins, i.e. the simplest candidate
    fkgl_scores = [textstat.flesch_kincaid_grade(s) for s in sentences]
    return sentences[np.argmin(fkgl_scores)]


def full_pipeline(source, simpl_model, simpl_tok, tokens, lexical=False):
    modified, complex_words, replacements = id_replace_complex(
        source, threshold=0.2) if lexical else (source, None, None)
    cands = generate_candidate_text(tokens + modified, simpl_model, simpl_tok)
    output = rank_candidate_text(cands)
    return output, complex_words, replacements


def main():
    aug_tok = ['c_', 'lev_', 'dep_', 'rank_', 'rat_', 'n_syl_']
    base_tokens = ['CharRatio', 'LevSim', 'DependencyTreeDepth',
                   'WordComplexity', 'WordRatio', 'NumberOfSyllables']
    default_values = [0.8, 0.6, 0.9, 0.8, 0.9, 1.9]
    user_values = list(default_values)
    tok_values = dict(zip(base_tokens, default_values))
    example_sentences = [
        "A matchbook is a small cardboard folder (matchcover) enclosing a quantity of matches and having a coarse striking surface on the exterior.",
        "If there are no strong land use controls, buildings are built along a bypass, converting it into an ordinary town road, and the bypass may eventually become as congested as the local streets it was intended to avoid.",
        "Plot Captain Caleb Holt (Kirk Cameron) is a firefighter in Albany, Georgia and firmly keeps the cardinal rule of all firemen, \"Never leave your partner behind\".",
        "Britpop emerged from the British independent music scene of the early 1990s and was characterised by bands influenced by British guitar pop music of the 1960s and 1970s.",
    ]

    st.title("Make it Simple")
    with st.expander("Example sentences"):
        for s in example_sentences:
            st.code(body=s)

    with st.form(key="simplify"):
        input_sentence = st.text_area("Original sentence")
        lexical = st.checkbox("Identify and replace complex words", value=True)
        tok = st.multiselect(label="Tokens to augment the sentence",
                             options=base_tokens, default=base_tokens)
        if tok:
            st.text("Select the desired intensity")
            for t in tok:
                # index into the full token list so values stay aligned when
                # some tokens are deselected; NumberOfSyllables defaults to
                # 1.9, so its slider range has to widen past 1.0
                idx = base_tokens.index(t)
                user_values[idx] = st.slider(
                    t, min_value=0., max_value=max(1., tok_values[t]),
                    value=tok_values[t], step=0.1, key=t)
        submit = st.form_submit_button("Process")

    if submit:
        tokens = " ".join(t + str(v)
                          for t, v in zip(aug_tok, user_values)) + " "
        output, words, replacements = full_pipeline(
            input_sentence, simpl_model, simpl_tok, tokens, lexical)
        c1, c2, c3 = st.columns([1, 1, 2])
        with c1:
            st.markdown("#### Words identified as complex")
            if words:
                for w in words:
                    st.markdown(f"* {w}")
            else:
                st.markdown("None :smile:")
        with c2:
            st.markdown("#### Their mask-predicted replacement")
            if replacements:
                for w in replacements:
                    st.markdown(f"* {w}")
            else:
                st.markdown("None :smile:")
        with c3:
            st.markdown(f"#### Original Sentence:\n > {input_sentence}")
            st.markdown(f"#### Output Sentence:\n > {output}")


if __name__ == '__main__':
    main()
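
# A minimal sketch of driving the pipeline without the UI, assuming the models
# above are loaded; the control-token string uses the app's default values and
# the input sentence is a shortened example from the list above:
#
#   ctrl = "c_0.8 lev_0.6 dep_0.9 rank_0.8 rat_0.9 n_syl_1.9 "
#   out, words, repls = full_pipeline(
#       "A matchbook is a small cardboard folder enclosing a quantity of matches.",
#       simpl_model, simpl_tok, ctrl, lexical=True)
#   print(out)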