import json
import os
import sys
import time
from pathlib import Path
from typing import List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)

from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer

Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    role: Role
    content: str


class CompletionPrediction(TypedDict, total=False):
    generation: str
    tokens: List[str]
    logprobs: List[float]


class ChatPrediction(TypedDict, total=False):
    generation: Message
    tokens: List[str]
    logprobs: List[float]


Dialog = List[Message]

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
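
# How these markers are assembled (see chat_completion below): the system prompt is
# folded into the first user message, so a single-turn dialog is tokenized roughly as
#   <s> [INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST]
# and each completed (user, assistant) exchange as
#   <s> [INST] {user} [/INST] {assistant} </s>
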
class Llama:
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        model_parallel_size: Optional[int] = None,
    ) -> "Llama":
        """Load a pretrained checkpoint and tokenizer and return a Llama instance."""
        # Initialize torch.distributed and fairscale model parallelism; the model
        # parallel size defaults to the launcher's WORLD_SIZE.
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")
        if not model_parallel_is_initialized():
            if model_parallel_size is None:
                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
            initialize_model_parallel(model_parallel_size)

        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)

        # Seed all ranks identically so sampling stays in sync across model-parallel shards.
        torch.manual_seed(1)

        # Silence stdout on every rank except rank 0.
        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        start_time = time.time()
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(
            checkpoints
        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
        ckpt_path = checkpoints[get_model_parallel_rank()]
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        tokenizer = Tokenizer(model_path=tokenizer_path)
        model_args.vocab_size = tokenizer.n_words
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
        model = Transformer(model_args)
        model.load_state_dict(checkpoint, strict=False)
        print(f"Loaded in {time.time() - start_time:.2f} seconds")

        return Llama(model, tokenizer)

    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        """Generate continuations for a batch of already-tokenized prompts."""
        params = self.model.params
        bsz = len(prompt_tokens)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

        # Right-pad every prompt to total_len so the batch decodes in lockstep.
        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0
        eos_reached = torch.tensor([False] * bsz, device="cuda")
        input_text_mask = tokens != pad_id
        for cur_pos in range(min_prompt_len, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # Only use the sampled token where the prompt has already been consumed;
            # otherwise keep feeding the remaining prompt tokens.
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                next_token == self.tokenizer.eos_id
            )
            prev_pos = cur_pos
            if all(eos_reached):
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            # Cut to max_gen_len, optionally echoing the prompt.
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
            # Cut at the first EOS token, if any.
            if self.tokenizer.eos_id in toks:
                eos_idx = toks.index(self.tokenizer.eos_id)
                toks = toks[:eos_idx]
                probs = probs[:eos_idx] if logprobs else None
            out_tokens.append(toks)
            out_logprobs.append(probs)
        return (out_tokens, out_logprobs if logprobs else None)

    def text_completion(
        self,
        prompts: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ) -> List[CompletionPrediction]:
        """Tokenize raw text prompts, generate, and decode the completions."""
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
        )
        if logprobs:
            return [
                {
                    "generation": self.tokenizer.decode(t),
                    "tokens": [self.tokenizer.decode(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]
    def chat_completion(
        self,
        dialogs: List[Dialog],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
    ) -> List[ChatPrediction]:
        """Generate assistant replies for a batch of chat dialogs."""
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        prompt_tokens = []
        for dialog in dialogs:
            # Fall back to the default system prompt, then fold the system prompt
            # into the first user message.
            if dialog[0]["role"] != "system":
                dialog = [
                    {
                        "role": "system",
                        "content": DEFAULT_SYSTEM_PROMPT,
                    }
                ] + dialog
            dialog = [
                {
                    "role": dialog[1]["role"],
                    "content": B_SYS
                    + dialog[0]["content"]
                    + E_SYS
                    + dialog[1]["content"],
                }
            ] + dialog[2:]
            assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
                [msg["role"] == "assistant" for msg in dialog[1::2]]
            ), (
                "model only supports 'system', 'user' and 'assistant' roles, "
                "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
            )
            # Encode every completed (user, assistant) turn as one [INST]...[/INST] block.
            dialog_tokens: List[int] = sum(
                [
                    self.tokenizer.encode(
                        f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                        bos=True,
                        eos=True,
                    )
                    for prompt, answer in zip(
                        dialog[::2],
                        dialog[1::2],
                    )
                ],
                [],
            )
            assert (
                dialog[-1]["role"] == "user"
            ), f"Last message must be from user, got {dialog[-1]['role']}"
            # The trailing user message is left open (no EOS) for the model to answer.
            dialog_tokens += self.tokenizer.encode(
                f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
                bos=True,
                eos=False,
            )
            prompt_tokens.append(dialog_tokens)

        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
        )
        if logprobs:
            return [
                {
                    "generation": {
                        "role": "assistant",
                        "content": self.tokenizer.decode(t),
                    },
                    "tokens": [self.tokenizer.decode(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [
            {"generation": {"role": "assistant", "content": self.tokenizer.decode(t)}}
            for t in generation_tokens
        ]


def sample_top_p(probs, p):
    """Nucleus (top-p) sampling: sample from the smallest set of tokens whose
    cumulative probability exceeds p, after renormalizing within that set.

    For example, with sorted probabilities [0.5, 0.3, 0.15, 0.05] and p = 0.7,
    only 0.5 and 0.3 survive the mask and are renormalized to [0.625, 0.375].
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Drop a token if the cumulative mass *before* it already exceeds p.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled position back to the original vocabulary index.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
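

# Minimal usage sketch (illustrative only: the checkpoint directory, tokenizer path,
# and the torchrun launch below are assumptions, not defined by this module):
#
#   torchrun --nproc_per_node 1 your_script.py
#
#   generator = Llama.build(
#       ckpt_dir="llama-2-7b-chat/",
#       tokenizer_path="tokenizer.model",
#       max_seq_len=512,
#       max_batch_size=4,
#   )
#   results = generator.chat_completion(
#       [[{"role": "user", "content": "What is the capital of France?"}]],
#       temperature=0.6,
#       top_p=0.9,
#   )
#   print(results[0]["generation"]["content"])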