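"""Make it Simple: a Streamlit demo for controllable text simplification.

Optionally identifies complex words with a CWI regressor and replaces them
via mask filling, then rewrites the sentence with a control-token-conditioned
BART model, keeping the beam-search candidate with the lowest
Flesch-Kincaid grade level. Launch with `streamlit run` on this file.
"""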
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BartTokenizer, BartForConditionalGeneration, pipeline
import numpy as np
import torch
import re
from textstat import textstat


MAX_LEN = 256           # maximum generated sequence length (tokens)
NUM_BEAMS = 4           # beam-search width
EARLY_STOPPING = True   # stop the search once enough finished beams exist
N_OUT = 4               # simplification candidates returned per input


# Complex-word identification (CWI): a regressor scoring word difficulty
cwi_tok = AutoTokenizer.from_pretrained('twigs/cwi-regressor')
cwi_model = AutoModelForSequenceClassification.from_pretrained(
    'twigs/cwi-regressor')
# Control-token-conditioned BART simplifier
simpl_tok = BartTokenizer.from_pretrained('twigs/bart-text2text-simplifier')
simpl_model = BartForConditionalGeneration.from_pretrained(
    'twigs/bart-text2text-simplifier')
# function_to_apply='none' keeps the raw regression score
cwi_pipe = pipeline('text-classification', model=cwi_model,
                    tokenizer=cwi_tok, function_to_apply='none')
# default fill-mask model; keep only the top prediction
fill_pipe = pipeline('fill-mask', top_k=1)


def id_replace_complex(s, threshold=0.2):
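    """Identify words scored >= `threshold` by the CWI regressor and replace
    each with its top fill-mask prediction. Returns the modified sentence and
    the lists of complex words and their replacements."""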

    # split the sentence into word tokens
    tokens = re.compile(r'\w+').findall(s)
    # build one scoring input per token: the token prepended to the sentence
    cands = [f"{t}. {s}" for t in tokens]
    # keep tokens whose regression score reaches the threshold
    compl_tok = [tokens[idx] for idx, x in enumerate(
        cwi_pipe(cands)) if x['score'] >= threshold]

    if not compl_tok:  # nothing complex found, return the sentence unchanged
        return s, [], []

    # mask the first occurrence of each complex word
    masked = [s[:s.index(t)] + '<mask>' + s[s.index(t) + len(t):]
              for t in compl_tok]
    cands = fill_pipe(masked)
    # fill-mask output is a list of dicts for one input, a list of lists for several
    replacements = [el['token_str'] if isinstance(el, dict)
                    else el[0]['token_str'] for el in cands]
    # some predicted tokens come back prefixed with a space
    replacements = [tok.lstrip() for tok in replacements]

    # splice each replacement into the sentence at the word's first occurrence
    for i, el in enumerate(compl_tok):
        idx = s.index(el)
        s = s[:idx] + replacements[i] + s[idx + len(el):]

    return s, compl_tok, replacements

def generate_candidate_text(s, model, tokenizer, tokenized=False):
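    """Generate N_OUT simplification candidates for `s` with beam search."""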
    # tokenize unless the caller already passed tokenizer output
    out = tokenizer([s], max_length=MAX_LEN, padding="max_length", truncation=True,
                    return_tensors='pt') if not tokenized else s

    generated_ids = model.generate(
        input_ids=out['input_ids'],
        attention_mask=out['attention_mask'],
        use_cache=True,
        decoder_start_token_id=model.config.pad_token_id,
        num_beams=NUM_BEAMS,
        max_length=MAX_LEN,
        early_stopping=EARLY_STOPPING,
        num_return_sequences=N_OUT
    )

    # decoded text begins with a space, so drop the first character
    return [tokenizer.decode(ids, skip_special_tokens=True,
                             clean_up_tokenization_spaces=True)[1:]
            for ids in generated_ids]


def rank_candidate_text(sentences):
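    """Pick the candidate with the lowest Flesch-Kincaid grade (simplest)."""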
    fkgl_scores = [textstat.flesch_kincaid_grade(s) for s in sentences]
    return sentences[np.argmin(fkgl_scores)]


def full_pipeline(source, simpl_model, simpl_tok, tokens, lexical=False):
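    """Optionally run lexical replacement, then generate and rank candidates."""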
    modified, complex_words, replacements = id_replace_complex(
        source, threshold=0.2) if lexical else (source, None, None)
    # prepend the control tokens so generation is conditioned on them
    cands = generate_candidate_text(tokens + modified, simpl_model, simpl_tok)
    output = rank_candidate_text(cands)
    return output, complex_words, replacements

def main():
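    """Render the Streamlit UI and run the simplification pipeline on submit."""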

    # control-token prefixes and their display names (parallel lists)
    aug_tok = ['c_', 'lev_', 'dep_', 'rank_', 'rat_', 'n_syl_']
    base_tokens = ['CharRatio', 'LevSim', 'DependencyTreeDepth',
                   'WordComplexity', 'WordRatio', 'NumberOfSyllables']

    default_values = [0.8, 0.6, 0.9, 0.8, 0.9, 1.9]
    user_values = default_values.copy()  # copy so sliders don't mutate the defaults
    tok_values = {t: v for t, v in zip(base_tokens, default_values)}

    example_sentences = [
        "A matchbook is a small cardboard folder (matchcover) enclosing a quantity of matches and having a coarse striking surface on the exterior.",
        "If there are no strong land use controls, buildings are built along a bypass, converting it into an ordinary town road, and the bypass may eventually become as congested as the local streets it was intended to avoid.",
        "Plot Captain Caleb Holt (Kirk Cameron) is a firefighter in Albany, Georgia and firmly keeps the cardinal rule of all firemen, \"Never leave your partner behind\".",
        "Britpop emerged from the British independent music scene of the early 1990s and was characterised by bands influenced by British guitar pop music of the 1960s and 1970s."]


    st.title("Make it Simple")

    with st.expander("Example sentences"):
        for s in example_sentences:
            st.code(body=s)


    with st.form(key="simplify"):
        input_sentence = st.text_area("Original sentence")
    
        lexical = st.checkbox("Identify and replace complex words", value=True)

        tok = st.multiselect(
            label="Tokens to augment the sentence", options=base_tokens, default=base_tokens)
        if tok:
            st.text("Select the desired intensity")
            for t in tok:
                # index via base_tokens so values stay aligned with aug_tok
                # even when some tokens are deselected
                user_values[base_tokens.index(t)] = st.slider(
                    t, min_value=0., max_value=1., value=tok_values[t], step=0.1, key=t)

        submit = st.form_submit_button("Process")
        if submit:
            # build the control-token prefix, e.g. "c_0.8 lev_0.6 ... n_syl_1.9 "
            tokens = " ".join(t + str(v) for t, v in zip(aug_tok, user_values)) + " "
            output, words, replacements = full_pipeline(
                input_sentence, simpl_model, simpl_tok, tokens, lexical)

            c1, c2, c3 = st.columns([1,1,2])

            with c1:
                st.markdown("#### Words identified as complex")
                if words:
                    for w in words:
                        st.markdown(f"* {w}")

                else:
                    st.markdown("None :smile:")

            with c2:
                st.markdown("#### Their mask-predicted replacement")
                if replacements:
                    for w in replacements:
                        st.markdown(f"* {w}")

                else:
                    st.markdown("None :smile:")

            with c3:
                st.markdown(f"#### Original Sentence:\n > {input_sentence}") 
                st.markdown(f"#### Output Sentence:\n > {output}") 


if __name__ == '__main__':
    main()