zhangyang6
Add application file
9e62907
import logging
import json
import requests
import gradio as gr
import time
import jwt
SYSTEM_PROMPT = """你是一个心理学家,擅长通过跟用户对话的方式,判断用户的意图和情绪。"""
TOPIC_PROMPT_TEMPLATE = """
目前你想讨论的话题有["星座","旅行","音乐","其他"]。例如,
user: 今晚的星空很美
ai: 星座
user: 我曾经买过一把吉他
ai: 音乐
根据下面的聊天记录
{history}
请识别用户想谈论的话题,只返回["星座","旅行","音乐","其他"]中的一个,不要返回多余的内容,使用中文。
"""
EMOTION_PROMPT_TEMPLATE = """
根据下面的聊天记录
{history}
判断用户的尬聊程度,10分是最尴尬,0分是最兴奋,请只给出分数的数字,不要返回多余的内容。
"""
def encode_jwt_token(ak, sk):
headers = {
"alg": "HS256",
"typ": "JWT"
}
payload = {
"iss": ak,
"exp": int(time.time()) + 1800, # 填写您期望的有效时间,此处示例代表当前时间+30分钟
"nbf": int(time.time()) - 5 # 填写您期望的生效时间,此处示例代表当前时间-5秒
}
token = jwt.encode(payload, sk, headers=headers)
return token
def sensenova_classification(ak, sk, c_type, history):
messages = [{
"role": "system",
"content": SYSTEM_PROMPT
}]
if c_type == "话题":
messages.append({
"role": "user",
"content": TOPIC_PROMPT_TEMPLATE.format(history=history)
})
elif c_type == "情绪":
messages.append({
"role": "user",
"content": EMOTION_PROMPT_TEMPLATE.format(history=history)
})
else:
raise ValueError("不支持的识别类型")
data = {
"messages": messages,
"model": "nova-ptc-xl-v2-1-0-8k-internal",
}
logging.info("request data: %s", json.dumps(data, ensure_ascii=False, indent=2))
response = requests.post(url="https://api.sensenova.cn/v1/llm/chat-completions", headers={
"Authorization": "Bear " + encode_jwt_token(ak, sk),
}, json=data, stream=True)
if response.status_code == 200:
return response.json()["data"]["choices"][0]["message"]
return response.content
with gr.Blocks() as demo:
input_ak = gr.Textbox("AK", label="AK")
input_sk = gr.Textbox("SK", label="SK")
chat_history = gr.TextArea(placeholder="user: 今天星空很美\nai: ", label="聊天记录")
classification_type = gr.Dropdown(choices=["话题", "情绪"])
output = gr.Textbox(label="识别结果")
greet_btn = gr.Button("识别")
greet_btn.click(
sensenova_classification, inputs=[input_ak, input_sk, classification_type, chat_history],
outputs=output)
if __name__ == "__main__":
demo.launch()