File size: 4,658 Bytes
3ac04fa
5799733
8344bac
5799733
2fa4e4c
 
241f191
5799733
 
 
 
 
 
 
 
 
 
 
 
 
 
73119ac
 
 
a752f35
2fa4e4c
 
 
 
 
 
73119ac
 
2fa4e4c
5799733
241f191
 
5799733
241f191
 
b57ad3d
5799733
 
64e9b3c
8607d84
 
b57ad3d
8607d84
 
 
5799733
 
241f191
5799733
 
 
 
8607d84
5799733
241f191
 
b57ad3d
5799733
64e9b3c
8607d84
 
b57ad3d
5799733
8607d84
 
 
5799733
 
8344bac
 
 
 
 
 
 
2fa4e4c
 
e4e3ccf
55b26f1
8607d84
55b26f1
8607d84
5799733
d89d143
e4e3ccf
 
 
5799733
 
8344bac
 
 
 
 
 
 
 
8607d84
 
 
 
 
 
 
 
 
8344bac
8607d84
 
 
 
 
 
 
 
e52ef2a
3ac04fa
 
 
 
 
 
 
 
 
b57ad3d
5799733
 
 
 
 
 
 
863617a
5799733
 
726a01e
 
5799733
241f191
f0929ee
 
241f191
 
459fbe3
 
5799733
241f191
f0929ee
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import json
import gradio as gr
from typing import List, Dict
from utils.logging_util import logger
from models.cpp_qwen2 import Qwen2Simulator as Bot
# from models.hf_qwen2 import Qwen2Simulator as Bot
# from models.mock import MockSimulator as Bot

#
# def postprocess(self, y):
#     if y is None:
#         return []
#     for i, (message, response) in enumerate(y):
#         y[i] = (
#             None if message is None else mdtex2html.convert((message)),
#             None if response is None else mdtex2html.convert(response),
#         )
#     return y
#
# gr.Chatbot.postprocess = postprocess

# System prompts offered to the bot, one persona per entry.
# NOTE: every entry needs a trailing comma; a missing one makes Python silently
# concatenate adjacent string literals into a single prompt.
system_list = [
    "You are a helpful assistant.",
    "你是一个导游。",  # tour guide
    "你是一名投资经理。",  # investment manager
    "你是一名医生。",  # doctor
    "你是一个英语老师。",  # English teacher
    "你是一个程序员。",  # programmer
    "你是一个心理咨询师。",  # psychological counselor
    "你是一名AI写作助手。",  # AI writing assistant
    "你是一名作家,擅长写小说。",  # writer who is good at writing novels
]

# Single module-level simulator instance, built once at import time from the
# system prompts above; every handler below (generation, undo, sampling
# setters) goes through this shared object.
# NOTE(review): how Bot uses system_list (e.g. pre-tokenizing prompts) is
# defined in models/cpp_qwen2 — confirm there.
bot = Bot(system_list)


def generate_user_message(chatbot, history, show_warning=True):
    """Stream a simulated *user* turn from the bot and append it to the chat.

    :param chatbot: list of (user_text, assistant_text) rows shown in the UI;
        a new row is appended and filled in while the stream arrives.
    :param history: list of message dicts with "role" and "content" keys
        (generated entries also carry "tokens").
    :param show_warning: when True, show a UI warning if the last message is
        already a user turn.
    :yield: ``(chatbot, history)`` after every streamed chunk and once more
        after the final message has been appended to ``history``.
    """
    if history and history[-1]["role"] == "user":
        # Two user turns in a row are not allowed; leave the state unchanged.
        if show_warning:
            gr.Warning('You should generate assistant-response.')
        yield chatbot, history
        return

    # Placeholder row that is overwritten with the partial text on each chunk.
    chatbot.append(None)
    generated = False
    user_content, user_tokens = None, None
    # Each streamed item is unpacked as (content, tokens); the last pair wins.
    for user_content, user_tokens in bot.generate(history, stream=True):
        generated = True
        chatbot[-1] = (user_content, None)
        yield chatbot, history

    if not generated:
        # The stream produced nothing: previously this raised NameError on the
        # unbound loop variables and left a None row in chatbot.
        chatbot.pop()
        yield chatbot, history
        return

    user_tokens = bot.strip_stoptokens(user_tokens)
    history.append({"role": "user", "content": user_content, "tokens": user_tokens})
    yield chatbot, history


def generate_assistant_message(chatbot, history, show_warning=True):
    """Stream an *assistant* reply to the last user turn and append it.

    The last user turn may have come from the bot (auto mode) or been typed
    by a person (manual mode); either way it must be ``history[-1]``.

    :param chatbot: list of (user_text, assistant_text) rows shown in the UI;
        the reply half of the last row is filled in while the stream arrives.
    :param history: list of message dicts with "role" and "content" keys.
    :param show_warning: when True, show a UI warning if the last message is
        not a user turn.
    :yield: ``(chatbot, history)`` after every streamed chunk and once more
        after the final message has been appended to ``history``.
    """
    if history[-1]["role"] != "user":
        # An assistant reply needs a pending user turn; leave state unchanged.
        if show_warning:
            gr.Warning('You should generate or type user-input first.')
        yield chatbot, history
        return

    user_content = history[-1]["content"]
    generated = False
    assistant_content, assistant_tokens = None, None
    # Each streamed item is unpacked as (content, tokens); the last pair wins.
    for assistant_content, assistant_tokens in bot.generate(history, stream=True):
        generated = True
        chatbot[-1] = (user_content, assistant_content)
        yield chatbot, history

    if not generated:
        # The stream produced nothing: previously this raised NameError on the
        # unbound loop variables.
        yield chatbot, history
        return

    assistant_tokens = bot.strip_stoptokens(assistant_tokens)
    history.append({"role": "assistant", "content": assistant_content, "tokens": assistant_tokens})
    yield chatbot, history


def chat(chatbot: List[tuple], history: List[Dict]):
    """Advance the self-chat by one turn, alternating user and assistant.

    If the last message is from the system or the assistant, a simulated user
    turn is generated; if it is from the user, an assistant reply is generated.

    :param chatbot: list of (user_text, assistant_text) rows shown in the UI.
    :param history: list of message dicts with at least a "role" key.
    :yield: ``(chatbot, history)`` pairs streamed from the chosen generator.
        For an unexpected (or missing) last role, a warning is shown and the
        unchanged state is yielded once.
    """
    request_param = json.dumps({'chatbot': chatbot, 'history': history}, ensure_ascii=False)
    logger.info(f"request_param: {request_param}")

    last_role = history[-1]["role"] if history else None
    if last_role in ("assistant", "system"):
        streamer = generate_user_message(chatbot, history)
    elif last_role == "user":
        streamer = generate_assistant_message(chatbot, history)
    else:
        # Previously fell through to iterating None (TypeError); now report
        # and hand the state back untouched.
        gr.Warning("bug")
        yield chatbot, history
        return

    yield from streamer


def append_user_to_history(input_content, chatbot, history):
    """Add a typed user message to both the UI rows and the message history.

    Refuses (with a UI warning) when the last message is already a user turn,
    so user and assistant turns keep alternating.

    :param input_content: text the person typed.
    :param chatbot: list of (user_text, assistant_text) rows; appended in place.
    :param history: list of message dicts; appended in place.
    :return: the (possibly updated) ``(chatbot, history)`` pair.
    """
    awaiting_reply = history[-1]["role"] == "user"
    if not awaiting_reply:
        chatbot.append((input_content, None))
        history.append({"role": "user", "content": input_content})
    else:
        gr.Warning('You should generate assistant-response.')
    return chatbot, history


def append_assistant_to_history(input_content, chatbot, history):
    """Record a typed assistant reply to the pending user turn.

    The reply fills the assistant half of the last chatbot row and is appended
    to ``history``. Refuses (with a UI warning) unless the last message is a
    user turn.

    :param input_content: reply text the person typed.
    :param chatbot: list of (user_text, assistant_text) rows; last row replaced.
    :param history: list of message dicts; appended in place.
    :return: the (possibly updated) ``(chatbot, history)`` pair.
    """
    if history[-1]["role"] == "user":
        question = chatbot[-1][0]
        chatbot[-1] = (question, input_content)
        history.append({"role": "assistant", "content": input_content})
    else:
        gr.Warning('You should generate or type user-input first.')
    return chatbot, history


def undo_generate(chatbot, history):
    """Roll back the most recent message.

    - Last message is a user turn: drop it from ``history`` and drop its row
      from ``chatbot`` (both via new, shortened lists).
    - Last message is an assistant turn: drop it from ``history`` (new list)
      and blank the reply half of the last ``chatbot`` row *in place*.
    - Otherwise (e.g. only the system prompt): nothing changes.

    :return: the resulting ``(chatbot, history)`` pair, which is also logged.
    """
    last_role = history[-1]["role"]
    if last_role == "user":
        history, chatbot = history[:-1], chatbot[:-1]
    elif last_role == "assistant":
        history = history[:-1]
        question = chatbot[-1][0]
        chatbot[-1] = (question, None)
    logger.info(f"after undo, {json.dumps(chatbot, ensure_ascii=False)}, {json.dumps(history, ensure_ascii=False)}")
    return chatbot, history


def reset_user_input():
    """Return a Gradio update that clears the user-input textbox."""
    cleared = gr.update(value='')
    return cleared


def reset_state(system):
    """Start a fresh conversation seeded with the given system prompt.

    :param system: system-prompt text.
    :return: ``(chatbot, history)`` — no UI rows, and a history holding only
        the system message.
    """
    seed_message = {"role": "system", "content": system}
    empty_rows = []
    return empty_rows, [seed_message]


def _set_generation_kwarg(name, value):
    """Store one sampling option on the shared bot's ``generation_kwargs``."""
    bot.generation_kwargs[name] = value


def set_max_new_tokens(max_new_tokens):
    """Set the generation length cap (stored under the "max_tokens" key)."""
    _set_generation_kwarg("max_tokens", max_new_tokens)


def set_temperature(temperature):
    """Set the sampling temperature used by subsequent generations."""
    _set_generation_kwarg("temperature", temperature)


def set_top_p(top_p):
    """Set the nucleus-sampling ``top_p`` used by subsequent generations."""
    _set_generation_kwarg("top_p", top_p)


def set_top_k(top_k):
    """Set the ``top_k`` sampling cutoff used by subsequent generations."""
    _set_generation_kwarg("top_k", top_k)