|
|
|
|
|
|
|
from flask import Flask, request, jsonify |
|
import threading |
|
from flask_cors import CORS |
|
import os |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
# Restrict the process to GPUs 0 and 1. CUDA expects a comma-separated list
# with no whitespace; the original value '0, 1' embedded a space, which CUDA
# may fail to parse past (entries after an invalid token are ignored).
# NOTE(review): this must take effect before the first CUDA context is
# created — torch initializes CUDA lazily, so setting it here (after the
# transformers import but before any model call) should be fine; confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
|
|
|
|
|
# Flask application exposing the Baichuan-7B conversation endpoint.
app = Flask(__name__)

# Allow cross-origin requests so browser front-ends on other origins can call the API.
CORS(app)


# Load the tokenizer and model once at startup (slow; several GB of weights).
# NOTE(review): trust_remote_code=True executes Python code shipped with the
# Hub repository "baichuan-inc/baichuan-7B" — acceptable only if that source
# is trusted.
tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/baichuan-7B", trust_remote_code=True)
# device_map="auto" lets accelerate place/shard the model across the visible GPUs.
model = AutoModelForCausalLM.from_pretrained("baichuan-inc/baichuan-7B", device_map="auto", trust_remote_code=True)


# Simple in-process concurrency limiter: `counter` tracks in-flight requests,
# `lock` guards its updates, and requests beyond MAX_CONCURRENT_REQUESTS are
# rejected by the route handler.
lock = threading.Lock()
counter = 0
MAX_CONCURRENT_REQUESTS = 50
|
|
|
|
|
@app.route('/baichuan/conversation', methods=['POST'])
def conversation():
    """Generate a Baichuan-7B completion for the posted question.

    Expects a JSON body ``{"question": <str>}`` and returns
    ``{"result": <str>}`` with the generated continuation, or
    ``{"message": ...}`` when MAX_CONCURRENT_REQUESTS requests are
    already in flight.
    """
    global counter

    # Check-and-increment atomically under one lock acquisition. The original
    # code read `counter` outside the lock and incremented it in a separate
    # critical section, a TOCTOU race that let concurrent requests all pass
    # the check and exceed MAX_CONCURRENT_REQUESTS.
    with lock:
        if counter >= MAX_CONCURRENT_REQUESTS:
            return jsonify({'message': '请稍等再试'})
        counter += 1

    try:
        question = request.json['question']
        # '->' acts as the prompt/answer separator the model was prompted with.
        question += '->'

        inputs = tokenizer(question, return_tensors='pt')
        inputs = inputs.to('cuda:0')

        pred = model.generate(**inputs, max_new_tokens=1024, repetition_penalty=1.1)
        text = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)

        # Strip the echoed prompt from the decoded sequence.
        # NOTE(review): slicing by character length assumes decode reproduces
        # the prompt text exactly at the front — confirm for this tokenizer.
        response = {'result': text[len(question):]}
        return jsonify(response)

    finally:
        # Always release the concurrency slot, even if generation raised.
        with lock:
            counter -= 1
|
|
|
|
|
if __name__ == '__main__':
    # Print the startup banner first: app.run blocks until the server exits.
    print("Flask服务器已启动")
    # Listen on all interfaces on port 30908 (Flask dev server; not for
    # production — use a WSGI server such as gunicorn behind a proxy).
    app.run(host='0.0.0.0', port=30908)
|
|