kaixuan42 commited on
Commit
7a10dd5
1 Parent(s): ef12332

Create app_http.py

Browse files
Files changed (1) hide show
  1. app_http.py +65 -0
app_http.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 服务器端代码
2
+
3
+ # 导入Flask库和其他必要的库
4
+ from flask import Flask, request, jsonify
5
+ import threading
6
+ from flask_cors import CORS
7
+ import os
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
+
10
+ # 设置可见的GPU设备
11
+ os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1'
12
+
13
+ # 创建Flask应用对象
14
+ app = Flask(__name__)
15
+ # 允许跨域请求
16
+ CORS(app)
17
+
18
+ # 加载百川大模型的分词器和模型
19
+ tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/baichuan-7B", trust_remote_code=True)
20
+ model = AutoModelForCausalLM.from_pretrained("baichuan-inc/baichuan-7B", device_map="auto", trust_remote_code=True)
21
+
22
+ # 创建线程锁和计数器
23
+ lock = threading.Lock()
24
+ counter = 0
25
+ MAX_CONCURRENT_REQUESTS = 50 # 最大并发请求数
26
+
27
+ # 定义服务接口的路由和方法
28
+ @app.route('/baichuan/conversation', methods=['POST'])
29
+ def conversation():
30
+ global counter # 使用全局变量
31
+
32
+ # 请求过载,返回提示信息
33
+ if counter >= MAX_CONCURRENT_REQUESTS:
34
+ return jsonify({'message': '请稍等再试'})
35
+
36
+ # 获取线程锁
37
+ with lock:
38
+ counter += 1 # 增加计数器
39
+
40
+ try:
41
+ # 接收POST请求的数据
42
+ question = request.json['question']
43
+ question += '->' # 添加分隔符
44
+
45
+ # 对输入进行分词和编码
46
+ inputs = tokenizer(question, return_tensors='pt')
47
+ inputs = inputs.to('cuda:0') # 移动到GPU上
48
+
49
+ # 调用模型进行生成
50
+ pred = model.generate(**inputs, max_new_tokens=1024, repetition_penalty=1.1)
51
+ text = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True) # 对输出进行解码
52
+
53
+ # 返回结果
54
+ response = {'result': text[len(question):]} # 去掉输入部分
55
+ return jsonify(response)
56
+
57
+ finally:
58
+ # 释放线程锁并减少计数器
59
+ with lock:
60
+ counter -= 1
61
+
62
+ # 主函数
63
+ if __name__ == '__main__':
64
+ print("Flask服务器已启动")
65
+ app.run(host='0.0.0.0', port=30908) # 设置主机地址和端口号