import os
import spacy
from accelerate import PartialState
from accelerate.utils import set_seed
from flask import Flask, request, jsonify

from gpt2_generation import Translator
from gpt2_generation import generate_prompt, MODEL_CLASSES
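
# Requires: flask, spacy (with the en_core_web_sm model installed), accelerate,
# and the local gpt2_generation module, which provides Translator,
# generate_prompt, and MODEL_CLASSES.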

os.environ["http_proxy"] = "http://127.0.0.1:7890"
os.environ["https_proxy"] = "http://127.0.0.1:7890"

app = Flask(__name__)

# Path to the fine-tuned GPT-2 checkpoint to serve.
path_for_model = "./output/gpt2_openprompt/checkpoint-4500"

# Generation settings; the keys mirror the CLI arguments of the underlying
# generation script.
args = {
    "model_type": "gpt2",
    "model_name_or_path": path_for_model,
    "length": 80,                # maximum number of tokens to generate
    "stop_token": None,          # truncate output at this token, if set
    "temperature": 1.0,          # softmax temperature; 1.0 = no rescaling
    "length_penalty": 1.2,
    "repetition_penalty": 1.2,   # > 1.0 discourages repeated tokens
    "k": 3,                      # top-k sampling cutoff
    "p": 0.9,                    # nucleus (top-p) sampling cutoff
    "prefix": "",
    "padding_text": "",
    "xlm_language": "",
    "seed": 42,
    "use_cpu": False,
    "num_return_sequences": 1,
    "fp16": False,               # if True, run the model in half precision
    "jit": False,
}

# PartialState picks the execution device (GPU when available, CPU if
# use_cpu is set).
distributed_state = PartialState(cpu=args["use_cpu"])

# Seed all RNGs for reproducible sampling.
if args["seed"] is not None:
    set_seed(args["seed"])

# Lazily initialized globals shared across requests.
tokenizer = None
model = None
zh_en_translator = None
nlp = None

def load_model_and_components():
    global tokenizer, model, zh_en_translator, nlp

    # Initialize the model and tokenizer
    try:
        args["model_type"] = args["model_type"].lower()
        model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]]
    except KeyError:
        raise KeyError(
            f"the model type {args['model_type']!r} is not supported. "
            "You are welcome to add it and open a PR :)"
        )

    tokenizer = tokenizer_class.from_pretrained(args["model_name_or_path"], padding_side='left')
    # GPT-2 defines no pad or mask token; reuse the EOS token for both.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.mask_token = tokenizer.eos_token
    model = model_class.from_pretrained(args["model_name_or_path"])
    print("Model loaded!")

    # translator
    zh_en_translator = Translator("Helsinki-NLP/opus-mt-zh-en")
    print("Translator loaded!")

    # spaCy English pipeline, used to filter the generated text
    nlp = spacy.load('en_core_web_sm')
    print("Filter loaded!")

    # Set the model to the right device
    model.to(distributed_state.device)

    if args["fp16"]:
        model.half()

@app.route('/chat', methods=['POST'])
def chat():
    data = request.get_json(silent=True) or {}
    phrase = data.get('phrase')
    if not phrase:
        return jsonify({"error": "missing 'phrase' in JSON request body"}), 400

    # Fall back to lazy loading if the app was started without __main__
    # (e.g. under a WSGI server).
    if tokenizer is None or model is None or zh_en_translator is None or nlp is None:
        load_model_and_components()

    messages = generate_prompt(
        prompt_text=phrase,
        args=args,
        zh_en_translator=zh_en_translator,
        nlp=nlp,
        model=model,
        tokenizer=tokenizer,
        distributed_state=distributed_state,
    )

    return jsonify(messages)

if __name__ == '__main__':
    load_model_and_components()
    app.run(host='0.0.0.0', port=10008, debug=False)
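
# Example client call (a minimal sketch; assumes the server is running locally
# and that generate_prompt returns a JSON-serializable payload):
#
#   import requests
#   resp = requests.post("http://127.0.0.1:10008/chat",
#                        json={"phrase": "a rainy day in the city"})
#   print(resp.json())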