kuaizhirui commited on
Commit
1682dd0
1 Parent(s): f9d4d8d

fix NormHead eval

Browse files

I encountered this problem when using Baichuan2-7B-Base with deepspeed stage3 for sft. A similar situation also happened in the place such as https://github.com/baichuan-inc/Baichuan2/issues/39#issuecomment-1710146497
I found that Baichuan2-13B-Chat has solved this problem, so I synced the code here

Files changed (1) hide show
  1. modeling_baichuan.py +2 -1
modeling_baichuan.py CHANGED
@@ -502,9 +502,10 @@ class NormHead(nn.Module):
502
  def forward(self, hidden_states):
503
  if self.training:
504
  norm_weight = nn.functional.normalize(self.weight)
 
505
  elif self.first_flag:
506
  self.first_flag = False
507
- self.weight = nn.Parameter(nn.functional.normalize(self.weight))
508
  norm_weight = self.weight
509
  else:
510
  norm_weight = self.weight
 
502
  def forward(self, hidden_states):
503
  if self.training:
504
  norm_weight = nn.functional.normalize(self.weight)
505
+ self.first_flag = False
506
  elif self.first_flag:
507
  self.first_flag = False
508
+ self.weight.data = nn.functional.normalize(self.weight)
509
  norm_weight = self.weight
510
  else:
511
  norm_weight = self.weight