daking committed on
Commit
150c524
·
1 Parent(s): e837ad7

LLM-foundry update June 16, 2023 22:55:57

Browse files
Files changed (2) hide show
  1. custom_embedding.py +1 -2
  2. modeling_mpt.py +11 -1
custom_embedding.py CHANGED
@@ -3,10 +3,9 @@ import torch.nn as nn
3
  import torch.nn.functional as F
4
  from torch import Tensor
5
 
6
-
7
  class SharedEmbedding(nn.Embedding):
8
 
9
- def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
10
  if unembed:
11
  return F.linear(input, self.weight)
12
  return super().forward(input)
 
3
  import torch.nn.functional as F
4
  from torch import Tensor
5
 
 
6
  class SharedEmbedding(nn.Embedding):
7
 
8
+ def forward(self, input: Tensor, unembed: bool=False) -> Tensor:
9
  if unembed:
10
  return F.linear(input, self.weight)
11
  return super().forward(input)
modeling_mpt.py CHANGED
@@ -40,6 +40,11 @@ class MPTModel(MPTPreTrainedModel):
40
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
41
  self.alibi = config.attn_config['alibi']
42
  self.alibi_bias_max = config.attn_config['alibi_bias_max']
 
 
 
 
 
43
  if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
44
  norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
45
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
@@ -47,7 +52,7 @@ class MPTModel(MPTPreTrainedModel):
47
  self.embedding_fraction = config.embedding_fraction
48
  self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
49
  if not self.alibi:
50
- self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
51
  self.emb_drop = nn.Dropout(config.emb_pdrop)
52
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
53
  self.norm_f = norm_class(config.d_model, device=config.init_device)
@@ -221,6 +226,11 @@ class MPTForCausalLM(MPTPreTrainedModel):
221
  if not config.tie_word_embeddings:
222
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
223
  self.transformer = MPTModel(config)
 
 
 
 
 
224
  self.logit_scale = None
225
  if config.logit_scale is not None:
226
  logit_scale = config.logit_scale
 
40
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
41
  self.alibi = config.attn_config['alibi']
42
  self.alibi_bias_max = config.attn_config['alibi_bias_max']
43
+ if config.init_device == 'mixed':
44
+ if dist.get_local_rank() == 0:
45
+ config.init_device = 'cpu'
46
+ else:
47
+ config.init_device = 'meta'
48
  if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
49
  norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
50
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
 
52
  self.embedding_fraction = config.embedding_fraction
53
  self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
54
  if not self.alibi:
55
+ self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
56
  self.emb_drop = nn.Dropout(config.emb_pdrop)
57
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
58
  self.norm_f = norm_class(config.d_model, device=config.init_device)
 
226
  if not config.tie_word_embeddings:
227
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
228
  self.transformer = MPTModel(config)
229
+ for child in self.transformer.children():
230
+ if isinstance(child, torch.nn.ModuleList):
231
+ continue
232
+ if isinstance(child, torch.nn.Module):
233
+ child._fsdp_wrap = True
234
  self.logit_scale = None
235
  if config.logit_scale is not None:
236
  logit_scale = config.logit_scale