BucketOfFish committed on
Commit
78f6f3b
1 Parent(s): 16cc769

Renaming state dict keys from Phi2

Browse files
Files changed (2) hide show
  1. phi2_model.py +8 -8
  2. streaming_inference.py +23 -33
phi2_model.py CHANGED
@@ -91,7 +91,7 @@ class Embedding(nn.Module):
91
  class Phi2Model(Phi2PreTrainedModel):
92
  def __init__(self, config: Phi2Config) -> None:
93
  super().__init__(config)
94
- self.embedding = Embedding(
95
  vocab_size=config.vocab_size,
96
  d_embedding=config.d_embedding,
97
  embd_pdrop=config.embd_pdrop,
@@ -117,10 +117,10 @@ class Phi2Model(Phi2PreTrainedModel):
117
 
118
  """
119
  def get_input_embeddings(self) -> nn.Embedding:
120
- return self.embedding.embeddings
121
 
122
  def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
123
- self.embedding.embeddings = new_embeddings
124
  """
125
 
126
  def forward(
@@ -129,7 +129,7 @@ class Phi2Model(Phi2PreTrainedModel):
129
  kv_cache: KVCache | None = None,
130
  key_padding_mask: torch.BoolTensor | None = None,
131
  ) -> torch.FloatTensor:
132
- x = self.embedding(input_ids)
133
  for block in self.parallel_blocks:
134
  x = block(
135
  x,
@@ -143,8 +143,8 @@ class Phi2ModelForCausalLM(Phi2PreTrainedModel):
143
  def __init__(self, config: Phi2Config) -> None:
144
  super().__init__(config)
145
  self.pretrained_model = Phi2Model(config)
146
- self.layer_norm = nn.LayerNorm(config.d_embedding, eps=config.layer_norm_epsilon)
147
- self.linear = nn.Linear(config.d_embedding, config.vocab_size)
148
  self.loss_fn = nn.CrossEntropyLoss()
149
  self.post_init() # calls self._init_weights() for all modules
150
 
@@ -156,8 +156,8 @@ class Phi2ModelForCausalLM(Phi2PreTrainedModel):
156
  labels: torch.LongTensor | None = None,
157
  ) -> CausalLMOutputWithPast:
158
  x = self.pretrained_model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
159
- x = self.layer_norm(x)
160
- logits = self.linear(x).to(torch.float32)
161
  loss = (
162
  self.loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
163
  if labels is not None
 
91
  class Phi2Model(Phi2PreTrainedModel):
92
  def __init__(self, config: Phi2Config) -> None:
93
  super().__init__(config)
94
+ self.rotary_embedding = Embedding(
95
  vocab_size=config.vocab_size,
96
  d_embedding=config.d_embedding,
97
  embd_pdrop=config.embd_pdrop,
 
117
 
118
  """
119
  def get_input_embeddings(self) -> nn.Embedding:
120
+ return self.rotary_embedding.embeddings
121
 
122
  def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
123
+ self.rotary_embedding.embeddings = new_embeddings
124
  """
125
 
126
  def forward(
 
129
  kv_cache: KVCache | None = None,
130
  key_padding_mask: torch.BoolTensor | None = None,
131
  ) -> torch.FloatTensor:
132
+ x = self.rotary_embedding(input_ids)
133
  for block in self.parallel_blocks:
134
  x = block(
135
  x,
 
143
  def __init__(self, config: Phi2Config) -> None:
144
  super().__init__(config)
145
  self.pretrained_model = Phi2Model(config)
146
+ self.lm_head_layer_norm = nn.LayerNorm(config.d_embedding, eps=config.layer_norm_epsilon)
147
+ self.lm_head_linear = nn.Linear(config.d_embedding, config.vocab_size)
148
  self.loss_fn = nn.CrossEntropyLoss()
149
  self.post_init() # calls self._init_weights() for all modules
150
 
 
156
  labels: torch.LongTensor | None = None,
157
  ) -> CausalLMOutputWithPast:
158
  x = self.pretrained_model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
159
+ x = self.lm_head_layer_norm(x)
160
+ logits = self.lm_head_linear(x).to(torch.float32)
161
  loss = (
162
  self.loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
163
  if labels is not None
streaming_inference.py CHANGED
@@ -1,43 +1,11 @@
1
  import json
2
  from threading import Thread
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
- import torch
5
 
6
  from .phi2_configuration import Phi2Config
7
  from .phi2_model import Phi2ModelForCausalLM
8
 
9
 
10
- # This works, but is not streaming
11
- """
12
- if __name__ == "__main__":
13
- device = "cuda"
14
-
15
- model_config = Phi2Config(**json.load(open("simplified_phi2/config.json")))
16
- model = Phi2ModelForCausalLM(model_config).to(device)
17
- phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
18
- model.load_state_dict(phi_model.state_dict())
19
-
20
- tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
21
-
22
- text = "Write an essay on sea monkeys: "
23
- tokens = tokenizer(text, return_tensors="pt", return_attention_mask=False).to(device)
24
- outputs = model.generate(**tokens, max_length=200)
25
- text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
26
- print(text)
27
- """
28
-
29
-
30
- # This is streaming, but does not work because you can't set trust_remote_code=True
31
- """
32
- if __name__ == "__main__":
33
- client = InferenceClient(model="microsoft/phi-2")
34
- text = "How do you make cheese?"
35
- for token in client.text_generation(text, max_new_tokens=500, stream=True):
36
- print(token, end="")
37
- """
38
-
39
-
40
- # This is trying the TextIteratorStreamer class
41
  if __name__ == "__main__":
42
  # make and load tokenizer, use tokenizer to initialize token_streamer
43
  tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
@@ -48,7 +16,29 @@ if __name__ == "__main__":
48
  model_config = Phi2Config(**json.load(open("simplified_phi2/config.json")))
49
  model = Phi2ModelForCausalLM(model_config).to(device)
50
  phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
51
- model.load_state_dict(phi_model.state_dict())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  thread = Thread(
53
  target=model.generate,
54
  kwargs=dict(
 
1
  import json
2
  from threading import Thread
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
4
 
5
  from .phi2_configuration import Phi2Config
6
  from .phi2_model import Phi2ModelForCausalLM
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  if __name__ == "__main__":
10
  # make and load tokenizer, use tokenizer to initialize token_streamer
11
  tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
 
16
  model_config = Phi2Config(**json.load(open("simplified_phi2/config.json")))
17
  model = Phi2ModelForCausalLM(model_config).to(device)
18
  phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
19
+
20
+ phi_model_state_dict = phi_model.state_dict()
21
+ model_state_dict = {}
22
+ for key, value in phi_model_state_dict.items():
23
+ # transformer.embd.wte.weight -> model.rotary_embedding.embeddings.weight
24
+ # transformer.h.0.mlp.fc1.weight -> pretrained_model.parallel_blocks.0.mlp.fc1.weight
25
+ # transformer.h.0.ln.weight -> pretrained_model.parallel_blocks.0.layer_norm.weight
26
+ # transformer.h.0.mixer.Wqkv.weight -> pretrained_model.parallel_blocks.0.multi_head_attention.Wqkv.weight
27
+ # transformer.h.0.mixer.out_proj.weight -> pretrained_model.parallel_blocks.0.multi_head_attention.fc_out.weight
28
+ # lm_head.ln.weight -> lm_head_layer_norm.weight
29
+ # lm_head.linear.weight -> lm_head_linear.weight
30
+ if key.startswith("transformer"):
31
+ key.replace("transformer.", "model.")
32
+ key.replace(".embd.wte.", ".rotary_embedding.embeddings.")
33
+ key.replace(".h.", ".parallel_blocks")
34
+ key.replace(".ln.", ".layer_norm.")
35
+ key.replace(".mixer.Wqkv.", ".multi_head_attention.Wqkv.")
36
+ key.replace(".mixer.out_proj.", ".multi_head_attention.fc_out.")
37
+ key.replace(".lm_head.ln.", ".lm_head_layer_norm.")
38
+ key.replace(".lm_head.linear.", ".lm_head_linear.")
39
+ model_state_dict[key] = value
40
+ model.load_state_dict(model_state_dict)
41
+
42
  thread = Thread(
43
  target=model.generate,
44
  kwargs=dict(