import sys
import time
import warnings
from pathlib import Path
from typing import Optional

import lightning as L
import torch

from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import EmptyInitOnDevice, lazy_load

def generate(
    model: torch.nn.Module,
    idx: torch.Tensor,
    max_new_tokens: int,
    max_seq_length: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
    tokenizer=None,
) -> torch.Tensor:
"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. | |
The implementation of this function is modified from A. Karpathy's nanoGPT. | |
Args: | |
model: The model to use. | |
idx: Tensor of shape (T) with indices of the prompt sequence. | |
max_new_tokens: The number of new tokens to generate. | |
max_seq_length: The maximum sequence length allowed. | |
temperature: Scales the predicted logits by 1 / temperature | |
top_k: If specified, only sample among the tokens with the k highest probabilities | |
eos_id: If specified, stop generating any more token once the <eos> token is triggered | |
""" | |
    # create an empty tensor of the expected final shape and fill in the current tokens
    if isinstance(idx, tuple):
        # total prompt length: prefix tokens + injected embeddings + suffix tokens
        T = idx[0].shape[-1] + idx[2].shape[-1] + len(idx[1])
        before_len = idx[0].shape[-1]
        # place zero placeholders where the injected embeddings go
        catted = torch.cat(
            (idx[0], torch.zeros((1, len(idx[1])), device=idx[0].device), idx[2]), dim=1
        ).long()
        T_new = T + max_new_tokens
        empty = torch.empty(T_new, dtype=catted.dtype, device=catted.device)
        empty[:T] = catted.view(-1)
        idx = (empty, idx[1], [before_len])
    else:
        T = idx.size(0)
        T_new = T + max_new_tokens
        empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
        empty[:T] = idx
        idx = empty

    # generate max_new_tokens tokens
    for t in range(T, T_new):
        if isinstance(idx, tuple):
            # ignore the not-filled-yet tokens
            idx_cond = idx[0][:t]
            # if the sequence context is growing too long we must crop it at max_seq_length
            idx_cond = idx_cond if t <= max_seq_length else idx_cond[-max_seq_length:]
            idx_cond = (idx_cond.view(1, -1), idx[1].unsqueeze(0), idx[2])
        else:
            # ignore the not-filled-yet tokens
            idx_cond = idx[:t]
            # if the sequence context is growing too long we must crop it at max_seq_length
            idx_cond = idx_cond if t <= max_seq_length else idx_cond[-max_seq_length:]

        # forward
        if isinstance(idx, tuple):
            logits = model(idx_cond, maxlen=idx_cond[0].size(1))
        else:
            logits = model(idx_cond.view(1, -1))
        # pluck the logits at the final step and scale by the desired temperature
        logits = logits[0, -1] / temperature

        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[[-1]]] = -float("Inf")

        # convert logits to normalized probabilities and sample the next token
        probs = torch.nn.functional.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)

        # concatenate the new generation
        if isinstance(idx, tuple):
            seq = idx[0]
            seq[t] = idx_next
            idx = (seq, idx[1], idx[2])
        else:
            idx[t] = idx_next
        # if <eos> token is triggered, return the output (stop generation)
        if idx_next == eos_id:
            if isinstance(idx, tuple):
                return idx[0][: t + 1]  # include the EOS token
            else:
                return idx[: t + 1]  # include the EOS token

    if isinstance(idx, tuple):
        return idx[0]
    else:
        return idx
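

# Example usage (illustrative sketch; assumes a `model`, `tokenizer`, and `fabric` set up as in main() below):
#
#     encoded = tokenizer.encode("Hello, my name is", bos=True, eos=False, device=fabric.device)
#     output = generate(model, encoded, max_new_tokens=50, max_seq_length=model.config.block_size, top_k=200)
#     print(tokenizer.decode(output))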


def main(
    prompt: str = "Hello, my name is",
    *,
    num_samples: int = 1,
    max_new_tokens: int = 50,
    top_k: int = 200,
    temperature: float = 0.8,
    checkpoint_path: Optional[Path] = None,
    tokenizer_path: Optional[Path] = None,
    model_size: str = "7B",
    quantize: Optional[str] = None,
) -> None:
"""Generates text samples based on a pre-trained LLaMA model and tokenizer. | |
Args: | |
prompt: The prompt string to use for generating the samples. | |
num_samples: The number of text samples to generate. | |
max_new_tokens: The number of generation steps to take. | |
top_k: The number of top most probable tokens to consider in the sampling process. | |
temperature: A value controlling the randomness of the sampling process. Higher values result in more random | |
samples. | |
checkpoint_path: The checkpoint path to load. | |
tokenizer_path: The tokenizer path to load. | |
model_size: The model size to load. | |
quantize: Whether to quantize the model and using which method: | |
``"llm.int8"``: LLM.int8() mode, | |
``"gptq.int4"``: GPTQ 4-bit mode. | |
""" | |
    if not checkpoint_path:
        checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/lit-llama.pth")
    if not tokenizer_path:
        tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
    assert checkpoint_path.is_file(), checkpoint_path
    assert tokenizer_path.is_file(), tokenizer_path

    fabric = L.Fabric(accelerator="cuda", devices=1)
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32

    print("Loading model ...", file=sys.stderr)
    t0 = time.time()
    with EmptyInitOnDevice(
        device=fabric.device, dtype=dtype, quantization_mode=quantize
    ):
        model = LLaMA.from_name(model_size)

    checkpoint = lazy_load(checkpoint_path)
    model.load_state_dict(checkpoint)
    print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

    model.eval()
    model = fabric.setup_module(model)

    tokenizer = Tokenizer(tokenizer_path)
    encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)

    L.seed_everything(1234)
    t0 = time.perf_counter()
    for _ in range(num_samples):
        y = generate(
            model,
            encoded_prompt,
            max_new_tokens,
            model.config.block_size,  # type: ignore[union-attr,arg-type]
            temperature=temperature,
            top_k=top_k,
        )
        print(tokenizer.decode(y))

    t = time.perf_counter() - t0
    print(f"\n\nTime for inference: {t:.02f} sec total, {num_samples * max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
if __name__ == "__main__": | |
from jsonargparse import CLI | |
torch.set_float32_matmul_precision("high") | |
warnings.filterwarnings( | |
# Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31 | |
"ignore", | |
message="ComplexHalf support is experimental and many operators don't support it yet" | |
) | |
warnings.filterwarnings( | |
# Triggered in bitsandbytes/autograd/_functions.py:298 | |
"ignore", | |
message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization", | |
) | |
CLI(main) | |
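
# Example invocation (illustrative; flag names follow the main() signature above, which jsonargparse's
# CLI exposes as command-line options, and the paths follow the default checkpoint layout in main()):
#
#     python generate.py --prompt "Hello, my name is" --max_new_tokens 50 --quantize llm.int8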