# LLaMA-13B / gen.py
from typing import Tuple
import os
import torch
import time
import json
from pathlib import Path
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
from llama import ModelArgs, Transformer, Tokenizer, LLaMA, default_quantize
from google.cloud import storage
bucket_name = os.environ.get("GCS_BUCKET")
llama_weight_path = "weights/llama"
tokenizer_weight_path = "weights/tokenizer"

def download_pretrained_models(
    ckpt_path: str,
    tokenizer_path: str
):
    """Download model and tokenizer weights from the GCS bucket into local directories."""
    print("creating local directories...")
    os.makedirs(llama_weight_path)
    os.makedirs(tokenizer_weight_path)

    print("initialize GCS client...")
    storage_client = storage.Client.create_anonymous_client()
    bucket = storage_client.bucket(bucket_name)

    print(f"download {ckpt_path} model weights...")
    blobs = bucket.list_blobs(prefix=f"{ckpt_path}/")
    for blob in blobs:
        filename = blob.name.split("/")[1]
        print(f"-{filename}")
        blob.download_to_filename(f"{llama_weight_path}/{filename}")

    print(f"download {tokenizer_path} tokenizer weights...")
    blobs = bucket.list_blobs(prefix=f"{tokenizer_path}/")
    for blob in blobs:
        filename = blob.name.split("/")[1]
        print(f"-{filename}")
        blob.download_to_filename(f"{tokenizer_weight_path}/{filename}")


def get_pretrained_models(
    ckpt_path: str,
    tokenizer_path: str) -> LLaMA:
    """Download the weights and build a ready-to-use LLaMA generator."""
    download_pretrained_models(ckpt_path, tokenizer_path)

    # reference defaults: max_seq_len: int = 512, max_batch_size: int = 32;
    # batch size is fixed to 1 here
    generator = load(
        ckpt_dir=llama_weight_path,
        tokenizer_path=tokenizer_weight_path,
        max_seq_len=512,
        max_batch_size=1
    )

    return generator


def load(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    """Build a quantized Transformer from sharded checkpoints and wrap it in a LLaMA generator."""
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    print(checkpoints)

    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=f"{tokenizer_path}/tokenizer.model")
    model_args.vocab_size = tokenizer.n_words

    # allocate the transformer with half-precision defaults and the
    # default_quantize flag enabled while its weights are created
    torch.set_default_tensor_type(torch.HalfTensor)
    print("Allocating transformer on host")
    ctx_tok = default_quantize.set(True)
    model = Transformer(model_args)
    default_quantize.set(ctx_tok)

    # dimension along which each weight family is split across checkpoint
    # shards (None means the tensor is replicated, so it is read only once)
    key_to_dim = {
        "w1": 0,
        "w2": -1,
        "w3": 0,
        "wo": -1,
        "wq": 0,
        "wk": 0,
        "wv": 0,
        "output": 0,
        "tok_embeddings": -1,
        "ffn_norm": None,
        "attention_norm": None,
        "norm": None,
        "rope": None,
    }

    # reset the default tensor type so tensors created from here on are float32
    torch.set_default_tensor_type(torch.FloatTensor)

    # load the state dict incrementally, to avoid memory problems
    for i, ckpt in enumerate(checkpoints):
        print(f"Loading checkpoint {i}")
        checkpoint = torch.load(ckpt, map_location="cpu")
        for parameter_name, parameter in model.named_parameters():
            short_name = parameter_name.split(".")[-2]
            if key_to_dim[short_name] is None and i == 0:
                parameter.data = checkpoint[parameter_name]
            elif key_to_dim[short_name] == 0:
                size = checkpoint[parameter_name].size(0)
                parameter.data[size * i : size * (i + 1), :] = checkpoint[
                    parameter_name
                ]
            elif key_to_dim[short_name] == -1:
                size = checkpoint[parameter_name].size(-1)
                parameter.data[:, size * i : size * (i + 1)] = checkpoint[
                    parameter_name
                ]
            del checkpoint[parameter_name]
        del checkpoint

    model.cuda()

    generator = LLaMA(model, tokenizer)
    print(
        f"Loaded in {time.time() - start_time:.2f} seconds with {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GiB"
    )
    return generator


def get_output(
    generator: LLaMA,
    prompt: str,
    max_gen_len: int = 256,
    temperature: float = 0.8,
    top_p: float = 0.95):
    """Generate a completion for a single prompt with the loaded model."""
    prompts = [prompt]
    results = generator.generate(
        prompts,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p
    )

    return results
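

# --- Minimal usage sketch (not part of the original file) ---------------------
# Assumes GCS_BUCKET points at an anonymously readable bucket whose objects are
# laid out as "<ckpt_path>/*.pth" plus "<ckpt_path>/params.json" and
# "<tokenizer_path>/tokenizer.model". The path names passed below are
# illustrative placeholders, not the bucket's actual layout.
if __name__ == "__main__":
    generator = get_pretrained_models("LLaMA-13B", "tokenizer")
    outputs = get_output(generator, "The capital of South Korea is", max_gen_len=64)
    print(outputs)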