twigs committed
Commit 98ce4ad
1 Parent(s): b4f09d3

Update app.py

Files changed (1)
  1. app.py +10 -14
app.py CHANGED
@@ -20,11 +20,10 @@ simpl_model = BartForConditionalGeneration.from_pretrained(
     'twigs/bart-text2text-simplifier')
 cwi_pipe = pipeline('text-classification', model=cwi_model,
                     tokenizer=cwi_tok, function_to_apply='none')
-fill_pipe = pipeline('fill-mask', model=simpl_model,
-                     tokenizer=simpl_tok, top_k=1)
+fill_pipe = pipeline('fill-mask', top_k=1)
 
 
-def id_replace_complex(s, threshold=0.4):
+def id_replace_complex(s, threshold=0.2):
 
     # get all tokens
     tokens = re.compile('\w+').findall(s)
@@ -34,19 +33,16 @@ def id_replace_complex(s, threshold=0.4):
     compl_tok = [tokens[idx] for idx, x in enumerate(
         cwi_pipe(cands)) if x['score'] >= threshold]
 
-    replacements = []
-    # potentially parallelizable, depends on desired behaviour
-    for t in compl_tok:
-        idx = s.index(t)
-        s = s[:idx] + '<mask>' + s[idx+len(t):]
-        # get top candidate for mask fill in complex token
-        top_result = fill_pipe(s)[0]
-        s = top_result['sequence']
-        print(s)
-        replacements.append(top_result['token_str'])
+    masked = [s[:s.index(t)] + '<mask>' + s[s.index(t)+len(t):] for t in compl_tok]
+    cands = fill_pipe(masked)
+    replacements = [el['token_str'][1:] if type(el) == dict else el[0]['token_str'][1:] for el in cands]
+
+    for i, el in enumerate(compl_tok):
+        idx = s.index(el)
+        s = s[:idx] + replacements[i] + s[idx+len(el):]
+
     return s, compl_tok, replacements
 
-
 def generate_candidate_text(s, model, tokenizer, tokenized=False):
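For context, a minimal standalone sketch of the batched mask-and-fill flow this commit switches to; the example sentence, the complex-token list, and the fallback to the pipeline's default fill-mask checkpoint are illustrative assumptions, not part of the commit:

    from transformers import pipeline

    # With no model argument, the fill-mask pipeline loads its default
    # masked-LM checkpoint (an assumption here; the commit drops the BART
    # simplifier model/tokenizer from this pipeline).
    fill_pipe = pipeline('fill-mask', top_k=1)

    s = "The committee reached a unanimous verdict."
    compl_tok = ["unanimous", "verdict"]  # hypothetical tokens flagged as complex

    # One masked copy of the sentence per complex token, filled in a single
    # batch instead of the old one-call-per-token loop.
    masked = [s[:s.index(t)] + '<mask>' + s[s.index(t) + len(t):] for t in compl_tok]
    cands = fill_pipe(masked)

    # Depending on the transformers version, top_k=1 yields a dict per input
    # or a one-element list; token_str often carries a leading space, hence
    # the strip.
    replacements = [(el if isinstance(el, dict) else el[0])['token_str'].strip()
                    for el in cands]

    # Splice each top candidate back into the sentence in place of the
    # complex token it replaces.
    for t, r in zip(compl_tok, replacements):
        idx = s.index(t)
        s = s[:idx] + r + s[idx + len(t):]

    print(s)

Collecting all masked sentences and calling fill_pipe once replaces the old per-token loop, so the fill-mask pipeline is invoked a single time per input sentence rather than once per complex word.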