Spaces:
Sleeping
Sleeping
File size: 3,743 Bytes
10bcd3f daa9d8a bad204d daa9d8a 10bcd3f daa9d8a 10bcd3f 2c7fa70 10bcd3f daa9d8a 10bcd3f daa9d8a 10bcd3f daa9d8a 2c7fa70 10bcd3f daa9d8a 10bcd3f daa9d8a 10bcd3f daa9d8a 10bcd3f bad204d 2c7fa70 bad204d 2c7fa70 bad204d daa9d8a 10bcd3f daa9d8a 10bcd3f daa9d8a 10bcd3f daa9d8a 10bcd3f 2c7fa70 daa9d8a 10bcd3f daa9d8a 10bcd3f daa9d8a 10bcd3f daa9d8a 2c7fa70 bad204d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
from fastapi import APIRouter, Depends
from langchain_core.prompts import ChatPromptTemplate
from datetime import datetime
# import uuid
from global_state import get
from db.tbs_db import TbsDb
from auth import get_current_user
from db_model.user import UserModel
from db_model.chat import ChatModel
# Router on which the endpoint defined in this module is registered.
router = APIRouter()
# Path to the database module file, built from the 'project_root' value
# returned by global_state.get; it is passed to every TbsDb(...) call in this file.
db_module_filename = f"{get('project_root')}/db/cloudflare.py"
@router.post("/chat/completions")
async def chat_completions(chat_model: ChatModel, current_user: UserModel = Depends(get_current_user)):
    """Serve an OpenAI-style chat completion request.

    Resolves the model name (falling back to the database default when the
    request does not name one), looks up an API key for it, builds the
    matching LangChain chat model, sends the conversation and returns the
    reply converted to the OpenAI ``chat.completion`` layout.

    Args:
        chat_model: Request body; ``model`` and ``messages`` (items exposing
            ``role`` and ``content``) are read from it.
        current_user: Value produced by the ``get_current_user`` dependency.
            It is not used in the body; presumably it only enforces
            authentication -- confirm against ``auth``.

    Returns:
        The dict built by ``convert_to_openai_format`` on success, or
        ``{'error': <message>}`` when no key is available or the model call
        fails.
    """
    # A missing/empty model name means "use the default model".
    model = getattr(chat_model, 'model', None)
    if not model:
        model = await get_default_model()

    api_key_info = await get_api_key(model) or {}
    if not api_key_info:
        # Without a key row there is nothing to call the provider with.
        return {'error': f"no API key available for model '{model}'"}

    api_key = api_key_info.get('api_key', '')
    group_name = api_key_info.get('group_name', '')
    base_url = api_key_info.get('base_url', '')

    if group_name == 'gemini':
        # Google API: build a Gemini chat model.
        from langchain_google_genai import ChatGoogleGenerativeAI
        llm = ChatGoogleGenerativeAI(
            api_key=api_key,
            model=model,
        )
    else:
        # Everything else is treated as a ChatGPT-compatible API.
        from langchain_openai import ChatOpenAI
        llm = ChatOpenAI(
            model=model,
            api_key=api_key,
            base_url=base_url,
        )

    # Pass the conversation as (role, content) pairs straight to the model.
    # Routing request text through a prompt template would interpret any
    # "{...}" in a message as a template variable and fail, because no
    # variables are ever supplied.
    lc_messages = [(message.role, message.content) for message in chat_model.messages]
    try:
        # Await the async API so the event loop is not blocked while the
        # provider responds. The result is an AIMessage object.
        result = await llm.ainvoke(lc_messages)
    except Exception as e:
        return {'error': str(e)}

    # Convert to the OpenAI response format.
    return convert_to_openai_format(result)
# Fetch the default model from the database.
async def get_default_model():
    """Return the default model name, or ``''`` when it cannot be read.

    Queries the first row of ``api_names`` ordered by ``default_order`` and
    returns its ``api_name`` value.

    Returns:
        The ``api_name`` of the first row, or an empty string when the
        response does not have the expected
        ``response['result'][0]['results'][0]`` shape.
    """
    query = "SELECT * FROM api_names order by default_order limit 1"
    response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
    try:
        return response['result'][0]['results'][0]['api_name']
    except (KeyError, IndexError, TypeError):
        # Empty result set or an unexpected response shape.
        return ''
async def get_api_key(model):
    """Look up an enabled LLM API key for ``model`` and mark it as used.

    Selects the matching key row with the smallest ``last_call_at`` and then
    sets that key's ``last_call_at`` to the current time, so successive calls
    move on to other keys of the same group.

    Args:
        model: Model name matched against ``api_names.api_name``. It
            originates from the request, so it is escaped before being
            placed in the SQL text.

    Returns:
        The selected row (keys ``api_name``, ``api_key``, ``base_url``,
        ``group_name``), or an empty dict when no row matches or the
        response does not have the expected shape.
    """
    # Double any single quote so the value cannot terminate the SQL string
    # literal it is embedded in.
    safe_model = str(model).replace("'", "''")
    query = f"""
    SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
    FROM api_keys ak
    JOIN api_groups ag ON ak.api_group_id = ag.id
    JOIN api_names an ON an.api_group_id = ag.id
    WHERE ak.category='LLM' and an.api_name='{safe_model}' and disabled=0
    ORDER BY ak.last_call_at
    limit 1
    """
    response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
    try:
        result = response['result'][0]['results'][0]
        api_key = result['api_key']
    except (KeyError, IndexError, TypeError):
        # No matching key: return an empty mapping instead of failing on an
        # unbound name, and skip the pointless timestamp update.
        return {}

    safe_key = str(api_key).replace("'", "''")
    query = f"update api_keys set last_call_at=datetime('now') where api_key='{safe_key}'"
    TbsDb(db_module_filename, "Cloudflare").execute_query(query)
    return result
def convert_to_openai_format(original_json):
    """Convert a chat-model reply into an OpenAI ``chat.completion`` dict.

    Args:
        original_json: Reply object exposing ``content`` and, optionally,
            ``usage_metadata`` (a mapping with ``input_tokens``,
            ``output_tokens`` and ``total_tokens``, or ``None``).

    Returns:
        A dict with ``id``, ``object``, ``created``, a single-element
        ``choices`` list carrying the assistant message, and ``usage``
        token counts (0 for any count that is not reported).
    """
    # Imported locally so this function does not rely on a file-level import.
    import uuid

    # Usage data may be missing or None; treat both as "no counts reported".
    usage = getattr(original_json, "usage_metadata", None) or {}

    return {
        # A fresh id per response so clients can tell completions apart.
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        # Current time as a Unix timestamp in whole seconds.
        "created": int(datetime.now().timestamp()),
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": original_json.content,
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": usage.get("input_tokens", 0),
            "completion_tokens": usage.get("output_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0),
        },
    }
|