yairschiff
committed on
Update modeling_caduceus.py
Change `loss_weight` to `loss_weights` in weighted CE.
modeling_caduceus.py CHANGED (+4 -4)
@@ -263,15 +263,15 @@ def cross_entropy(logits, y, ignore_index=-100):
     return F.cross_entropy(logits, y, ignore_index=ignore_index)


-def weighted_cross_entropy(logits, y, loss_weight, ignore_index=-100):
+def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
     """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome)."""
     logits = logits.view(-1, logits.shape[-1])
     y = y.view(-1)
     ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
-    loss_weight = loss_weight.view(-1)
-    loss_weight[y == ignore_index] = 0.0
+    loss_weights = loss_weights.view(-1)
+    loss_weights[y == ignore_index] = 0.0
     # TODO: Follows GPN implementation, but should we remove weight normalization?
-    return (ce * (loss_weight / loss_weight.sum())).sum()
+    return (ce * (loss_weights / loss_weights.sum())).sum()


 class CaduceusPreTrainedModel(PreTrainedModel):
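For context, a minimal, self-contained sketch of how the renamed function might be called. The function body is copied from the diff above; the tensor shapes, the masked positions, and the 0.5 discount are illustrative assumptions, not values from the repo.

```python
import torch
import torch.nn.functional as F


def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
    """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome)."""
    logits = logits.view(-1, logits.shape[-1])
    y = y.view(-1)
    ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
    loss_weights = loss_weights.view(-1)
    loss_weights[y == ignore_index] = 0.0
    # TODO: Follows GPN implementation, but should we remove weight normalization?
    return (ce * (loss_weights / loss_weights.sum())).sum()


# Illustrative call: shapes, masked positions, and the 0.5 discount are
# assumptions for this example only.
batch, seq_len, vocab = 2, 8, 16
logits = torch.randn(batch, seq_len, vocab)
y = torch.randint(0, vocab, (batch, seq_len))
y[0, :2] = -100                           # padded positions, excluded from the loss
loss_weights = torch.ones(batch, seq_len)
loss_weights[:, 4:] = 0.5                 # e.g., down-weight repeated base pairs

loss = weighted_cross_entropy(logits, y, loss_weights)
print(loss.item())
```

Note that the per-token weights are normalized to sum to 1 after zeroing ignored positions, so the result is a weighted average of the per-token losses; the in-code TODO questions whether that normalization should be kept.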