import os

import spacy
from accelerate import PartialState
from accelerate.utils import set_seed
from flask import Flask, request, jsonify

from gpt2_generation import Translator
from gpt2_generation import generate_prompt, MODEL_CLASSES

# Route outbound HTTP(S) requests through a local proxy.
os.environ["http_proxy"] = "http://127.0.0.1:7890"
os.environ["https_proxy"] = "http://127.0.0.1:7890"

app = Flask(__name__)

path_for_model = "./output/gpt2_openprompt/checkpoint-4500"

# Generation settings passed through to generate_prompt().
args = {
    "model_type": "gpt2",
    "model_name_or_path": path_for_model,
    "length": 80,
    "stop_token": None,
    "temperature": 1.0,
    "length_penalty": 1.2,
    "repetition_penalty": 1.2,
    "k": 3,
    "p": 0.9,
    "prefix": "",
    "padding_text": "",
    "xlm_language": "",
    "seed": 42,
    "use_cpu": False,
    "num_return_sequences": 1,
    "fp16": False,
    "jit": False,
}

distributed_state = PartialState(cpu=args["use_cpu"])

if args["seed"] is not None:
    set_seed(args["seed"])

tokenizer = None
model = None
zh_en_translator = None
nlp = None


def load_model_and_components():
    global tokenizer, model, zh_en_translator, nlp

    # Initialize the model and tokenizer.
    try:
        args["model_type"] = args["model_type"].lower()
        model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]]
    except KeyError:
        raise KeyError(
            "the model {} you specified is not supported. "
            "You are welcome to add it and open a PR :)".format(args["model_type"])
        )

    tokenizer = tokenizer_class.from_pretrained(args["model_name_or_path"], padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.mask_token = tokenizer.eos_token
    model = model_class.from_pretrained(args["model_name_or_path"])
    print("Model loaded!")

    # Chinese-to-English translator (MarianMT).
    zh_en_translator = Translator("Helsinki-NLP/opus-mt-zh-en")
    print("Translator loaded!")

    # spaCy English pipeline, used to filter generated text.
    nlp = spacy.load("en_core_web_sm")
    print("Filter loaded!")

    # Move the model to the right device.
    model.to(distributed_state.device)

    if args["fp16"]:
        model.half()


@app.route("/chat", methods=["POST"])
def chat():
    phrase = request.json.get("phrase")

    # Lazily load the model and its components if they are not ready yet.
    if tokenizer is None or model is None or zh_en_translator is None or nlp is None:
        load_model_and_components()

    messages = generate_prompt(
        prompt_text=phrase,
        args=args,
        zh_en_translator=zh_en_translator,
        nlp=nlp,
        model=model,
        tokenizer=tokenizer,
        distributed_state=distributed_state,
    )
    return jsonify(messages)


if __name__ == "__main__":
    load_model_and_components()
    app.run(host="0.0.0.0", port=10008, debug=False)
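
# Example request (a minimal sketch, assuming the server is running locally on
# the port configured in app.run above; "phrase" is the JSON key that chat()
# reads from the request body, here a Chinese prompt ("hello") for the
# zh->en translation step):
#
#   curl -X POST http://127.0.0.1:10008/chat \
#        -H "Content-Type: application/json" \
#        -d '{"phrase": "你好"}'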