BucketOfFish committed
Commit c572a14
1 Parent(s): 0f3418e
Corrected param name

- phi2_model.py +4 -4
- streaming_inference.py +2 -2
phi2_model.py CHANGED
@@ -87,7 +87,7 @@ class Embedding(nn.Module):
 class Phi2Model(Phi2PreTrainedModel):
     def __init__(self, config: Phi2Config) -> None:
         super().__init__(config)
-        self.
+        self.embedding = Embedding(
             vocab_size=config.vocab_size,
             d_embedding=config.d_embedding,
             embd_pdrop=config.embd_pdrop,
@@ -113,10 +113,10 @@ class Phi2Model(Phi2PreTrainedModel):
 
     """
     def get_input_embeddings(self) -> nn.Embedding:
-        return self.
+        return self.embedding.embeddings
 
     def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
-        self.
+        self.embedding.embeddings = new_embeddings
     """
 
     def forward(
@@ -125,7 +125,7 @@ class Phi2Model(Phi2PreTrainedModel):
         kv_cache: KVCache | None = None,
         key_padding_mask: torch.BoolTensor | None = None,
     ) -> torch.FloatTensor:
-        x = self.
+        x = self.embedding(input_ids)
         for block in self.parallel_blocks:
             x = block(
                 x,
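For context on the rename, here is a minimal sketch of what the Embedding submodule plausibly looks like, reconstructed only from what this diff shows: it takes vocab_size, d_embedding, and embd_pdrop, exposes an .embeddings attribute that get_input_embeddings() returns as an nn.Embedding, and is called directly on input_ids in forward(). The dropout wiring is an assumption, not the repo's confirmed implementation.

import torch
import torch.nn as nn

class Embedding(nn.Module):
    """Token embedding with dropout (sketch; internals assumed)."""

    def __init__(self, vocab_size: int, d_embedding: int, embd_pdrop: float) -> None:
        super().__init__()
        # get_input_embeddings() returns this attribute, so it must be an nn.Embedding
        self.embeddings = nn.Embedding(vocab_size, d_embedding)
        # assumption: embd_pdrop is a dropout probability applied to the embeddings
        self.dropout = nn.Dropout(embd_pdrop)

    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        # (batch, seq_len) -> (batch, seq_len, d_embedding)
        return self.dropout(self.embeddings(input_ids))

Under this shape, get_input_embeddings() returning self.embedding.embeddings lines up with the checkpoint key model.embedding.embeddings.weight used in streaming_inference.py below.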
streaming_inference.py CHANGED
@@ -22,14 +22,14 @@ if __name__ == "__main__":
     for key, value in phi_model_state_dict.items():
         # lm_head.ln.weight -> lm_head_layer_norm.weight
         # lm_head.linear.weight -> lm_head_linear.weight
-        # transformer.embd.wte.weight -> model.
+        # transformer.embd.wte.weight -> model.embedding.embeddings.weight
         # transformer.h.0.mlp.fc1.weight -> model.parallel_blocks.0.mlp.fc1.weight
         # transformer.h.0.ln.weight -> model.parallel_blocks.0.layer_norm.weight
         # transformer.h.0.mixer.Wqkv.weight -> model.parallel_blocks.0.multi_head_attention.Wqkv.weight
         # transformer.h.0.mixer.out_proj.weight -> model.parallel_blocks.0.multi_head_attention.fc_out.weight
         if key.startswith("transformer"):
             key = key.replace("transformer.", "model.")
-            key = key.replace(".embd.wte.", ".
+            key = key.replace(".embd.wte.", ".embedding.embeddings.")
             key = key.replace(".h.", ".parallel_blocks.")
             key = key.replace(".ln.", ".layer_norm.")
             key = key.replace(".mixer.Wqkv.", ".multi_head_attention.Wqkv.")
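A hedged sketch of the full key-remapping step this hunk sits inside, assembled only from the renames visible above. The lm_head handling and the trailing .mixer.out_proj replacement are not shown in the hunk and are inferred from the comments; phi_model_state_dict (the upstream phi-2 checkpoint's state dict) is taken as given.

# Sketch: rename upstream phi-2 checkpoint keys to this repo's module names.
renamed_state_dict = {}
for key, value in phi_model_state_dict.items():
    if key.startswith("lm_head"):
        # Inferred from the comments: lm_head.ln.* -> lm_head_layer_norm.*,
        # lm_head.linear.* -> lm_head_linear.*
        key = key.replace("lm_head.ln.", "lm_head_layer_norm.")
        key = key.replace("lm_head.linear.", "lm_head_linear.")
    if key.startswith("transformer"):
        key = key.replace("transformer.", "model.")
        key = key.replace(".embd.wte.", ".embedding.embeddings.")
        key = key.replace(".h.", ".parallel_blocks.")
        key = key.replace(".ln.", ".layer_norm.")
        key = key.replace(".mixer.Wqkv.", ".multi_head_attention.Wqkv.")
        # Not visible in the hunk; inferred from the out_proj comment above.
        key = key.replace(".mixer.out_proj.", ".multi_head_attention.fc_out.")
    renamed_state_dict[key] = value

The string replacements target dotted segments (e.g. ".h.") rather than bare substrings, which keeps the renames from accidentally matching inside longer key names.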