import argparse
import logging

from flask import Flask, request
from transformers import AutoModelForCausalLM, AutoTokenizer


class LLMInstance:
    """Thin wrapper around a local Mistral-7B-Instruct-v0.1 checkpoint for single-turn chat."""

    def __init__(self, model_path: str, device: str = "cuda"):
        # Load the model and tokenizer once at startup, then move the model to the target device.
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model.to(device)
        self.device = device

    def query(self, message):
        try:
            # Wrap the raw user text in the single-turn chat format expected by the chat template.
            messages = [
                {"role": "user", "content": message},
            ]
            encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
            model_inputs = encodeds.to(self.device)
            generated_ids = self.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
            decoded = self.tokenizer.batch_decode(generated_ids)
            # The answer is the part of decoded[0] after "[/INST]"; strip the trailing "</s>" token if present.
            output = decoded[0].split("[/INST]")[1].split("</s>")[0]
            return {
                'code': 0,
                'ret': True,
                'error_msg': None,
                'output': output
            }
        except Exception as e:
            return {
                'code': 1,
                'ret': False,
                'error_msg': str(e),
                'output': None
            }
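

# Direct-usage sketch (illustrative only; the hub id below is the standard
# Mistral-7B-Instruct-v0.1 checkpoint, adjust it to your local model path):
#
#     core = LLMInstance("mistralai/Mistral-7B-Instruct-v0.1")
#     result = core.query("What is the capital of France?")
#     # success: {'code': 0, 'ret': True, 'error_msg': None, 'output': '...'}
#     # failure: {'code': 1, 'ret': False, 'error_msg': '<exception text>', 'output': None}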


def create_app(core):
    app = Flask(__name__)

    @app.route('/ask_llm_for_answer', methods=['POST'])
    def ask_llm_for_answer():
        user_text = request.json['user_text']
        # Flask (>=1.1) serializes a returned dict into a JSON response automatically.
        return core.query(user_text)

    return app


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model_path', required=True, help='path to the Mistral-7B-Instruct-v0.1 model')
    parser.add_argument('--ip', default='0.0.0.0')
    parser.add_argument('-p', '--port', type=int, default=8001)
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()

    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)
    else:
        logging.getLogger().setLevel(logging.INFO)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.getLogger().handlers[0].setFormatter(logging.Formatter("%(message)s"))

    core = LLMInstance(args.model_path)
    app = create_app(core)
    app.run(host=args.ip, port=args.port)
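
# Example request once the server is running (assuming the default 0.0.0.0:8001 binding):
#
#     curl -X POST http://127.0.0.1:8001/ask_llm_for_answer \
#          -H "Content-Type: application/json" \
#          -d '{"user_text": "Hello, who are you?"}'
#
# The endpoint responds with the JSON dict produced by LLMInstance.query.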