# File: simplified_phi2/streaming_inference.py
# Last commit 4f25dda (BucketOfFish): "Fixed inference script bug and made deterministic"
import json
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from .phi2_configuration import Phi2Config
from .phi2_model import Phi2ModelForCausalLM
# Ordered (old, new) substring renames applied to checkpoint keys that start
# with "transformer". Order matters: the "transformer." prefix is rewritten
# first, then the sub-module names, exactly as the conversion has always done.
_TRANSFORMER_KEY_RENAMES = (
    ("transformer.", "model."),
    (".embd.wte.", ".embedding.embeddings."),
    (".h.", ".parallel_blocks."),
    (".ln.", ".layer_norm."),
    (".mixer.Wqkv.", ".multi_head_attention.Wqkv."),
    (".mixer.out_proj.", ".multi_head_attention.fc_out."),
)

# Renames applied to every other key (in practice the "lm_head.*" entries).
_LM_HEAD_KEY_RENAMES = (
    ("lm_head.ln.", "lm_head_layer_norm."),
    ("lm_head.linear.", "lm_head_linear."),
)


def remap_phi2_key(key):
    """Translate one "microsoft/phi-2" state-dict key to the simplified model's name.

    Examples of the mapping:
        lm_head.ln.weight                    -> lm_head_layer_norm.weight
        lm_head.linear.weight                -> lm_head_linear.weight
        transformer.embd.wte.weight          -> model.embedding.embeddings.weight
        transformer.h.0.mlp.fc1.weight       -> model.parallel_blocks.0.mlp.fc1.weight
        transformer.h.0.ln.weight            -> model.parallel_blocks.0.layer_norm.weight
        transformer.h.0.mixer.Wqkv.weight    -> model.parallel_blocks.0.multi_head_attention.Wqkv.weight
        transformer.h.0.mixer.out_proj.weight -> model.parallel_blocks.0.multi_head_attention.fc_out.weight

    Args:
        key: Parameter name as it appears in the reference checkpoint.

    Returns:
        The renamed key. Keys matching none of the patterns are returned
        unchanged.
    """
    if key.startswith("transformer"):
        renames = _TRANSFORMER_KEY_RENAMES
    else:
        renames = _LM_HEAD_KEY_RENAMES
    for old, new in renames:
        key = key.replace(old, new)
    return key


def convert_phi2_state_dict(state_dict):
    """Return a new dict with every key renamed by ``remap_phi2_key``.

    Values are passed through untouched (no copy), so this only relabels the
    entries of the reference checkpoint.
    """
    return {remap_phi2_key(key): value for key, value in state_dict.items()}


def main(
    prompt="Here is an essay on sea monkeys: ",
    max_new_tokens=500,
    device="cuda",
    config_path="simplified_phi2/config.json",
):
    """Load the simplified Phi-2 model, stream a completion to stdout, and return it.

    The weights come from the reference "microsoft/phi-2" checkpoint and are
    renamed to the simplified model's parameter names before loading.

    Args:
        prompt: Text the model should continue.
        max_new_tokens: Upper bound on the number of generated tokens.
        device: Device string the model and inputs are moved to.
        config_path: Path to the JSON file holding the ``Phi2Config`` kwargs.

    Returns:
        The concatenation of every text chunk yielded by the streamer.
    """
    # NOTE(review): trust_remote_code=True runs Python code shipped in the hub
    # repository; keep it only for repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    # The streamer is fed by model.generate() and iterated below.
    token_streamer = TextIteratorStreamer(tokenizer)

    # Build the simplified model from its JSON config; the context manager
    # guarantees the config file is closed (the bare open() call leaked it).
    with open(config_path, encoding="utf-8") as config_file:
        model_config = Phi2Config(**json.load(config_file))
    model = Phi2ModelForCausalLM(model_config).to(device)

    # Copy the reference weights into the simplified model under their new names.
    phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    model.load_state_dict(convert_phi2_state_dict(phi_model.state_dict()))
    # Presumably load_state_dict copies values into the model's own parameters
    # (standard torch behaviour), so the reference model can be released
    # before generation instead of staying in memory for the whole run.
    del phi_model
    model.eval()

    inputs = tokenizer(  # returns a torch dictionary
        prompt,
        return_tensors="pt",
        return_attention_mask=False,
    ).to(device)

    # generate() blocks until it finishes, so it runs on a worker thread while
    # this thread consumes the streamer.
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            inputs,
            streamer=token_streamer,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
        ),
    )
    thread.start()

    chunks = []
    for new_text in token_streamer:
        chunks.append(new_text)
        # Flush on every chunk so the text appears as it is produced.
        print(new_text, end="", flush=True)
    print()

    # The streamer is exhausted; wait for the worker to finish cleanly.
    thread.join()
    return "".join(chunks)


if __name__ == "__main__":
    main()