# Winnie / web.py
# Author: lewiswu1209
# Last commit: c1e6869 ("add qa skills")
import argparse
import os
import random
import re
import secrets
import string
from datetime import timedelta

import requests
from flask import Flask, session, request, jsonify, render_template
from transformers.models.bert.tokenization_bert import BertTokenizer

from bot.chatbot import ChatBot
from bot.config import special_token_list
app = Flask(__name__)
# Prefer a stable key from the environment so session cookies survive restarts
# and are accepted by every worker process; otherwise fall back to a random
# per-process key (the previous behaviour), which invalidates all sessions on
# each restart.
app.config["SECRET_KEY"] = os.environ.get("FLASK_SECRET_KEY") or os.urandom(74)
# NOTE(review): Flask only applies this lifetime to sessions marked
# `session.permanent = True`; nothing visible in this file sets that.
app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(days=7)

# Tokenizer shared by the request handlers; None until main() assigns it.
tokenizer: "BertTokenizer | None" = None
# session hash -> chat history, where each history entry is the list of token
# ids produced by tokenizer.encode() for one utterance (user or bot).
history_matrix: dict = {}
def move_history_from_session_to_global_memory() -> None:
    """Copy the session's chat history into ``history_matrix`` under its hash.

    Nothing is stored unless the session has both a truthy ``session_hash`` and a
    non-empty ``history``. A session that has a hash but no ``history`` key is
    treated as empty instead of raising ``KeyError``.
    """
    session_hash = session.get("session_hash")
    history = session.get("history")
    if session_hash and history:
        history_matrix[session_hash] = history
def move_history_from_global_memory_to_session() -> None:
    """Load into the session the history kept in ``history_matrix`` for its hash.

    If the session has a hash, ``session["history"]`` is overwritten with the
    matching ``history_matrix`` entry, or with ``None`` when there is none.
    Sessions without a hash are left untouched.
    """
    current_hash = session.get("session_hash")
    if not current_hash:
        return
    session["history"] = history_matrix.get(current_hash)
def set_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the web server's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which keeps existing no-argument calls unchanged.

    Returns:
        A namespace with ``vocab_path`` (vocabulary file, default ``None``) and
        ``model_path`` (dialogue model path, default ``"lewiswu1209/Winnie"``).
    """
    parser = argparse.ArgumentParser()
    # Help text: "choose vocabulary".
    parser.add_argument("--vocab_path", default=None, type=str, required=False, help="选择词库")
    # Help text: "dialogue model path".
    parser.add_argument("--model_path", default="lewiswu1209/Winnie", type=str, required=False, help="对话模型路径")
    return parser.parse_args(argv)
@app.route("/chitchat/history", methods=["GET"])
def get_history_list() -> str:
    """Return the session's chat history as a JSON list of decoded strings.

    The session history is first refreshed from ``history_matrix`` when the
    session has a hash. Each stored entry is a list of token ids. It is turned
    back into tokens, any leading ``##`` continuation marker is stripped, and the
    pieces are concatenated with no separator. Missing history yields ``[]``.
    """
    move_history_from_global_memory_to_session()
    stored_entries = session.get("history")
    if stored_entries is None:
        stored_entries = []

    def _decode(token_ids):
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        return "".join(tok[2:] if tok.startswith("##") else tok for tok in tokens)

    return jsonify([_decode(ids) for ids in stored_entries])
# Prediction endpoint of the hosted Winnie Space that produces the bot replies.
_PREDICT_URL = "https://hf.space/embed/lewiswu1209/Winnie/+/api/predict/"
# Upper bound in seconds for one remote call. Without it, a stalled Space would
# block the request-handling thread indefinitely.
_PREDICT_TIMEOUT = 120
# Command that clears both the local history and the remote bot's memory.
_ERASE_COMMAND = "ERASE MEMORY"
# "<TAG>=<VALUE>" inputs set persona fields; they are forwarded but not recorded.
_TAG_PATTERN = re.compile(r"^<.+>=.+$")


def _predict(text: str, session_hash: str) -> str:
    """POST *text* to the remote Winnie Space and return the first output item.

    Any exception from ``requests`` (including a timeout) or from decoding and
    indexing the JSON response propagates to the caller.
    """
    response = requests.post(
        url=_PREDICT_URL,
        json={"data": [text], "session_hash": session_hash},
        timeout=_PREDICT_TIMEOUT,
    )
    return response.json()["data"][0]


def _new_session_hash() -> str:
    """Return a new 11-character session hash of lowercase letters and digits.

    ``secrets`` is used instead of ``random`` because the hash is the only key
    guarding a user's stored history. Any client that sends it via ``?hash=`` can
    load that history, so it must not be predictable.
    """
    alphabet = string.ascii_lowercase + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(11))


@app.route("/chitchat/chat", methods = ["GET"])
def talk() -> str:
    """Handle one chat turn and return the reply as a one-element JSON list.

    Query parameters:
        hash: optional session hash to adopt. Its history is loaded from
            ``history_matrix``.
        text: the user's message. A missing or empty value returns ``[""]``.

    Special inputs:
        "HELP" (any case): returns the list of help lines instead of chatting.
        "ERASE MEMORY": clears the local history and sends the same command to
            the remote bot.
        "<TAG>=<VALUE>": forwarded to the remote bot but not recorded.

    Recorded turns (user text and bot reply, as token ids) are saved to both
    the session and ``history_matrix``.
    """
    requested_hash = request.args.get("hash")
    if requested_hash:
        session["session_hash"] = requested_hash
        move_history_from_global_memory_to_session()
    if session.get("session_hash") is None:
        session["session_hash"] = _new_session_hash()

    input_text = request.args.get("text")
    if not input_text:
        return jsonify([""])

    if input_text.upper()=="HELP":
        help_info_list = ["输入任意文字,Winnie会和你对话",
                          "输入ERASE MEMORY,Winnie会清空记忆",
                          "输入\"<TAG>=<VALUE>\",Winnie会记录你的角色信息",
                          "例如:<NAME>=Vicky,Winnie会修改自己的名字",
                          "可以修改的角色信息有:",
                          "<NAME>, <GENDER>, <YEAROFBIRTH>, <MONTHOFBIRTH>, <DAYOFBIRTH>, <ZODIAC>, <AGE>",
                          "输入“上联:XXXXXXX”,Winnie会和你对对联",
                          "输入“写诗:XXXXXXX”,Winnie会以XXXXXXX为开头写诗",
                          "以\"请问\"开头并以问号结尾,Winnie会回答该问题"
                          ]
        return jsonify(help_info_list)

    session_hash = session["session_hash"]
    history_list = session.get("history")
    is_erase = input_text == _ERASE_COMMAND

    if not history_list or is_erase:
        # Starting from empty history: reset the remote bot too so both sides agree.
        history_list = []
        output_text = _predict(_ERASE_COMMAND, session_hash)

    if not is_erase:
        record_turn = _TAG_PATTERN.match(input_text) is None
        if record_turn:
            history_list.append(tokenizer.encode(input_text, add_special_tokens=False))
        output_text = _predict(input_text, session_hash)
        if record_turn:
            history_list.append(tokenizer.encode(output_text, add_special_tokens=False))

    session["history"] = history_list
    history_matrix[session_hash] = history_list
    return jsonify([output_text])
@app.route("/")
def index() -> str:
    """Answer the root URL with a fixed greeting string."""
    greeting = "Hello world!"
    return greeting
@app.route("/chitchat/hash", methods = ["GET"])
def get_hash() -> str:
    """Return the current session hash, optionally adopting one from ``?hash=``.

    If the query string carries a non-empty ``hash``, it becomes this session's
    hash. The history stored under it in ``history_matrix`` is then loaded into
    the session.

    Returns:
        The session hash, or a single space when the session has none.
    """
    requested_hash = request.args.get("hash")
    if requested_hash:
        session["session_hash"] = requested_hash
        move_history_from_global_memory_to_session()
    session_hash = session.get("session_hash")
    # Keep the single-space "no hash" reply unchanged; the client may depend on it.
    return session_hash if session_hash else " "
@app.route("/chitchat", methods=["GET"])
def chitchat() -> str:
    """Render the ``chat_template.html`` template for the /chitchat route."""
    template_name = "chat_template.html"
    return render_template(template_name)
def main() -> None:
    """Load the tokenizer named by the command-line options, then start serving.

    The app runs on Flask's built-in server, bound to 127.0.0.1:8080.
    """
    global tokenizer
    options = set_args()
    # Stored module-wide because the request handlers encode/decode history with it.
    tokenizer = ChatBot.get_tokenizer(
        options.model_path,
        vocab_path=options.vocab_path,
        special_token_list=special_token_list,
    )
    app.run(host="127.0.0.1", port=8080)
# Start the server only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()