File size: 4,781 Bytes
2b20905
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import re

import numpy as np
import streamlit as st
import torch
from textstat import textstat
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BartTokenizer, BartForConditionalGeneration, pipeline



# Generation settings used by generate_candidate_text's call to model.generate.
MAX_LEN: int = 256  # max_length passed to generate
NUM_BEAMS: int = 4  # beam width for beam search
EARLY_STOPPING: bool = True  # early_stopping flag passed to generate
N_OUT: int = 4  # num_return_sequences: candidates produced per input



def _cache_resource(func):
    """Cache ``func``'s return value across Streamlit reruns.

    Streamlit re-executes the whole script on each interaction, so uncached
    module-level model loading would be repeated every time. Uses
    ``st.cache_resource`` when this Streamlit version provides it and falls
    back to the legacy ``st.cache(allow_output_mutation=True)`` otherwise.
    """
    cache = getattr(st, "cache_resource", None)
    if cache is not None:
        return cache(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def _load_models():
    """Load the CWI regressor, the BART simplifier and their pipelines once.

    Returns:
        ``(cwi_tok, cwi_model, simpl_tok, simpl_model, cwi_pipe, fill_pipe)``.
    """
    # First GPU when available, otherwise CPU (-1), instead of requiring CUDA.
    device = 0 if torch.cuda.is_available() else -1

    cwi_tok = AutoTokenizer.from_pretrained('twigs/cwi-regressor')
    cwi_model = AutoModelForSequenceClassification.from_pretrained('twigs/cwi-regressor')
    simpl_tok = BartTokenizer.from_pretrained('twigs/bart-text2text-simplifier')
    simpl_model = BartForConditionalGeneration.from_pretrained('twigs/bart-text2text-simplifier')
    # function_to_apply='none': report the raw regressor output as the score.
    cwi_pipe = pipeline('text-classification', model=cwi_model, tokenizer=cwi_tok,
                        function_to_apply='none', device=device)
    # NOTE: the pipeline is expected to move simpl_model onto `device`, which is
    # where generate_candidate_text later runs generation as well.
    fill_pipe = pipeline('fill-mask', model=simpl_model, tokenizer=simpl_tok,
                         top_k=1, device=device)
    return cwi_tok, cwi_model, simpl_tok, simpl_model, cwi_pipe, fill_pipe


cwi_tok, cwi_model, simpl_tok, simpl_model, cwi_pipe, fill_pipe = _load_models()


def id_replace_complex(s, threshold=0.4, *, cwi=None, fill=None):
    """Find complex words in ``s`` and replace each with a masked-LM suggestion.

    Every word (``\\w+`` run) is scored by the CWI pipeline on the input
    ``"<word>. <sentence>"``; words scoring ``>= threshold`` are complex.
    Complex words are then handled one at a time: the first whole-word
    occurrence is replaced by ``<mask>`` and the fill-mask pipeline's top
    prediction becomes the new sentence.

    Args:
        s: Sentence to process.
        threshold: Minimum CWI score for a word to count as complex.
        cwi: Text-classification callable; defaults to module-level ``cwi_pipe``.
        fill: Fill-mask callable; defaults to module-level ``fill_pipe``.

    Returns:
        ``(new_sentence, complex_words)`` where ``complex_words`` lists the
        flagged words in sentence order.
    """
    cwi = cwi_pipe if cwi is None else cwi
    fill = fill_pipe if fill is None else fill

    words = re.findall(r"\w+", s)
    if not words:
        return s, []

    cands = [f"{w}. {s}" for w in words]
    compl_tok = [w for w, pred in zip(words, cwi(cands)) if pred['score'] >= threshold]

    # Sequential on purpose: each fill sees the previous replacements.
    for word in compl_tok:
        # Whole-word match so "cat" is not masked inside e.g. "concatenate".
        match = re.search(rf"\b{re.escape(word)}\b", s)
        if match is None:
            # An earlier fill rewrote this word away; nothing left to replace.
            continue
        s = s[:match.start()] + '<mask>' + s[match.end():]
        # Take the top candidate for the masked position.
        s = fill(s)[0]['sequence']

    return s, compl_tok


def generate_candidate_text(s, model, tokenizer, tokenized=False):
    """Generate ``N_OUT`` beam-search rewrites of ``s`` with a seq2seq model.

    Args:
        s: Raw input text, or -- when ``tokenized`` is true -- an already
            encoded batch indexable by ``'input_ids'`` and ``'attention_mask'``
            (presumably a tokenizer ``BatchEncoding``; it is used as-is).
        model: Seq2seq model whose ``generate`` is called.
        tokenizer: Tokenizer used to encode ``s`` and decode the outputs.
        tokenized: Whether ``s`` is already encoded.

    Returns:
        List of ``N_OUT`` decoded strings, each with its first character removed.
    """
    if tokenized:
        inputs = s
    else:
        # Encode with the tokenizer passed in and put tensors on the model's own
        # device rather than a hard-coded 'cuda'.
        inputs = tokenizer(
            [s], max_length=MAX_LEN, padding="max_length", truncation=True,
            return_tensors='pt').to(model.device)

    generated_ids = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        use_cache=True,
        decoder_start_token_id=model.config.pad_token_id,
        num_beams=NUM_BEAMS,
        max_length=MAX_LEN,
        early_stopping=EARLY_STOPPING,
        num_return_sequences=N_OUT,
    )

    decoded = tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    # NOTE(review): drops the first character of every output, presumably a
    # leading space left by decoding -- confirm; `.lstrip()` would be safer if so.
    return [text[1:] for text in decoded]


def rank_candidate_text(sentences):
    """Return the most readable candidate, scored with plain FKGL.

    Readability is the Flesch-Kincaid grade level (lower = easier); when
    several candidates tie, the earliest one is returned.
    """
    grades = np.array([textstat.flesch_kincaid_grade(candidate) for candidate in sentences])
    best = int(grades.argmin())
    return sentences[best]
    

def full_pipeline(source, simpl_model, simpl_tok, tokens, lexical=False, threshold=0.2):
    """Simplify ``source`` end to end.

    Optionally performs lexical substitution first (``id_replace_complex``),
    prefixes the control tokens, generates candidates with
    ``generate_candidate_text`` and returns the one picked by
    ``rank_candidate_text``.

    Args:
        source: Sentence to simplify.
        simpl_model: Seq2seq model used for generation.
        simpl_tok: Tokenizer matching ``simpl_model``.
        tokens: Control-token prefix: either a ready-made string (prepended
            verbatim) or a sequence of tokens (joined with single spaces).
        lexical: If true, run complex-word replacement before generation.
        threshold: CWI score threshold used when ``lexical`` is true.

    Returns:
        ``(output, complex_words)``; ``complex_words`` is ``None`` unless
        ``lexical`` is true.
    """
    # Two explicit branches: `x, y = f() if c else s, None` parses as
    # `x, y = (f() if c else s), None` and would never unpack f()'s tuple.
    if lexical:
        modified, complex_words = id_replace_complex(source, threshold=threshold)
    else:
        modified, complex_words = source, None

    if isinstance(tokens, str):
        model_input = tokens + modified
    else:
        # NOTE(review): assumes the model expects space-separated control
        # tokens followed by the sentence -- confirm against training format.
        prefix = " ".join(tokens)
        model_input = f"{prefix} {modified}" if prefix else modified

    cands = generate_candidate_text(model_input, simpl_model, simpl_tok)
    output = rank_candidate_text(cands)
    return output, complex_words


# Control-token prefixes; main() appends a value to each to build the model prefix.
aug_tok = ['c_', 'lev_', 'dep_', 'rank_', 'rat_', 'n_syl_']
# UI names for the first five prefixes: tokens[i] pairs with aug_tok[i] and default_values[i].
# NOTE(review): 'n_syl_' has no UI entry, so its value is never set from a slider.
tokens = ['CharRatio', 'LevSim', 'DependencyTreeDepth',
          'WordComplexity', 'WordRatio']

default_values = [0.8, 0.6, 0.9, 0.8, 0.9, 1.9]
# Independent copy so that writing user choices never mutates the defaults.
user_values = list(default_values)
# Initial slider value per UI token (zip stops after the five named tokens).
tok_values = dict(zip(tokens, default_values))

example_sentences = ["A matchbook is a small cardboard folder (matchcover) enclosing a quantity of matches and having a coarse striking surface on the exterior.",
                     "If there are no strong land use controls, buildings are built along a bypass, converting it into an ordinary town road, and the bypass may eventually become as congested as the local streets it was intended to avoid.",
                     "Plot Captain Caleb Holt (Kirk Cameron) is a firefighter in Albany, Georgia and firmly keeps the cardinal rule of all firemen, \"Never leave your partner behind\".",
                     "Britpop emerged from the British independent music scene of the early 1990s and was characterised by bands influenced by British guitar pop music of the 1960s and 1970s."]


def main():
    """Render the "Make it Simple" Streamlit app.

    Shows the example sentences, then a form with the input sentence, a
    multiselect of control tokens and one intensity slider per selected
    token. On submit it runs ``full_pipeline`` and shows the original and
    simplified sentences.
    """
    st.title("Make it Simple")

    with st.expander("Example sentences"):
        for example in example_sentences:
            st.code(body=example)

    with st.form(key="form"):
        input_sentence = st.text_area("Original sentence")
        tok = st.multiselect(
            label="Tokens to augment the sentence", options=tokens, default=tokens)

        # Start from the defaults each run; deselected tokens keep their default.
        values = list(default_values)
        if tok:
            st.text("Select the desired intensity")
            for t in tok:
                # Index by the token's position in `tokens` (aligned with
                # `aug_tok`/`default_values`), not its position in the
                # selection, so deselecting one token doesn't shift the rest.
                idx = tokens.index(t)
                values[idx] = st.slider(
                    t, min_value=0., max_value=1., value=tok_values[t], step=0.1, key=t)

        submit = st.form_submit_button("Process")
        if submit:
            if not input_sentence.strip():
                st.warning("Please enter a sentence to simplify.")
                return

            # A distinct name: assigning to `tokens` here would make it local to
            # main() and break the multiselect above with UnboundLocalError.
            control_tokens = [prefix + str(v) for prefix, v in zip(aug_tok, values)]
            # NOTE(review): space-separated control tokens before the sentence
            # is an assumption about the model's input format -- confirm.
            control_prefix = " ".join(control_tokens) + " "
            output, _ = full_pipeline(input_sentence, simpl_model, simpl_tok, control_prefix)

            with st.container():
                st.write("Original sentence:")
                st.write(input_sentence)
                st.write("Output sentence:")
                st.write(output)


# Entry point: build the UI when this file is executed as a script
# (presumably via `streamlit run`, which runs it as __main__ -- verify).
if __name__ == '__main__':
    main()