# simplified_phi2/streaming_inference.py
import json
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from .phi2_configuration import Phi2Config
from .phi2_model import Phi2ModelForCausalLM
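
# NOTE: because of the relative imports above, run this file as a module from
# the repo root, e.g. `python -m simplified_phi2.streaming_inference`
# (an inferred usage note; the original does not document how to invoke it).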
if __name__ == "__main__":
    # make and load the tokenizer, and use it to initialize the token streamer
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    token_streamer = TextIteratorStreamer(tokenizer)
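    # A common variant (an assumption, not part of the original script): pass
    # skip_prompt=True so the streamer yields only newly generated text, and
    # skip_special_tokens=True to strip special tokens from the decoded output.
    # token_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)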
    # build the model and run model.generate(streamer=token_streamer) on a thread
    device = "cuda"
    with open("simplified_phi2/config.json") as config_file:
        model_config = Phi2Config(**json.load(config_file))
    model = Phi2ModelForCausalLM(model_config).to(device)
    phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    phi_model_state_dict = phi_model.state_dict()
    model_state_dict = {}
    # rename the Phi2 state dict keys to match this model's parameter names
    for key, value in phi_model_state_dict.items():
        # transformer.embd.wte.weight -> model.rotary_embedding.embeddings.weight
        # transformer.h.0.mlp.fc1.weight -> model.parallel_blocks.0.mlp.fc1.weight
        # transformer.h.0.ln.weight -> model.parallel_blocks.0.layer_norm.weight
        # transformer.h.0.mixer.Wqkv.weight -> model.parallel_blocks.0.multi_head_attention.Wqkv.weight
        # transformer.h.0.mixer.out_proj.weight -> model.parallel_blocks.0.multi_head_attention.fc_out.weight
        # lm_head.ln.weight -> lm_head_layer_norm.weight
        # lm_head.linear.weight -> lm_head_linear.weight
        if key.startswith("transformer."):
            # str.replace returns a new string, so the result must be reassigned
            key = key.replace("transformer.", "model.")
            key = key.replace(".embd.wte.", ".rotary_embedding.embeddings.")
            key = key.replace(".h.", ".parallel_blocks.")
            key = key.replace(".ln.", ".layer_norm.")
            key = key.replace(".mixer.Wqkv.", ".multi_head_attention.Wqkv.")
            key = key.replace(".mixer.out_proj.", ".multi_head_attention.fc_out.")
        elif key.startswith("lm_head."):
            # lm_head keys do not start with "transformer", so rename them here
            key = key.replace("lm_head.ln.", "lm_head_layer_norm.")
            key = key.replace("lm_head.linear.", "lm_head_linear.")
        model_state_dict[key] = value
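    # A minimal sanity check (an assumption, not in the original): compare the
    # renamed keys against the model's expected parameter names, so a mapping
    # mistake shows up as a readable diff rather than a load_state_dict error.
    expected_keys = set(model.state_dict())
    if set(model_state_dict) != expected_keys:
        print("missing:", expected_keys - set(model_state_dict))
        print("unexpected:", set(model_state_dict) - expected_keys)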
    model.load_state_dict(model_state_dict)
    # tokenize the prompt; the tokenizer returns a dict-like BatchEncoding of tensors
    model_inputs = tokenizer(
        "Here is an essay on sea monkeys: ",
        return_tensors="pt",
        return_attention_mask=False,
    ).to(device)
    # run generation on a background thread; dict(model_inputs, ...) expands the
    # BatchEncoding into keyword arguments for model.generate alongside the
    # streamer and generation settings
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            model_inputs,
            streamer=token_streamer,
            max_new_tokens=500,
            eos_token_id=tokenizer.eos_token_id,
        ),
    )
    thread.start()
    # consume and print tokens from the streamer as they are generated
    my_output = ""
    for new_token in token_streamer:
        my_output += new_token
        print(new_token, end="", flush=True)
    print()
    thread.join()  # ensure the generation thread finishes before the script exits