|
""" |
|
Copied from https://github.com/lm-sys/FastChat. |
|
Later we will contribute our changes into it. |
|
""" |
|
import dataclasses |
|
from enum import auto, IntEnum |
|
from typing import List, Any, Dict |
|
import math |
|
from typing import List, Optional, Tuple, Union |
|
import random |
|
import numpy as np |
|
|
|
import torch |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
|
from transformers.activations import ACT2FN |
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings |
|
from transformers import ( |
|
LogitsProcessorList, |
|
MinLengthLogitsProcessor, |
|
TopKLogitsWarper, |
|
TemperatureLogitsWarper, |
|
TopPLogitsWarper, |
|
StoppingCriteriaList, |
|
MaxLengthCriteria, |
|
BitsAndBytesConfig, |
|
) |
|
|
|
|
|
|
|
class SeparatorStyle(IntEnum):
    """How roles, messages and separators are joined when a prompt is rendered.

    Values are spelled out explicitly; they match what ``enum.auto()`` assigns
    to ``IntEnum`` members (consecutive integers starting at 1).
    """

    ADD_COLON_SINGLE = 1  # "role: message<sep>", open turn ends with "role:"
    ADD_COLON_TWO = 2  # like ADD_COLON_SINGLE, alternating sep / sep2
    ADD_COLON_SPACE_SINGLE = 3  # open turn ends with "role: " (trailing space)
    NO_COLON_SINGLE = 4  # "role" + "message" + sep, no colon
    NO_COLON_TWO = 5  # like NO_COLON_SINGLE, alternating sep / sep2
    ADD_NEW_LINE_SINGLE = 6  # "role\nmessage<sep>"
|
|
|
|
|
@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history."""

    # Template name; also the key used by the template registry.
    name: str
    # Format string for the system prompt; "{system_message}" is substituted.
    system_template: str = "{system_message}"
    # Text substituted into system_template.
    system_message: str = ""
    # Role labels. Index 0 is the user role and index 1 the assistant role.
    # (The previous default was a nested tuple, which made roles[0] a tuple.)
    roles: Tuple[str, ...] = ("USER", "ASSISTANT")
    # Messages as [role, message] pairs. A falsy message (e.g. None) renders
    # only the role prefix, leaving an open slot for the model to complete.
    messages: List[List[str]] = ()
    # Stored and copied, but not used when rendering the prompt.
    offset: int = 0
    # How roles, messages and separators are joined (see get_prompt).
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    # Separator after each message; the *_TWO styles alternate sep and sep2.
    sep: str = "\n"
    sep2: Optional[str] = None
    # Stop string(s) associated with the template; not read by this class.
    stop_str: Optional[Union[str, List[str]]] = None
    # Stop token ids associated with the template; not read by this class.
    stop_token_ids: Optional[List[int]] = None

    def __post_init__(self):
        # Normalize into a fresh list of mutable [role, message] lists so that
        # append_message / update_last_message also work when an immutable
        # tuple (such as the default ``()``) was supplied.
        self.messages = [list(message) for message in self.messages]

    def get_prompt(self) -> str:
        """Render the system prompt and all messages into one prompt string.

        A message whose text is falsy contributes only its role prefix, so a
        trailing ``[role, None]`` leaves the prompt open for the model.

        Raises:
            ValueError: If ``sep_style`` is not one of the handled styles.
        """
        system_prompt = self.system_template.format(system_message=self.system_message)

        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret

        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret

        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    # Unlike ADD_COLON_SINGLE, the open slot keeps a trailing space.
                    ret += role + ": "
            return ret

        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
            # An empty system prompt contributes no leading separator.
            ret = "" if system_prompt == "" else system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + "\n" + message + self.sep
                else:
                    ret += role + "\n"
            return ret

        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret

        elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + message + seps[i % 2]
                else:
                    ret += role
            return ret

        # Previously this fell through and returned None silently.
        raise ValueError(f"Invalid style: {self.sep_style}")

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: Optional[str]):
        """Append a new ``[role, message]`` pair to the conversation."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def copy(self):
        """Return a new Conversation with the same settings and a copied message list."""
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        """Return a plain-dict summary of the template name, system message, roles, messages and offset."""
        return {
            "template_name": self.name,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }
|
|
|
|
|
|
|
# Registry of conversation templates, keyed by template name.
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Store *template* in ``conv_templates`` under its ``name``.

    With ``override=False`` (the default), re-registering an existing name
    fails an ``assert``. With ``override=True``, the existing entry is replaced.
    """
    key = template.name
    if not override:
        assert key not in conv_templates, f"{key} has been registered."
    conv_templates[key] = template
|
|
|
|
|
def get_conv_template(name: str) -> Conversation:
    """Return a copy of the registered template called *name*.

    The copy has its own message list, so callers can append messages without
    touching the registry entry. Unknown names raise ``KeyError``.
    """
    registered = conv_templates[name]
    return registered.copy()
|
|
|
def get_conversation_template(model_path: str) -> Conversation:
    """Choose a conversation template by substring match on *model_path*.

    Candidates are tried in order ("aquila-v1", "aquila-chat",
    "aquila-legacy"); the first one contained in the path wins. If none
    matches, the generic "aquila" template is returned.
    """
    for candidate in ("aquila-v1", "aquila-chat", "aquila-legacy"):
        if candidate in model_path:
            return get_conv_template(candidate)
    return get_conv_template("aquila")
|
|
|
|
|
|
|
# "aquila-chat" renders as:
#   <system>###Human: Hello!###Assistant: Hi!###Human: How are you?###Assistant:
register_conv_template(
    Conversation(
        name="aquila-chat",
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="###",
        sep2="",  # unused by ADD_COLON_SINGLE
        roles=("Human", "Assistant", "System"),
        system_message=(
            "A chat between a curious human and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the human's questions."
        ),
        messages=(),
        offset=0,
        stop_str=["###", "</s>", "[UNK]"],
    )
)

# "aquila-legacy" renders as (roles already carry the "### ... : " prefix):
#   <system>\n\n### Human: Hello!\n### Assistant: Hi!</s>### Human: How are you?\n### Assistant: 
register_conv_template(
    Conversation(
        name="aquila-legacy",
        sep_style=SeparatorStyle.NO_COLON_TWO,
        sep="\n",
        sep2="</s>",
        roles=("### Human: ", "### Assistant: ", "System"),
        system_message=(
            "A chat between a curious human and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
        ),
        messages=(),
        offset=0,
        stop_str=["</s>", "[UNK]"],
    )
)

# "aquila" renders as:
#   <system>###Human: Hello!###Assistant: Hi!</s>Human: How are you?###Assistant:
register_conv_template(
    Conversation(
        name="aquila",
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep="###",
        sep2="</s>",
        roles=("Human", "Assistant", "System"),
        system_message=(
            "A chat between a curious human and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the human's questions."
        ),
        messages=(),
        offset=0,
        stop_str=["</s>", "[UNK]"],
    )
)

# "aquila-v1" has an empty system prompt and renders as:
#   <|startofpiece|>Hello!<|endofpiece|>Hi!</s><|startofpiece|>How are you?<|endofpiece|>
register_conv_template(
    Conversation(
        name="aquila-v1",
        sep_style=SeparatorStyle.NO_COLON_TWO,
        sep="",
        sep2="</s>",
        roles=("<|startofpiece|>", "<|endofpiece|>", ""),
        messages=(),
        offset=0,
        stop_str=["</s>", "<|endoftext|>"],
    )
)
|
|
|
|
|
if __name__ == "__main__":
    # Render the same short dialogue with each Aquila template; the final
    # assistant turn is left open (None) as it would be before generation.
    demo_runs = (
        ("aquila template:", "aquila"),
        ("aquila-chat template:", "aquila-chat"),
        ("aquila-v1 template:", "aquila-v1"),
        ("aquila-legacy template:", "aquila-legacy"),
    )
    for heading, template_name in demo_runs:
        print(heading)
        conv = get_conv_template(template_name)
        user_role, assistant_role = conv.roles[0], conv.roles[1]
        for speaker, utterance in (
            (user_role, "Hello!"),
            (assistant_role, "Hi!"),
            (user_role, "How are you?"),
            (assistant_role, None),
        ):
            conv.append_message(speaker, utterance)
        print(conv.get_prompt())
        print("\n")
|
|
|
def set_random_seed(seed):
    """Seed the ``random``, NumPy and PyTorch generators for reproducibility.

    Seeding happens only for a positive seed. ``None``, zero and negative
    values leave every generator untouched.
    """
    if seed is None or not seed > 0:
        return
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
|
|
|
def _encode_conversation(conv, tokenizer):
    """Tokenize ``conv.get_prompt()`` plus a trailing space and return its input ids."""
    return tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']


def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token, convo_template="aquila-chat"):
    """Build prompt token ids for *text* plus as much *history* as fits in *max_token*.

    Args:
        text: The current user query.
        history: Prior turns as a chronological list (oldest first) of
            ``(role, message)`` pairs. ``role == 'ASSISTANT'`` marks model
            replies; any other role is treated as a user turn. Values that are
            not lists are treated as empty. The list is read, never modified.
        tokenizer: Object providing ``encode_plus(...)['input_ids']``.
        max_token: Token budget for the rendered prompt including history.
        convo_template: Name of a registered conversation template.

    Returns:
        Input ids of the rendered prompt, which ends with an open assistant turn.

    History is added from the newest entry backwards until the encoded prompt
    reaches *max_token*. The history message that reached the budget is then
    removed again. The current query itself is never dropped, even if it alone
    exceeds the budget. The final prompt is printed to stdout.
    """
    conv = get_conv_template(convo_template)

    # Messages are collected newest-first and reversed once at the end: the
    # open assistant slot goes in first, then the current query.
    conv.append_message(conv.roles[1], None)
    conv.append_message(conv.roles[0], text)
    example = _encode_conversation(conv, tokenizer)

    if history is None or not isinstance(history, list):
        history = []

    added_history = False
    # Walk the caller's history newest-first without consuming it (the old
    # version pop()ed entries out of the caller's list).
    for entry in reversed(history):
        if len(example) >= max_token:
            break
        role = conv.roles[1] if entry[0] == 'ASSISTANT' else conv.roles[0]
        conv.append_message(role, entry[1])
        added_history = True
        example = _encode_conversation(conv, tokenizer)

    # Drop the history message that pushed the prompt to the budget. Guarding
    # on added_history keeps the current query when it alone is over budget.
    if added_history and len(example) >= max_token:
        conv.messages.pop()

    conv.messages = conv.messages[::-1]
    print('model in:', conv.get_prompt())
    return _encode_conversation(conv, tokenizer)


def _strip_generation_markers(text):
    """Cut *text* at the first "[UNK]", then at the first "</s>", and drop one leading space."""
    for marker in ("[UNK]", "</s>"):
        text = text.split(marker, 1)[0]
    if text.startswith(" "):
        text = text[1:]
    return text


def predict(model, text, tokenizer=None,
            max_gen_len=200, top_p=0.95,
            seed=1234, topk=100,
            temperature=0.9,
            sft=True, convo_template="",
            device="cuda",
            model_name="AquilaChat2-7B",
            history=None,
            **kwargs):
    """Sample a reply to *text* from *model* and return it as a string.

    Args:
        model: Object providing ``sample(...)`` as called below; its result
            must support ``out["sequences"]``.
        text: The user query (``sft=True``) or raw prompt (``sft=False``).
        tokenizer: Required; used for ``encode_plus`` and ``decode``.
        max_gen_len: Maximum number of new tokens to sample.
        top_p, topk, temperature: Sampling parameters. ``temperature == 0``
            switches to top-1 sampling with temperature 1.0.
        seed: Passed to ``set_random_seed`` (None or non-positive skips seeding).
        sft: If true, wrap *text* and *history* in a chat template (prompt
            budget 2048 tokens); otherwise tokenize *text* as-is.
        convo_template: Template name; empty means choose from *model_name*.
        device: Device the prompt tensor is moved to.
        model_name: Selects a template when *convo_template* is empty.
        history: Optional chronological list of ``(role, message)`` pairs.
            When it is a list, ``('USER', text)`` and ``('ASSISTANT', reply)``
            are appended to it after generation.
        **kwargs: Accepted for compatibility and ignored.

    Returns:
        The decoded reply, cut at the first "[UNK]" and then at the first
        "</s>", with a single leading space removed.

    Raises:
        ValueError: If *tokenizer* is None.
    """
    if tokenizer is None:
        raise ValueError("predict() requires a tokenizer")

    template_map = {"AquilaChat2-7B": "aquila-v1",
                    "AquilaChat2-34B": "aquila-legacy",
                    "AquilaChat2-7B-16K": "aquila",
                    "AquilaChat2-34B-16K": "aquila"}
    if not convo_template:
        convo_template = template_map.get(model_name, "aquila-chat")

    set_random_seed(seed)
    if temperature == 0:
        # Emulate greedy decoding by sampling from the single best token.
        topk = 1
        temperature = 1.0

    if sft:
        token_ids = covert_prompt_to_input_ids_with_history(
            text, history=history, tokenizer=tokenizer, max_token=2048,
            convo_template=convo_template)
    else:
        token_ids = tokenizer.encode_plus(text)["input_ids"]
        print(tokenizer.decode(token_ids))
    tokens = torch.tensor(token_ids)[None,].to(device)
    input_length = len(tokens[0])

    with torch.no_grad():
        # NOTE(review): 100007 is presumably the Aquila EOS id, and transformers'
        # min_length counts prompt tokens, so this processor likely never
        # fires for a non-empty prompt — kept for parity, confirm intent.
        logits_processor = LogitsProcessorList(
            [
                MinLengthLogitsProcessor(1, eos_token_id=100007),
            ]
        )
        logits_warper = LogitsProcessorList(
            [
                TopPLogitsWarper(top_p),
                TopKLogitsWarper(topk),
                TemperatureLogitsWarper(temperature),
            ]
        )
        stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=input_length + max_gen_len)])
        # Per-step scores are no longer requested: the old code turned them
        # into token probabilities that were never returned.
        out = model.sample(
            tokens,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            stopping_criteria=stopping_criteria,
            return_dict_in_generate=True,
        )

    out_ids = out["sequences"][0][input_length:].tolist()
    reply = _strip_generation_markers(tokenizer.decode(out_ids))

    if isinstance(history, list):
        # Append chronologically, matching how the prompt builder reads history
        # (the old insert(0, ...) stored turns newest-first).
        history.append(('USER', text))
        history.append(('ASSISTANT', reply))

    return reply
|
|