Spaces: Runtime error
Upload inductor.py
Browse files
- inductor.py +9 -8
inductor.py
CHANGED
@@ -76,25 +76,25 @@ class BartInductor(object):
|
|
76 |
self.stop_weight = stop_weight[0, :]
|
77 |
|
78 |
def clean(self, text):
|
79 |
-
segments =
|
80 |
-
|
81 |
-
|
|
|
82 |
else:
|
83 |
return text
|
84 |
|
85 |
def generate(self, inputs, k=10, topk=10, return_scores=False):
|
86 |
with torch.no_grad():
|
87 |
tB_probs = self.generate_rule(inputs, k)
|
|
|
88 |
if return_scores:
|
89 |
-
ret = [(t[0]
|
90 |
-
new_ret = []
|
91 |
for temp in ret:
|
92 |
temp = (self.clean(temp[0].strip()), temp[1])
|
93 |
if len(new_ret) < topk and temp not in new_ret:
|
94 |
new_ret.append(temp)
|
95 |
else:
|
96 |
-
ret = [t[0]
|
97 |
-
new_ret = []
|
98 |
for temp in ret:
|
99 |
temp = self.clean(temp.strip())
|
100 |
if len(new_ret) < topk and temp not in new_ret:
|
@@ -134,7 +134,7 @@ class BartInductor(object):
|
|
134 |
return ret
|
135 |
|
136 |
def extract_words_for_tA_bart(self, tA, k=6, print_it = False):
|
137 |
-
spans = [t.lower().strip() for t in tA[:-1]
|
138 |
generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].to(device).to(torch.int64)
|
139 |
generated_ret = self.orion_instance_generator.generate(generated_ids, num_beams=max(120, k),
|
140 |
#num_beam_groups=max(120, k),
|
@@ -300,6 +300,7 @@ class BartInductor(object):
|
|
300 |
for k1 in tB_prob:
|
301 |
ret.append([k1, tB_prob[k1]])
|
302 |
ret = sorted(ret, key=lambda x: x[1], reverse=True)[:k]
|
|
|
303 |
if self.if_then:
|
304 |
for i, temp in enumerate(ret):
|
305 |
sentence = temp[0]
|
|
|
76 |
self.stop_weight = stop_weight[0, :]
|
77 |
|
78 |
def clean(self, text):
|
79 |
+
segments = re.split(r'<ent\d>', text)
|
80 |
+
last_segment = segments[-1]
|
81 |
+
if last_segment.startswith('.'):
|
82 |
+
return text[:text.rfind(last_segment)]+'.'
|
83 |
else:
|
84 |
return text
|
85 |
|
86 |
def generate(self, inputs, k=10, topk=10, return_scores=False):
|
87 |
with torch.no_grad():
|
88 |
tB_probs = self.generate_rule(inputs, k)
|
89 |
+
new_ret = []
|
90 |
if return_scores:
|
91 |
+
ret = [(t[0], t[1]) for t in tB_probs]
|
|
|
92 |
for temp in ret:
|
93 |
temp = (self.clean(temp[0].strip()), temp[1])
|
94 |
if len(new_ret) < topk and temp not in new_ret:
|
95 |
new_ret.append(temp)
|
96 |
else:
|
97 |
+
ret = [t[0] for t in tB_probs]
|
|
|
98 |
for temp in ret:
|
99 |
temp = self.clean(temp.strip())
|
100 |
if len(new_ret) < topk and temp not in new_ret:
|
|
|
134 |
return ret
|
135 |
|
136 |
def extract_words_for_tA_bart(self, tA, k=6, print_it = False):
|
137 |
+
spans = [t.lower().strip() for t in re.split(r'<.*?>', tA[:-1])]
|
138 |
generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].to(device).to(torch.int64)
|
139 |
generated_ret = self.orion_instance_generator.generate(generated_ids, num_beams=max(120, k),
|
140 |
#num_beam_groups=max(120, k),
|
|
|
300 |
for k1 in tB_prob:
|
301 |
ret.append([k1, tB_prob[k1]])
|
302 |
ret = sorted(ret, key=lambda x: x[1], reverse=True)[:k]
|
303 |
+
|
304 |
if self.if_then:
|
305 |
for i, temp in enumerate(ret):
|
306 |
sentence = temp[0]
|