# Linly-Talker/api/llm_api.py
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from loguru import logger
import gc
import sys
import torch
from typing import Optional

sys.path.append('./')
app = FastAPI()
# Global variable used to store the currently loaded LLM model
from LLM import LLM
llm_class = LLM(mode='offline')
# By default no LLM model is used: questions are answered directly, which also reduces GPU memory usage
llm = llm_class.init_model('直接回复 Direct Reply')
# Default prompts (kept in Chinese because they are sent to Chinese-language models)
PREFIX_PROMPT = '请用少于25个字回答以下问题\n\n'  # "Answer the following question in fewer than 25 characters"
PREFIX_PROMPT = ''  # Prefix prompt disabled by default; the line above shows the intended constraint
DEFAULT_SYSTEM = '你是一个很有帮助的助手'  # "You are a very helpful assistant"
class LLMRequest(BaseModel):
    question: str = '请问什么是FastAPI?'  # "What is FastAPI?"
    model_name: str = 'Linly'
    gemini_apikey: str = ''  # API key for the Gemini model
    openai_apikey: str = ''  # API key for OpenAI
    proxy_url: Optional[str] = None  # Optional proxy URL
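# Example JSON body accepted by this schema (a sketch; every field is optional and
# falls back to the defaults above, and /llm_response/ only reads `question`):
#
#   {"question": "请问什么是FastAPI?"}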
@app.post("/llm_change_model/")
async def change_model(
model_name: str = Query(..., description="要加载的LLM模型名称"),
gemini_apikey: str = Query('', description="Gemini API 密钥"),
openai_apikey: str = Query('', description="OpenAI API 密钥"),
proxy_url: str = Query(None, description="代理 URL")
):
"""更换LLM模型并加载相应资源。"""
global llm
# 清理显存(具体实现依赖于模型库)
await clear_memory()
    try:
        if model_name == 'Linly':
            llm = llm_class.init_model('Linly', 'Linly-AI/Chinese-LLaMA-2-7B-hf', prefix_prompt=PREFIX_PROMPT)
            logger.info("Linly model loaded successfully")
        elif model_name in ['Qwen', 'Qwen2']:
            model_path = 'Qwen/Qwen-1_8B-Chat' if model_name == 'Qwen' else 'Qwen/Qwen1.5-0.5B-Chat'
            llm = llm_class.init_model(model_name, model_path, prefix_prompt=PREFIX_PROMPT)
            logger.info(f"{model_name} model loaded successfully")
        elif model_name == 'Gemini':
            if gemini_apikey:
                llm = llm_class.init_model('Gemini', 'gemini-pro', gemini_apikey, proxy_url)
                logger.info("Gemini model loaded successfully")
            else:
                raise HTTPException(status_code=400, detail="Please provide a Gemini API key")
        elif model_name == 'ChatGLM':
            llm = llm_class.init_model('ChatGLM', 'THUDM/chatglm3-6b', prefix_prompt=PREFIX_PROMPT)
            logger.info("ChatGLM model loaded successfully")
        elif model_name == 'ChatGPT':
            if openai_apikey:
                llm = llm_class.init_model('ChatGPT', api_key=openai_apikey, proxy_url=proxy_url, prefix_prompt=PREFIX_PROMPT)
                logger.info("ChatGPT model loaded successfully")
            else:
                raise HTTPException(status_code=400, detail="Please provide an OpenAI API key")
        elif model_name == 'GPT4Free':
            llm = llm_class.init_model('GPT4Free', prefix_prompt=PREFIX_PROMPT)
            logger.info("GPT4Free model loaded successfully; note that this model may be unstable")
        elif model_name == '直接回复 Direct Reply':
            llm = llm_class.init_model(model_name)
            logger.info("Direct-reply mode activated; no LLM model is used")
        else:
            raise HTTPException(status_code=400, detail=f"Unknown LLM model: {model_name}")
    except HTTPException:
        # Re-raise the 400 errors above unchanged instead of converting them into 500s
        raise
    except Exception as e:
        logger.error(f"Failed to load model {model_name}: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to load model {model_name}: {e}")
    return {"message": f"{model_name} model loaded successfully"}
@app.post("/llm_response/")
async def llm_response(request: LLMRequest):
"""处理LLM模型的问答请求。"""
global llm
if not request.question:
raise HTTPException(status_code=400, detail="问题内容不能为空")
if llm is None:
raise HTTPException(status_code=400, detail="LLM模型未加载,请先加载模型")
try:
answer = llm.generate(request.question, DEFAULT_SYSTEM)
logger.info(f"LLM 回复:{answer}")
return {"answer": answer}
except Exception as e:
logger.error(f"处理LLM请求失败: {e}")
raise HTTPException(status_code=500, detail=f"处理LLM请求失败: {e}")
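# Usage sketch (assumptions: the server runs locally on port 8002 and the `requests`
# package is installed). Only the `question` field is read by this endpoint:
#
#   import requests
#   resp = requests.post("http://localhost:8002/llm_response/",
#                        json={"question": "请问什么是FastAPI?"})
#   print(resp.json()["answer"])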
async def clear_memory():
    """Release cached GPU memory."""
    logger.info("Clearing GPU memory")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        logger.info(f"GPU memory allocated: {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8002)
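# Alternatively (assumption: this file lives at api/llm_api.py and the command is run from
# the repository root so that `from LLM import LLM` resolves), the server can be started
# with the uvicorn CLI:
#
#   uvicorn api.llm_api:app --host 0.0.0.0 --port 8002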