M commited on
Commit
8b2b3a6
1 Parent(s): 85c1f1c

Fixing "RuntimeError: expected scalar type Half but found Float" error

Browse files

Change in Commit "85c1f1c201273bbfee661d4a2f8307c95f8956c9" is raising an error when using the model for 8bit inference.

Files changed (1) hide show
  1. modeling_mpt.py +1 -1
modeling_mpt.py CHANGED
@@ -182,7 +182,7 @@ class MPTModel(MPTPreTrainedModel):
182
  x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
183
  assert isinstance(self.emb_drop, nn.Module)
184
  x = self.emb_drop(x_shrunk)
185
- (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
186
  if use_cache and past_key_values is None:
187
  past_key_values = [() for _ in range(self.config.n_layers)]
188
  all_hidden_states = () if output_hidden_states else None
 
182
  x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
183
  assert isinstance(self.emb_drop, nn.Module)
184
  x = self.emb_drop(x_shrunk)
185
+ (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
186
  if use_cache and past_key_values is None:
187
  past_key_values = [() for _ in range(self.config.n_layers)]
188
  all_hidden_states = () if output_hidden_states else None