ltg
/

lgcharpe commited on
Commit
e1a145d
·
verified ·
1 Parent(s): 3f2f30b

Update modeling_ltgbert.py

Browse files
Files changed (1) hide show
  1. modeling_ltgbert.py +2 -2
modeling_ltgbert.py CHANGED
@@ -179,7 +179,7 @@ class Attention(nn.Module):
179
  - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
180
  position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
181
  position_indices = config.position_bucket_size - 1 + position_indices
182
- self.register_buffer("position_indices", position_indices, persistent=True)
183
 
184
  self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
185
  self.scale = 1.0 / math.sqrt(3 * self.head_size)
@@ -210,7 +210,7 @@ class Attention(nn.Module):
210
  position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
211
  - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
212
  position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
213
- position_indices = self.position_bucket_size - 1 + position_indices
214
  self.position_indices = position_indices.to(hidden_states.device)
215
 
216
  hidden_states = self.pre_layer_norm(hidden_states)
 
179
  - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
180
  position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
181
  position_indices = config.position_bucket_size - 1 + position_indices
182
+ self.register_buffer("position_indices", position_indices, persistent=False)
183
 
184
  self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
185
  self.scale = 1.0 / math.sqrt(3 * self.head_size)
 
210
  position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
211
  - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
212
  position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
213
+ position_indices = self.config.position_bucket_size - 1 + position_indices
214
  self.position_indices = position_indices.to(hidden_states.device)
215
 
216
  hidden_states = self.pre_layer_norm(hidden_states)