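"""Gradio chat demo for Taiwan-LLaMa.

Streams replies from a text-generation-inference backend, wrapping the chat
history in a Vicuna-style conversation template before each request.
"""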
import time
import os
import gradio as gr
from text_generation import Client
from conversation import get_default_conv_template
from transformers import AutoTokenizer


endpoint_url = os.environ.get("ENDPOINT_URL", "http://127.0.0.1:8080")
client = Client(endpoint_url, timeout=120)
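# A minimal sketch of serving a compatible backend with the official
# text-generation-inference Docker image (the port mapping and flags here are
# an assumption, not pinned by this script):
#
#   docker run --gpus all -p 8080:80 \
#       ghcr.io/huggingface/text-generation-inference:latest \
#       --model-id yentinglin/Taiwan-LLaMa-v1.0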
eos_token = "</s>"  # LLaMA end-of-sequence marker (kept for reference; unused below)
max_new_tokens = 512
# Budget the prompt against the model's 4096-token context window, leaving
# room for the generation plus a small safety margin.
max_prompt_length = 4096 - max_new_tokens - 10

# The tokenizer is only used client-side, to measure and truncate the prompt.
tokenizer = AutoTokenizer.from_pretrained("yentinglin/Taiwan-LLaMa-v1.0")

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        # Append the new user turn with an empty assistant slot; clear the textbox.
        return "", history + [[user_message, None]]

    def bot(history):
        # Rebuild the full conversation prompt, truncate it to fit the context
        # window, then stream the assistant's reply token by token.
        conv = get_default_conv_template("vicuna").copy()
        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}  # map human to USER and gpt to ASSISTANT
        for user_msg, bot_msg in history:
            conv.append_message(roles["human"], user_msg)
            conv.append_message(roles["gpt"], bot_msg)
        prompt = conv.get_prompt()
        prompt_tokens = tokenizer.encode(prompt)
        # If the prompt exceeds the budget, truncate from the left so the most
        # recent turns are kept.
        if len(prompt_tokens) > max_prompt_length:
            prompt = tokenizer.decode(prompt_tokens[-max_prompt_length + 1:])

        history[-1][1] = ""
        for response in client.generate_stream(
                prompt,
                max_new_tokens=max_new_tokens,
        ):
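            # Skip special tokens (e.g. the EOS marker) so they never reach the chat.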
            if not response.token.special:
                character = response.token.text
                history[-1][1] += character
                yield history


    def generate_response(history, max_new_token=512, top_p=0.9, temperature=0.8,
                          top_k=40, do_sample=True, repetition_penalty=1.1):
        # Alternative handler with full sampling controls; only referenced by
        # the disabled UI block below.
        conv = get_default_conv_template("vicuna").copy()
        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}  # map human to USER and gpt to ASSISTANT
        for user_msg, bot_msg in history:
            conv.append_message(roles["human"], user_msg)
            conv.append_message(roles["gpt"], bot_msg)
        prompt = conv.get_prompt()

        # Reset the pending answer once, before streaming, then accumulate tokens.
        history[-1][1] = ""
        for response in client.generate_stream(
                prompt,
                max_new_tokens=max_new_token,
                top_p=top_p,
                top_k=top_k,
                temperature=temperature,
                do_sample=do_sample,
                repetition_penalty=repetition_penalty,
        ):
            if not response.token.special:
                history[-1][1] += response.token.text
                time.sleep(0.05)  # slight delay to smooth the UI update rate
                yield history


    # On submit: record the user turn immediately, then stream the bot reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)
    
demo.queue()
demo.launch()
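
# A minimal sketch of querying the same backend without the UI (the raw prompt
# below is illustrative; the template wrapping above is what the model expects):
#
#   from text_generation import Client
#   client = Client("http://127.0.0.1:8080", timeout=120)
#   reply = "".join(
#       r.token.text
#       for r in client.generate_stream("USER: Hello ASSISTANT:", max_new_tokens=64)
#       if not r.token.special
#   )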

#
# Alternative UI with sampling sliders, kept disabled. As written it still
# lacks the step that appends the user's message to the chat history before
# generate_response runs.
# with gr.Blocks() as demo:
#     chatbot = gr.Chatbot()
#     with gr.Row():
#         with gr.Column(scale=4):
#             with gr.Column(scale=12):
#                 user_input = gr.Textbox(
#                     show_label=False,
#                     placeholder="Shift + Enter to send...",
#                     lines=10).style(
#                     container=False)
#             with gr.Column(min_width=32, scale=1):
#                 submitBtn = gr.Button("Submit", variant="primary")
#         with gr.Column(scale=1):
#             emptyBtn = gr.Button("Clear History")
#             max_new_token = gr.Slider(
#                 1,
#                 1024,
#                 value=128,
#                 step=1.0,
#                 label="Maximum New Token Length",
#                 interactive=True)
#             top_p = gr.Slider(0, 1, value=0.9, step=0.01,
#                               label="Top P", interactive=True)
#             temperature = gr.Slider(
#                 0,
#                 1,
#                 value=0.5,
#                 step=0.01,
#                 label="Temperature",
#                 interactive=True)
#             top_k = gr.Slider(1, 40, value=40, step=1,
#                               label="Top K", interactive=True)
#             do_sample = gr.Checkbox(
#                 value=True,
#                 label="Do Sample",
#                 info="use random sample strategy",
#                 interactive=True)
#             repetition_penalty = gr.Slider(
#                 1.0,
#                 3.0,
#                 value=1.1,
#                 step=0.1,
#                 label="Repetition Penalty",
#                 interactive=True)
#
#     params = [user_input, chatbot]
#     predict_params = [
#         chatbot,
#         max_new_token,
#         top_p,
#         temperature,
#         top_k,
#         do_sample,
#         repetition_penalty]
#
#     submitBtn.click(
#         generate_response,
#         predict_params,
#         [chatbot],
#         queue=False
#     )
#
#     user_input.submit(
#         generate_response,
#         predict_params,
#         [chatbot],
#         queue=False
#     )
#
#     submitBtn.click(lambda: None, [], [user_input])
#
#     emptyBtn.click(lambda: None, outputs=[chatbot], show_progress=True)
#
# demo.launch()