Chongruo Wu commited on
Commit
f1020dc
·
1 Parent(s): c899f8b

Fix a bug related to displaying ce_loss

Browse files

Former-commit-id: 2fd1438861a0ef29d82b049836c62160402e8bb7

Files changed (1) hide show
  1. model/LISA.py +1 -2
model/LISA.py CHANGED
@@ -306,7 +306,6 @@ class LISAForCausalLM(LlavaLlamaForCausalLM):
306
 
307
  ce_loss = model_output.loss
308
  ce_loss = ce_loss * self.ce_loss_weight
309
- loss = ce_loss
310
  mask_bce_loss = 0
311
  mask_dice_loss = 0
312
  num_masks = 0
@@ -333,7 +332,7 @@ class LISAForCausalLM(LlavaLlamaForCausalLM):
333
  mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
334
  mask_loss = mask_bce_loss + mask_dice_loss
335
 
336
- loss += mask_loss
337
 
338
  return {
339
  "loss": loss,
 
306
 
307
  ce_loss = model_output.loss
308
  ce_loss = ce_loss * self.ce_loss_weight
 
309
  mask_bce_loss = 0
310
  mask_dice_loss = 0
311
  num_masks = 0
 
332
  mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
333
  mask_loss = mask_bce_loss + mask_dice_loss
334
 
335
+ loss = ce_loss + mask_loss
336
 
337
  return {
338
  "loss": loss,