Using max_position_embeddings instead of max_sequence_length to standardise with HF
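max_position_embeddings is the attribute name that Hugging Face transformers configs conventionally expose for the maximum sequence length a model supports, so standardising on it lets generic tooling read the limit without model-specific logic. A minimal illustration (the checkpoint name here is just an example, not part of this repo):

from transformers import AutoConfig

# HF configs expose the maximum supported sequence length under the
# library-standard attribute name, regardless of model family.
config = AutoConfig.from_pretrained("bert-base-uncased")  # example checkpoint
print(config.max_position_embeddings)  # 512 for this checkpoint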
modeling_gptbert.py (+3 -3)
@@ -284,7 +284,7 @@ class RotaryPositionalEmbeddings(nn.Module):

         head_size = config.query_key_head_size
         assert head_size % 2 == 0
-        max_seq_len = config.max_sequence_length
+        max_seq_len = config.max_position_embeddings

         inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
         pos = torch.arange(max_seq_len, dtype=torch.float32)

@@ -370,14 +370,14 @@ class SelfAttention(nn.Module):

         # Initialize rotary embeddings based on whether FlashAttention is available
         if flash_attn_varlen_qkvpacked_func is not None:
-            self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_sequence_length)
+            self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_position_embeddings)
         else:
             self.rope_embedding = RotaryPositionalEmbeddings(config, theta)

         self.scale = 1.0 / math.sqrt(self.d_qk)
         self.lambdas = nn.Parameter(torch.tensor([0.5]))

-        self.sequence_length = config.max_sequence_length
+        self.sequence_length = config.max_position_embeddings
         self.window_length = None

     def set_window_length(self, window_length: int):
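For reference, the first changed line sizes the precomputed rotary-embedding tables. A self-contained sketch of that computation under the renamed attribute (build_rope_tables is an illustrative helper, not the repo's class; theta = 10000.0 is the common RoPE default, whereas the repo passes its own value in):

import torch

def build_rope_tables(head_size: int, max_position_embeddings: int, theta: float = 10000.0):
    # Mirrors the diffed lines: one inverse frequency per pair of head
    # dimensions, one row of angles per position up to the configured maximum.
    assert head_size % 2 == 0
    max_seq_len = max_position_embeddings
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
    pos = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(pos, inv_freq)  # shape: (max_seq_len, head_size // 2)
    return angles.cos(), angles.sin()

cos, sin = build_rope_tables(head_size=64, max_position_embeddings=512)
print(cos.shape)  # torch.Size([512, 32])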
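The second hunk's guard follows the usual optional-dependency pattern: flash_attn_varlen_qkvpacked_func is bound to None when FlashAttention is not installed, so the plain RotaryPositionalEmbeddings path takes over. A sketch of how such a guard is typically written (assumed here, not copied from the repo):

try:
    from flash_attn import flash_attn_varlen_qkvpacked_func
except ImportError:
    flash_attn_varlen_qkvpacked_func = None  # fall back to the pure-PyTorch rotary path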