Update modeling_ltgbert.py
Browse files- modeling_ltgbert.py +2 -2
modeling_ltgbert.py
CHANGED
@@ -179,7 +179,7 @@ class Attention(nn.Module):
|
|
179 |
- torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
|
180 |
position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
|
181 |
position_indices = config.position_bucket_size - 1 + position_indices
|
182 |
-
self.register_buffer("position_indices", position_indices, persistent=
|
183 |
|
184 |
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
185 |
self.scale = 1.0 / math.sqrt(3 * self.head_size)
|
@@ -210,7 +210,7 @@ class Attention(nn.Module):
|
|
210 |
position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
|
211 |
- torch.arange(query_len, dtype=torch.long).unsqueeze(0)
|
212 |
position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
|
213 |
-
position_indices = self.position_bucket_size - 1 + position_indices
|
214 |
self.position_indices = position_indices.to(hidden_states.device)
|
215 |
|
216 |
hidden_states = self.pre_layer_norm(hidden_states)
|
|
|
179 |
- torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
|
180 |
position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
|
181 |
position_indices = config.position_bucket_size - 1 + position_indices
|
182 |
+
self.register_buffer("position_indices", position_indices, persistent=False)
|
183 |
|
184 |
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
185 |
self.scale = 1.0 / math.sqrt(3 * self.head_size)
|
|
|
210 |
position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
|
211 |
- torch.arange(query_len, dtype=torch.long).unsqueeze(0)
|
212 |
position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
|
213 |
+
position_indices = self.config.position_bucket_size - 1 + position_indices
|
214 |
self.position_indices = position_indices.to(hidden_states.device)
|
215 |
|
216 |
hidden_states = self.pre_layer_norm(hidden_states)
|