import os
import time
import json
from pathlib import Path

import torch

# Silence the bitsandbytes startup banner; this must be set before the llama
# import below, which pulls in bitsandbytes.
os.environ["BITSANDBYTES_NOWELCOME"] = "1"

from llama import ModelArgs, Transformer, Tokenizer, LLaMA, default_quantize
from google.cloud import storage

# GCS bucket holding the model artifacts; must be set in the environment.
bucket_name = os.environ.get("GCS_BUCKET")

# Local directories the weights are downloaded into.
llama_weight_path = "weights/llama"
tokenizer_weight_path = "weights/tokenizer"
def download_pretrained_models(ckpt_path: str, tokenizer_path: str):
    """Download model and tokenizer weights from GCS into the local weight dirs."""
    print("creating local directories...")
    os.makedirs(llama_weight_path, exist_ok=True)
    os.makedirs(tokenizer_weight_path, exist_ok=True)

    print("initializing GCS client...")
    # An anonymous client only works if the bucket is publicly readable.
    storage_client = storage.Client.create_anonymous_client()
    bucket = storage_client.bucket(bucket_name)

    print(f"downloading {ckpt_path} model weights...")
    for blob in bucket.list_blobs(prefix=f"{ckpt_path}/"):
        filename = blob.name.split("/")[-1]
        if not filename:  # skip the "directory" placeholder blob, if present
            continue
        print(f"- {filename}")
        blob.download_to_filename(f"{llama_weight_path}/{filename}")

    print(f"downloading {tokenizer_path} tokenizer weights...")
    for blob in bucket.list_blobs(prefix=f"{tokenizer_path}/"):
        filename = blob.name.split("/")[-1]
        if not filename:
            continue
        print(f"- {filename}")
        blob.download_to_filename(f"{tokenizer_weight_path}/{filename}")
def get_pretrained_models(ckpt_path: str, tokenizer_path: str) -> LLaMA:
    """Download the weights and build a ready-to-use LLaMA generator."""
    download_pretrained_models(ckpt_path, tokenizer_path)
    # Upstream defaults are max_seq_len=512, max_batch_size=32; a batch size
    # of 1 keeps GPU memory down for single-prompt serving.
    generator = load(
        ckpt_dir=llama_weight_path,
        tokenizer_path=tokenizer_weight_path,
        max_seq_len=512,
        max_batch_size=1,
    )
    return generator
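
# Note: get_pretrained_models downloads the weights unconditionally on every
# call; long-running services may want to check whether the local weight
# directories are already populated before re-downloading.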
def load(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    print(f"found checkpoints: {checkpoints}")

    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=f"{tokenizer_path}/tokenizer.model")
    model_args.vocab_size = tokenizer.n_words

    # Allocate the transformer on the host in half precision, with
    # quantization enabled, then restore the previous quantization setting.
    torch.set_default_tensor_type(torch.HalfTensor)
    print("Allocating transformer on host")
    ctx_tok = default_quantize.set(True)
    model = Transformer(model_args)
    default_quantize.set(ctx_tok)
    # Each parameter family is sharded across the checkpoint files along a
    # fixed dimension: 0 concatenates shards along dim 0 (row-parallel), -1
    # along the last dim (column-parallel), and None means the tensor is
    # replicated in every shard and is loaded once, from shard 0.
    key_to_dim = {
        "w1": 0,
        "w2": -1,
        "w3": 0,
        "wo": -1,
        "wq": 0,
        "wk": 0,
        "wv": 0,
        "output": 0,
        "tok_embeddings": -1,
        "ffn_norm": None,
        "attention_norm": None,
        "norm": None,
        "rope": None,
    }
    # Restore the default tensor type; tensors created from here on (outside
    # the model) default to fp32 again.
    torch.set_default_tensor_type(torch.FloatTensor)
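
    # Worked illustration of the merge below (comments only, not executed):
    # with two shards and a dim-0 weight whose per-shard shape is [R, C],
    # shard 0 fills parameter.data[0:R, :] and shard 1 fills
    # parameter.data[R:2R, :]; a dim=-1 weight fills columns instead, and a
    # None entry (e.g. "norm") is copied whole from shard 0 only.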
    # Load the state dict incrementally, one shard at a time, to avoid
    # holding every checkpoint file in memory at once.
    for i, ckpt in enumerate(checkpoints):
        print(f"Loading checkpoint {i}")
        checkpoint = torch.load(ckpt, map_location="cpu")
        for parameter_name, parameter in model.named_parameters():
            short_name = parameter_name.split(".")[-2]
            if key_to_dim[short_name] is None and i == 0:
                parameter.data = checkpoint[parameter_name]
            elif key_to_dim[short_name] == 0:
                size = checkpoint[parameter_name].size(0)
                parameter.data[size * i : size * (i + 1), :] = checkpoint[
                    parameter_name
                ]
            elif key_to_dim[short_name] == -1:
                size = checkpoint[parameter_name].size(-1)
                parameter.data[:, size * i : size * (i + 1)] = checkpoint[
                    parameter_name
                ]
            del checkpoint[parameter_name]
        del checkpoint

    model.cuda()
    generator = LLaMA(model, tokenizer)
    print(
        f"Loaded in {time.time() - start_time:.2f} seconds "
        f"with {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GiB"
    )
    return generator
def get_output(
    generator: LLaMA,
    prompt: str,
    max_gen_len: int = 256,
    temperature: float = 0.8,
    top_p: float = 0.95,
):
    """Run sampled generation for a single prompt; returns a list of strings."""
    prompts = [prompt]
    results = generator.generate(
        prompts,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )
    return results
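

# Minimal usage sketch (an addition, not part of the original module). The
# "7B" and "tokenizer" prefixes are assumptions about the bucket layout;
# substitute whatever prefixes your GCS_BUCKET actually uses.
if __name__ == "__main__":
    generator = get_pretrained_models("7B", "tokenizer")
    # generate() returns one decoded string per prompt
    for text in get_output(generator, "The capital of France is"):
        print(text)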