yairschiff commited on
Commit
0fd3e52
·
verified ·
1 Parent(s): 37ffde2

Update modeling_caduceus.py

Browse files

Change `loss_weight` to `loss_weights` in weighted CE.

Files changed (1) hide show
  1. modeling_caduceus.py +4 -4
modeling_caduceus.py CHANGED
@@ -263,15 +263,15 @@ def cross_entropy(logits, y, ignore_index=-100):
263
  return F.cross_entropy(logits, y, ignore_index=ignore_index)
264
 
265
 
266
- def weighted_cross_entropy(logits, y, loss_weight, ignore_index=-100):
267
  """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome)."""
268
  logits = logits.view(-1, logits.shape[-1])
269
  y = y.view(-1)
270
  ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
271
- loss_weight = loss_weight.view(-1)
272
- loss_weight[y == ignore_index] = 0.0
273
  # TODO: Follows GPN implementation, but should we remove weight normalization?
274
- return (ce * (loss_weight / loss_weight.sum())).sum()
275
 
276
 
277
  class CaduceusPreTrainedModel(PreTrainedModel):
 
263
  return F.cross_entropy(logits, y, ignore_index=ignore_index)
264
 
265
 
266
+ def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
267
  """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome)."""
268
  logits = logits.view(-1, logits.shape[-1])
269
  y = y.view(-1)
270
  ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
271
+ loss_weights = loss_weights.view(-1)
272
+ loss_weights[y == ignore_index] = 0.0
273
  # TODO: Follows GPN implementation, but should we remove weight normalization?
274
+ return (ce * (loss_weights / loss_weights.sum())).sum()
275
 
276
 
277
  class CaduceusPreTrainedModel(PreTrainedModel):