import json
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from .phi2_configuration import Phi2Config
from .phi2_model import Phi2ModelForCausalLM
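# NOTE: the relative imports above assume this file lives inside a package and
# is run as a module (e.g. `python -m simplified_phi2.<module>`); running it
# directly as a script would make the relative imports fail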


if __name__ == "__main__":
    # load the tokenizer and use it to initialize the token streamer
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    token_streamer = TextIteratorStreamer(tokenizer)
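    # TextIteratorStreamer is a blocking iterator: generate() (running on a
    # worker thread below) pushes decoded text into it, and iterating over it
    # yields each chunk of text as soon as it is available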

    # build the simplified model from the local config and move it to the device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    with open("simplified_phi2/config.json") as f:
        model_config = Phi2Config(**json.load(f))
    model = Phi2ModelForCausalLM(model_config).to(device)
    phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
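    # phi_model is loaded only as a weight donor: its state dict keys are
    # renamed below to match the simplified model's module names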

    phi_model_state_dict = phi_model.state_dict()
    model_state_dict = {}
    for key, value in phi_model_state_dict.items():
        # lm_head.ln.weight -> lm_head_layer_norm.weight
        # lm_head.linear.weight -> lm_head_linear.weight
        # transformer.embd.wte.weight -> model.embedding.embeddings.weight
        # transformer.h.0.mlp.fc1.weight -> model.parallel_blocks.0.mlp.fc1.weight
        # transformer.h.0.ln.weight -> model.parallel_blocks.0.layer_norm.weight
        # transformer.h.0.mixer.Wqkv.weight -> model.parallel_blocks.0.multi_head_attention.Wqkv.weight
        # transformer.h.0.mixer.out_proj.weight -> model.parallel_blocks.0.multi_head_attention.fc_out.weight
        if key.startswith("transformer"):
            key = key.replace("transformer.", "model.")
            key = key.replace(".embd.wte.", ".embedding.embeddings.")
            key = key.replace(".h.", ".parallel_blocks.")
            key = key.replace(".ln.", ".layer_norm.")
            key = key.replace(".mixer.Wqkv.", ".multi_head_attention.Wqkv.")
            key = key.replace(".mixer.out_proj.", ".multi_head_attention.fc_out.")
        else:
            key = key.replace("lm_head.ln.", "lm_head_layer_norm.")
            key = key.replace("lm_head.linear.", "lm_head_linear.")
        model_state_dict[key] = value
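    # load_state_dict is strict by default, so any key that the renaming above
    # failed to translate will raise here instead of silently skipping weights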
    model.load_state_dict(model_state_dict)
    model.eval()

    # run generate() on a worker thread so the main thread can consume the
    # streamer while tokens are still being produced
    inputs = tokenizer(  # returns a BatchEncoding, a dict of tensors
        "Here is an essay on sea monkeys: ",
        return_tensors="pt",
        return_attention_mask=False,
    ).to(device)
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            inputs,  # unpacked by Thread into generate(input_ids=..., ...)
            streamer=token_streamer,
            max_new_tokens=500,
            eos_token_id=tokenizer.eos_token_id,
        ),
    )
    thread.start()

    # consume the stream: each iteration blocks until the worker thread pushes
    # the next chunk of decoded text, and the loop ends when generation does
    my_output = ""
    for new_token in token_streamer:
        my_output += new_token
        print(new_token, end="", flush=True)
    print()
    thread.join()  # streamer exhausted, so generate() has already returned