andreslu commited on
Commit
16c895c
1 Parent(s): d34098d

Update inductor.py

Browse files
Files changed (1) hide show
  1. inductor.py +15 -9
inductor.py CHANGED
@@ -82,17 +82,23 @@ class BartInductor(object):
82
  else:
83
  return text
84
 
85
- def generate(self, inputs, k=10, topk=10):
86
  with torch.no_grad():
87
  tB_probs = self.generate_rule(inputs, k)
88
- #ret = [t[0].replace('<ent0>','<mask>').replace('<ent1>','<mask>') for t in tB_probs]
89
- ret = [(t[0].replace('<ent0>','<mask>').replace('<ent1>','<mask>'), t[1]) for t in tB_probs]
90
-
91
- new_ret = []
92
- for temp in ret:
93
- temp = (self.clean(temp[0].strip()), temp[1])
94
- if len(new_ret) < topk and temp not in new_ret:
95
- new_ret.append(temp)
 
 
 
 
 
 
96
 
97
  return new_ret
98
 
 
82
  else:
83
  return text
84
 
85
+ def generate(self, inputs, k=10, topk=10, return_scores=False):
86
  with torch.no_grad():
87
  tB_probs = self.generate_rule(inputs, k)
88
+ if return_scores:
89
+ ret = [(t[0].replace('<ent0>','<mask>').replace('<ent1>','<mask>'), t[1]) for t in tB_probs]
90
+ new_ret = []
91
+ for temp in ret:
92
+ temp = (self.clean(temp[0].strip()), temp[1])
93
+ if len(new_ret) < topk and temp not in new_ret:
94
+ new_ret.append(temp)
95
+ else:
96
+ ret = [t[0].replace('<ent0>','<mask>').replace('<ent1>','<mask>') for t in tB_probs]
97
+ new_ret = []
98
+ for temp in ret:
99
+ temp = self.clean(temp[0].strip())
100
+ if len(new_ret) < topk and temp not in new_ret:
101
+ new_ret.append(temp)
102
 
103
  return new_ret
104