wuzhiying2023 commited on
Commit
229e4eb
1 Parent(s): e03d54f

fix NormHead eval

Browse files
Files changed (1) hide show
  1. modeling_baichuan.py +1 -0
modeling_baichuan.py CHANGED
@@ -502,6 +502,7 @@ class NormHead(nn.Module):
502
  def forward(self, hidden_states):
503
  if self.training:
504
  norm_weight = nn.functional.normalize(self.weight)
 
505
  elif self.first_flag:
506
  self.first_flag = False
507
  self.weight.data = nn.functional.normalize(self.weight)
 
502
  def forward(self, hidden_states):
503
  if self.training:
504
  norm_weight = nn.functional.normalize(self.weight)
505
+ self.first_flag = True
506
  elif self.first_flag:
507
  self.first_flag = False
508
  self.weight.data = nn.functional.normalize(self.weight)