Update modeling_mpt.py

#44
by vchiley - opened
Files changed (1) hide show
  1. modeling_mpt.py +1 -1
modeling_mpt.py CHANGED
@@ -182,7 +182,7 @@ class MPTModel(MPTPreTrainedModel):
182
  x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
183
  assert isinstance(self.emb_drop, nn.Module)
184
  x = self.emb_drop(x_shrunk)
185
- (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
186
  if use_cache and past_key_values is None:
187
  past_key_values = [() for _ in range(self.config.n_layers)]
188
  all_hidden_states = () if output_hidden_states else None
 
182
  x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
183
  assert isinstance(self.emb_drop, nn.Module)
184
  x = self.emb_drop(x_shrunk)
185
+ (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
186
  if use_cache and past_key_values is None:
187
  past_key_values = [() for _ in range(self.config.n_layers)]
188
  all_hidden_states = () if output_hidden_states else None