lgcharpe committed on
Commit 72230ef · verified · 1 Parent(s): 41b3fbd

Using max_position_embeddings instead of max_sequence_length to standardise with HF

Files changed (1)
  1. modeling_gptbert.py +3 -3
modeling_gptbert.py CHANGED
@@ -284,7 +284,7 @@ class RotaryPositionalEmbeddings(nn.Module):

        head_size = config.query_key_head_size
        assert head_size % 2 == 0
-       max_seq_len = config.max_sequence_length
+       max_seq_len = config.max_position_embeddings

        inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
        pos = torch.arange(max_seq_len, dtype=torch.float32)
@@ -370,14 +370,14 @@ class SelfAttention(nn.Module):

        # Initialize rotary embeddings based on whether FlashAttention is available
        if flash_attn_varlen_qkvpacked_func is not None:
-           self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_sequence_length)
+           self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_position_embeddings)
        else:
            self.rope_embedding = RotaryPositionalEmbeddings(config, theta)

        self.scale = 1.0 / math.sqrt(self.d_qk)
        self.lambdas = nn.Parameter(torch.tensor([0.5]))

-       self.sequence_length = config.max_sequence_length
+       self.sequence_length = config.max_position_embeddings
        self.window_length = None

    def set_window_length(self, window_length: int):
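
For context, a minimal sketch of how a RoPE cache can be precomputed from config.max_position_embeddings, mirroring the RotaryPositionalEmbeddings setup touched by this commit. The GptBertConfig class, its default values, and the rope_theta name below are illustrative assumptions for the sketch, not the repository's actual definitions.

import torch
from transformers import PretrainedConfig

class GptBertConfig(PretrainedConfig):
    # Hypothetical config for illustration; the real config class and defaults may differ.
    def __init__(self, query_key_head_size=64, max_position_embeddings=2048, rope_theta=10000.0, **kwargs):
        super().__init__(**kwargs)
        self.query_key_head_size = query_key_head_size
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta

config = GptBertConfig()
head_size = config.query_key_head_size
assert head_size % 2 == 0

# Same construction as in the diff: inverse frequencies over half the head size,
# positions up to config.max_position_embeddings (previously config.max_sequence_length).
inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
pos = torch.arange(config.max_position_embeddings, dtype=torch.float32)
freqs = torch.outer(pos, inv_freq)   # shape: (max_position_embeddings, head_size // 2)
cos, sin = freqs.cos(), freqs.sin()  # cached tables, indexed by token position at runtime
print(cos.shape)                     # torch.Size([2048, 32]) with the assumed defaults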