File size: 7,849 Bytes
0b9ff8b
 
 
ea2c97f
 
0b9ff8b
 
ea2c97f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fea061
28afeea
3fea061
 
 
 
 
0b9ff8b
 
 
ea2c97f
0b9ff8b
 
3fea061
 
 
 
 
0b9ff8b
 
 
ea2c97f
0b9ff8b
 
 
3fea061
0b9ff8b
 
 
ea2c97f
 
 
 
 
 
 
0b9ff8b
 
 
ea2c97f
0b9ff8b
 
 
 
 
 
ea2c97f
0b9ff8b
 
 
 
 
 
 
 
ea2c97f
 
0b9ff8b
ea2c97f
 
 
 
 
 
 
 
3fea061
 
 
 
 
 
 
 
 
 
0b9ff8b
 
 
ea2c97f
 
 
 
 
 
3fea061
 
ea2c97f
 
 
 
 
 
 
0b9ff8b
 
 
 
 
 
 
 
ea2c97f
0b9ff8b
 
ea2c97f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fea061
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
from flask import Flask, request, Response, stream_with_context, jsonify
from openai import OpenAI
import json
import tiktoken
#import httpx

# Flask application that hosts the chat-completions proxy route.
app = Flask(__name__)

# Authentication: no API key is configured server-side. Each request must send
# "Authorization: Bearer <key>"; chat() forwards that key as the api_key of the
# upstream OpenAI client it creates per request.

# 模型的最大上下文长度
MODEL_MAX_CONTEXT_LENGTH = {
    "gpt-4": 8192,
    "gpt-4-0613": 8192,
    "gpt-4o": 4096,
    "gpt-4-turbo": 4096,
    "claude-3-opus-20240229": 4096
}

def calculate_max_tokens(model_name, messages, requested_max_tokens, min_tokens=100):
    """Return a ``max_tokens`` value for a chat completion request.

    For "gpt-4" and "gpt-4-0613" the prompt is tokenised with tiktoken, and the
    room left in the model's context window (MODEL_MAX_CONTEXT_LENGTH) is
    returned, clamped to ``requested_max_tokens`` when that value is truthy.
    For every other model, the MODEL_MAX_CONTEXT_LENGTH entry (default 4096) is
    returned unchanged and ``requested_max_tokens`` is ignored.

    Args:
        model_name: Model identifier from the request body.
        messages: Chat messages (dicts) that will be sent upstream.
        requested_max_tokens: Completion budget requested by the client, or a
            falsy value meaning "no explicit limit".
        min_tokens: Floor for the computed context budget, so a nearly full
            prompt still gets a usable completion budget. It never overrides a
            smaller explicit ``requested_max_tokens``.

    Returns:
        int: Number of completion tokens to request. For the GPT-4 models this
        is always >= 1.
    """
    if model_name not in ("gpt-4", "gpt-4-0613"):
        # No token counting for other models: return the configured budget as-is.
        return MODEL_MAX_CONTEXT_LENGTH.get(model_name, 4096)

    try:
        encoding = tiktoken.encoding_for_model(model_name)
    except KeyError as e:
        # tiktoken signals an unmapped model name with KeyError; fall back to
        # the general-purpose encoding.
        print(f"Error getting encoding for model {model_name}: {e}")
        encoding = tiktoken.get_encoding("cl100k_base")

    tokens_per_message = 3  # fixed overhead per message (role/content/boundary tokens)
    tokens_per_name = 1     # extra token when a message carries a 'name' field
    prompt_tokens = 3       # base overhead counted once per request

    for message in messages:
        prompt_tokens += tokens_per_message
        for key, value in message.items():
            if key == 'name':
                prompt_tokens += tokens_per_name
            if value is None:
                continue  # e.g. a null 'content' — nothing to encode
            if not isinstance(value, str):
                # Approximation: count non-string fields by their JSON text.
                value = json.dumps(value, ensure_ascii=False)
            # disallowed_special=() encodes user text such as "<|endoftext|>" as
            # ordinary text instead of raising ValueError.
            prompt_tokens += len(encoding.encode(value, disallowed_special=()))

    # Apply the floor to the context budget first, then honour an explicit
    # smaller request (the old order let the floor exceed the client's limit).
    budget = max(min_tokens, MODEL_MAX_CONTEXT_LENGTH[model_name] - prompt_tokens)
    if requested_max_tokens:
        budget = min(budget, requested_max_tokens)
    return max(1, budget)

# End-of-stream sentinel expected by OpenAI-style SSE clients.
_SSE_DONE = "data: [DONE]\n\n"


def _bearer_token(auth_header):
    """Return the token from an ``Authorization: Bearer <token>`` header, or None.

    None is returned when the header is missing, uses another scheme, or
    carries no token (e.g. ``"Bearer "``).
    """
    if not auth_header or not auth_header.startswith('Bearer '):
        return None
    parts = auth_header[len('Bearer '):].split()
    return parts[0] if parts else None


def _sse(payload):
    """Format *payload* as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _relay_stream(upstream):
    """Re-emit upstream chunks as SSE frames, then the end-of-stream marker.

    This generator runs after chat() has returned, so its exceptions cannot
    reach chat()'s error handler. Instead they are reported to the client as
    an ``error`` frame.
    """
    try:
        for chunk in upstream:
            yield _sse(chunk.to_dict())
    except Exception as e:
        print("Stream exception:", e)
        yield _sse({"error": str(e)})
    yield _SSE_DONE


def _pseudo_stream(content, chunk_size=20):
    """Split an already complete reply into SSE delta frames of *chunk_size* characters."""
    for start in range(0, len(content), chunk_size):
        piece = content[start:start + chunk_size]
        yield _sse({'choices': [{'delta': {'content': piece}}]})
    yield _SSE_DONE


def _message_to_dict(message):
    """Serialise a response message, keeping function/tool calls when present."""
    result = {"role": message.role, "content": message.content}
    function_call = getattr(message, "function_call", None)
    if function_call is not None:
        result["function_call"] = function_call.to_dict()
    tool_calls = getattr(message, "tool_calls", None)
    if tool_calls:
        result["tool_calls"] = [call.to_dict() for call in tool_calls]
    return result


def _completion_to_dict(response):
    """Convert a ChatCompletion object into a JSON-serialisable dict."""
    usage = response.usage
    return {
        "id": response.id,
        "object": response.object,
        "created": response.created,
        "model": response.model,
        "choices": [
            {
                "message": _message_to_dict(choice.message),
                "index": choice.index,
                "finish_reason": choice.finish_reason,
                # to_dict() also converts nested models; __dict__ left them
                # as objects that jsonify cannot serialise.
                "logprobs": choice.logprobs.to_dict() if choice.logprobs else None,
            }
            for choice in response.choices
        ],
        "usage": None if usage is None else {
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
        },
    }


@app.route('/hf/v1/chat/completions', methods=['POST'])
def chat():
    """Proxy an OpenAI-style chat completion request to https://api.aimlapi.com.

    The caller's bearer token is forwarded as the upstream API key. The reply
    is JSON, or ``text/event-stream`` when the body sets ``"stream": true``.
    Claude models are always requested non-streamed upstream; when streaming
    was asked for, the finished reply is re-chunked locally.

    Error responses:
        401: missing or empty bearer token.
        400: body is not a JSON object, lacks 'model'/'messages', or has
            ill-typed values for them.
        upstream status: the HTTP status of an upstream API error, when the
            exception exposes one.
        500: any other failure.
    """
    try:
        api_key = _bearer_token(request.headers.get('Authorization'))
        if api_key is None:
            return jsonify({"error": "Unauthorized"}), 401

        # silent=True: a non-JSON or malformed body yields None (-> 400)
        # instead of raising an exception that would surface as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'messages' not in data or 'model' not in data:
            return jsonify({"error": "Missing 'messages' or 'model' in request body"}), 400

        model = data['model']
        messages = data['messages']
        if not isinstance(model, str) or not isinstance(messages, list):
            return jsonify({"error": "'model' must be a string and 'messages' a list"}), 400

        is_claude = model.startswith("claude")
        if is_claude:
            # System messages are stripped for Claude models and not forwarded.
            # NOTE(review): a top-level "system" field is not forwarded either;
            # create() has no `system` argument (extra_body would be the route).
            messages = [
                m for m in messages
                if not (isinstance(m, dict) and m.get('role') == 'system')
            ]

        # Arguments shared by every upstream call. The client's max_tokens is
        # deliberately not forwarded (unchanged behaviour).
        params = {
            "model": model,
            "messages": messages,
            "temperature": data.get('temperature', 0.7),
            "top_p": data.get('top_p', 1.0),
            "n": data.get('n', 1),
        }
        # Only send function-calling fields the client actually supplied,
        # rather than explicit nulls.
        for key in ('functions', 'function_call'):
            if data.get(key) is not None:
                params[key] = data[key]

        # One client per request, because each request carries its own key.
        client = OpenAI(api_key=api_key, base_url="https://api.aimlapi.com")

        if data.get('stream', False):
            # Upstream calls are made here, not inside the generators, so
            # connection/auth failures become a JSON error from the handler below.
            if is_claude:
                completion = client.chat.completions.create(**params)
                events = _pseudo_stream(completion.choices[0].message.content or "")
            else:
                events = _relay_stream(client.chat.completions.create(stream=True, **params))
            return Response(stream_with_context(events), content_type='text/event-stream')

        response = client.chat.completions.create(**params)
        return jsonify(_completion_to_dict(response)), 200

    except Exception as e:
        print("Exception:", e)
        # openai's APIStatusError carries the upstream HTTP status as status_code.
        status = getattr(e, 'status_code', None)
        if not isinstance(status, int) or not 400 <= status < 600:
            status = 500
        return jsonify({"error": str(e)}), status

if __name__ == "__main__":
    # Start Flask's built-in server on all interfaces, port 7860, handling
    # each request in its own thread.
    server_options = dict(host='0.0.0.0', port=7860, threaded=True)
    app.run(**server_options)