Update modeling_baichuan.py

#6
by JaheimLee - opened
Files changed (1) hide show
  1. modeling_baichuan.py +1 -1
modeling_baichuan.py CHANGED
@@ -513,7 +513,7 @@ class NormHead(nn.Module):
513
  norm_weight = nn.functional.normalize(self.weight)
514
  elif self.first_flag:
515
  self.first_flag = False
516
- self.weight = nn.Parameter(nn.functional.normalize(self.weight))
517
  norm_weight = self.weight
518
  else:
519
  norm_weight = self.weight
 
513
  norm_weight = nn.functional.normalize(self.weight)
514
  elif self.first_flag:
515
  self.first_flag = False
516
+ self.weight.data = nn.functional.normalize(self.weight)
517
  norm_weight = self.weight
518
  else:
519
  norm_weight = self.weight