Spaces:
Build error
Build error
danseith
commited on
Commit
•
1ca245c
1
Parent(s):
e20eecd
Added more sanity checking.
Browse files
app.py
CHANGED
@@ -134,14 +134,13 @@ scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
|
|
134 |
|
135 |
def unmask(text, temp, rounds):
|
136 |
sampling = 'multi'
|
137 |
-
successful_iters = 0
|
138 |
for round in range(rounds):
|
139 |
text = add_mask(text, size=1)
|
140 |
split_text = text.split()
|
141 |
res = scrambler(text, temp=temp, top_k=10)
|
142 |
mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
|
143 |
out = {item["token_str"]: item["score"] for item in res}
|
144 |
-
score_to_str = {out[k]:k for k in out.keys()}
|
145 |
score_list = list(score_to_str.keys())
|
146 |
if sampling == 'multi':
|
147 |
idx = np.argmax(np.random.multinomial(1, score_list, 1))
|
@@ -151,12 +150,11 @@ def unmask(text, temp, rounds):
|
|
151 |
new_token = score_to_str[score]
|
152 |
if len(list(new_token)) < 2:
|
153 |
continue
|
154 |
-
split_text[mask_pos] = new_token
|
155 |
text = ' '.join(split_text)
|
156 |
-
successful_iters += 1
|
157 |
text = list(text)
|
158 |
text[0] = text[0].upper()
|
159 |
-
return ''.join(text)
|
160 |
|
161 |
|
162 |
textbox = gr.Textbox(label="Example prompts", lines=5)
|
|
|
134 |
|
135 |
def unmask(text, temp, rounds):
|
136 |
sampling = 'multi'
|
|
|
137 |
for round in range(rounds):
|
138 |
text = add_mask(text, size=1)
|
139 |
split_text = text.split()
|
140 |
res = scrambler(text, temp=temp, top_k=10)
|
141 |
mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
|
142 |
out = {item["token_str"]: item["score"] for item in res}
|
143 |
+
score_to_str = {out[k] : k for k in out.keys()}
|
144 |
score_list = list(score_to_str.keys())
|
145 |
if sampling == 'multi':
|
146 |
idx = np.argmax(np.random.multinomial(1, score_list, 1))
|
|
|
150 |
new_token = score_to_str[score]
|
151 |
if len(list(new_token)) < 2:
|
152 |
continue
|
153 |
+
split_text[mask_pos] = '*' + new_token + '*'
|
154 |
text = ' '.join(split_text)
|
|
|
155 |
text = list(text)
|
156 |
text[0] = text[0].upper()
|
157 |
+
return ''.join(text)
|
158 |
|
159 |
|
160 |
textbox = gr.Textbox(label="Example prompts", lines=5)
|