# Italia-9B / generate.py
# Repository page header (leafspark · "add model" · 56811f1 verified)
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
# Derived from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/generate/base.py
import os
import sys
import time
from pathlib import Path
from typing import Any, Optional
import torch
# support running without installing as a package:
# the directory two levels above this file is appended to the import path
# before the project import below is attempted
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from modello_italia import Italia, ItaliaConfig, Tokenizer
# Default device used by main() for tokenization and the model:
# CUDA when available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# System prompt that main() places in the <|system|> slot of the chat template.
# English: "You are Modello Italia, a natural language model trained by iGenius."
MI_SYSTEM_PROMPT_SHORT = (
"Tu sei Modello Italia, un modello di linguaggio naturale addestrato da iGenius."
)
def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
if torch._dynamo.is_compiling():
# Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
distribution = torch.empty_like(probs).exponential_(1)
return torch.argmax(probs / distribution, dim=-1, keepdim=True)
return torch.multinomial(probs, num_samples=1)
def sample(
    logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None
) -> torch.Tensor:
    """Pick the next token id from the logits of the last position.

    Only ``logits[0, -1]`` is considered, i.e. the final position of the
    first batch element.

    Args:
        logits: Model output; indexed as ``logits[0, -1]`` to obtain the
            vocabulary scores for the last position.
        temperature: Divides the logits before the softmax. Any value that is
            not strictly positive selects greedy (argmax) decoding.
        top_k: If given, only the ``top_k`` highest-scoring tokens remain
            eligible (capped at the vocabulary size).

    Returns:
        Tensor with one element holding the chosen token index.
    """
    last_logits = logits[0, -1]

    # optionally restrict the candidates to the k best-scoring tokens
    if top_k is not None:
        k = min(top_k, last_logits.size(-1))
        best_values, best_indices = torch.topk(last_logits, k)
        # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
        masked = torch.full_like(last_logits, float("-inf"))
        last_logits = masked.scatter_(-1, best_indices, best_values)

    # greedy decoding when the temperature is not strictly positive
    if not temperature > 0.0:
        return torch.argmax(last_logits, dim=-1, keepdim=True)

    # otherwise scale the logits and sample from the resulting distribution
    scaled = last_logits / temperature
    probs = torch.nn.functional.softmax(scaled, dim=-1)
    return multinomial_num_samples_1(probs)
def next_token(
    model: Italia, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any
) -> torch.Tensor:
    """Run one forward pass and sample the token that follows ``x``.

    Args:
        model: Called as ``model(x, input_pos)`` to produce logits.
        input_pos: Positions passed through to the model alongside ``x``.
        x: Input token ids; the result is cast to this tensor's dtype.
        **kwargs: Forwarded unchanged to ``sample`` (e.g. ``temperature``,
            ``top_k``).

    Returns:
        The sampled token id, with the same dtype as ``x``.
    """
    sampled = sample(model(x, input_pos), **kwargs)
    return sampled.to(dtype=x.dtype)
@torch.inference_mode()
def generate(
    model: Italia,
    prompt: torch.Tensor,
    tokenizer: Tokenizer,
    max_returned_tokens: int,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    While generating, the terminal is cleared and the text decoded so far is
    re-printed after every new token (from the second generated token on), so
    the output appears to stream.

    Args:
        model: The model to use.
        prompt: Tensor of shape (T) with indices of the prompt sequence.
        tokenizer: Tokenizer instance to decode generated tokens; its
            ``eos_id`` is the default stop token.
        max_returned_tokens: The maximum number of tokens to return (given plus generated).
        temperature: Scales the predicted logits by 1 / temperature.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        eos_id: If specified, generation stops once this token is sampled.
            Defaults to ``tokenizer.eos_id``.

    Returns:
        1-D tensor holding the prompt followed by the generated tokens. The
        stop token is included when it was sampled.
    """
    T = prompt.size(0)
    assert max_returned_tokens > T
    device = prompt.device

    # Honor an explicitly requested stop token; otherwise use the tokenizer's.
    stop_id = tokenizer.eos_id if eos_id is None else eos_id

    tokens = [prompt]
    input_pos = torch.tensor([T], device=device)

    # Prefill: feed the whole prompt at positions 0..T-1 and sample the first token.
    token = next_token(
        model,
        torch.arange(0, T, device=device),
        prompt.view(1, -1),
        temperature=temperature,
        top_k=top_k,
    ).clone()
    tokens.append(token)
    # The first sampled token can already be the stop token; do not feed it back.
    if token == stop_id:
        return torch.cat(tokens)

    # Decode: one token per step, feeding back only the previously sampled token.
    for _ in range(2, max_returned_tokens - T + 1):
        token = next_token(
            model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k
        ).clone()
        tokens.append(token)
        if token == stop_id:
            break
        # Refresh the terminal with everything generated so far (prompt excluded).
        os.system('cls' if os.name == 'nt' else 'clear')
        print(tokenizer.decode(torch.cat(tokens)[T:]))
        input_pos = input_pos.add_(1)

    return torch.cat(tokens)
@torch.inference_mode()
def main(
    prompt: str = "Ciao, chi sei?",
    *,
    num_samples: int = 1,
    max_new_tokens: int = 200,
    top_k: Optional[int] = 200,
    temperature: float = 0.4,
    checkpoint_dir: Path = Path("."),
) -> None:
    """Load the model from a checkpoint directory and generate replies to a prompt.

    The user prompt is wrapped in the ``<|system|>`` / ``<|user|>`` /
    ``<|assistant|>`` chat template before being tokenized.

    Args:
        prompt: The user message to answer.
        num_samples: How many independent generations to run.
        max_new_tokens: Upper bound on the number of tokens generated per sample.
        top_k: Number of most probable tokens considered when sampling.
        temperature: Controls the randomness of sampling; higher values give
            more random samples.
        checkpoint_dir: Directory containing ``italia.bin``; also passed to
            the tokenizer.
    """
    config = ItaliaConfig()
    checkpoint_path = checkpoint_dir / "italia.bin"
    tokenizer = Tokenizer(checkpoint_dir)

    # Wrap the user message in the chat template and tokenize it.
    chat_prompt = f"<|system|>{MI_SYSTEM_PROMPT_SHORT}\n<|user|>{prompt}\n<|assistant|>"
    encoded = tokenizer.encode(chat_prompt, device=device)
    prompt_length = encoded.size(0)
    max_returned_tokens = prompt_length + max_new_tokens

    print(f"Loading model {str(checkpoint_path)!r}")
    load_start = time.perf_counter()
    model = Italia(config)
    # NOTE(review): torch.load unpickles the checkpoint; only load files from a
    # trusted source (consider `weights_only=True` if the checkpoint allows it).
    state_dict = torch.load(checkpoint_path, mmap=True)
    model.load_state_dict(state_dict)
    model.to(device)
    load_seconds = time.perf_counter() - load_start
    print(
        f"Time to instantiate model: {load_seconds:.02f} seconds.",
        file=sys.stderr,
    )

    # Size the sequence length and KV cache for prompt plus requested tokens.
    model.max_seq_length = max_returned_tokens
    model.set_kv_cache(batch_size=1, device=device)
    model.eval()

    for _ in range(num_samples):
        gen_start = time.perf_counter()
        y = generate(
            model,
            encoded,
            tokenizer,
            max_returned_tokens,
            temperature=temperature,
            top_k=top_k,
        )
        elapsed = time.perf_counter() - gen_start

        # Reset every block's KV cache before the next sample.
        for block in model.transformer.h:
            block.attn.kv_cache.reset_parameters()

        tokens_generated = y.size(0) - prompt_length
        print(f"\nTime for inference: {elapsed:.02f} sec total, {tokens_generated / elapsed:.02f} tokens/sec")
if __name__ == "__main__":
    # Configure float32 matmul precision before anything runs.
    torch.set_float32_matmul_precision("high")

    # Imported here so jsonargparse is only needed when run as a script.
    from jsonargparse import CLI

    # Expose main()'s signature as the command-line interface.
    CLI(main)