from fix_int8 import fix_pytorch_int8
fix_pytorch_int8()

# Credit:
# https://huggingface.co/spaces/ljsabc/Fujisaki/blob/main/app.py

import torch
import gradio as gr
from threading import Thread
from model import model, tokenizer
from session import db, logger, log_sys_info
from transformers import AutoTokenizer, GenerationConfig, AutoModel

max_length = 224
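# Rough English gist of default_start below: "You are Kuma, please chat with me;
# separate each sentence with two vertical bars." / "OK, what do you want to talk
# about?" This pair is prepended to every conversation as the seed exchange.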
default_start = ["你是Kuma,请和我聊天,每句话以两个竖杠分隔。", "好的,你想聊什么?"]
gr_title = """<h1 align="center">KumaGLM</h1>
<h3 align='center'>这是一个 AI Kuma,你可以与他聊天,或者直接在文本框按下Enter</h3>
<p align='center'>采样范围 2020/06/13 - 2023/04/15</p>
<p align='center'>GitHub Repo: <a class="github-button" href="https://github.com/KumaTea/ChatGLM" aria-label="Star KumaTea/ChatGLM on GitHub">KumaTea/ChatGLM</a></p>
<script async defer src="https://buttons.github.io/buttons.js"></script>
"""
gr_footer = """<p align='center'>
本项目基于
<a href='https://github.com/ljsabc/Fujisaki' target='_blank'>ljsabc/Fujisaki</a>
,模型采用
<a href='https://huggingface.co/THUDM/chatglm-6b' target='_blank'>THUDM/chatglm-6b</a>
。
</p>
<p align='center'>
<em>每天起床第一句!</em>
</p>"""
def evaluate(context, temperature, top_p):
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        # top_k=top_k,
        # repetition_penalty=1.1,
        num_beams=1,
        do_sample=True,
    )
    with torch.no_grad():
        # input_text = f"Context: {context}Answer: "
        # input_text = '||'.join(default_start) + '||'
        # No need for starting prompt in API
        if not context.endswith('||'):
            context += '||'
        # logger.info('[API] Request: ' + context)
        ids = tokenizer([context], return_tensors="pt")
        inputs = ids.to("cpu")
        out = model.generate(
            **inputs,
            max_length=max_length,
            generation_config=generation_config
        )
        out = out.tolist()[0]
        decoder_output = tokenizer.decode(out)
        # out_text = decoder_output.split("Answer: ")[1]
        out_text = decoder_output
        logger.info('[API] Results: ' + out_text.replace('\n', '<br>'))
        return out_text
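

# evaluate_wrapper(): runs evaluate() while holding the db lock and stores the
# prompt and result under the current db index. `db` comes from the local
# `session` module (not shown here); judging from its usage it is a small
# in-memory store with lock()/unlock(), islocked(), an auto-advancing `index`,
# and `prompts`/`results` mappings.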
def evaluate_wrapper(context, temperature, top_p):
    db.lock()
    index = db.index
    db.set(index, prompt=context)
    result = evaluate(context, temperature, top_p)
    db.set(index, result=result)
    db.unlock()
    return result
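

# api_wrapper(): JSON API behind the hidden "API" button (api_name='chat').
# Two modes:
#   * context given -> if the server is busy, return 503; if the same prompt was
#     already answered, return the cached result (200); otherwise start
#     evaluate_wrapper() in a background thread and return 202 with the index to poll.
#   * query given   -> look up a previous index: 200 with the result if done,
#     202 if still generating, 404 if unknown.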
def api_wrapper(context='', temperature=0.5, top_p=0.8, query=0):
    query = int(query)
    assert context or query
    return_json = {
        'status': '',
        'code': 0,
        'message': '',
        'index': 0,
        'result': ''
    }

    if context:
        if db.islocked():
            logger.info(f'[API] Request: {context}, Status: busy')
            return_json['status'] = 'busy'
            return_json['code'] = 503
            return_json['message'] = '[context] Server is busy, please try again later.'
            return return_json
        else:
            for index in db.prompts:
                if db.prompts[index] == context:
                    return_json['status'] = 'done'
                    return_json['code'] = 200
                    return_json['message'] = '[context] Request cached.'
                    return_json['index'] = index
                    return_json['result'] = db.results[index]
                    return return_json
            # new request: start generation in a background thread
            index = db.index
            t = Thread(target=evaluate_wrapper, args=(context, temperature, top_p))
            t.start()
            logger.info(f'[API] Request: {context}, Status: processing, Index: {index}')
            return_json['status'] = 'processing'
            return_json['code'] = 202
            return_json['message'] = '[context] Request accepted, please check back later.'
            return_json['index'] = index
            return return_json
    else:  # query
        if query in db.prompts and query in db.results:
            logger.info(f'[API] Query: {query}, Status: hit')
            return_json['status'] = 'done'
            return_json['code'] = 200
            return_json['message'] = '[query] Request processed.'
            return_json['index'] = query
            return_json['result'] = db.results[query]
            return return_json
        else:
            if db.islocked():
                logger.info(f'[API] Query: {query}, Status: processing')
                return_json['status'] = 'processing'
                return_json['code'] = 202
                return_json['message'] = '[query] Request is still being processed, please check back later.'
                return_json['index'] = query
                return return_json
            else:
                logger.info(f'[API] Query: {query}, Status: error')
                return_json['status'] = 'error'
                return_json['code'] = 404
                return_json['message'] = '[query] Index not found.'
                return_json['index'] = query
                return return_json
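

# evaluate_stream(): streaming handler for the web UI. Keeps at most the last
# four exchanges of history, joins all turns with the '||' separator after the
# default_start seed, trims the context until it fits within max_length tokens,
# then yields partial replies from model.stream_chat().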
def evaluate_stream(msg, history, temperature, top_p):
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        # repetition_penalty=1.1,
        num_beams=1,
        do_sample=True,
    )

    if not msg:
        msg = '……'

    history.append([msg, ""])

    context = '||'.join(default_start) + '||'
    if len(history) > 4:
        history.pop(0)

    for j in range(len(history)):
        history[j][0] = history[j][0].replace("<br>", "")

    # concatenate the chat history into the '||'-separated context
    for h in history[:-1]:
        context += h[0] + "||" + h[1] + "||"
    context += history[-1][0] + "||"
    context = context.replace(r'<br>', '')

    # TODO: avoid overly long contexts
    # CUTOFF = 224
    while len(tokenizer.encode(context)) > max_length:
        # chop 15 characters off the front at a time until the context fits
        context = context[15:]

    h = []
    logger.info('[UI] Request: ' + context)
    for response, h in model.stream_chat(tokenizer, context, h, max_length=max_length, top_p=top_p, temperature=temperature):
        history[-1][1] = response
        yield history, ""
    logger.info('[UI] Results: ' + response.replace('\n', '<br>'))
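

# Gradio UI: sliders for temperature / top-p, a chat box, and a clear button.
# The hidden "API" and "Info" buttons are never shown; they exist only to
# expose api_wrapper and log_sys_info as named endpoints (api_name='chat'
# and api_name='info') for programmatic access.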
with gr.Blocks() as demo:
    gr.HTML(gr_title)
    # state = gr.State()
    with gr.Row():
        with gr.Column(scale=2):
            temp = gr.components.Slider(minimum=0, maximum=1.1, value=0.5, label="Temperature",
                                        info="温度参数,越高的温度生成的内容越丰富,但是有可能出现语法问题。小的温度也能帮助生成更相关的回答。")
            top_p = gr.components.Slider(minimum=0.5, maximum=1.0, value=0.8, label="Top-p",
                                         info="top-p参数,只输出前p>top-p的文字,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。")
            # code = gr.Textbox(label="temp_output", info="解码器输出")
            # top_k = gr.components.Slider(minimum=1, maximum=200, step=1, value=25, label="Top k",
            #                              info="top-k参数,下一个输出的文字会从top-k个文字中进行选择,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="聊天框", info="")
            msg = gr.Textbox(label="输入框", placeholder="最近过得怎么样?",
                             info="输入你的内容,按 [Enter] 发送。什么都不填经常会出错。")
            clear = gr.Button("清除聊天")
            api_handler = gr.Button("API", visible=False)
            api_index = gr.Number(visible=False)
            api_result = gr.JSON(visible=False)
            info_handler = gr.Button("Info", visible=False)
            info_text = gr.Textbox('System info logged. Check it in the log viewer.', visible=False)

    msg.submit(evaluate_stream, [msg, chatbot, temp, top_p], [chatbot, msg])
    clear.click(lambda: None, None, chatbot, queue=False)
    api_handler.click(api_wrapper, [msg, temp, top_p, api_index], api_result, api_name='chat')
    info_handler.click(log_sys_info, None, info_text, api_name='info')
    gr.HTML(gr_footer)
demo.queue()
demo.launch(debug=False)
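
# Untested sketch of how a client might call the hidden 'chat' endpoint via
# gradio_client (the Space URL below is a placeholder, not the real one):
#
#     from gradio_client import Client
#     client = Client("https://<user>-<space>.hf.space")
#     # positional args mirror [msg, temp, top_p, api_index] from the click handler
#     resp = client.predict("最近过得怎么样?", 0.5, 0.8, 0, api_name="/chat")
#     print(resp)  # JSON with status / code / message / index / result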