from flask import Flask, request, jsonify, current_app
from flask_cors import CORS, cross_origin
from typing import List

from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import threading
import time
import torch
import yaml

# Populated by threaded_load_model() once the weights have been loaded.
tokenizer = None
model = None
personas = []


def threaded_load_model():
    """Load the tokenizer, model, and encoded persona prefix.

    Runs in the main thread after the web server thread has started, so the
    server can bind its port immediately while the slow model load proceeds.
    """
    global tokenizer, model, personas
    try:
        cache_dir = os.environ.get('TRANSFORMERS_CACHE')
        print("TRANSFORMERS_CACHE", cache_dir)
        print("Loading model...")
        start_time = time.time()
        tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/DialoGPT-medium", cache_dir=cache_dir, local_files_only=True)
        model = AutoModelForCausalLM.from_pretrained(
            "af1tang/personaGPT", cache_dir=cache_dir, local_files_only=True)
        if torch.cuda.is_available():
            model = model.cuda()
        elapsed = round(time.time() - start_time, 1)
        print(f"Model loaded after {elapsed} seconds")

        # Read the bot's personality facts and encode them once as the
        # PersonaGPT control prefix: <|p2|> facts... <|sep|> <|start|>.
        with open('persona.yaml', 'r') as f:
            persona_dict = yaml.safe_load(f)
        for fact in persona_dict['persona']:
            personas.append(fact + tokenizer.eos_token)
        personas = tokenizer.encode(
            ''.join(['<|p2|>'] + personas + ['<|sep|>'] + ['<|start|>']))
    except Exception as e:
        print(e)


def flatten(list_of_lists):
    """Concatenate a list of token-id lists into one flat list."""
    return [item for sublist in list_of_lists for item in sublist]


def to_data(x):
    """Move a tensor to the CPU and return it as a NumPy array."""
    if torch.cuda.is_available():
        x = x.cpu()
    return x.data.numpy()


def to_var(x):
    """Wrap raw token ids in a tensor, on the GPU when one is available."""
    if not torch.is_tensor(x):
        x = torch.Tensor(x)
    if torch.cuda.is_available():
        x = x.cuda()
    return x


def generate_next(bot_input_ids, top_k=10, top_p=.92, max_length=1000):
    """Sample a reply and return only the newly generated token ids."""
    full_msg = model.generate(
        bot_input_ids, do_sample=True, top_k=top_k, top_p=top_p,
        max_length=max_length, pad_token_id=tokenizer.eos_token_id)
    # Drop the prompt; keep only the tokens generated after it.
    return to_data(full_msg.detach()[0])[bot_input_ids.shape[-1]:]


def get_bot_response(text_input: str, dialog_history: List[List[int]]):
    user_inp = tokenizer.encode(text_input + tokenizer.eos_token)
    # Append the user's turn to the chat history.
    dialog_history.append(user_inp)

    # Generate a response; generate_next() caps the persona prefix plus
    # history plus reply at max_length (1000) tokens.
    bot_input_ids = to_var([personas + flatten(dialog_history)]).long()
    msg = generate_next(bot_input_ids)
    bot_response_text = tokenizer.decode(msg, skip_special_tokens=True)
    print(f"User: {text_input}\nBot: {bot_response_text}\n")
    return bot_response_text, dialog_history


app = Flask(__name__, static_url_path='', static_folder='./frontend/out')
cors = CORS(app)
app.config['CORS_HEADERS'] = 'Content-Type'


@app.route('/')
@cross_origin()
def index():
    return current_app.send_static_file('index.html')


@app.route('/ask', methods=['POST'])
@cross_origin()
def ask():
    # The server starts accepting requests before the model finishes
    # loading, so guard against that race.
    if model is None or tokenizer is None:
        return jsonify(error="Model is still loading, try again shortly")
    request_json = request.get_json()
    if not request_json:
        return jsonify(error="Missing params: { text_input: str, dialog_history: List[List[int]] }")
    text_input = request_json.get("text_input", "hi")
    dialog_history = request_json.get("dialog_history", [])
    bot_response_text, new_dialog_history = get_bot_response(
        text_input, dialog_history)
    return jsonify(bot_response_text=bot_response_text,
                   dialog_history=new_dialog_history)


def threaded_web_server(app):
    app.run(host='0.0.0.0', port=7860, threaded=True)


# Start the web server in a background thread, then load the model in the
# main thread so the port is bound right away.
web_server_thread = threading.Thread(target=threaded_web_server, args=(app,))
web_server_thread.start()

threaded_load_model()
# web_server_thread.join()
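
# --- Usage sketch (not part of the server) ----------------------------------
# A minimal example of driving the /ask endpoint, assuming the server above is
# running locally on port 7860 and the model has finished loading:
#
#   import requests
#
#   history = []
#   resp = requests.post("http://localhost:7860/ask",
#                        json={"text_input": "hi there",
#                              "dialog_history": history}).json()
#   print(resp["bot_response_text"])
#   # Pass the returned token-id history back on the next turn so the bot
#   # keeps conversational context:
#   history = resp["dialog_history"]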