# simplified_phi2/streaming_inference.py
import json
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from .phi2_configuration import Phi2Config
from .phi2_model import Phi2ModelForCausalLM
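
# note: the relative imports above mean this file must be run as a module from the
# repo root, e.g. `python -m simplified_phi2.streaming_inference`, not as a bare script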

if __name__ == "__main__":
    # make and load tokenizer, use tokenizer to initialize token_streamer
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    token_streamer = TextIteratorStreamer(tokenizer)
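    # TextIteratorStreamer pushes decoded text onto an internal queue as tokens are
    # generated; iterating over it (in the loop at the bottom) blocks until the next
    # chunk arrives. Pass skip_prompt=True to exclude the prompt from the stream.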

    # make model and run model.generate(streamer=TextIteratorStreamer) on a thread
    device = "cuda"
    with open("simplified_phi2/config.json") as config_file:  # close the handle promptly
        model_config = Phi2Config(**json.load(config_file))
    model = Phi2ModelForCausalLM(model_config).to(device)
    phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    phi_model_state_dict = phi_model.state_dict()
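
    # the reference checkpoint's parameter names differ from this repo's simplified
    # module names, so every state dict key is renamed before loading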
    model_state_dict = {}
    for key, value in phi_model_state_dict.items():
        # lm_head.ln.weight -> lm_head_layer_norm.weight
        # lm_head.linear.weight -> lm_head_linear.weight
        # transformer.embd.wte.weight -> model.embedding.embeddings.weight
        # transformer.h.0.mlp.fc1.weight -> model.parallel_blocks.0.mlp.fc1.weight
        # transformer.h.0.ln.weight -> model.parallel_blocks.0.layer_norm.weight
        # transformer.h.0.mixer.Wqkv.weight -> model.parallel_blocks.0.multi_head_attention.Wqkv.weight
        # transformer.h.0.mixer.out_proj.weight -> model.parallel_blocks.0.multi_head_attention.fc_out.weight
        if key.startswith("transformer"):
            key = key.replace("transformer.", "model.")
            key = key.replace(".embd.wte.", ".embedding.embeddings.")
            key = key.replace(".h.", ".parallel_blocks.")
            key = key.replace(".ln.", ".layer_norm.")
            key = key.replace(".mixer.Wqkv.", ".multi_head_attention.Wqkv.")
            key = key.replace(".mixer.out_proj.", ".multi_head_attention.fc_out.")
        else:
            key = key.replace("lm_head.ln.", "lm_head_layer_norm.")
            key = key.replace("lm_head.linear.", "lm_head_linear.")
        model_state_dict[key] = value
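    # worked example of the chained renames applied above, for one attention weight:
    #   transformer.h.0.mixer.out_proj.weight
    #   -> model.h.0.mixer.out_proj.weight                ("transformer." -> "model.")
    #   -> model.parallel_blocks.0.mixer.out_proj.weight  (".h." -> ".parallel_blocks.")
    #   -> model.parallel_blocks.0.multi_head_attention.fc_out.weight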
    # strict load (the default), so any key the renaming missed raises a RuntimeError
    model.load_state_dict(model_state_dict)
    model.eval()  # inference mode: disables dropout and other train-time behavior
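
    # generate() blocks until it finishes, so it runs on a worker thread while the
    # main thread consumes the streamer's queue below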
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            tokenizer(  # returns a dict-like BatchEncoding (here just input_ids)
                "Here is an essay on sea monkeys: ",
                return_tensors="pt",
                return_attention_mask=False,
            ).to(device),
            streamer=token_streamer,
            max_new_tokens=500,
            eos_token_id=tokenizer.eos_token_id,
        ),
    )
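    # note: dict(mapping, **kwargs) above merges the tokenized inputs with the
    # generation arguments into the single kwargs dict that generate() receives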
    thread.start()

    # stream the output: print each decoded chunk as soon as the worker produces it
    my_output = ""
    for new_token in token_streamer:
        my_output += new_token
        print(new_token, end="", flush=True)
    print()
    thread.join()  # the streamer is only exhausted once generation has completed
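    # if the custom generate() follows Hugging Face's convention of putting the prompt
    # tokens into the streamer first, my_output holds the prompt plus the completion
    # (skip_prompt was left at its default of False when the streamer was constructed)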