"""Load microsoft/phi-2 weights into the simplified Phi-2 model and stream a sample generation.

The reference checkpoint is fetched with ``AutoModelForCausalLM``. Its state-dict
keys are renamed to match ``Phi2ModelForCausalLM``'s parameter names, and the
converted weights are loaded into the simplified model. Generation then runs on
a background thread while the main thread prints text chunks as they arrive
through a ``TextIteratorStreamer``.
"""

import json
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from .phi2_configuration import Phi2Config
from .phi2_model import Phi2ModelForCausalLM

# Substring rewrites applied, in order, to checkpoint keys that start with "transformer".
#   transformer.embd.wte.weight            -> model.embedding.embeddings.weight
#   transformer.h.0.mlp.fc1.weight         -> model.parallel_blocks.0.mlp.fc1.weight
#   transformer.h.0.ln.weight              -> model.parallel_blocks.0.layer_norm.weight
#   transformer.h.0.mixer.Wqkv.weight      -> model.parallel_blocks.0.multi_head_attention.Wqkv.weight
#   transformer.h.0.mixer.out_proj.weight  -> model.parallel_blocks.0.multi_head_attention.fc_out.weight
_TRANSFORMER_KEY_REPLACEMENTS = (
    ("transformer.", "model."),
    (".embd.wte.", ".embedding.embeddings."),
    (".h.", ".parallel_blocks."),
    (".ln.", ".layer_norm."),
    (".mixer.Wqkv.", ".multi_head_attention.Wqkv."),
    (".mixer.out_proj.", ".multi_head_attention.fc_out."),
)

# Substring rewrites applied, in order, to every other checkpoint key.
#   lm_head.ln.weight      -> lm_head_layer_norm.weight
#   lm_head.linear.weight  -> lm_head_linear.weight
_LM_HEAD_KEY_REPLACEMENTS = (
    ("lm_head.ln.", "lm_head_layer_norm."),
    ("lm_head.linear.", "lm_head_linear."),
)


def _convert_phi2_key(key: str) -> str:
    """Return the simplified-model parameter name for a microsoft/phi-2 state-dict key.

    Keys starting with ``"transformer"`` go through ``_TRANSFORMER_KEY_REPLACEMENTS``.
    All other keys go through ``_LM_HEAD_KEY_REPLACEMENTS``. Replacements are plain
    substring substitutions applied in table order; unmatched keys pass through unchanged.
    """
    if key.startswith("transformer"):
        replacements = _TRANSFORMER_KEY_REPLACEMENTS
    else:
        replacements = _LM_HEAD_KEY_REPLACEMENTS
    for old, new in replacements:
        key = key.replace(old, new)
    return key


if __name__ == "__main__":
    # Make and load the tokenizer, then use it to initialize the token streamer.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    # skip_prompt is left at its default, so the prompt text is streamed before the generated text.
    token_streamer = TextIteratorStreamer(tokenizer)

    # Build the simplified model from its JSON config; `with` guarantees the file is closed.
    device = "cuda"
    with open("simplified_phi2/config.json", encoding="utf-8") as config_file:
        model_config = Phi2Config(**json.load(config_file))
    model = Phi2ModelForCausalLM(model_config).to(device)

    # Pull the reference weights and rename their keys to the simplified model's scheme.
    phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    model_state_dict = {
        _convert_phi2_key(key): value for key, value in phi_model.state_dict().items()
    }
    model.load_state_dict(model_state_dict)
    # load_state_dict copies the values into `model`'s own parameters, so the reference
    # model and its tensors are no longer needed; drop them to free that memory.
    del phi_model, model_state_dict
    model.eval()

    # Run model.generate(streamer=token_streamer) on a background thread so the main
    # thread can consume the streamer as text is produced.
    inputs = tokenizer(  # returns a dict-like BatchEncoding
        "Here is an essay on sea monkeys: ",
        return_tensors="pt",
        return_attention_mask=False,
    ).to(device)
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            inputs,
            streamer=token_streamer,
            max_new_tokens=500,
            eos_token_id=tokenizer.eos_token_id,
        ),
    )
    thread.start()

    # Print each chunk as it arrives; collect chunks in a list to avoid quadratic `+=`.
    chunks = []
    for new_token in token_streamer:
        chunks.append(new_token)
        print(new_token, end="", flush=True)
    print()
    thread.join()
    my_output = "".join(chunks)