Files changed (1) hide show
  1. modeling_baichuan.py +2 -1
modeling_baichuan.py CHANGED
@@ -502,9 +502,10 @@ class NormHead(nn.Module):
502
  def forward(self, hidden_states):
503
  if self.training:
504
  norm_weight = nn.functional.normalize(self.weight)
 
505
  elif self.first_flag:
506
  self.first_flag = False
507
- self.weight = nn.Parameter(nn.functional.normalize(self.weight))
508
  norm_weight = self.weight
509
  else:
510
  norm_weight = self.weight
 
502
  def forward(self, hidden_states):
503
  if self.training:
504
  norm_weight = nn.functional.normalize(self.weight)
505
+ self.first_flag = False
506
  elif self.first_flag:
507
  self.first_flag = False
508
+ self.weight.data = nn.functional.normalize(self.weight)
509
  norm_weight = self.weight
510
  else:
511
  norm_weight = self.weight