import os
import torch
import time
import json
from pathlib import Path

os.environ["BITSANDBYTES_NOWELCOME"] = "1"

from llama import ModelArgs, Transformer, Tokenizer, LLaMA, default_quantize
from google.cloud import storage

bucket_name = os.environ.get("GCS_BUCKET")
llama_weight_path = "weights/llama"
tokenizer_weight_path = "weights/tokenizer"


def download_pretrained_models(ckpt_path: str, tokenizer_path: str):
    """Download model and tokenizer weights from the GCS bucket into the local weights/ directories."""
    print("creating local directories...")
    os.makedirs(llama_weight_path, exist_ok=True)
    os.makedirs(tokenizer_weight_path, exist_ok=True)

    print("initializing GCS client...")
    storage_client = storage.Client.create_anonymous_client()
    bucket = storage_client.bucket(bucket_name)

    print(f"download {ckpt_path} model weights...")
    blobs = bucket.list_blobs(prefix=f"{ckpt_path}/")
    for blob in blobs:
        # Skip "directory" placeholder objects; keep only real files.
        if blob.name.endswith("/"):
            continue
        filename = blob.name.split("/")[-1]
        print(f"-{filename}")
        blob.download_to_filename(f"{llama_weight_path}/{filename}")

    print(f"download {tokenizer_path} tokenizer weights...")
    blobs = bucket.list_blobs(prefix=f"{tokenizer_path}/")
    for blob in blobs:
        if blob.name.endswith("/"):
            continue
        filename = blob.name.split("/")[-1]
        print(f"-{filename}")
        blob.download_to_filename(f"{tokenizer_weight_path}/{filename}")


def get_pretrained_models(ckpt_path: str, tokenizer_path: str) -> LLaMA:
    download_pretrained_models(ckpt_path, tokenizer_path)
    # max_seq_len: int = 512, max_batch_size: int = 32
    generator = load(
        ckpt_dir=llama_weight_path,
        tokenizer_path=tokenizer_weight_path,
        max_seq_len=512,
        max_batch_size=1,
    )
    return generator


def load(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    print(checkpoints)

    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=f"{tokenizer_path}/tokenizer.model")
    model_args.vocab_size = tokenizer.n_words

    # Build the (quantized) transformer in half precision on the host.
    torch.set_default_tensor_type(torch.HalfTensor)
    print("Allocating transformer on host")
    ctx_tok = default_quantize.set(True)
    model = Transformer(model_args)
    default_quantize.set(ctx_tok)

    # Maps each parameter's short name to the dimension along which the
    # checkpoint shards are concatenated (None = replicated across shards).
    key_to_dim = {
        "w1": 0,
        "w2": -1,
        "w3": 0,
        "wo": -1,
        "wq": 0,
        "wk": 0,
        "wv": 0,
        "output": 0,
        "tok_embeddings": -1,
        "ffn_norm": None,
        "attention_norm": None,
        "norm": None,
        "rope": None,
    }  # ?
    torch.set_default_tensor_type(torch.FloatTensor)

    # load the state dict incrementally, to avoid memory problems
    for i, ckpt in enumerate(checkpoints):
        print(f"Loading checkpoint {i}")
        checkpoint = torch.load(ckpt, map_location="cpu")
        for parameter_name, parameter in model.named_parameters():
            short_name = parameter_name.split(".")[-2]
            if key_to_dim[short_name] is None and i == 0:
                parameter.data = checkpoint[parameter_name]
            elif key_to_dim[short_name] == 0:
                size = checkpoint[parameter_name].size(0)
                parameter.data[size * i : size * (i + 1), :] = checkpoint[
                    parameter_name
                ]
            elif key_to_dim[short_name] == -1:
                size = checkpoint[parameter_name].size(-1)
                parameter.data[:, size * i : size * (i + 1)] = checkpoint[
                    parameter_name
                ]
            del checkpoint[parameter_name]
        del checkpoint

    model.cuda()
    generator = LLaMA(model, tokenizer)
    print(
        f"Loaded in {time.time() - start_time:.2f} seconds "
        f"with {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GiB"
    )
    return generator


def get_output(
    generator: LLaMA,
    prompt: str,
    max_gen_len: int = 256,
    temperature: float = 0.8,
    top_p: float = 0.95,
):
    prompts = [prompt]
    results = generator.generate(
        prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
    )
    return results
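

# Example usage: a minimal sketch, not part of the loading code above.
# It assumes the GCS_BUCKET environment variable is set and that the bucket
# stores model shards under a "7B" prefix and the tokenizer under a
# "tokenizer" prefix; both prefixes are placeholders and must be adjusted
# to match the actual bucket layout.
if __name__ == "__main__":
    llama_generator = get_pretrained_models("7B", "tokenizer")
    outputs = get_output(
        llama_generator, "The capital of France is", max_gen_len=32
    )
    for text in outputs:
        print(text)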