BucketOfFish committed on
Commit c572a14
1 Parent(s): 0f3418e

Corrected param name

Files changed (2)
  1. phi2_model.py +4 -4
  2. streaming_inference.py +2 -2
phi2_model.py CHANGED

@@ -87,7 +87,7 @@ class Embedding(nn.Module):
 class Phi2Model(Phi2PreTrainedModel):
     def __init__(self, config: Phi2Config) -> None:
         super().__init__(config)
-        self.rotary_embedding = Embedding(
+        self.embedding = Embedding(
             vocab_size=config.vocab_size,
             d_embedding=config.d_embedding,
             embd_pdrop=config.embd_pdrop,
@@ -113,10 +113,10 @@ class Phi2Model(Phi2PreTrainedModel):
 
     """
     def get_input_embeddings(self) -> nn.Embedding:
-        return self.rotary_embedding.embeddings
+        return self.embedding.embeddings
 
     def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
-        self.rotary_embedding.embeddings = new_embeddings
+        self.embedding.embeddings = new_embeddings
     """
 
     def forward(
@@ -125,7 +125,7 @@ class Phi2Model(Phi2PreTrainedModel):
         kv_cache: KVCache | None = None,
         key_padding_mask: torch.BoolTensor | None = None,
     ) -> torch.FloatTensor:
-        x = self.rotary_embedding(input_ids)
+        x = self.embedding(input_ids)
         for block in self.parallel_blocks:
             x = block(
                 x,
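The rename only touches the attribute name: the module it points to is a token-embedding lookup (it is constructed from vocab_size, d_embedding and embd_pdrop), so calling it rotary_embedding was misleading. For context, a minimal sketch of what an Embedding wrapper with this interface might look like; the real definition sits earlier in phi2_model.py, and this version is only an assumption based on the fields the diff uses:

import torch
import torch.nn as nn

class Embedding(nn.Module):
    """Hypothetical sketch: token-embedding lookup plus dropout, exposing
    `.embeddings` so get_input_embeddings()/set_input_embeddings() can reach
    the underlying nn.Embedding (field names taken from the diff above)."""

    def __init__(self, vocab_size: int, d_embedding: int, embd_pdrop: float) -> None:
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, d_embedding)
        self.dropout = nn.Dropout(embd_pdrop)

    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        # (batch, seq_len) token ids -> (batch, seq_len, d_embedding) vectors
        return self.dropout(self.embeddings(input_ids))

With that shape, self.embedding.embeddings is the nn.Embedding that get_input_embeddings() and set_input_embeddings() would hand back and forth, matching the (currently commented-out) accessors in the hunk above.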
streaming_inference.py CHANGED

@@ -22,14 +22,14 @@ if __name__ == "__main__":
     for key, value in phi_model_state_dict.items():
         # lm_head.ln.weight -> lm_head_layer_norm.weight
         # lm_head.linear.weight -> lm_head_linear.weight
-        # transformer.embd.wte.weight -> model.rotary_embedding.embeddings.weight
+        # transformer.embd.wte.weight -> model.embedding.embeddings.weight
         # transformer.h.0.mlp.fc1.weight -> model.parallel_blocks.0.mlp.fc1.weight
         # transformer.h.0.ln.weight -> model.parallel_blocks.0.layer_norm.weight
         # transformer.h.0.mixer.Wqkv.weight -> model.parallel_blocks.0.multi_head_attention.Wqkv.weight
         # transformer.h.0.mixer.out_proj.weight -> model.parallel_blocks.0.multi_head_attention.fc_out.weight
         if key.startswith("transformer"):
             key = key.replace("transformer.", "model.")
-            key = key.replace(".embd.wte.", ".rotary_embedding.embeddings.")
+            key = key.replace(".embd.wte.", ".embedding.embeddings.")
             key = key.replace(".h.", ".parallel_blocks.")
             key = key.replace(".ln.", ".layer_norm.")
             key = key.replace(".mixer.Wqkv.", ".multi_head_attention.Wqkv.")
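This remapping exists because the published Phi-2 checkpoint keeps the original module names (transformer.embd.wte, transformer.h.N.mixer, ...), while this reimplementation names its submodules differently; once the model attribute is renamed from rotary_embedding to embedding, the mapping target has to change with it, or load_state_dict would report missing and unexpected keys. A rough, self-contained sketch of the remapping under that assumption (the loading code around this hunk is not shown, and the helper name is hypothetical):

import torch

def remap_phi2_keys(phi_model_state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Hypothetical helper collecting the per-key renames from the hunk above."""
    remapped: dict[str, torch.Tensor] = {}
    for key, value in phi_model_state_dict.items():
        if key.startswith("transformer"):
            key = key.replace("transformer.", "model.")
            key = key.replace(".embd.wte.", ".embedding.embeddings.")
            key = key.replace(".h.", ".parallel_blocks.")
            key = key.replace(".ln.", ".layer_norm.")
            key = key.replace(".mixer.Wqkv.", ".multi_head_attention.Wqkv.")
            # The mixer.out_proj -> multi_head_attention.fc_out rename noted in the
            # comments presumably follows in the original file (not visible in this hunk).
            key = key.replace(".mixer.out_proj.", ".multi_head_attention.fc_out.")
        # lm_head.* keys are renamed elsewhere in the original script (not shown here).
        remapped[key] = value
    return remapped

A dict produced this way can then be loaded with load_state_dict(..., strict=True), which will surface any checkpoint key that the rename missed.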