working app
Browse files- app.py +491 -0
- requirements.txt +8 -0
app.py
ADDED
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from numpy import ndarray
|
5 |
+
import pandas as pd
|
6 |
+
import torch as T
|
7 |
+
from torch import Tensor, device
|
8 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoConfig, AutoModel
|
9 |
+
from nltk.corpus import stopwords
|
10 |
+
from nltk.stem.porter import *
|
11 |
+
import json
|
12 |
+
import nltk
|
13 |
+
from nltk import FreqDist
|
14 |
+
from nltk.corpus import gutenberg
|
15 |
+
import urllib.request
|
16 |
+
from string import punctuation
|
17 |
+
from math import log,exp,sqrt
|
18 |
+
import random
|
19 |
+
|
20 |
+
nltk.download('stopwords')
|
21 |
+
nltk.download('gutenberg')
|
22 |
+
|
23 |
+
cos = T.nn.CosineSimilarity(dim=0)
|
24 |
+
|
25 |
+
urllib.request.urlretrieve("https://github.com/ondovb/nCloze/raw/1b57ab719c367c070aeba8a53e71a536ce105091/dict-info.txt")
|
26 |
+
urllib.request.urlretrieve("https://github.com/ondovb/nCloze/raw/1b57ab719c367c070aeba8a53e71a536ce105091/dict-unix.txt")
|
27 |
+
urllib.request.urlretrieve("https://github.com/ondovb/nCloze/raw/1b57ab719c367c070aeba8a53e71a536ce105091/profanity.json")
|
28 |
+
|
29 |
+
#gdown.download('https://drive.google.com/uc?id=16j6oQbqIUfdY1kMFOonXVDdG7A0C6CXD&confirm=t',use_cookies=True)
|
30 |
+
#gdown.download(id='13-3DyP4Df1GzrdQ_W4fLhPYAA1Gscg1j',use_cookies=True)
|
31 |
+
#gdown.download(id='180X6ztER2lKVP_dKinSJNE0XRtmnixAM',use_cookies=True)
|
32 |
+
|
33 |
+
CONTEXTUAL_EMBEDDING_LAYERS = [12]
|
34 |
+
EXTEND_SUBWORDS=True
|
35 |
+
MAX_SUBWORDS=1
|
36 |
+
DEBUG_OUTPUT=True
|
37 |
+
DISTRACTORS_FROM_TEXT=False
|
38 |
+
MIN_SENT_WORDS = 7
|
39 |
+
|
40 |
+
# Frequencies are used to decide if a distractor candidate might be a subword
|
41 |
+
stemmer = PorterStemmer()
|
42 |
+
freq = FreqDist(i.lower() for i in gutenberg.words())
|
43 |
+
print(freq.most_common()[:5])
|
44 |
+
|
45 |
+
words_unix = set(line.strip() for line in open('dict-unix.txt'))
|
46 |
+
words_info = set(line.strip() for line in open('dict-info.txt'))
|
47 |
+
words_small = words_unix.intersection(words_info)
|
48 |
+
words_large = words_unix.union(words_info)
|
49 |
+
f = open('profanity.json')
|
50 |
+
profanity = json.load(f)
|
51 |
+
|
52 |
+
import stanza
|
53 |
+
|
54 |
+
nlp = stanza.Pipeline(lang='en', processors='tokenize')#, model_dir='/data/ondovbd/stanza_resources')
|
55 |
+
|
56 |
+
nltk.download('punkt')
|
57 |
+
nltk_sent_toker = nltk.data.load('tokenizers/punkt/english.pickle')
|
58 |
+
|
59 |
+
def is_word(str):
|
60 |
+
'''Check if word exists in dictionary'''
|
61 |
+
splt = str.lower().split("'")
|
62 |
+
if len(splt) > 2:
|
63 |
+
return False
|
64 |
+
elif len(splt) == 2:
|
65 |
+
return is_word(splt[0]) and (splt[1] in ['t','nt','s','ll'])
|
66 |
+
elif '-' in str:
|
67 |
+
for word in str.split('-'):
|
68 |
+
if not is_word(word):
|
69 |
+
return False
|
70 |
+
return True
|
71 |
+
else:
|
72 |
+
return str.lower() in words_unix or str.lower() in words_info
|
73 |
+
|
74 |
+
def get_emb(snt_toks, tgt_toks, layers=None):
|
75 |
+
'''Embeds a group of subword tokens in place of a mask, using the entire
|
76 |
+
sentence for context. Returns the average of the target token embeddings,
|
77 |
+
which are summed over the hidden layers.
|
78 |
+
|
79 |
+
snt_toks: the tokenized sentence, including the mask token
|
80 |
+
tgt_toks: the tokens (subwords) to replace the mask token
|
81 |
+
layers (optional): which hidden layers to sum (list of indices)'''
|
82 |
+
mask_idx = snt_toks.index(toker.mask_token_id)
|
83 |
+
snt_toks = snt_toks.copy()
|
84 |
+
|
85 |
+
while mask_idx + len(tgt_toks)-1 >= 512:
|
86 |
+
# Shift text by 100 words
|
87 |
+
snt_toks = snt_toks[100:]
|
88 |
+
mask_idx -= 100
|
89 |
+
|
90 |
+
snt_toks[mask_idx:mask_idx+1] = tgt_toks
|
91 |
+
snt_toks = snt_toks[:512]
|
92 |
+
with T.no_grad():
|
93 |
+
if T.cuda.is_available():
|
94 |
+
T.tensor([snt_toks]).cuda()
|
95 |
+
T.tensor([[1]*len(snt_toks)]).cuda()
|
96 |
+
output = model(T.tensor([snt_toks]), T.tensor([[1]*len(snt_toks)]), output_hidden_states=True)
|
97 |
+
layers = CONTEXTUAL_EMBEDDING_LAYERS if layers is None else layers
|
98 |
+
output = T.stack([output.hidden_states[i] for i in layers]).sum(0).squeeze()
|
99 |
+
# Only select the tokens that constitute the requested word
|
100 |
+
return output[mask_idx:mask_idx+len(tgt_toks)].mean(dim=0)
|
101 |
+
|
102 |
+
def energy(ctx, scaled_dists, scaled_sims, choices, words, ans):
|
103 |
+
|
104 |
+
#Calculate and add cosine similarity scores
|
105 |
+
'''Cost function to help choose best distractors'''
|
106 |
+
#e = [embs[i] for i in choices] #+ [sem_emb_ans]
|
107 |
+
#w = [words[i] for i in choices] #+ [ans]
|
108 |
+
|
109 |
+
hm_sim = 0
|
110 |
+
e_ctx = 0
|
111 |
+
for i in choices:
|
112 |
+
hm_sim += 1./scaled_sims[i]
|
113 |
+
e_ctx += ctx[i]
|
114 |
+
|
115 |
+
e_sim = float(len(choices))/hm_sim
|
116 |
+
|
117 |
+
hm_emb = 0
|
118 |
+
count = 0
|
119 |
+
c = choices + [len(ctx)]
|
120 |
+
for i in range(len(c)):
|
121 |
+
for j in range(i):
|
122 |
+
d = scaled_dists['%s-%s'%(max(c[i],c[j]), min(c[i], c[j]))]
|
123 |
+
#print(c[i], c[j], d)
|
124 |
+
hm_emb += 1./d
|
125 |
+
count += 1
|
126 |
+
e_emb = float(count)/hm_emb
|
127 |
+
return float(e_emb), e_ctx, float(e_sim)
|
128 |
+
|
129 |
+
def anneal(probs_sent_context, probs_para_context, embs, emb_ans, words, k, ans):
|
130 |
+
'''find k distractor indices that are optimally high probability and distant
|
131 |
+
in embedding space'''
|
132 |
+
# probs_sent_context = T.as_tensor(probs_sent_context) / sum(probs_sent_context)
|
133 |
+
m = len(probs_sent_context)
|
134 |
+
# probs_para_context = T.as_tensor(probs_para_context) / sum(probs_para_context)
|
135 |
+
its = 1000
|
136 |
+
n = len(probs_para_context)
|
137 |
+
choices = list(range(k))
|
138 |
+
|
139 |
+
dists = {}
|
140 |
+
embsa = embs + [emb_ans]
|
141 |
+
for i in range(len(embsa)):
|
142 |
+
for j in range(i):
|
143 |
+
dists['%s-%s'%(i,j)] = 1-cos(embsa[i], embsa[j]) # cosine "distance"
|
144 |
+
#print(words[i], words[j], 1-cos(embs[i], embs[j]))
|
145 |
+
|
146 |
+
dist_min = T.min(T.tensor(list(dists.values())))
|
147 |
+
dist_max = T.max(T.tensor(list(dists.values())))
|
148 |
+
for key, dist in dists.items():
|
149 |
+
dists[key] = (dist - dist_min)/(dist_max-dist_min)
|
150 |
+
|
151 |
+
sims = T.tensor([cos(emb_ans, emb) for emb in embs])
|
152 |
+
scaled_sims = (sims - T.min(sims))/(T.max(sims)-T.min(sims))
|
153 |
+
|
154 |
+
ctx = T.tensor(probs_sent_context).log()-ALPHA*T.tensor(probs_para_context).log()
|
155 |
+
ctx = (ctx-T.min(ctx))/(T.max(ctx)-T.min(ctx))
|
156 |
+
|
157 |
+
e_emb, e_ctx, e_sim = energy(ctx, dists, scaled_sims, choices, words, ans)
|
158 |
+
e = e_ctx + BETA * e_emb
|
159 |
+
#e = SIM_ANNEAL_EMB_WEIGHT * e_emb + e_prob
|
160 |
+
for i in range(its):
|
161 |
+
t = 1.-(i)/its
|
162 |
+
mut_idx = random.randrange(k) # which choice to mutate
|
163 |
+
orig = choices[mut_idx]
|
164 |
+
new = orig
|
165 |
+
while (new in choices): # mutate choice until not in current list
|
166 |
+
new = random.randrange(m)
|
167 |
+
choices[mut_idx] = new
|
168 |
+
e_emb, e_ctx, e_sim = energy(ctx, dists, scaled_sims, choices, words, ans)
|
169 |
+
e_new = e_ctx + BETA * e_emb
|
170 |
+
delta = e_new - e
|
171 |
+
exponent = delta/t
|
172 |
+
if exponent < -50:
|
173 |
+
exponent = -50 # avoid underflow
|
174 |
+
if delta > 0 or exp(exponent) > random.random():
|
175 |
+
e = e_new # accept new state
|
176 |
+
else:
|
177 |
+
choices[mut_idx] = orig
|
178 |
+
if DEBUG_OUTPUT:
|
179 |
+
print([words[j] for j in choices] + [ans], "e: %f"%(e))
|
180 |
+
return choices
|
181 |
+
|
182 |
+
def get_softmax_logits(toks, n_masks = 1, sub_ids = []):
|
183 |
+
# Tokenize text - Keep length of inpts at or below 512 (including answer token length artifically added at end)
|
184 |
+
msk_idx = toks.index(toker.mask_token_id)
|
185 |
+
toks = toks.copy()
|
186 |
+
toks[msk_idx:msk_idx+1] = [toker.mask_token_id] * n_masks + sub_ids
|
187 |
+
|
188 |
+
# If the masked_token is over 512 (excluding answer token length artifically added at end) tokens away
|
189 |
+
while msk_idx >= 512:
|
190 |
+
# Shift text by 100 words
|
191 |
+
toks = toks[100:]
|
192 |
+
msk_idx -= 100
|
193 |
+
toks = toks[:512]
|
194 |
+
# Find the predicted words for the fill-in-the-blank mask term based on sentence-context alone
|
195 |
+
with T.no_grad():
|
196 |
+
t=T.tensor([toks])
|
197 |
+
m=T.tensor([[1]*len(toks)])
|
198 |
+
if T.cuda.is_available():
|
199 |
+
t.cuda()
|
200 |
+
m.cuda()
|
201 |
+
output = model(t, m)
|
202 |
+
sm = T.softmax(output.logits[0, msk_idx:msk_idx+n_masks, :], dim=1)
|
203 |
+
return sm
|
204 |
+
|
205 |
+
e=1e-10
|
206 |
+
|
207 |
+
def candidates(text, answer):
|
208 |
+
'''Create list of unique distractors that does not include the actual answer'''
|
209 |
+
if DEBUG_OUTPUT:
|
210 |
+
print(text)
|
211 |
+
|
212 |
+
# Get only sentence with blanked text to tokenize
|
213 |
+
doc = nlp(text)
|
214 |
+
#sents = [sentence.text for sentence in doc.sentences]
|
215 |
+
sents = nltk_sent_toker.tokenize(text)
|
216 |
+
msk_snt_idx = [i for i in range(len(sents)) if toker.mask_token in sents[i]][0]
|
217 |
+
just_masked_sentence = sents[msk_snt_idx]
|
218 |
+
|
219 |
+
prv_snts = sents[:msk_snt_idx]
|
220 |
+
nxt_snts = sents[msk_snt_idx+1:]
|
221 |
+
|
222 |
+
if len(just_masked_sentence.split(' ')) < MIN_SENT_WORDS and len(prv_snts):
|
223 |
+
just_masked_sentence = ' '.join([prv_snts.pop(), just_masked_sentence])
|
224 |
+
|
225 |
+
while len(just_masked_sentence.split(' ')) < MIN_SENT_WORDS and (len(prv_snts) or len(nxt_snts)):
|
226 |
+
if T.rand(1) < 0.5 and len(prv_snts):
|
227 |
+
just_masked_sentence = ' '.join([prv_snts.pop(), just_masked_sentence])
|
228 |
+
elif len(nxt_snts):
|
229 |
+
just_masked_sentence = ' '.join([just_masked_sentence, nxt_snts.pop(0)])
|
230 |
+
|
231 |
+
ctx = just_masked_sentence
|
232 |
+
while len(ctx.split(' ')) < 3 * len(just_masked_sentence.split(' ')) and (len(prv_snts) or len(nxt_snts)):
|
233 |
+
if len(prv_snts):
|
234 |
+
ctx = ' '.join([prv_snts.pop(), ctx])
|
235 |
+
if len(nxt_snts):
|
236 |
+
ctx = ' '.join([ctx, nxt_snts.pop(0)])
|
237 |
+
|
238 |
+
# just_masked_sentence = ' '.join([just_masked_sentence.replace('<mask>', 'banana'),
|
239 |
+
# just_masked_sentence.replace('<mask>', 'banana'),
|
240 |
+
## just_masked_sentence,
|
241 |
+
# just_masked_sentence.replace('<mask>', 'banana'),
|
242 |
+
# just_masked_sentence.replace('<mask>', 'banana')])
|
243 |
+
#just_masked_sentence = ' '.join([just_masked_sentence, just_masked_sentence, just_masked_sentence, just_masked_sentence, just_masked_sentence])
|
244 |
+
|
245 |
+
tiled = just_masked_sentence
|
246 |
+
while len(tiled) < len(text):
|
247 |
+
tiled += ' ' + just_masked_sentence
|
248 |
+
just_masked_sentence = tiled
|
249 |
+
|
250 |
+
if DEBUG_OUTPUT:
|
251 |
+
print(ctx)
|
252 |
+
print(text)
|
253 |
+
print(just_masked_sentence)
|
254 |
+
toks_para = toker.encode(text)
|
255 |
+
toks_sent = toker.encode(just_masked_sentence)
|
256 |
+
# Get softmaxed logits from sentence alone and full-text
|
257 |
+
# sent_sm, sent_pos, sent_ids = get_span_logits(just_masked_sentence, answer)
|
258 |
+
# para_sm, para_pos, para_ids = get_span_logits(text, answer)
|
259 |
+
|
260 |
+
sent_sms_all = []
|
261 |
+
para_sms_all = []
|
262 |
+
para_sms_right = []
|
263 |
+
|
264 |
+
for i in range(MAX_SUBWORDS):
|
265 |
+
para_sms = get_softmax_logits(toks_para, i + 1)
|
266 |
+
para_sms_all.append(para_sms)
|
267 |
+
sent_sms = get_softmax_logits(toks_sent, i + 1)
|
268 |
+
sent_sms_all.append(sent_sms)
|
269 |
+
para_sms_right.append(T.exp((sent_sms[i].log()+para_sms[i].log())/2) * (suffix_mask_inv if i == 0 else suffix_mask))
|
270 |
+
|
271 |
+
# Create 2 lists: (1) notes highest probability for each token across n-mask lists if token is suffix and (2) notes number of mask terms to add
|
272 |
+
para_sm_best, para_pos_best = T.max(T.vstack(para_sms_right), 0)
|
273 |
+
|
274 |
+
distractors = []
|
275 |
+
stems = []
|
276 |
+
embs = []
|
277 |
+
sent_probs = []
|
278 |
+
para_probs = []
|
279 |
+
|
280 |
+
ans_stem = stemmer.stem(answer.lower())
|
281 |
+
|
282 |
+
emb_ans = get_emb(toks_para, toker(answer)['input_ids'][1:-1])
|
283 |
+
para_words = text.lower().split(' ')
|
284 |
+
blank_word_idx = [idx for idx, word in enumerate(text.split(' ')) if toker.mask_token in word][0] # Need to remove punctuation
|
285 |
+
if (blank_word_idx - 1) < 0:
|
286 |
+
prev_word = 'beforeanytext'
|
287 |
+
else:
|
288 |
+
prev_word = para_words[blank_word_idx-1]
|
289 |
+
if (blank_word_idx + 1) >= len(para_words):
|
290 |
+
next_word = 'afteralltext'
|
291 |
+
else:
|
292 |
+
next_word = para_words[blank_word_idx+1]
|
293 |
+
|
294 |
+
# Need to check if the token is outside of the tokenizer based on predictions being made at all
|
295 |
+
if len(para_sms_all[0]) > 0:
|
296 |
+
top_ctx = T.topk((sent_sms_all[0][0]*word_mask+e).log() - ALPHA * (para_sms_all[0][0]*word_mask+e).log(), len(para_sms_all[0][0]), dim=0)
|
297 |
+
para_top_ids = top_ctx.indices.tolist()
|
298 |
+
para_top_probs = top_ctx.values.tolist()
|
299 |
+
|
300 |
+
for i, id in enumerate(para_top_ids):
|
301 |
+
|
302 |
+
sub_ids = [int(id)] # cumulative list of subword token ids
|
303 |
+
dec = toker.decode(sub_ids).strip()
|
304 |
+
if DEBUG_OUTPUT:
|
305 |
+
print('Trying:', dec)
|
306 |
+
#print(para_pos[id])
|
307 |
+
#if para_pos_best[id] > 0:
|
308 |
+
# continue
|
309 |
+
|
310 |
+
if dec.isupper() != answer.isupper():
|
311 |
+
continue
|
312 |
+
|
313 |
+
if EXTEND_SUBWORDS and para_pos_best[id] > 0:
|
314 |
+
if DEBUG_OUTPUT:
|
315 |
+
print("Extending %s with %d masks..."%(dec, para_pos_best[id]))
|
316 |
+
ext_ids, _ = extend(toks_sent, toks_para, [id], para_pos_best[id], para_words)
|
317 |
+
sub_ids = ext_ids + sub_ids
|
318 |
+
dec_ext = toker.decode(sub_ids).strip()
|
319 |
+
if DEBUG_OUTPUT:
|
320 |
+
print("Extended %s to %s"%(dec, dec_ext))
|
321 |
+
if is_word(dec_ext) or (dec_ext != '' and dec_ext in para_words):
|
322 |
+
dec = dec_ext # choose new word
|
323 |
+
else:
|
324 |
+
sub_ids = [int(id)] # reset
|
325 |
+
|
326 |
+
if len(toker.decode(sub_ids).lower().strip()) < 2:
|
327 |
+
continue
|
328 |
+
|
329 |
+
if dec[0].isupper() != answer[0].isupper():
|
330 |
+
continue
|
331 |
+
|
332 |
+
# Only add distractor if it does not contain punctuation
|
333 |
+
#if any(p in dec for p in punctuation):
|
334 |
+
# pass
|
335 |
+
#continue
|
336 |
+
|
337 |
+
if dec.lower() in profanity:
|
338 |
+
continue
|
339 |
+
|
340 |
+
# make sure is a word, either in dict or somewhere else in text
|
341 |
+
if not is_word(dec) and dec.lower() not in para_words:
|
342 |
+
continue
|
343 |
+
|
344 |
+
# make sure is not the same as an adjacent word
|
345 |
+
if dec.lower() == prev_word or dec.lower() == next_word:
|
346 |
+
continue
|
347 |
+
|
348 |
+
# Don't add the distractor if stem matches another
|
349 |
+
stem = stemmer.stem(dec).lower()
|
350 |
+
if stem in stems or stem == ans_stem:
|
351 |
+
continue
|
352 |
+
|
353 |
+
# Only add distractor if it does not contain a number
|
354 |
+
if any(char.isdigit() for char in toker.decode([id])):
|
355 |
+
continue
|
356 |
+
|
357 |
+
# Only add distractor if the distractor exists in the text already
|
358 |
+
if DISTRACTORS_FROM_TEXT and dec.lower() not in para_words:
|
359 |
+
continue
|
360 |
+
|
361 |
+
#if answer[0].isupper():
|
362 |
+
# dec = dec.capitalize()
|
363 |
+
|
364 |
+
# PASSED ALL TESTS; finally add distractor and computations
|
365 |
+
distractors.append(dec)
|
366 |
+
stems.append(stem)
|
367 |
+
sent_logprob = 0
|
368 |
+
para_logprob = 0
|
369 |
+
nsubs = len(sub_ids)
|
370 |
+
for j in range(nsubs):
|
371 |
+
sub_id = sub_ids[j]
|
372 |
+
sent_logprob_j = log(sent_sms_all[nsubs-1][j][sub_id])
|
373 |
+
para_logprob_j = log(para_sms_all[nsubs-1][j][sub_id])
|
374 |
+
#if j == 0 or sent_logprob_j > sent_logprob:
|
375 |
+
# sent_logprob = sent_logprob_j
|
376 |
+
#if j == 0 or para_logprob_j > para_logprob:
|
377 |
+
# para_logprob = para_logprob_j
|
378 |
+
sent_logprob += sent_logprob_j
|
379 |
+
para_logprob += para_logprob_j
|
380 |
+
sent_logprob /= nsubs
|
381 |
+
para_logprob /= nsubs
|
382 |
+
if DEBUG_OUTPUT:
|
383 |
+
print("%s (p_sent=%f, p_para=%f)"%(dec,sent_logprob,para_logprob))
|
384 |
+
sent_probs.append(exp(sent_logprob))
|
385 |
+
para_probs.append(exp(para_logprob))
|
386 |
+
# sent_probs.append(sent_sms_all[nsubs-1][nsubs-1][sub_id])
|
387 |
+
# para_probs.append(para_sms_all[nsubs-1][nsubs-1][sub_id])
|
388 |
+
embs.append(get_emb(toks_para, sub_ids))
|
389 |
+
|
390 |
+
if len(distractors) >= K:
|
391 |
+
break
|
392 |
+
if DEBUG_OUTPUT:
|
393 |
+
print('Corresponding Text: ', text)
|
394 |
+
print('Correct Answer: ', answer)
|
395 |
+
print('Distractors created before annealing: ', distractors)
|
396 |
+
#indices = anneal(sent_probs, para_probs, embs, emb_ans, number_of_distractors, distractors, answer)
|
397 |
+
#distractors = [distractors[i] for i in indices]
|
398 |
+
#distractors += [''] * (number_of_distractors - len(distractors))
|
399 |
+
|
400 |
+
return sent_probs, para_probs, embs, emb_ans, distractors
|
401 |
+
|
402 |
+
def create_distractors(text, answer):
|
403 |
+
sent_probs, para_probs, embs, emb_ans, distractors = candidates(text, answer)
|
404 |
+
#print(distractors)
|
405 |
+
indices = anneal(sent_probs, para_probs, embs, emb_ans, distractors, 3, answer)
|
406 |
+
return [distractors[x] for x in indices]
|
407 |
+
|
408 |
+
st.title("nCloze")
|
409 |
+
st.subheader("Create a multiple-choice cloze test from a passage")
|
410 |
+
|
411 |
+
|
412 |
+
def blank(tok):
|
413 |
+
if tok == 'a(n)':
|
414 |
+
strp = tok
|
415 |
+
else:
|
416 |
+
strp = tok.strip(punctuation)
|
417 |
+
print(strp, tok.replace(strp, toker.mask_token))
|
418 |
+
return strp, tok.replace(strp, toker.mask_token)
|
419 |
+
|
420 |
+
test = """In contrast to necrosis, which is a form of traumatic cell death that results from acute cellular injury, apoptosis is a highly regulated and controlled process that confers advantages during an organism's life cycle. For example, the separation of fingers and toes in a developing human embryo occurs because cells between the digits undergo apoptosis. Unlike necrosis, apoptosis produces cell fragments called apoptotic bodies that phagocytes are able to engulf and remove before the contents of the cell can spill out onto surrounding cells and cause damage to them."""
|
421 |
+
st.header("Basic options")
|
422 |
+
SPACING = int(st.text_input('Blank spacing', value="7"))
|
423 |
+
OFFSET = int(st.text_input('First word to blank (0 to use spacing)', value="0"))
|
424 |
+
st.header("Advanced options")
|
425 |
+
ALPHA = float(st.text_input('Incorrectness weight', value="0.75"))
|
426 |
+
BETA = float(st.text_input('Distinctness weight', value="0.75"))
|
427 |
+
MODEL_TYPE = st.text_input('Masked Language Model (from HuggingFace)', value="roberta-large")
|
428 |
+
K = 16
|
429 |
+
|
430 |
+
model = AutoModelForMaskedLM.from_pretrained(MODEL_TYPE)#, cache_dir=CACHE_DIR)
|
431 |
+
|
432 |
+
if T.cuda.is_available():
|
433 |
+
model.cuda()
|
434 |
+
|
435 |
+
toker = AutoTokenizer.from_pretrained(MODEL_TYPE, add_prefix_space=True)
|
436 |
+
|
437 |
+
sorted_toker_vocab_dict = sorted(toker.vocab.items(), key=lambda x:x[1])
|
438 |
+
if toker.mask_token == '[MASK]': # BERT style
|
439 |
+
suffix_mask = T.FloatTensor([1 if (('##' == x[0][:2]) and (re.match("^[A-Za-z0-9']*$", x[0]) is not None)) else 0 for x in sorted_toker_vocab_dict]) # 1 means is-suffix and 0 mean not-suffix
|
440 |
+
else: # RoBERTa style
|
441 |
+
suffix_mask = T.FloatTensor([1 if (('Ġ' != x[0][0]) and (re.match("^[A-Za-z0-9']*$", x[0]) is not None)) else 0 for x in sorted_toker_vocab_dict]) # 1 means is-suffix and 0 mean not-suffix
|
442 |
+
suffix_mask_inv = suffix_mask * -1 + 1
|
443 |
+
word_mask = suffix_mask_inv*T.FloatTensor([1 if is_word(x[0][1:]) and x[0][1:].lower() not in profanity else 0 for x in sorted_toker_vocab_dict])
|
444 |
+
if T.cuda.is_available():
|
445 |
+
suffix_mask=suffix_mask.cuda()
|
446 |
+
suffix_mask_inv=suffix_mask_inv.cuda()
|
447 |
+
word_mask = word_mask.cuda()
|
448 |
+
|
449 |
+
st.subheader("Passage")
|
450 |
+
st.text_area('Passage to create a cloze test from:',value=test,key="text", max_chars=1024, height=275)
|
451 |
+
|
452 |
+
def generate():
|
453 |
+
ws = st.session_state.text.split()
|
454 |
+
wb = st.session_state.text.split()
|
455 |
+
|
456 |
+
qs = []
|
457 |
+
i = OFFSET - 1 if OFFSET > 0 else SPACING
|
458 |
+
j = 0
|
459 |
+
while i < len(ws):
|
460 |
+
a, b = blank(ws[i])
|
461 |
+
while b == '' and i < len(ws)-1:
|
462 |
+
i += 1
|
463 |
+
a, b = blank(ws[i])
|
464 |
+
if b != '':
|
465 |
+
w = ws[i]
|
466 |
+
ws[i] = b
|
467 |
+
wb[i] = b
|
468 |
+
|
469 |
+
while j<i:
|
470 |
+
yield(' ' + ws[j])
|
471 |
+
j += 1
|
472 |
+
masked = ' '.join(ws)
|
473 |
+
#st.write(masked)
|
474 |
+
ds = create_distractors(masked, a)
|
475 |
+
print(ds, a)
|
476 |
+
q = ds+[a+'\*']
|
477 |
+
random.shuffle(q)
|
478 |
+
yield(b.replace(toker.mask_token,' **['+', '.join(q)+']**'))
|
479 |
+
j+=1
|
480 |
+
qs.append(ds)
|
481 |
+
ws[i] = w
|
482 |
+
i += SPACING
|
483 |
+
while j<len(ws):
|
484 |
+
yield(' ' + ws[j])
|
485 |
+
j += 1
|
486 |
+
|
487 |
+
# Load model and run inference
|
488 |
+
if st.button("Generate"):
|
489 |
+
|
490 |
+
st.write_stream(generate())
|
491 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
transformers
|
3 |
+
torch
|
4 |
+
nltk
|
5 |
+
numpy
|
6 |
+
stanza
|
7 |
+
pandas
|
8 |
+
urllib
|