"""FastAPI router exposing an OpenAI-compatible ``/chat/completions`` endpoint.

Each request is answered by either Google Gemini (``langchain_google_genai``)
or an OpenAI-compatible API (``langchain_openai``), chosen by the
``group_name`` of the API group the requested model belongs to in the
database accessed through ``TbsDb``.
"""
import uuid
from datetime import datetime

from fastapi import APIRouter, Depends
# No longer used in this module; kept in case other code imports it from here.
from langchain_core.prompts import ChatPromptTemplate  # noqa: F401

from global_state import get
from db.tbs_db import TbsDb
from auth import get_current_user
from db_model.user import UserModel
from db_model.chat import ChatModel

router = APIRouter()

# Backend module path passed to every TbsDb instance created below.
db_module_filename = f"{get('project_root')}/db/cloudflare.py"


@router.post("/chat/completions")
async def chat_completions(chat_model: ChatModel, current_user: UserModel = Depends(get_current_user)):
    """Handle an OpenAI-style chat completion request.

    Args:
        chat_model: Request body. ``model`` (may be missing/empty) names the
            model; ``messages`` is a list of objects with ``role`` and
            ``content``.
        current_user: User resolved by ``get_current_user``; it enforces
            authentication and is not otherwise read here.

    Returns:
        An OpenAI ``chat.completion``-shaped dict on success, or
        ``{'error': <message>}`` when no API key is configured for the model
        or the model call fails.
    """
    # Fall back to the database default when the request names no model.
    model = getattr(chat_model, "model", None) or await get_default_model()

    api_key_info = await get_api_key(model)
    if not api_key_info:
        return {'error': f"no API key configured for model '{model}'"}
    api_key = api_key_info.get('api_key', '')
    group_name = api_key_info.get('group_name', '')
    base_url = api_key_info.get('base_url', '')

    if group_name == 'gemini':
        # Google API: build a Gemini chat model.
        from langchain_google_genai import ChatGoogleGenerativeAI
        llm = ChatGoogleGenerativeAI(
            api_key=api_key,
            model=model,
        )
    else:
        # Everything else is treated as an OpenAI-compatible API.
        from langchain_openai import ChatOpenAI
        llm = ChatOpenAI(
            model=model,
            api_key=api_key,
            base_url=base_url,
        )

    # Pass (role, content) tuples straight to the chat model. Routing them
    # through ChatPromptTemplate would treat any "{...}" in user content
    # (code snippets, JSON) as a template variable and make the call fail.
    lc_messages = [(message.role, message.content) for message in chat_model.messages]

    try:
        # ainvoke keeps the event loop free while waiting on the provider.
        result = await llm.ainvoke(lc_messages)  # AIMessage
    except Exception as e:
        return {'error': str(e)}

    # Convert to the OpenAI response format.
    return convert_to_openai_format(result)


def _sql_literal(value) -> str:
    """Return *value* as a single-quoted SQL string literal.

    Embedded single quotes are doubled, the standard SQL/SQLite escape for
    string literals (the queries here use SQLite's ``datetime('now')``, so
    SQLite quoting rules are assumed). TbsDb is only ever given a raw query
    string in this module; prefer bound parameters if TbsDb supports them.
    """
    return "'" + str(value).replace("'", "''") + "'"


def _first_row(response) -> dict:
    """Return the first row of a TbsDb query response, or ``{}`` if there is none.

    Assumes the response shape ``{'result': [{'results': [row, ...]}]}``,
    which is how this module reads rows. The shape is inferred from usage;
    confirm against TbsDb.
    """
    try:
        return response['result'][0]['results'][0]
    except (KeyError, IndexError, TypeError):
        return {}


async def get_default_model():
    """Return the ``api_name`` of the default model, or ``''`` if none is found.

    The default is the ``api_names`` row with the lowest ``default_order``.
    """
    query = "SELECT * FROM api_names order by default_order limit 1"
    response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
    return _first_row(response).get('api_name', '')


async def get_api_key(model):
    """Pick the least-recently-used enabled LLM API key for *model*.

    The chosen key's ``last_call_at`` is set to now, so the next call (which
    orders by ``last_call_at``) prefers a different key for the same model.

    Args:
        model: Model name matched against ``api_names.api_name``.

    Returns:
        A dict with ``api_name``, ``api_key``, ``base_url`` and
        ``group_name``, or an empty dict when no matching key exists.
    """
    query = f"""
        SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
        FROM api_keys ak
        JOIN api_groups ag ON ak.api_group_id = ag.id
        JOIN api_names an ON an.api_group_id = ag.id
        WHERE ak.category='LLM' and an.api_name={_sql_literal(model)} and disabled=0
        ORDER BY ak.last_call_at
        limit 1
    """
    response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
    row = _first_row(response)

    api_key = row.get('api_key')
    if api_key:
        # Only touch the timestamp when a key was actually selected.
        update = f"update api_keys set last_call_at=datetime('now') where api_key={_sql_literal(api_key)}"
        TbsDb(db_module_filename, "Cloudflare").execute_query(update)
    return row


def convert_to_openai_format(original_json):
    """Convert a LangChain chat reply into an OpenAI ``chat.completion`` dict.

    Args:
        original_json: The model reply. Despite the name, this is an
            ``AIMessage``-like object exposing ``content`` and
            ``usage_metadata``.

    Returns:
        A dict shaped like an OpenAI chat completion response. Token counts
        default to 0 when the provider reports no usage metadata.
    """
    # usage_metadata may be None when the provider does not report token usage.
    usage = getattr(original_json, "usage_metadata", None) or {}
    return {
        # Unique per response, as OpenAI clients expect.
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),  # current Unix timestamp
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": original_json.content,
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": usage.get("input_tokens", 0),
            "completion_tokens": usage.get("output_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0),
        },
    }