|
import sys |
|
import time |
|
import warnings |
|
from pathlib import Path |
|
from typing import Optional |
|
|
|
import lightning as L |
|
import torch |
|
|
|
|
|
# Make the repository root importable so the ``lit_llama`` imports below resolve when this file is run as a script.
# Resolve *before* walking up: with a relative ``__file__`` (Python < 3.9) or a symlinked script, walking up first
# and resolving afterwards would land in the current working directory / the symlink's location instead.
wd = Path(__file__).resolve().parent.parent
sys.path.append(str(wd))
|
|
|
from lit_llama import LLaMA, Tokenizer |
|
from lit_llama.utils import lazy_load, llama_model_lookup, quantization |
|
|
|
|
|
@torch.no_grad()
def generate(
    model: LLaMA,
    idx: torch.Tensor,
    max_new_tokens: int,
    *,
    max_seq_length: Optional[int] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        idx: Tensor of shape (T) with indices of the prompt sequence.
        max_new_tokens: The number of new tokens to generate.
        max_seq_length: The maximum sequence length allowed. Defaults to
            ``min(T + max_new_tokens, model.config.block_size)``.
        temperature: Scales the predicted logits by 1 / temperature. A value of exactly 0 selects the most
            likely token at every step (greedy decoding) instead of sampling.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
            Has no effect when ``temperature == 0``.
        eos_id: If specified, stop generating any more token once the <eos> token is produced.

    Returns:
        A 1-D tensor with the same dtype and device as ``idx``: the prompt followed by the generated tokens,
        of length ``T + max_new_tokens``. If generation stops early because ``eos_id`` was produced, the
        returned tensor ends just *before* the <eos> token.
    """
    T = idx.size(0)
    T_new = T + max_new_tokens
    if max_seq_length is None:
        max_seq_length = min(T_new, model.config.block_size)

    device, dtype = idx.device, idx.dtype

    # pre-allocate the output of the final length and copy the prompt into its head
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = idx
    idx = empty
    # the first forward pass covers every prompt position; afterwards only the newest position is fed
    input_pos = torch.arange(0, T, device=device)

    if idx.device.type == "xla":
        import torch_xla.core.xla_model as xm

        xm.mark_step()

    for _ in range(max_new_tokens):
        x = idx.index_select(0, input_pos).view(1, -1)

        # only the logits at the last position are needed to choose the next token
        logits = model(x, max_seq_length, input_pos)
        logits = logits[0, -1]

        if temperature == 0:
            # dividing by 0 would produce inf/nan probabilities and make multinomial fail; decode greedily instead
            idx_next = torch.argmax(logits, dim=-1, keepdim=True).to(dtype=dtype)
        else:
            logits = logits / temperature

            # optionally mask out everything below the k-th largest logit
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

            probs = torch.nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

        # advance to the slot the new token is written to (also the position fed on the next step)
        input_pos = input_pos[-1:] + 1

        if idx.device.type == "xla":
            xm.mark_step()

        # place the new token in its slot (index_copy returns a new tensor)
        idx = idx.index_copy(0, input_pos, idx_next)

        # stop early on <eos>; the slice ends before input_pos, so the <eos> token itself is excluded
        if eos_id is not None and idx_next == eos_id:
            return idx[:input_pos]

    return idx
|
|
|
|
|
def main(
    prompt: str = "Hello, my name is",
    *,
    num_samples: int = 1,
    max_new_tokens: int = 50,
    top_k: int = 200,
    temperature: float = 0.8,
    checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
    tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
    quantize: Optional[str] = None,
    seed: int = 1234,
) -> None:
    """Generates text samples based on a pre-trained LLaMA model and tokenizer.

    Args:
        prompt: The prompt string to use for generating the samples.
        num_samples: The number of text samples to generate.
        max_new_tokens: The number of generation steps to take.
        top_k: The number of top most probable tokens to consider in the sampling process.
        temperature: A value controlling the randomness of the sampling process. Higher values result in more random
            samples.
        checkpoint_path: The checkpoint path to load.
        tokenizer_path: The tokenizer path to load.
        quantize: Whether to quantize the model and using which method:
            ``"llm.int8"``: LLM.int8() mode,
            ``"gptq.int4"``: GPTQ 4-bit mode.
        seed: Random seed passed to ``L.seed_everything`` once, before the sampling loop.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` or ``tokenizer_path`` is not an existing file.
    """
    # explicit checks instead of ``assert`` so the validation still runs under ``python -O``
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
    if not tokenizer_path.is_file():
        raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_path}")

    # bfloat16 weights when the GPU supports them, full float32 otherwise
    precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
    fabric = L.Fabric(devices=1, precision=precision)

    print("Loading model ...", file=sys.stderr)
    # perf_counter is monotonic (unlike time.time), matching the inference timing below
    t0 = time.perf_counter()
    with lazy_load(checkpoint_path) as checkpoint:
        # derive the model configuration name from the checkpoint contents
        name = llama_model_lookup(checkpoint)

        # empty_init skips weight initialization because the weights are loaded from the checkpoint right after
        with fabric.init_module(empty_init=True), quantization(mode=quantize):
            model = LLaMA.from_name(name)

        # load while the lazy_load context is still open
        model.load_state_dict(checkpoint)
    print(f"Time to load model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)

    model.eval()
    model = fabric.setup(model)

    tokenizer = Tokenizer(tokenizer_path)
    encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
    prompt_length = encoded.size(0)

    L.seed_everything(seed)
    for i in range(num_samples):
        t0 = time.perf_counter()
        y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
        t = time.perf_counter() - t0

        # clear state accumulated during ``generate`` (presumably the KV cache) before the next sample
        model.reset_cache()
        print(tokenizer.decode(y))
        tokens_generated = y.size(0) - prompt_length
        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
    if fabric.device.type == "cuda":
        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
|
|
|
|
|
if __name__ == "__main__":
    from jsonargparse import CLI

    # permit reduced-precision (e.g. TF32) kernels for float32 matrix multiplications
    torch.set_float32_matmul_precision("high")

    # warning messages this script deliberately silences; filters are installed in this order
    _silenced_messages = (
        "ComplexHalf support is experimental and many operators don't support it yet",
        "MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
    )
    for _silenced in _silenced_messages:
        warnings.filterwarnings("ignore", message=_silenced)

    # expose ``main``'s parameters as command-line arguments
    CLI(main)
|
|