# EasyPrompt / app.py
import os
import spacy
from accelerate import PartialState
from accelerate.utils import set_seed
from flask import Flask, request, jsonify
from gpt2_generation import Translator
from gpt2_generation import generate_prompt, MODEL_CLASSES
os.environ["http_proxy"] = "http://127.0.0.1:7890"
os.environ["https_proxy"] = "http://127.0.0.1:7890"
app = Flask(__name__)
# Fine-tuned GPT-2 checkpoint used for prompt generation.
path_for_model = "./output/gpt2_openprompt/checkpoint-4500"
# Generation settings passed through to generate_prompt().
args = {
    "model_type": "gpt2",
    "model_name_or_path": path_for_model,
    "length": 80,                # maximum number of tokens to generate
    "stop_token": None,
    "temperature": 1.0,          # sampling temperature; 1.0 keeps the raw distribution
    "length_penalty": 1.2,
    "repetition_penalty": 1.2,
    "k": 3,                      # top-k sampling cutoff
    "p": 0.9,                    # nucleus (top-p) sampling cutoff
    "prefix": "",
    "padding_text": "",
    "xlm_language": "",
    "seed": 42,
    "use_cpu": False,
    "num_return_sequences": 1,
    "fp16": False,
    "jit": False,
}
distributed_state = PartialState(cpu=args["use_cpu"])

if args["seed"] is not None:
    set_seed(args["seed"])

# Lazily-initialized globals so the model is loaded once and reused across requests.
tokenizer = None
model = None
zh_en_translator = None
nlp = None
def load_model_and_components():
    """Load the GPT-2 model, tokenizer, zh->en translator, and spaCy filter."""
    global tokenizer, model, zh_en_translator, nlp

    # Initialize the model and tokenizer.
    try:
        args["model_type"] = args["model_type"].lower()
        model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]]
    except KeyError:
        raise KeyError(
            f"The model type {args['model_type']!r} is not supported. "
            "You are welcome to add it and open a PR :)"
        )
    tokenizer = tokenizer_class.from_pretrained(args["model_name_or_path"], padding_side='left')
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.mask_token = tokenizer.eos_token
    model = model_class.from_pretrained(args["model_name_or_path"])
    print("Model loaded!")

    # Chinese-to-English translator, so Chinese input phrases can be handled too.
    zh_en_translator = Translator("Helsinki-NLP/opus-mt-zh-en")
    print("Translator loaded!")

    # spaCy pipeline used to filter the generated text.
    nlp = spacy.load('en_core_web_sm')
    print("Filter loaded!")

    # Move the model to the right device and optionally cast to half precision.
    model.to(distributed_state.device)
    if args["fp16"]:
        model.half()
@app.route('/chat', methods=['POST'])
def chat():
    data = request.get_json(silent=True) or {}
    phrase = data.get('phrase')
    if phrase is None:
        return jsonify({"error": "missing 'phrase' field"}), 400

    # Load everything on first use if the server was started without preloading.
    if tokenizer is None or model is None or zh_en_translator is None or nlp is None:
        load_model_and_components()

    messages = generate_prompt(
        prompt_text=phrase,
        args=args,
        zh_en_translator=zh_en_translator,
        nlp=nlp,
        model=model,
        tokenizer=tokenizer,
        distributed_state=distributed_state,
    )
    return jsonify(messages)
if __name__ == '__main__':
    # Preload the model before serving requests.
    load_model_and_components()
    app.run(host='0.0.0.0', port=10008, debug=False)
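
# Example client call (a minimal sketch, not part of the app itself): it assumes
# the server is running on this machine on port 10008, and the example phrase is
# arbitrary.
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:10008/chat",
#       json={"phrase": "a cat sitting on a windowsill"},
#   )
#   print(resp.json())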