tianlong12 commited on
Commit
0b9ff8b
1 Parent(s): 25a7ff3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -0
app.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, Response, stream_with_context, jsonify
2
+ from openai import OpenAI
3
+ import json
4
+ import tiktoken
5
+ #import httpx
6
+
7
+ app = Flask(__name__)
8
+
9
# Local API key for incoming auth (disabled: the caller's own bearer token is
# forwarded upstream instead — see chat()).
#MY_API_KEY = "sk-gyxzhao"

# Maximum context length (in tokens) for each supported model.
MODEL_MAX_CONTEXT_LENGTH = {
    "gpt-4": 8192,
    "gpt-4-0613": 8192,
    "gpt-4o": 4096,
    "gpt-4-turbo": 4096,
    "claude-3-opus-20240229": 4096
}

def calculate_max_tokens(model_name, messages, requested_max_tokens):
    """Return the completion-token budget for a request.

    For ``gpt-4`` / ``gpt-4-0613`` the prompt is tokenized with tiktoken and
    the remaining context (model maximum minus prompt tokens) is returned,
    capped by ``requested_max_tokens`` when given and floored at 100.
    For every other model the configured maximum from
    ``MODEL_MAX_CONTEXT_LENGTH`` (default 4096) is returned unchanged.

    :param model_name: model identifier as sent by the client.
    :param messages: list of chat message dicts (``role``/``content``/...).
    :param requested_max_tokens: client-requested cap, or a falsy value to
        use only the remaining-context computation.
    """
    if model_name in ["gpt-4", "gpt-4-0613"]:
        try:
            encoding = tiktoken.encoding_for_model(model_name)
        except Exception as e:
            print(f"Error getting encoding for model {model_name}: {e}")
            encoding = tiktoken.get_encoding("cl100k_base")  # generic fallback encoding

        max_context_length = MODEL_MAX_CONTEXT_LENGTH[model_name]

        tokens_per_message = 3  # fixed per-message overhead (role + content + boundary tokens)
        tokens_per_name = 1     # extra token when a message carries a 'name' field
        messages_length = 3     # initial priming tokens

        for message in messages:
            messages_length += tokens_per_message
            for key, value in message.items():
                # str() guards against non-string fields (e.g. function_call
                # dicts when function calling is used) — tiktoken can only
                # encode strings and would raise TypeError otherwise.
                messages_length += len(encoding.encode(str(value)))
                if key == 'name':
                    messages_length += tokens_per_name

        #print(f"Message length in tokens: {messages_length}")  # debug

        max_tokens = max_context_length - messages_length
        if requested_max_tokens:
            max_tokens = min(max_tokens, requested_max_tokens)

        # Floor at 100 so an over-long prompt still requests a usable reply.
        return max(100, max_tokens)

    else:
        # Other models: return their configured maximum (4096 if unknown).
        return MODEL_MAX_CONTEXT_LENGTH.get(model_name, 4096)
52
+
53
@app.route('/v1/chat/completions', methods=['POST'])
def chat():
    """OpenAI-compatible /v1/chat/completions proxy endpoint.

    Forwards the caller's request (including their own bearer token) to
    api.aimlapi.com through the OpenAI SDK.  Claude models are special-cased:
    'system'-role messages are stripped from the list, and streaming is
    simulated by slicing a non-streaming reply into 20-character SSE chunks.
    Returns 401 on a missing/malformed Authorization header, 400 on a missing
    'messages'/'model' field, and 500 (with the exception text) on any error.
    """
    try:
        # Validate the API key carried in the request headers.
        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            return jsonify({"error": "Unauthorized"}), 401

        # The caller's own key is forwarded upstream as-is.
        api_key = auth_header.split(" ")[1]

        data = request.json
        #print("Received data:", data)  # debug: dump the request body

        # Validate the request format.
        if not data or 'messages' not in data or 'model' not in data:
            return jsonify({"error": "Missing 'messages' or 'model' in request body"}), 400

        model = data['model']
        messages = data['messages']
        temperature = data.get('temperature', 0.7)  # default 0.7
        requested_max_tokens = data.get('max_tokens', MODEL_MAX_CONTEXT_LENGTH.get(model, 4096))
        #max_tokens = calculate_max_tokens(model, messages, requested_max_tokens)
        top_p = data.get('top_p', 1.0)  # default 1.0
        n = data.get('n', 1)  # default 1
        stream = data.get('stream', False)  # default False
        functions = data.get('functions', None)  # functions for function calling
        function_call = data.get('function_call', None)  # specific function call request

        # Claude models: adjust message format by dropping 'system' messages.
        # NOTE(review): system_message is captured here but never sent upstream
        # (the system= argument below is commented out) — confirm intended.
        system_message = None
        if model.startswith("claude"):
            messages = [msg for msg in messages if msg['role'] != 'system']
            if 'system' in data:
                system_message = data['system']

        # Create a per-request OpenAI client pointed at the upstream relay.
        # NOTE(review): most OpenAI-compatible relays expect a '/v1' suffix on
        # base_url — verify "https://api.aimlapi.com" is correct as written.
        client = OpenAI(
            api_key=api_key,
            base_url="https://api.aimlapi.com",
        )

        # Dispatch on streaming vs. non-streaming.
        if stream:
            # Streaming response (server-sent events).
            def generate():
                if model.startswith("claude"):
                    # Claude: fetch the full completion, then simulate
                    # streaming by emitting 20-character delta chunks.
                    response = client.chat.completions.create(
                        model=model,
                        messages=messages,
                        temperature=temperature,
                        #max_tokens=max_tokens,
                        top_p=top_p,
                        n=n,
                        functions=functions,
                        function_call=function_call,
                        #system=system_message  # pass system_message as a top-level parameter
                    )
                    content = response.choices[0].message.content
                    for i in range(0, len(content), 20):  # 20 characters per chunk
                        chunk = content[i:i+20]
                        yield f"data: {json.dumps({'choices': [{'delta': {'content': chunk}}]})}\n\n"
                else:
                    # Other models: true upstream streaming, relayed chunk by chunk.
                    response = client.chat.completions.create(
                        model=model,
                        messages=messages,
                        temperature=temperature,
                        #max_tokens=max_tokens,
                        top_p=top_p,
                        n=n,
                        stream=True,
                        functions=functions,
                        function_call=function_call
                    )
                    for chunk in response:
                        yield f"data: {json.dumps(chunk.to_dict())}\n\n"

            return Response(stream_with_context(generate()), content_type='text/event-stream')
        else:
            # Non-streaming response.
            if model.startswith("claude"):
                response = client.chat.completions.create(
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    #max_tokens=max_tokens,
                    top_p=top_p,
                    n=n,
                    functions=functions,
                    function_call=function_call,
                    #system=system_message  # pass system_message as a top-level parameter
                )
            else:
                response = client.chat.completions.create(
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    #max_tokens=max_tokens,
                    top_p=top_p,
                    n=n,
                    functions=functions,
                    function_call=function_call,
                )

            # Debug print of the raw upstream response.
            #print("API response:", response)

            # Convert the SDK response object into a plain JSON-serializable dict.
            response_dict = {
                "id": response.id,
                "object": response.object,
                "created": response.created,
                "model": response.model,
                "choices": [
                    {
                        "message": {
                            "role": choice.message.role,
                            "content": choice.message.content
                        },
                        "index": choice.index,
                        "finish_reason": choice.finish_reason,
                        "logprobs": choice.logprobs.__dict__ if choice.logprobs else None  # convert ChoiceLogprobs to a dict
                    }
                    for choice in response.choices
                ],
                "usage": {
                    "prompt_tokens": response.usage.prompt_tokens,
                    "completion_tokens": response.usage.completion_tokens,
                    "total_tokens": response.usage.total_tokens
                }
            }

            # Debug print of the serialized response dict.
            #print("Response dict:", json.dumps(response_dict, ensure_ascii=False, indent=2))

            # Return well-formed JSON to the caller.
            return jsonify(response_dict), 200

    except Exception as e:
        print("Exception:", e)
        return jsonify({"error": str(e)}), 500
194
+
195
if __name__ == '__main__':
    # Werkzeug development server: listen on all interfaces, port 4500,
    # debug off, forking 4 worker processes (valid because threaded is
    # left at its default of False).
    app.run(host='0.0.0.0', port=4500, debug=False,processes=4)