from flask import Flask, request, jsonify, current_app
from flask_cors import CORS, cross_origin
from typing import List
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import threading
import time
import torch
import yaml

# Globals populated by the background loader thread.
tokenizer = None
model = None
personas = []

def threaded_load_model():
    global tokenizer, model, personas
    try:
        cache_dir = os.environ['TRANSFORMERS_CACHE']
        print("TRANSFORMERS_CACHE", cache_dir)
        print("Loading model...")
        start_time = time.time()
        tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/DialoGPT-medium", cache_dir=cache_dir, local_files_only=True)
        model = AutoModelForCausalLM.from_pretrained(
            "af1tang/personaGPT", cache_dir=cache_dir, local_files_only=True)
        if torch.cuda.is_available():
            model = model.cuda()
        elapsed = round(time.time() - start_time, 1)
        print(f"Model loaded after {elapsed} seconds")
        # Load the persona facts and encode them once as the fixed prompt prefix.
        with open('persona.yaml', 'r') as f:
            personal_dict = yaml.safe_load(f)
        for fact in personal_dict['persona']:
            personas.append(fact + tokenizer.eos_token)
        personas = tokenizer.encode(
            ''.join(['<|p2|>'] + personas + ['<|sep|>'] + ['<|start|>']))
    except Exception as e:
        print(e)
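
# The loader above only assumes persona.yaml holds a top-level `persona` key
# mapping to a list of first-person facts. A minimal sketch of that file (the
# real one ships alongside this script and is not reproduced here):
#
# persona:
#   - "i like to hike on weekends."
#   - "i work as a software engineer."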

def flatten(l):
    return [item for sublist in l for item in sublist]

def to_data(x):
    # Move a tensor off the GPU (if it lives there) and return a numpy array.
    if x.is_cuda:
        x = x.cpu()
    return x.data.numpy()


def to_var(x):
    # Wrap raw token ids in a tensor and move it to the GPU when one is present.
    if not torch.is_tensor(x):
        x = torch.Tensor(x)
    if torch.cuda.is_available():
        x = x.cuda()
    return x

def generate_next(bot_input_ids, top_k=10, top_p=.92, max_length=1000):
    # Sample a continuation with combined top-k / nucleus (top-p) sampling.
    full_msg = model.generate(bot_input_ids, do_sample=True,
                              top_k=top_k, top_p=top_p,
                              max_length=max_length, pad_token_id=tokenizer.eos_token_id)
    # Keep only the newly generated tokens, dropping the prompt prefix.
    msg = to_data(full_msg.detach()[0])[bot_input_ids.shape[-1]:]
    return msg

def get_bot_response(text_input: str, dialog_history: List[List[int]]):
    user_inp = tokenizer.encode(text_input + tokenizer.eos_token)
    # Append the new user turn to the chat history.
    dialog_history.append(user_inp)
    # Generate a response; max_length in generate_next caps the total
    # context (persona prefix plus history) at 1000 tokens.
    bot_input_ids = to_var([personas + flatten(dialog_history)]).long()
    msg = generate_next(bot_input_ids)
    bot_response_text = tokenizer.decode(msg, skip_special_tokens=True)
    print(f"User: {text_input}\nBot: {bot_response_text}\n")
    return bot_response_text, dialog_history
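
# The tensor fed to the model is the encoded persona prefix followed by the
# running dialog. Sketching the token stream (illustrative, not exact ids):
#   <|p2|> fact 1 <eos> ... fact N <eos> <|sep|> <|start|>
#   user turn <eos> bot turn <eos> user turn <eos> ...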

app = Flask(__name__,
            static_url_path='',
            static_folder='./frontend/out')
cors = CORS(app)
app.config['CORS_HEADERS'] = 'Content-Type'


@app.route('/')
def index():
    return current_app.send_static_file('index.html')

@app.route('/ask', methods=['POST'])
@cross_origin()
def ask():
    request_json = request.get_json()
    if not request_json:
        return jsonify(error="Missing params: { text_input: str, dialog_history: List[List[int]] }"), 400
    text_input = request_json.get("text_input", "hi")
    dialog_history = request_json.get("dialog_history", [])
    bot_response_text, new_dialog_history = get_bot_response(
        text_input, dialog_history)
    return jsonify(bot_response_text=bot_response_text, dialog_history=new_dialog_history)

def threaded_web_server(app):
    # threaded=True lets Flask handle requests concurrently.
    app.run(host='0.0.0.0', port=7860, threaded=True)


# Serve HTTP from a background thread and load the model on the main thread,
# so the server is reachable while the model is still loading.
web_server_thread = threading.Thread(target=threaded_web_server, args=(app,))
web_server_thread.start()
threaded_load_model()
# web_server_thread.join()
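
# A quick smoke test against the /ask route above, assuming the server is up
# on port 7860 (a sketch, not part of the app itself):
#
#   curl -X POST http://localhost:7860/ask \
#     -H "Content-Type: application/json" \
#     -d '{"text_input": "hello there", "dialog_history": []}'
#
# The response carries `bot_response_text` plus the updated `dialog_history`
# (token-id lists) to pass back in on the next request.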