zpn committed on
Commit
6148d34
1 Parent(s): 1ef33b1

Update modeling_hf_nomic_bert.py

Browse files
Files changed (1) hide show
  1. modeling_hf_nomic_bert.py +8 -1
modeling_hf_nomic_bert.py CHANGED
@@ -321,7 +321,8 @@ class NomicBertPreTrainedModel(PreTrainedModel):
321
  ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
322
  num_labels = kwargs.pop("num_labels", None)
323
  rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
324
- config.rotary_scaling_factor = rotary_scaling_factor
 
325
  if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
326
  config.n_positions = 2048
327
  if num_labels:
@@ -554,6 +555,12 @@ class NomicBertRotaryEmbedding(nn.Module):
554
  self.register_buffer("inv_freq", inv_freq, persistent=False)
555
  self.interleaved = interleaved
556
  self.scale_base = scale_base
 
 
 
 
 
 
557
 
558
  self._seq_len_cached = 0
559
  self._cos_cached = None
 
321
  ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
322
  num_labels = kwargs.pop("num_labels", None)
323
  rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
324
+ if rotary_scaling_factor:
325
+ config.rotary_scaling_factor = rotary_scaling_factor
326
  if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
327
  config.n_positions = 2048
328
  if num_labels:
 
555
  self.register_buffer("inv_freq", inv_freq, persistent=False)
556
  self.interleaved = interleaved
557
  self.scale_base = scale_base
558
+ scale = (
559
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
560
+ if scale_base is not None
561
+ else None
562
+ )
563
+ self.register_buffer("scale", scale, persistent=False)
564
 
565
  self._seq_len_cached = 0
566
  self._cos_cached = None