danseith commited on
Commit
1ca245c
1 Parent(s): e20eecd

Added more sanity checking.

Browse files
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -134,14 +134,13 @@ scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
134
 
135
  def unmask(text, temp, rounds):
136
  sampling = 'multi'
137
- successful_iters = 0
138
  for round in range(rounds):
139
  text = add_mask(text, size=1)
140
  split_text = text.split()
141
  res = scrambler(text, temp=temp, top_k=10)
142
  mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
143
  out = {item["token_str"]: item["score"] for item in res}
144
- score_to_str = {out[k]:k for k in out.keys()}
145
  score_list = list(score_to_str.keys())
146
  if sampling == 'multi':
147
  idx = np.argmax(np.random.multinomial(1, score_list, 1))
@@ -151,12 +150,11 @@ def unmask(text, temp, rounds):
151
  new_token = score_to_str[score]
152
  if len(list(new_token)) < 2:
153
  continue
154
- split_text[mask_pos] = new_token
155
  text = ' '.join(split_text)
156
- successful_iters += 1
157
  text = list(text)
158
  text[0] = text[0].upper()
159
- return ''.join(text) + str(successful_iters)
160
 
161
 
162
  textbox = gr.Textbox(label="Example prompts", lines=5)
 
134
 
135
  def unmask(text, temp, rounds):
136
  sampling = 'multi'
 
137
  for round in range(rounds):
138
  text = add_mask(text, size=1)
139
  split_text = text.split()
140
  res = scrambler(text, temp=temp, top_k=10)
141
  mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
142
  out = {item["token_str"]: item["score"] for item in res}
143
+ score_to_str = {out[k] : k for k in out.keys()}
144
  score_list = list(score_to_str.keys())
145
  if sampling == 'multi':
146
  idx = np.argmax(np.random.multinomial(1, score_list, 1))
 
150
  new_token = score_to_str[score]
151
  if len(list(new_token)) < 2:
152
  continue
153
+ split_text[mask_pos] = '*' + new_token + '*'
154
  text = ' '.join(split_text)
 
155
  text = list(text)
156
  text[0] = text[0].upper()
157
+ return ''.join(text)
158
 
159
 
160
  textbox = gr.Textbox(label="Example prompts", lines=5)