andreslu committed on
Commit
5e42615
1 Parent(s): 4e06688

Upload inductor.py

Browse files
Files changed (1) hide show
  1. inductor.py +9 -8
inductor.py CHANGED
@@ -76,25 +76,25 @@ class BartInductor(object):
76
  self.stop_weight = stop_weight[0, :]
77
 
78
  def clean(self, text):
79
- segments = text.split('<mask>')
80
- if len(segments) == 3 and segments[2].startswith('.'):
81
- return '<mask>'.join(segments[:2]) + '<mask>.'
 
82
  else:
83
  return text
84
 
85
  def generate(self, inputs, k=10, topk=10, return_scores=False):
86
  with torch.no_grad():
87
  tB_probs = self.generate_rule(inputs, k)
 
88
  if return_scores:
89
- ret = [(t[0].replace('<ent0>','<mask>').replace('<ent1>','<mask>'), t[1]) for t in tB_probs]
90
- new_ret = []
91
  for temp in ret:
92
  temp = (self.clean(temp[0].strip()), temp[1])
93
  if len(new_ret) < topk and temp not in new_ret:
94
  new_ret.append(temp)
95
  else:
96
- ret = [t[0].replace('<ent0>','<mask>').replace('<ent1>','<mask>') for t in tB_probs]
97
- new_ret = []
98
  for temp in ret:
99
  temp = self.clean(temp.strip())
100
  if len(new_ret) < topk and temp not in new_ret:
@@ -134,7 +134,7 @@ class BartInductor(object):
134
  return ret
135
 
136
  def extract_words_for_tA_bart(self, tA, k=6, print_it = False):
137
- spans = [t.lower().strip() for t in tA[:-1].split('<mask>')]
138
  generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].to(device).to(torch.int64)
139
  generated_ret = self.orion_instance_generator.generate(generated_ids, num_beams=max(120, k),
140
  #num_beam_groups=max(120, k),
@@ -300,6 +300,7 @@ class BartInductor(object):
300
  for k1 in tB_prob:
301
  ret.append([k1, tB_prob[k1]])
302
  ret = sorted(ret, key=lambda x: x[1], reverse=True)[:k]
 
303
  if self.if_then:
304
  for i, temp in enumerate(ret):
305
  sentence = temp[0]
 
76
  self.stop_weight = stop_weight[0, :]
77
 
78
  def clean(self, text):
79
+ segments = re.split(r'<ent\d>', text)
80
+ last_segment = segments[-1]
81
+ if last_segment.startswith('.'):
82
+ return text[:text.rfind(last_segment)]+'.'
83
  else:
84
  return text
85
 
86
  def generate(self, inputs, k=10, topk=10, return_scores=False):
87
  with torch.no_grad():
88
  tB_probs = self.generate_rule(inputs, k)
89
+ new_ret = []
90
  if return_scores:
91
+ ret = [(t[0], t[1]) for t in tB_probs]
 
92
  for temp in ret:
93
  temp = (self.clean(temp[0].strip()), temp[1])
94
  if len(new_ret) < topk and temp not in new_ret:
95
  new_ret.append(temp)
96
  else:
97
+ ret = [t[0] for t in tB_probs]
 
98
  for temp in ret:
99
  temp = self.clean(temp.strip())
100
  if len(new_ret) < topk and temp not in new_ret:
 
134
  return ret
135
 
136
  def extract_words_for_tA_bart(self, tA, k=6, print_it = False):
137
+ spans = [t.lower().strip() for t in re.split(r'<.*?>', tA[:-1])]
138
  generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].to(device).to(torch.int64)
139
  generated_ret = self.orion_instance_generator.generate(generated_ids, num_beams=max(120, k),
140
  #num_beam_groups=max(120, k),
 
300
  for k1 in tB_prob:
301
  ret.append([k1, tB_prob[k1]])
302
  ret = sorted(ret, key=lambda x: x[1], reverse=True)[:k]
303
+
304
  if self.if_then:
305
  for i, temp in enumerate(ret):
306
  sentence = temp[0]