Tags: Text Generation · Transformers · PyTorch · mpt · Composer · MosaicML · llm-foundry · custom_code · text-generation-inference
Alex Birch committed
Commit ec8bed8
1 Parent(s): ec8ea9d

apply device-transfer patch from https://github.com/mosaicml/llm-foundry/pull/225/files

Files changed (1):
  1. modeling_mpt.py +6 -1
modeling_mpt.py CHANGED
@@ -298,7 +298,12 @@ class MPTForCausalLM(MPTPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
-        logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
+        # move outputs to same device as weights for token embedding
+        # needed to support HF `device_map`
+        logits = F.linear(
+            outputs.last_hidden_state.to(self.transformer.wte.weight.device),
+            self.transformer.wte.weight
+        )
         if self.logit_scale is not None:
             if self.logit_scale == 0:
                 warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
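For context: this change matters when the model is loaded with Hugging Face's `device_map` sharding, where layers can land on different devices, so `outputs.last_hidden_state` may not be on the same device as `self.transformer.wte.weight`; the explicit `.to(...)` avoids a cross-device error in the logits projection. Below is a minimal usage sketch of that scenario; the model id `mosaicml/mpt-7b`, dtype, and generation settings are illustrative assumptions, not part of this commit.

    # Minimal sketch: loading MPT with HF `device_map` sharding (assumed model id: mosaicml/mpt-7b)
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('mosaicml/mpt-7b')
    model = AutoModelForCausalLM.from_pretrained(
        'mosaicml/mpt-7b',
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,  # MPT ships custom modeling code (modeling_mpt.py)
        device_map='auto',       # shard layers across available devices; the patched
                                 # F.linear call above handles the resulting device mismatch
    )

    prompt = tokenizer('MosaicML is', return_tensors='pt').to(model.device)
    out = model.generate(**prompt, max_new_tokens=20)
    print(tokenizer.decode(out[0]))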