Hyungtae Kim commited on
Commit
27675b5
1 Parent(s): 8fb2de8

Remove custom embedding.

Browse files
Files changed (1) hide show
  1. modeling_mpt.py +2 -3
modeling_mpt.py CHANGED
@@ -12,7 +12,6 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokeniz
12
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
  from .attention import attn_bias_shape, build_attn_bias
14
  from .blocks import MPTBlock
15
- from .custom_embedding import SharedEmbedding
16
  from .norm import NORM_CLASS_REGISTRY
17
  from .configuration_mpt import MPTConfig
18
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
@@ -56,7 +55,7 @@ class MPTModel(MPTPreTrainedModel):
56
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
57
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
58
  self.embedding_fraction = config.embedding_fraction
59
- self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
60
  if not self.alibi:
61
  self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
62
  self.emb_drop = nn.Dropout(config.emb_pdrop)
@@ -322,7 +321,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
322
  if inputs_embeds is not None:
323
  raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')
324
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
325
- logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
326
  if self.logit_scale is not None:
327
  if self.logit_scale == 0:
328
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
 
12
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
  from .attention import attn_bias_shape, build_attn_bias
14
  from .blocks import MPTBlock
 
15
  from .norm import NORM_CLASS_REGISTRY
16
  from .configuration_mpt import MPTConfig
17
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
 
55
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
56
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
57
  self.embedding_fraction = config.embedding_fraction
58
+ self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
59
  if not self.alibi:
60
  self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
61
  self.emb_drop = nn.Dropout(config.emb_pdrop)
 
321
  if inputs_embeds is not None:
322
  raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')
323
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
324
+ logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
325
  if self.logit_scale is not None:
326
  if self.logit_scale == 0:
327
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')