Spaces:
Running
Running
seonil
commited on
Commit
•
a1a10ca
1
Parent(s):
37d452a
bugfix
Browse files- __pycache__/harim_scorer.cpython-39.pyc +0 -0
- harim_plus.py +3 -2
- harim_scorer.py +3 -2
__pycache__/harim_scorer.cpython-39.pyc
CHANGED
Binary files a/__pycache__/harim_scorer.cpython-39.pyc and b/__pycache__/harim_scorer.cpython-39.pyc differ
|
|
harim_plus.py
CHANGED
@@ -207,18 +207,19 @@ class Harimplus_Scorer:
|
|
207 |
emp_in = emp_in.to(self._device)
|
208 |
tgt_in = tgt_in.to(self._device)
|
209 |
tgt_mask = tgt_mask.to(self._device)
|
|
|
210 |
|
211 |
with torch.no_grad():
|
212 |
# token_type_ids attribute causes error
|
213 |
s2s_logits = self._encdec_model.forward(
|
214 |
input_ids = src_in.input_ids,
|
215 |
attention_mask = src_in.attention_mask,
|
216 |
-
labels = tgt_in.input_ids,
|
217 |
return_dict=True).logits
|
218 |
lm_logits = self._encdec_model.forward(
|
219 |
input_ids = emp_in.input_ids,
|
220 |
attention_mask = emp_in.attention_mask,
|
221 |
-
labels = tgt_in.input_ids,
|
222 |
return_dict=True).logits
|
223 |
sent_lengths = tgt_mask.sum(-1)
|
224 |
ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
|
|
|
207 |
emp_in = emp_in.to(self._device)
|
208 |
tgt_in = tgt_in.to(self._device)
|
209 |
tgt_mask = tgt_mask.to(self._device)
|
210 |
+
fill_ignore_mask = ~(tgt_mask.bool())
|
211 |
|
212 |
with torch.no_grad():
|
213 |
# token_type_ids attribute causes error
|
214 |
s2s_logits = self._encdec_model.forward(
|
215 |
input_ids = src_in.input_ids,
|
216 |
attention_mask = src_in.attention_mask,
|
217 |
+
labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
|
218 |
return_dict=True).logits
|
219 |
lm_logits = self._encdec_model.forward(
|
220 |
input_ids = emp_in.input_ids,
|
221 |
attention_mask = emp_in.attention_mask,
|
222 |
+
labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
|
223 |
return_dict=True).logits
|
224 |
sent_lengths = tgt_mask.sum(-1)
|
225 |
ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
|
harim_scorer.py
CHANGED
@@ -141,18 +141,19 @@ class Harimplus_Scorer:
|
|
141 |
emp_in = emp_in.to(self._device)
|
142 |
tgt_in = tgt_in.to(self._device)
|
143 |
tgt_mask = tgt_mask.to(self._device)
|
|
|
144 |
|
145 |
with torch.no_grad():
|
146 |
# token_type_ids attribute causes error
|
147 |
s2s_logits = self._encdec_model.forward(
|
148 |
input_ids = src_in.input_ids,
|
149 |
attention_mask = src_in.attention_mask,
|
150 |
-
labels = tgt_in.input_ids,
|
151 |
return_dict=True).logits
|
152 |
lm_logits = self._encdec_model.forward(
|
153 |
input_ids = emp_in.input_ids,
|
154 |
attention_mask = emp_in.attention_mask,
|
155 |
-
labels = tgt_in.input_ids,
|
156 |
return_dict=True).logits
|
157 |
sent_lengths = tgt_mask.sum(-1)
|
158 |
ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
|
|
|
141 |
emp_in = emp_in.to(self._device)
|
142 |
tgt_in = tgt_in.to(self._device)
|
143 |
tgt_mask = tgt_mask.to(self._device)
|
144 |
+
fill_ignore_mask = ~(tgt_mask.bool())
|
145 |
|
146 |
with torch.no_grad():
|
147 |
# token_type_ids attribute causes error
|
148 |
s2s_logits = self._encdec_model.forward(
|
149 |
input_ids = src_in.input_ids,
|
150 |
attention_mask = src_in.attention_mask,
|
151 |
+
labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
|
152 |
return_dict=True).logits
|
153 |
lm_logits = self._encdec_model.forward(
|
154 |
input_ids = emp_in.input_ids,
|
155 |
attention_mask = emp_in.attention_mask,
|
156 |
+
labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
|
157 |
return_dict=True).logits
|
158 |
sent_lengths = tgt_mask.sum(-1)
|
159 |
ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
|