BucketOfFish committed
Commit 10aca20
Parent: 78f6f3b

Fixed weight loading from original Phi2 model

Files changed (3):
  1. config.json +1 -1
  2. phi2_model.py +4 -7
  3. streaming_inference.py +14 -13
config.json CHANGED
@@ -13,7 +13,7 @@
   "torch_dtype": "float16",
   "transformers_version": "4.29.0",

-  "vocab_size": 50304,
+  "vocab_size": 51200,
   "vocab_chunk_for_gpu_efficiency": 64,
   "initial_cos_sin_cache_len": 2048,
   "d_embedding": 2560,
phi2_model.py CHANGED
@@ -13,11 +13,6 @@ class Phi2PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = False
     # _no_split_modules = ["ParallelAttentionBlock"]

-    # weight loading
-    # base_model_prefix = "transformer"
-    # _keys_to_ignore_on_load_missing = [""]
-    # _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
-
     def __init__(self, config: Phi2Config):
         super().__init__(config)
         self.config = config
@@ -42,6 +37,7 @@ class Phi2PreTrainedModel(PreTrainedModel):
         input_ids: torch.LongTensor,  # dim: (batch_size, seq_len)
         kv_cache: KVCache | None = None,
         key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
+        **kwargs,
     ) -> dict[str, Any]:
         if not kv_cache:
             kv_cache = KVCache(
@@ -142,7 +138,7 @@ class Phi2Model(Phi2PreTrainedModel):
 class Phi2ModelForCausalLM(Phi2PreTrainedModel):
     def __init__(self, config: Phi2Config) -> None:
         super().__init__(config)
-        self.pretrained_model = Phi2Model(config)
+        self.model = Phi2Model(config)
         self.lm_head_layer_norm = nn.LayerNorm(config.d_embedding, eps=config.layer_norm_epsilon)
         self.lm_head_linear = nn.Linear(config.d_embedding, config.vocab_size)
         self.loss_fn = nn.CrossEntropyLoss()
@@ -154,8 +150,9 @@ class Phi2ModelForCausalLM(Phi2PreTrainedModel):
         kv_cache: KVCache | None = None,
         key_padding_mask: torch.BoolTensor | None = None,
         labels: torch.LongTensor | None = None,
+        **kwargs,
     ) -> CausalLMOutputWithPast:
-        x = self.pretrained_model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
+        x = self.model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
         x = self.lm_head_layer_norm(x)
         logits = self.lm_head_linear(x).to(torch.float32)
         loss = (
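Note on the rename: load_state_dict() matches parameters by dotted attribute path, so the submodule holding the transformer stack has to be named model for the remapped model.* keys in streaming_inference.py to resolve; the added **kwargs presumably just lets forward() absorb extra keyword arguments (e.g. ones passed by generation utilities) without raising a TypeError. A minimal illustration of the naming rule (toy module, not this repo's code):

import torch.nn as nn

# Toy example: state_dict keys are derived from attribute names, so renaming
# self.pretrained_model -> self.model changes every key prefix the checkpoint must match.
class Wrapper(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Linear(4, 4)  # was: self.pretrained_model = ...

print(list(Wrapper().state_dict()))  # ['model.weight', 'model.bias']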
streaming_inference.py CHANGED
@@ -20,22 +20,23 @@ if __name__ == "__main__":
20
  phi_model_state_dict = phi_model.state_dict()
21
  model_state_dict = {}
22
  for key, value in phi_model_state_dict.items():
23
- # transformer.embd.wte.weight -> model.rotary_embedding.embeddings.weight
24
- # transformer.h.0.mlp.fc1.weight -> pretrained_model.parallel_blocks.0.mlp.fc1.weight
25
- # transformer.h.0.ln.weight -> pretrained_model.parallel_blocks.0.layer_norm.weight
26
- # transformer.h.0.mixer.Wqkv.weight -> pretrained_model.parallel_blocks.0.multi_head_attention.Wqkv.weight
27
- # transformer.h.0.mixer.out_proj.weight -> pretrained_model.parallel_blocks.0.multi_head_attention.fc_out.weight
28
  # lm_head.ln.weight -> lm_head_layer_norm.weight
29
  # lm_head.linear.weight -> lm_head_linear.weight
 
 
 
 
 
30
  if key.startswith("transformer"):
31
- key.replace("transformer.", "model.")
32
- key.replace(".embd.wte.", ".rotary_embedding.embeddings.")
33
- key.replace(".h.", ".parallel_blocks")
34
- key.replace(".ln.", ".layer_norm.")
35
- key.replace(".mixer.Wqkv.", ".multi_head_attention.Wqkv.")
36
- key.replace(".mixer.out_proj.", ".multi_head_attention.fc_out.")
37
- key.replace(".lm_head.ln.", ".lm_head_layer_norm.")
38
- key.replace(".lm_head.linear.", ".lm_head_linear.")
 
39
  model_state_dict[key] = value
40
  model.load_state_dict(model_state_dict)
41
 
 
20
  phi_model_state_dict = phi_model.state_dict()
21
  model_state_dict = {}
22
  for key, value in phi_model_state_dict.items():
 
 
 
 
 
23
  # lm_head.ln.weight -> lm_head_layer_norm.weight
24
  # lm_head.linear.weight -> lm_head_linear.weight
25
+ # transformer.embd.wte.weight -> model.rotary_embedding.embeddings.weight
26
+ # transformer.h.0.mlp.fc1.weight -> model.parallel_blocks.0.mlp.fc1.weight
27
+ # transformer.h.0.ln.weight -> model.parallel_blocks.0.layer_norm.weight
28
+ # transformer.h.0.mixer.Wqkv.weight -> model.parallel_blocks.0.multi_head_attention.Wqkv.weight
29
+ # transformer.h.0.mixer.out_proj.weight -> model.parallel_blocks.0.multi_head_attention.fc_out.weight
30
  if key.startswith("transformer"):
31
+ key = key.replace("transformer.", "model.")
32
+ key = key.replace(".embd.wte.", ".rotary_embedding.embeddings.")
33
+ key = key.replace(".h.", ".parallel_blocks.")
34
+ key = key.replace(".ln.", ".layer_norm.")
35
+ key = key.replace(".mixer.Wqkv.", ".multi_head_attention.Wqkv.")
36
+ key = key.replace(".mixer.out_proj.", ".multi_head_attention.fc_out.")
37
+ else:
38
+ key = key.replace("lm_head.ln.", "lm_head_layer_norm.")
39
+ key = key.replace("lm_head.linear.", "lm_head_linear.")
40
  model_state_dict[key] = value
41
  model.load_state_dict(model_state_dict)
42
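The underlying bug: str.replace() returns a new string rather than editing in place, so the old loop discarded every rename and load_state_dict() still saw the original transformer.* keys; the ".h." pattern was also missing its trailing dot, and the lm_head.* renames could never fire inside the transformer branch. If the remapping regresses again, a short check before loading makes the mismatch explicit (a sketch using the script's existing model and model_state_dict variables, not part of this commit):

# Sketch: compare the remapped keys against the freshly initialized model before loading,
# so an unassigned replace() or a missing dot shows up as a named key rather than a
# generic strict-loading error.
expected = set(model.state_dict())
remapped = set(model_state_dict)
assert remapped == expected, (
    f"unmatched source keys: {sorted(remapped - expected)[:5]}\n"
    f"unfilled target keys:  {sorted(expected - remapped)[:5]}"
)
model.load_state_dict(model_state_dict)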