# -*- coding: utf-8 -*-
import time
import gradio as gr
from pipeline import Llama3
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model_path', type=str, required=True, help='path to the bmodel file')
parser.add_argument('-t', '--tokenizer_path', type=str, default="../support/token_config", help='path to the tokenizer file')
parser.add_argument('-d', '--devid', type=str, default='0', help='device ID to use')
parser.add_argument('--enable_history', type=bool, default=True, help='enable storing of chat history (note: any non-empty value parses as True)')
parser.add_argument('--temperature', type=float, default=1.0, help='temperature scaling factor for the likelihood distribution')
parser.add_argument('--top_p', type=float, default=1.0, help='cumulative probability threshold for nucleus (top-p) sampling')
parser.add_argument('--repeat_penalty', type=float, default=1.0, help='penalty applied to repeated tokens')
parser.add_argument('--repeat_last_n', type=int, default=32, help='number of recent tokens the repeat penalty is applied to')
parser.add_argument('--max_new_tokens', type=int, default=1024, help='max new token length to generate')
parser.add_argument('--generation_mode', type=str, choices=["greedy", "penalty_sample"], default="greedy", help='mode for generating next token')
parser.add_argument('--prompt_mode', type=str, choices=["prompted", "unprompted"], default="prompted", help='use prompt format or original input')
parser.add_argument('--decode_mode', type=str, default="basic", choices=["basic", "jacobi"], help='mode for decoding')
args = parser.parse_args()
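# Example launch (the script name and bmodel path below are illustrative, not from the repo):
#   python web_demo.py -m ../models/llama3-8b_int4_1dev.bmodel -d 0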

model = Llama3(args)
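# The Llama3 wrapper from pipeline.py is expected to provide the attributes used
# below: encode_tokens(), tokenizer, EOS, SEQLEN, history/system/answer_cur, and
# a low-level model object with forward_first()/forward_next() and token_length.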


def gr_update_history():
    if model.model.token_length >= model.SEQLEN:
        # print("... (reach the maximal length)", flush=True, end='')
        gr.Warning("reach the maximal length, Llama3 would clear all history record")
        model.history = [model.system]
    else:
        model.history.append({"role": "assistant", "content": model.answer_cur})


def gr_user(user_input, history):
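    # Keep the raw question on the model and append a (user, bot) pair to the
    # Chatbot history; the bot slot is filled in while gr_chat streams.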
    model.input_str = user_input
    return "", history + [[user_input, None]]


def gr_chat(history):
    """
    Stream the prediction for the given query.
    """
    tokens = model.encode_tokens()

    # check tokens
    if not tokens:
        gr.Warning("Sorry: your question is empty!!")
        return
    if len(tokens) > model.SEQLEN:
        gr.Warning(
            "The question must be shorter than {} tokens, but got {} instead.".format(
                model.SEQLEN, len(tokens)
            )
        )
        gr_update_history()
        return

    model.answer_cur = ""
    model.answer_token = []
    token_num = 0

    first_start = time.time()
    token = model.model.forward_first(tokens)
    first_end = time.time()

    history[-1][1] = ""
    full_word_tokens = []

    while token not in model.EOS and model.model.token_length < model.SEQLEN:
        full_word_tokens.append(token)
        t_word = model.tokenizer.decode(full_word_tokens, skip_special_tokens=True)
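        # If t_word ends with "�", the buffered tokens stop mid multi-byte character;
        # the branch below keeps accumulating tokens until the word decodes cleanly.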

        if "�" in t_word:
            token = model.model.forward_next()
            token_num += 1
            continue

        model.answer_token += full_word_tokens

        history[-1][1] += t_word
        full_word_tokens = []
        yield history
        token = model.model.forward_next()
        token_num += 1

    next_end = time.time()
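    # FTL is the prefill (first-token) latency; TPS is the decode throughput.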
    first_duration = first_end - first_start
    next_duration = next_end - first_end
    tps = token_num / next_duration

    print()
    print(f"FTL: {first_duration:.3f} s")
    print(f"TPS: {tps:.3f} token/s")

    model.answer_cur = model.tokenizer.decode(model.answer_token)
    gr_update_history()


def reset():
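    # Clear the model-side conversation state and reset the Chatbot component.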
    model.clear()
    return [[None, None]]



description = """
# Llama3 TPU 🏁 
"""
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot(label="Llama3", height=1050)

            with gr.Row():
                user_input = gr.Textbox(show_label=False, placeholder="Ask Llama3", lines=1, min_width=300, scale=6)
                submitBtn = gr.Button("Submit", variant="primary", scale=1)
                emptyBtn = gr.Button(value="Clear", scale=1)

    user_input.submit(gr_user, [user_input, chatbot], [user_input, chatbot]).then(gr_chat, chatbot, chatbot)
    # Clicking Submit behaves the same as pressing Enter in the textbox.
    submitBtn.click(gr_user, [user_input, chatbot], [user_input, chatbot]).then(gr_chat, chatbot, chatbot)

    emptyBtn.click(reset, outputs=[chatbot])
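
# queue(max_size=20) caps the number of queued requests; the app serves on all
# interfaces (0.0.0.0) at port 8003 without creating a public share link.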

demo.queue(max_size=20).launch(share=False, server_name="0.0.0.0", inbrowser=True, server_port=8003)