|
import json |
|
from threading import Thread |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
|
|
from .phi2_configuration import Phi2Config |
|
from .phi2_model import Phi2ModelForCausalLM |
|
|
|
|
|
# Ordered (old, new) substring renames applied to checkpoint parameter names
# before they are handed to ``model.load_state_dict``.
#
# Order matters: the two ".lm_head." rules come before the generic ".ln."
# rule, because ".lm_head.ln." contains ".ln." and could never match once the
# generic rule had already rewritten it.
_PHI2_KEY_RENAMES = (
    ("transformer.", "model."),
    (".embd.wte.", ".rotary_embedding.embeddings."),
    # NOTE(review): the replacement used to be ".parallel_blocks" (no trailing
    # dot), which would fuse the layer index into the name
    # ("parallel_blocks0"). Assuming a dotted index is intended -- confirm
    # against the parameter names of Phi2ModelForCausalLM.
    (".h.", ".parallel_blocks."),
    (".lm_head.ln.", ".lm_head_layer_norm."),
    (".lm_head.linear.", ".lm_head_linear."),
    (".ln.", ".layer_norm."),
    (".mixer.Wqkv.", ".multi_head_attention.Wqkv."),
    (".mixer.out_proj.", ".multi_head_attention.fc_out."),
)


def remap_phi2_key(key: str) -> str:
    """Return *key* with the checkpoint-to-local parameter renames applied.

    Only keys that start with ``"transformer"`` are rewritten; every other key
    is returned unchanged.

    ``str.replace`` returns a new string, so each result is re-bound to
    ``key``. Calling it without using the return value (as this script did
    before) leaves the key untouched.

    NOTE(review): the ".lm_head." patterns begin with a dot and sit behind the
    "transformer" prefix check. If the checkpoint stores its head parameters
    under top-level names such as "lm_head.ln.weight", they pass through here
    unrenamed -- verify against the actual checkpoint keys.
    """
    if not key.startswith("transformer"):
        return key
    for old, new in _PHI2_KEY_RENAMES:
        key = key.replace(old, new)
    return key


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    token_streamer = TextIteratorStreamer(tokenizer)

    device = "cuda"

    # Build the local model from its JSON config; the context manager closes
    # the file handle instead of leaving it to the garbage collector.
    with open("simplified_phi2/config.json", encoding="utf-8") as config_file:
        model_config = Phi2Config(**json.load(config_file))
    model = Phi2ModelForCausalLM(model_config).to(device)

    # Load the reference checkpoint and copy its weights into the local model
    # under the renamed keys.
    phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    phi_model_state_dict = phi_model.state_dict()
    model_state_dict = {
        remap_phi2_key(key): value
        for key, value in phi_model_state_dict.items()
    }
    model.load_state_dict(model_state_dict)

    # Run generation on a worker thread so this thread can consume the
    # streamer while tokens are still being produced. The tokenizer output is
    # merged with the generation options into a single kwargs dict.
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            tokenizer(
                "Here is an essay on sea monkeys: ",
                return_tensors="pt",
                return_attention_mask=False,
            ).to(device),
            streamer=token_streamer,
            max_new_tokens=500,
            eos_token_id=tokenizer.eos_token_id,
        ),
    )
    thread.start()

    # Echo each decoded chunk as it arrives and keep the full text as well.
    output_chunks = []
    for new_token in token_streamer:
        output_chunks.append(new_token)
        print(new_token, end="", flush=True)
    print()
    my_output = "".join(output_chunks)

    # The stream is exhausted, so generation is done; reap the worker thread.
    thread.join()
|
|