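"""Flask chat server that generates persona-conditioned replies with PersonaGPT.

The web server runs in a background thread while the model loads in the main
thread, so the container can start serving the static frontend immediately.
"""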
from flask import Flask, request, jsonify, current_app
from flask_cors import CORS, cross_origin
from typing import List
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import threading
import time
import torch
import yaml
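
# Globals populated by threaded_load_model(); they stay None/empty until loading finishes.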
tokenizer = None
model = None
personas = []
def threaded_load_model():
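    """Load the tokenizer and model from the local cache and pre-encode the persona prefix."""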
    global tokenizer, model, personas
    try:
        cache_dir = os.environ['TRANSFORMERS_CACHE']
        print("TRANSFORMERS_CACHE", cache_dir)
        print("Loading model...")
        start_time = time.time()
        # The tokenizer is shared with DialoGPT; the weights come from PersonaGPT.
        tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/DialoGPT-medium", cache_dir=cache_dir, local_files_only=True)
        model = AutoModelForCausalLM.from_pretrained(
            "af1tang/personaGPT", cache_dir=cache_dir, local_files_only=True)
        if torch.cuda.is_available():
            model = model.cuda()
        elapsed = round(time.time() - start_time, 1)
        print(f"Model loaded after {elapsed} seconds")
        # Load the persona facts and encode them once as the fixed prompt prefix.
        with open('persona.yaml', 'r') as f:
            persona_dict = yaml.safe_load(f)
        persona_lines = [text + tokenizer.eos_token
                         for text in persona_dict['persona']]
        personas = tokenizer.encode(
            ''.join(['<|p2|>'] + persona_lines + ['<|sep|>'] + ['<|start|>']))
    except Exception as e:
        print(e)


def flatten(l):
    return [item for sublist in l for item in sublist]


def to_data(x):
    # Move tensors back to the CPU before converting to numpy.
    if x.is_cuda:
        x = x.cpu()
    return x.data.numpy()


def to_var(x):
    if not torch.is_tensor(x):
        x = torch.Tensor(x)
    if torch.cuda.is_available():
        x = x.cuda()
    return x
def generate_next(bot_input_ids, top_k=10, top_p=.92, max_length=1000):
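    """Sample a reply with top-k/nucleus sampling and return only the newly generated token ids."""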
    full_msg = model.generate(bot_input_ids, do_sample=True,
                              top_k=top_k, top_p=top_p,
                              max_length=max_length, pad_token_id=tokenizer.eos_token_id)
    # Strip the prompt so only the bot's new tokens remain.
    msg = to_data(full_msg.detach()[0])[bot_input_ids.shape[-1]:]
    return msg


def get_bot_response(text_input: str, dialog_history: List[List[int]]):
    user_inp = tokenizer.encode(text_input + tokenizer.eos_token)
    # Append the encoded user turn to the chat history.
    dialog_history.append(user_inp)
    # Generate a response; model.generate() caps the total sequence
    # (persona prefix + history + reply) at 1000 tokens.
    bot_input_ids = to_var([personas + flatten(dialog_history)]).long()
    msg = generate_next(bot_input_ids)
    bot_response_text = tokenizer.decode(msg, skip_special_tokens=True)
    print(f"User: {text_input}\nBot: {bot_response_text}\n")
    return bot_response_text, dialog_history
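
# Flask app: serves the prebuilt frontend bundle and exposes the /ask endpoint.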
app = Flask(__name__,
            static_url_path='',
            static_folder='./frontend/out')
cors = CORS(app)
app.config['CORS_HEADERS'] = 'Content-Type'
@app.route('/')
@cross_origin()
def index():
    return current_app.send_static_file('index.html')
@app.route('/ask', methods=['POST'])
@cross_origin()
def ask():
    # Guard against requests that arrive before the model has finished loading.
    if model is None or tokenizer is None:
        return jsonify(error="Model is still loading, please retry shortly"), 503
    request_json = request.get_json()
    if not request_json:
        return jsonify(error="Missing params: { text_input: str, dialog_history: List[List[int]] }"), 400
    text_input = request_json.get("text_input", "hi")
    dialog_history = request_json.get("dialog_history", [])
    bot_response_text, new_dialog_history = get_bot_response(
        text_input, dialog_history)
    return jsonify(bot_response_text=bot_response_text, dialog_history=new_dialog_history)


def threaded_web_server(app):
    app.run(host='0.0.0.0', port=7860, threaded=True)
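
# Start the web server in a background thread, then load the model in the main
# thread so the app begins serving while the weights are read from disk.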
web_server_thread = threading.Thread(target=threaded_web_server, args=(app,))
web_server_thread.start()
threaded_load_model()
# web_server_thread.join()