Spaces:
Runtime error
Runtime error
File size: 6,245 Bytes
8aecfe4 9ac6c7f 8aecfe4 9ac6c7f 8aecfe4 f4e5c25 8aecfe4 f4e5c25 8aecfe4 398bfe6 8aecfe4 42029d6 8aecfe4 42029d6 8aecfe4 ab40822 8aecfe4 4d9c77c 8aecfe4 e832744 1cdc713 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import os
import gradio as gr
import clueai
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Inference device: CUDA when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Tokenizer and T5 seq2seq model for the HR_Chat checkpoint, loaded once at
# import time; the model is moved onto `device` right after loading.
_HR_CHAT_REPO = "huolongguo10/HR_Chat"
tokenizer = T5Tokenizer.from_pretrained(_HR_CHAT_REPO)
model = T5ForConditionalGeneration.from_pretrained(_HR_CHAT_REPO).to(device)

# Fixed persona exchange; preprocess() prepends it to every prompt.
base_info = "用户:你是谁?\n小元:我是huolongguo10助手智能助手。\n"
def preprocess(text):
    """Build the raw model prompt from a dialogue string.

    The module-level ``base_info`` persona header is prefixed to ``text``.
    Every newline and tab in the result is then replaced by a literal
    backslash followed by ``n`` or ``t`` respectively; postprocess() maps
    those two-character sequences back.
    """
    escape_table = str.maketrans({"\n": "\\n", "\t": "\\t"})
    prompt = f"{base_info}{text}"
    return prompt.translate(escape_table)
def postprocess(text):
    """Turn raw model output back into display text.

    Replacements are applied in order: a literal backslash-``n`` becomes a
    newline, a literal backslash-``t`` becomes a tab, and ``%20`` becomes a
    space.
    """
    replacements = (("\\n", "\n"), ("\\t", "\t"), ("%20", " "))
    for escaped, plain in replacements:
        text = text.replace(escaped, plain)
    return text
# Reference generation settings for model.generate. answer() passes its own
# keyword arguments to model.generate instead of unpacking this dict.
generate_config = dict(
    do_sample=True,
    top_p=0.9,
    top_k=50,
    temperature=0.9,
    num_beams=1,
    max_length=1024,
    min_length=3,
    no_repeat_ngram_size=5,
    length_penalty=0.6,
    return_dict_in_generate=True,
    output_scores=True,
)
def answer(text, sample=True, top_p=0.9, temperature=0.9, max_input_tokens=1024):
    """Generate a reply from the local HR_Chat model for a dialogue prompt.

    Args:
        text: Dialogue prompt; preprocess() adds the persona header and
            escapes newlines/tabs before tokenization.
        sample: Whether to sample. True suits open-ended generation; False
            decodes greedily (num_beams=1).
        top_p: Nucleus-sampling mass between 0 and 1; larger values give more
            diverse output. Used only when ``sample`` is True.
        temperature: Sampling temperature; used only when ``sample`` is True.
        max_input_tokens: Maximum prompt length in tokens. Longer prompts are
            truncated from the LEFT, keeping the most recent text.

    Returns:
        The decoded first generated sequence, passed through postprocess().
    """
    text = preprocess(text)
    encoding = tokenizer(text=[text], return_tensors="pt")
    # The prompt ends with the newest user turn and the "小元:" answer cue.
    # The tokenizer's default (right-side) truncation would cut exactly that
    # part off, so keep the last `max_input_tokens` tokens instead. The batch
    # holds a single sequence, so there is no padding to account for.
    encoding = {
        name: tensor[:, -max_input_tokens:].to(device)
        for name, tensor in encoding.items()
    }
    if not sample:
        out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False,
                             max_new_tokens=1024, num_beams=1, length_penalty=0.6)
    else:
        out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False,
                             max_new_tokens=1024, do_sample=True, top_p=top_p,
                             temperature=temperature, no_repeat_ngram_size=12)
    out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
    return postprocess(out_text[0])
def clear_session():
    """Return reset values for the (chatbot, state) outputs: '' and None."""
    empty_chat, empty_state = '', None
    return empty_chat, empty_state
def chatyuan_bot(input, history):
    """Answer one user message with the local model and record the turn.

    Args:
        input: The user's new message.
        history: List of ``(user_text, bot_text)`` pairs, or a falsy value
            (e.g. an unset gr.State) meaning "no history yet". Only the last
            5 pairs are kept.

    Returns:
        ``(history, history)``, the updated list for both the chatbot display
        and the session state.
    """
    turns = history or []
    if len(turns) > 5:
        turns = turns[-5:]

    past_lines = [f"用户:{user_msg}\n小元:{bot_msg}" for user_msg, bot_msg in turns]
    prompt = "\n".join(past_lines) + "\n用户:" + input + "\n小元:"
    # Drops the leading newline when there is no history.
    prompt = prompt.strip()

    reply = answer(prompt)
    print("open_model".center(20, "="))
    print(f"{prompt}\n{reply}")

    turns.append((input, reply))
    return turns, turns
def chatyuan_bot_regenerate(input, history):
    """Re-generate the model's answer to the most recent user message.

    If there is history, its last ``(user_text, bot_text)`` pair is removed.
    That pair's user text is then answered again. With no history, the
    current ``input`` is answered instead.

    Args:
        input: Text currently in the message box; used only when ``history``
            is empty.
        history: List of ``(user_text, bot_text)`` pairs, or a falsy value
            meaning "no history yet".

    Returns:
        ``(history, history)``, the updated list for both the chatbot display
        and the session state.
    """
    history = history or []
    if history:
        # Re-ask the last user message and discard its previous answer.
        input = history[-1][0]
        history = history[:-1]
    # Use the same 5-turn window as chatyuan_bot. The old `> 1` / `[-1:]`
    # window gave the regenerated answer less context than the original one
    # had. Because the truncated list is returned, it also erased all but one
    # earlier turn from the visible chat.
    if len(history) > 5:
        history = history[-5:]
    context = "\n".join(
        f"用户:{input_text}\n小元:{answer_text}" for input_text, answer_text in history
    )
    # strip() drops the leading newline when there is no history.
    input_text = (context + "\n用户:" + input + "\n小元:").strip()
    output_text = answer(input_text)
    print("open_model".center(20, "="))
    print(f"{input_text}\n{output_text}")
    history.append((input, output_text))
    return history, history
block = gr.Blocks()
with block as demo:
    # Page title.
    gr.Markdown("<h1><center>huolongguo10助手</center></h1>\n")

    # Conversation display, input box, and per-session history store.
    chatbot = gr.Chatbot(label='ChatYuan')
    message = gr.Textbox()
    state = gr.State()
    # Pressing Enter in the textbox sends the message.
    message.submit(fn=chatyuan_bot, inputs=[message, state], outputs=[chatbot, state])

    with gr.Row():
        clear_history = gr.Button("👋 清除历史对话 | Clear History")
        clear = gr.Button('🧹 清除发送框 | Clear Input')
        send = gr.Button("🚀 发送 | Send")
        regenerate = gr.Button("🚀 重新生成本次结果 | regenerate")

    regenerate.click(fn=chatyuan_bot_regenerate, inputs=[message, state], outputs=[chatbot, state])
    send.click(fn=chatyuan_bot, inputs=[message, state], outputs=[chatbot, state])
    # Blank the textbox only.
    clear.click(fn=lambda: None, inputs=None, outputs=message, queue=False)
    # Reset both the visible chat and the stored history.
    clear_history.click(fn=clear_session, inputs=[], outputs=[chatbot, state], queue=False)
def ChatYuan(api_key, text_prompt):
    """Get one completion for ``text_prompt`` from the hosted ChatYuan-large model.

    A new ``clueai.Client`` is created per call, with key checking enabled.
    The text of the first generation is returned; an empty result is
    replaced by a fixed apology string.
    """
    client = clueai.Client(api_key, check_api_key=True)
    # To also get scores back, pass return_likelihoods="GENERATION" here.
    result = client.generate(model_name='ChatYuan-large', prompt=text_prompt)
    text = result.generations[0].text
    return text if text != '' else "很抱歉,我无法回答这个问题"
def chatyuan_bot_api(api_key, input, history):
    """Answer one user message through the hosted ChatYuan API and record the turn.

    Args:
        api_key: The user's ClueAI API key, forwarded to ChatYuan().
        input: The user's new message.
        history: List of ``(user_text, bot_text)`` pairs, or a falsy value
            meaning "no history yet". Only the last 5 pairs are kept.

    Returns:
        ``(history, history)``, the updated list for both the chatbot display
        and the session state.
    """
    history = history or []
    if len(history) > 5:
        history = history[-5:]
    context = "\n".join(
        f"用户:{input_text}\n小元:{answer_text}" for input_text, answer_text in history
    )
    # strip() drops the leading newline when there is no history.
    input_text = (context + "\n用户:" + input + "\n小元:").strip()
    output_text = ChatYuan(api_key, input_text)
    print("api".center(20, "="))
    # Never write the user's raw API key to stdout/logs; record only whether
    # one was supplied.
    masked_key = "***" if api_key else "<empty>"
    print(f"api_key:{masked_key}\n{input_text}\n{output_text}")
    history.append((input, output_text))
    return history, history
block = gr.Blocks()
with block as demo_1:
    gr.Markdown("\n")
    # gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4, where
    # `gr.inputs.Textbox` fails with AttributeError at import time. gr.Textbox
    # is the supported component; its initial value is `value=` (formerly
    # `default=`).
    api_key = gr.Textbox(label="请输入你的api-key(必填)", value="", type='password')
    chatbot = gr.Chatbot(label='ChatYuan')
    message = gr.Textbox()
    state = gr.State()
    # Pressing Enter in the textbox sends the message via the hosted API.
    message.submit(chatyuan_bot_api, inputs=[api_key, message, state], outputs=[chatbot, state])
    with gr.Row():
        clear_history = gr.Button("👋 清除历史对话 | Clear Context")
        clear = gr.Button('🧹 清除发送框 | Clear Input')
        send = gr.Button("🚀 发送 | Send")
    send.click(chatyuan_bot_api, inputs=[api_key, message, state], outputs=[chatbot, state], api_name='send')
    # Blank the textbox only.
    clear.click(lambda: None, None, message, queue=False)
    # Reset both the visible chat and the stored history.
    clear_history.click(fn=clear_session, inputs=[], outputs=[chatbot, state], queue=False)
block = gr.Blocks()
with block as introduction:
    # Placeholder page ("nothing here"); it is not added to the tabs below.
    gr.Markdown("啥也没有\n")

# Only the open-source-model demo is exposed, as a single tab.
gui = gr.TabbedInterface([demo], ["开源模型"])
gui.launch(share=False, show_api=True, quiet=True)