File size: 3,743 Bytes
10bcd3f
daa9d8a
bad204d
 
daa9d8a
10bcd3f
 
 
 
 
daa9d8a
 
 
10bcd3f
2c7fa70
10bcd3f
 
daa9d8a
10bcd3f
daa9d8a
 
10bcd3f
 
daa9d8a
2c7fa70
10bcd3f
 
 
 
daa9d8a
10bcd3f
 
daa9d8a
 
 
 
10bcd3f
 
 
 
 
 
 
 
daa9d8a
10bcd3f
 
 
bad204d
2c7fa70
bad204d
2c7fa70
 
bad204d
 
 
 
daa9d8a
10bcd3f
daa9d8a
10bcd3f
 
daa9d8a
10bcd3f
daa9d8a
 
 
 
 
10bcd3f
2c7fa70
daa9d8a
 
 
 
 
 
 
10bcd3f
daa9d8a
10bcd3f
daa9d8a
 
 
 
10bcd3f
 
daa9d8a
2c7fa70
bad204d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from fastapi import APIRouter, Depends
from langchain_core.prompts import ChatPromptTemplate
from datetime import datetime
# import uuid

from global_state import get
from db.tbs_db import TbsDb
from auth import get_current_user
from db_model.user import UserModel
from db_model.chat import ChatModel

router = APIRouter()

db_module_filename = f"{get('project_root')}/db/cloudflare.py"

@router.post("/chat/completions")
async def chat_completions(chat_model:ChatModel, current_user: UserModel = Depends(get_current_user)):
    """OpenAI-compatible chat completion endpoint.

    Resolves the requested model (falling back to the DB-configured default),
    builds the matching LangChain LLM client (Gemini or any ChatGPT-compatible
    API), runs the conversation through it, and returns the result in OpenAI
    chat-completion format. On LLM invocation failure, returns
    ``{'error': <message>}`` instead of raising.
    """
    # Fall back to the configured default when the request omits the model.
    # (Replaces a bare `except:` that silently swallowed every error.)
    model = getattr(chat_model, 'model', '') or ''
    if not model:
        model = await get_default_model()

    api_key_info = await get_api_key(model)
    api_key = api_key_info.get('api_key', '')
    group_name = api_key_info.get('group_name', '')
    base_url = api_key_info.get('base_url', '')

    if group_name == 'gemini':  # Google API: build a Gemini LLM
        from langchain_google_genai import ChatGoogleGenerativeAI
        llm = ChatGoogleGenerativeAI(
            api_key = api_key,
            model = model,
        )
    else:  # everything else goes through a ChatGPT-compatible API
        from langchain_openai import ChatOpenAI
        llm = ChatOpenAI(
            model = model,
            api_key = api_key,
            base_url = base_url,
        )

    # Build the prompt from the incoming (role, content) message pairs.
    lc_messages = [(message.role, message.content) for message in chat_model.messages]
    prompt_template = ChatPromptTemplate.from_messages(lc_messages)
    chain = prompt_template | llm
    try:
        result = chain.invoke({})  # AIMessage-like object
    except Exception as e:
        # Surface provider/LLM failures to the caller as a JSON error payload.
        return {'error': str(e)}

    # Convert the AIMessage into the OpenAI response schema.
    converted_data = convert_to_openai_format(result)
    return converted_data

# 从数据库获取默认模型
# Fetch the default model name from the database.
async def get_default_model():
    """Return the api_name with the lowest default_order, or '' when the
    table is empty or the response shape is unexpected."""
    # Constant query — no interpolation needed (was a pointless f-string).
    query = "SELECT * FROM api_names order by default_order limit 1"
    response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
    try:
        result = response['result'][0]['results'][0]['api_name']
    except (KeyError, IndexError, TypeError):  # narrow instead of bare except
        result = ''
    return result

async def get_api_key(model):
    """Look up the least-recently-used enabled API key row for *model*.

    Returns the joined row (api_name, api_key, base_url, group_name) as a
    dict, or {} when no matching key exists. On success, bumps the key's
    last_call_at so keys rotate LRU-style.
    """
    # NOTE(review): value is interpolated into SQL — escape single quotes as a
    # minimal injection guard. If TbsDb supports bound parameters, prefer
    # those instead — TODO confirm against db/cloudflare.py.
    safe_model = str(model).replace("'", "''")
    query = f"""
SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
FROM api_keys ak
JOIN api_groups ag ON ak.api_group_id = ag.id
JOIN api_names an ON an.api_group_id = ag.id
WHERE ak.category='LLM' and an.api_name='{safe_model}' and disabled=0
ORDER BY ak.last_call_at
limit 1
"""
    response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
    # Fix: `result` was previously unbound (NameError at `return result`)
    # whenever the lookup failed.
    result = {}
    try:
        result = response['result'][0]['results'][0]
        api_key = result['api_key']
    except (KeyError, IndexError, TypeError):  # narrow instead of bare except
        api_key = ''

    # Only touch last_call_at when a key was actually found; previously an
    # UPDATE with api_key='' was issued even on failure.
    if api_key:
        query = f"update api_keys set last_call_at=datetime('now') where api_key='{api_key}'"
        TbsDb(db_module_filename, "Cloudflare").execute_query(query)
    return result

def convert_to_openai_format(original_json):
    """Convert a LangChain AIMessage-like object into an OpenAI
    chat-completion response dict.

    Reads ``.content`` for the reply text and ``.usage_metadata`` (which may
    be None or absent on some providers) for token counts; missing counts
    default to 0.
    """
    # Fix: usage_metadata can be None/missing — previously this raised
    # AttributeError on `.get(...)`.
    usage = getattr(original_json, "usage_metadata", None) or {}
    new_json = {
        "id": "chatcmpl-123",  # placeholder; could be a generated unique ID
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),  # current unix timestamp
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": original_json.content  # original reply text
                },
                "finish_reason": "stop"
            }
        ],
        "usage": {
            "prompt_tokens": usage.get("input_tokens", 0),
            "completion_tokens": usage.get("output_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0)
        }
    }
    return new_json