Update modeling_ltgbert.py
Browse filesFixing error in the re-bucketing for the Attention module
- modeling_ltgbert.py +1 -1
modeling_ltgbert.py
CHANGED
@@ -209,7 +209,7 @@ class Attention(nn.Module):
|
|
209 |
if self.position_indices.size(0) < query_len:
|
210 |
position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
|
211 |
- torch.arange(query_len, dtype=torch.long).unsqueeze(0)
|
212 |
-
position_indices = self.make_log_bucket_position(position_indices, self.position_bucket_size, 512)
|
213 |
position_indices = self.position_bucket_size - 1 + position_indices
|
214 |
self.position_indices = position_indices.to(hidden_states.device)
|
215 |
|
|
|
209 |
if self.position_indices.size(0) < query_len:
|
210 |
position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
|
211 |
- torch.arange(query_len, dtype=torch.long).unsqueeze(0)
|
212 |
+
position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
|
213 |
position_indices = self.position_bucket_size - 1 + position_indices
|
214 |
self.position_indices = position_indices.to(hidden_states.device)
|
215 |
|