import json
import os
import time
from pathlib import Path

import torch
from google.cloud import storage

# silence the bitsandbytes banner; must be set before importing llama,
# which imports bitsandbytes
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
from llama import ModelArgs, Transformer, Tokenizer, LLaMA, default_quantize

bucket_name = os.environ.get("GCS_BUCKET")

llama_weight_path = "weights/llama"
tokenizer_weight_path = "weights/tokenizer"

def download_pretrained_models(
    ckpt_path: str,
    tokenizer_path: str
):
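    """Download model and tokenizer weights from the GCS bucket into the local weight directories."""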
    print("creating local directories...")
    os.makedirs(llama_weight_path)
    os.makedirs(tokenizer_weight_path)

    print("initialize GCS client...")
    storage_client = storage.Client.create_anonymous_client()
    bucket = storage_client.bucket(bucket_name)

    print(f"download {ckpt_path} model weights...")
    blobs = bucket.list_blobs(prefix=f"{ckpt_path}/")
    for blob in blobs:
        filename = blob.name.split("/")[1]
        print(f"-{filename}")
        blob.download_to_filename(f"{llama_weight_path}/{filename}")

    print(f"download {tokenizer_path} tokenizer weights...")
    blobs = bucket.list_blobs(prefix=f"{tokenizer_path}/")
    for blob in blobs:
        filename = blob.name.split("/")[1]
        print(f"-{filename}")
        blob.download_to_filename(f"{tokenizer_weight_path}/{filename}")    

def get_pretrained_models(ckpt_path: str, tokenizer_path: str) -> LLaMA:
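    """Download the pretrained weights from GCS and load them into a LLaMA generator."""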

    download_pretrained_models(ckpt_path, tokenizer_path)
    # the reference defaults are max_seq_len=512, max_batch_size=32;
    # a batch size of 1 keeps GPU memory usage down
    generator = load(
        ckpt_dir=llama_weight_path, 
        tokenizer_path=tokenizer_weight_path,
        max_seq_len=512,
        max_batch_size=1
    )

    return generator

def load(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
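    """Build the Transformer and splice the sharded checkpoint weights into it."""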
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    print(f"found {len(checkpoints)} checkpoint shard(s): {checkpoints}")

    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.load(f)

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=f"{tokenizer_path}/tokenizer.model")
    model_args.vocab_size = tokenizer.n_words

    # allocate the model in half precision by default
    torch.set_default_tensor_type(torch.HalfTensor)
    print("Allocating transformer on host")
    # enable default quantization while the Transformer is constructed,
    # then restore the previous setting via the ContextVar token
    ctx_tok = default_quantize.set(True)
    model = Transformer(model_args)
    default_quantize.reset(ctx_tok)
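    # axis along which each parameter is sharded across the checkpoint files:
    # 0 = concatenate along rows, -1 = concatenate along columns,
    # None = replicated identically in every shard (copied once)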
    key_to_dim = {
        "w1": 0,
        "w2": -1,
        "w3": 0,
        "wo": -1,
        "wq": 0,
        "wk": 0,
        "wv": 0,
        "output": 0,
        "tok_embeddings": -1,
        "ffn_norm": None,
        "attention_norm": None,
        "norm": None,
        "rope": None,
    }

    # restore the fp32 default so tensors created after this point are full precision
    torch.set_default_tensor_type(torch.FloatTensor)

    # load the state dict incrementally, to avoid memory problems
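    # sharded parameters are written into the matching slice of the
    # preallocated tensor; replicated ones are taken from shard 0 only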
    for i, ckpt in enumerate(checkpoints):
        print(f"Loading checkpoint {i}")
        checkpoint = torch.load(ckpt, map_location="cpu")
        for parameter_name, parameter in model.named_parameters():
            short_name = parameter_name.split(".")[-2]
            if key_to_dim[short_name] is None and i == 0:
                parameter.data = checkpoint[parameter_name]
            elif key_to_dim[short_name] == 0:
                size = checkpoint[parameter_name].size(0)
                parameter.data[size * i : size * (i + 1), :] = checkpoint[
                    parameter_name
                ]
            elif key_to_dim[short_name] == -1:
                size = checkpoint[parameter_name].size(-1)
                parameter.data[:, size * i : size * (i + 1)] = checkpoint[
                    parameter_name
                ]
            del checkpoint[parameter_name]
        del checkpoint

    model.cuda()

    generator = LLaMA(model, tokenizer)
    print(
        f"Loaded in {time.time() - start_time:.2f} seconds with {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GiB"
    )
    return generator


def get_output(
    generator: LLaMA,
    prompt: str,
    max_gen_len: int = 256,
    temperature: float = 0.8,
    top_p: float = 0.95,
):
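    """Generate a completion for a single prompt."""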
    
    prompts = [prompt]
    results = generator.generate(
        prompts,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )

    return results
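

# Minimal usage sketch, assuming GCS_BUCKET is exported and the bucket holds
# the weights under "7B/" and "tokenizer/" prefixes; those prefix names and
# the prompt below are illustrative, not part of this module.
if __name__ == "__main__":
    generator = get_pretrained_models("7B", "tokenizer")
    results = get_output(generator, "The capital of Germany is the city of")
    for result in results:
        print(result)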