|
import argparse |
|
|
|
import time |
|
from transformers import AutoTokenizer |
|
|
|
class Llama3(): |
|
def __init__(self, args): |
|
|
|
self.devices = [int(d) for d in args.devid.split(",")] |
|
|
|
|
|
print("Load " + args.tokenizer_path + " ...") |
|
self.tokenizer = AutoTokenizer.from_pretrained( |
|
args.tokenizer_path, trust_remote_code=True |
|
) |
|
|
|
|
|
self.tokenizer.decode([0]) |
|
|
|
|
|
self.system_prompt = 'You are Llama3, a helpful AI assistant.' |
|
self.EOS = [self.tokenizer.eos_token_id, self.tokenizer.convert_tokens_to_ids("<|eot_id|>")] |
|
self.system = {"role":"system","content":self.system_prompt} |
|
self.history = [self.system] |
|
self.enable_history = args.enable_history |
|
|
|
|
|
self.load_model(args) |
|
|
|
|
|
def load_model(self, args): |
|
if args.decode_mode == "basic": |
|
import chat |
|
self.model = chat.Llama3() |
|
self.model.init(self.devices, args.model_path) |
|
self.model.temperature = args.temperature |
|
self.model.top_p = args.top_p |
|
self.model.repeat_penalty = args.repeat_penalty |
|
self.model.repeat_last_n = args.repeat_last_n |
|
self.model.max_new_tokens = args.max_new_tokens |
|
self.model.generation_mode = args.generation_mode |
|
self.model.prompt_mode = args.prompt_mode |
|
else: |
|
raise ValueError("decode mode: {} is illegal!".format(args.decode_mode)) |
|
|
|
self.SEQLEN = self.model.SEQLEN |
|
|
|
|
|
def clear(self): |
|
self.history = [self.system] |
|
|
|
|
|
def update_history(self): |
|
if self.model.token_length >= self.SEQLEN: |
|
print("... (reach the maximal length)", flush=True, end='') |
|
self.history = [self.system] |
|
else: |
|
self.history.append({"role":"assistant","content":self.answer_cur}) |
|
|
|
|
|
def encode_tokens(self): |
|
self.history.append({"role":"user","content":self.input_str}) |
|
return self.tokenizer.apply_chat_template(self.history, tokenize=True, add_generation_prompt=True) |
|
|
|
|
|
def chat(self): |
|
""" |
|
Start a chat session. |
|
""" |
|
|
|
print( |
|
"""\n================================================================= |
|
1. If you want to quit, please enter one of [q, quit, exit] |
|
2. To create a new chat session, please enter one of [clear, new] |
|
=================================================================""" |
|
) |
|
|
|
while True: |
|
self.input_str = input("\nQuestion: ") |
|
|
|
if self.input_str in ["exit", "q", "quit"]: |
|
break |
|
|
|
elif self.input_str in ["clear", "new"]: |
|
self.clear() |
|
|
|
else: |
|
tokens = self.encode_tokens() |
|
|
|
|
|
if not tokens: |
|
print("Sorry: your question is empty!!") |
|
return |
|
if len(tokens) > self.SEQLEN: |
|
print( |
|
"The maximum question length should be shorter than {} but we get {} instead.".format( |
|
self.SEQLEN, len(tokens) |
|
) |
|
) |
|
return |
|
|
|
print("\nAnswer: ", end="") |
|
self.stream_answer(tokens) |
|
|
|
|
|
def stream_answer(self, tokens): |
|
""" |
|
Stream the answer for the given tokens. |
|
""" |
|
tok_num = 0 |
|
self.answer_cur = "" |
|
self.answer_token = [] |
|
|
|
|
|
first_start = time.time() |
|
token = self.model.forward_first(tokens) |
|
first_end = time.time() |
|
|
|
full_word_tokens = [] |
|
|
|
while token not in self.EOS and self.model.token_length < self.SEQLEN: |
|
full_word_tokens.append(token) |
|
word = self.tokenizer.decode(full_word_tokens, skip_special_tokens=True) |
|
if "�" in word: |
|
token = self.model.forward_next() |
|
tok_num += 1 |
|
continue |
|
|
|
self.answer_token += [token] |
|
print(word, flush=True, end="") |
|
token = self.model.forward_next() |
|
tok_num += 1 |
|
full_word_tokens = [] |
|
|
|
|
|
next_end = time.time() |
|
first_duration = first_end - first_start |
|
next_duration = next_end - first_end |
|
tps = tok_num / next_duration |
|
|
|
print() |
|
print(f"FTL: {first_duration:.3f} s") |
|
print(f"TPS: {tps:.3f} token/s") |
|
|
|
self.answer_cur = self.tokenizer.decode(self.answer_token) |
|
|
|
if self.enable_history: |
|
self.update_history() |
|
else: |
|
self.clear() |
|
|
|
|
|
|
|
def stream_predict(self, query): |
|
""" |
|
Stream the prediction for the given query. |
|
""" |
|
self.answer_cur = "" |
|
self.input_str = query |
|
tokens = self.encode_tokens() |
|
|
|
for answer_cur, history in self._generate_predictions(tokens): |
|
yield answer_cur, history |
|
|
|
def _generate_predictions(self, tokens): |
|
""" |
|
Generate predictions for the given tokens. |
|
""" |
|
|
|
next_token = self.model.forward_first(tokens) |
|
output_tokens = [next_token] |
|
|
|
|
|
while True: |
|
next_token = self.model.forward_next() |
|
if next_token == self.EOS: |
|
break |
|
output_tokens += [next_token] |
|
self.answer_cur = self.tokenizer.decode(output_tokens) |
|
if self.model.token_length >= self.SEQLEN: |
|
self.update_history() |
|
yield self.answer_cur + "\n\n\nReached the maximum length; The history context has been cleared.", self.history |
|
break |
|
else: |
|
yield self.answer_cur, self.history |
|
|
|
self.update_history() |
|
|
|
def main(args): |
|
model = Llama3(args) |
|
model.chat() |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('-m', '--model_path', type=str, required=True, help='path to the bmodel file') |
|
parser.add_argument('-t', '--tokenizer_path', type=str, default="../support/token_config", help='path to the tokenizer file') |
|
parser.add_argument('-d', '--devid', type=str, default='0', help='device ID to use') |
|
parser.add_argument('--temperature', type=float, default=1.0, help='temperature scaling factor for the likelihood distribution') |
|
parser.add_argument('--top_p', type=float, default=1.0, help='cumulative probability of token words to consider as a set of candidates') |
|
parser.add_argument('--repeat_penalty', type=float, default=1.0, help='penalty for repeated tokens') |
|
parser.add_argument('--repeat_last_n', type=int, default=32, help='repeat penalty for recent n tokens') |
|
parser.add_argument('--max_new_tokens', type=int, default=1024, help='max new token length to generate') |
|
parser.add_argument('--generation_mode', type=str, choices=["greedy", "penalty_sample"], default="greedy", help='mode for generating next token') |
|
parser.add_argument('--prompt_mode', type=str, choices=["prompted", "unprompted"], default="prompted", help='use prompt format or original input') |
|
parser.add_argument('--decode_mode', type=str, default="basic", choices=["basic", "jacobi"], help='mode for decoding') |
|
parser.add_argument('--enable_history', action='store_true', default=True, help="if set, enables storing of history memory.") |
|
args = parser.parse_args() |
|
main(args) |
|
|