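"""Minimal Flask service that exposes a causal chat LM (e.g. Mistral-7B-Instruct)
via a single POST endpoint, /ask_llm_for_answer."""
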
from transformers import AutoModelForCausalLM, AutoTokenizer
from flask import Flask, request
import argparse
import logging


class LLMInstance:
    """Wraps a causal LM and its tokenizer for single-turn chat queries."""

    def __init__(self, model_path: str, device: str = "cuda"):
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model.to(device)
        self.device = device

    def query(self, message):
        try:
            messages = [
                {"role": "user", "content": message},
            ]
            encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
            model_inputs = encodeds.to(self.device)

            generated_ids = self.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
            decoded = self.tokenizer.batch_decode(generated_ids)

            # The reply is the part of decoded[0] after the "[/INST]" tag
            # (Mistral chat template); strip the trailing "</s>" token if present.
            output = decoded[0].split("[/INST]")[1].split("</s>")[0]
            return {
                'code': 0,
                'ret': True,
                'error_msg': None,
                'output': output
            }
        except Exception as e:
            return {
                'code': 1,
                'ret': False,
                'error_msg': str(e),
                'output': None
            }
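# Example return value of LLMInstance.query on success (illustrative output text):
#   {'code': 0, 'ret': True, 'error_msg': None, 'output': 'Hi! How can I help?'}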


def create_app(core):
    app = Flask(__name__)

    @app.route('/ask_llm_for_answer', methods=['POST'])
    def ask_llm_for_answer():
        user_text = request.json['user_text']
        # Flask serializes the returned dict into a JSON response.
        return core.query(user_text)

    return app


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model_path', default='Mistral-7B-Instruct-v0.1', help='path to the model to serve')
    parser.add_argument('--ip', default='0.0.0.0')
    parser.add_argument('-p', '--port', type=int, default=8001)
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()

    # Log bare messages to stderr; --debug raises verbosity to DEBUG.
    level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=level, format="%(message)s")

    core = LLMInstance(args.model_path)
    app = create_app(core)
    app.run(host=args.ip, port=args.port)
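
# Example request against a running server (a sketch; assumes the default
# host/port and a reachable endpoint):
#   curl -X POST http://localhost:8001/ask_llm_for_answer \
#        -H 'Content-Type: application/json' \
#        -d '{"user_text": "Hello!"}'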