from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from loguru import logger
from typing import Optional
import gc
import torch
import sys

sys.path.append('./')

app = FastAPI()

# Global variable holding the currently loaded LLM model
from LLM import LLM
llm_class = LLM(mode='offline')
# By default no LLM model is used and questions are answered directly, which also reduces GPU memory usage
llm = llm_class.init_model('直接回复 Direct Reply')

# Default prompts
PREFIX_PROMPT = '请用少于25个字回答以下问题\n\n'
PREFIX_PROMPT = ''  # overrides the prefix above, i.e. the prefix prompt is disabled by default
DEFAULT_SYSTEM = '你是一个很有帮助的助手'
class LLMRequest(BaseModel):
    question: str = '请问什么是FastAPI?'
    model_name: str = 'Linly'
    gemini_apikey: str = ''          # API key for the Gemini model
    openai_apikey: str = ''          # API key for OpenAI
    proxy_url: Optional[str] = None  # proxy URL
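# Example request body for the question-answering endpoint (illustrative only;
# the values shown are simply the field defaults above):
#
#   {"question": "请问什么是FastAPI?", "model_name": "Linly"}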
@app.post("/change_model")  # route path assumed; the decorator is missing from the original listing
async def change_model(
    model_name: str = Query(..., description="Name of the LLM model to load"),
    gemini_apikey: str = Query('', description="Gemini API key"),
    openai_apikey: str = Query('', description="OpenAI API key"),
    proxy_url: Optional[str] = Query(None, description="Proxy URL")
):
    """Switch the LLM model and load the corresponding resources."""
    global llm
    # Free GPU memory first (the details depend on the model library)
    await clear_memory()
    try:
        if model_name == 'Linly':
            llm = llm_class.init_model('Linly', 'Linly-AI/Chinese-LLaMA-2-7B-hf', prefix_prompt=PREFIX_PROMPT)
            logger.info("Linly model loaded successfully")
        elif model_name in ['Qwen', 'Qwen2']:
            model_path = 'Qwen/Qwen-1_8B-Chat' if model_name == 'Qwen' else 'Qwen/Qwen1.5-0.5B-Chat'
            llm = llm_class.init_model(model_name, model_path, prefix_prompt=PREFIX_PROMPT)
            logger.info(f"{model_name} model loaded successfully")
        elif model_name == 'Gemini':
            if gemini_apikey:
                llm = llm_class.init_model('Gemini', 'gemini-pro', gemini_apikey, proxy_url)
                logger.info("Gemini model loaded successfully")
            else:
                raise HTTPException(status_code=400, detail="Please provide a Gemini API key")
        elif model_name == 'ChatGLM':
            llm = llm_class.init_model('ChatGLM', 'THUDM/chatglm3-6b', prefix_prompt=PREFIX_PROMPT)
            logger.info("ChatGLM model loaded successfully")
        elif model_name == 'ChatGPT':
            if openai_apikey:
                llm = llm_class.init_model('ChatGPT', api_key=openai_apikey, proxy_url=proxy_url, prefix_prompt=PREFIX_PROMPT)
                logger.info("ChatGPT model loaded successfully")
            else:
                raise HTTPException(status_code=400, detail="Please provide an OpenAI API key")
        elif model_name == 'GPT4Free':
            llm = llm_class.init_model('GPT4Free', prefix_prompt=PREFIX_PROMPT)
            logger.info("GPT4Free model loaded successfully; note that it may be unstable")
        elif model_name == '直接回复 Direct Reply':
            llm = llm_class.init_model(model_name)
            logger.info("Direct-reply mode activated; no LLM model is used")
        else:
            raise HTTPException(status_code=400, detail=f"Unknown LLM model: {model_name}")
    except HTTPException:
        # Re-raise HTTP errors (e.g. a missing API key) instead of wrapping them in a 500
        raise
    except Exception as e:
        logger.error(f"Failed to load model {model_name}: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to load model {model_name}: {e}")
    return {"message": f"Model {model_name} loaded successfully"}
@app.post("/llm_response")  # route path assumed; the decorator is missing from the original listing
async def llm_response(request: LLMRequest):
    """Handle a question-answering request against the loaded LLM model."""
    global llm
    if not request.question:
        raise HTTPException(status_code=400, detail="The question must not be empty")
    if llm is None:
        raise HTTPException(status_code=400, detail="No LLM model is loaded; please load a model first")
    try:
        answer = llm.generate(request.question, DEFAULT_SYSTEM)
        logger.info(f"LLM answer: {answer}")
        return {"answer": answer}
    except Exception as e:
        logger.error(f"Failed to handle the LLM request: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to handle the LLM request: {e}")
async def clear_memory():
    """Async helper that releases GPU memory."""
    logger.info("Releasing GPU memory")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        logger.info(f"GPU memory in use: {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB")
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8002)
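# ---------------------------------------------------------------------------
# Example usage (a sketch, not part of the original listing): it assumes the
# server is running locally on port 8002 and uses the /change_model and
# /llm_response route paths assumed in the decorators above.
#
#   curl -X POST "http://localhost:8002/change_model?model_name=Qwen"
#   curl -X POST "http://localhost:8002/llm_response" \
#        -H "Content-Type: application/json" \
#        -d '{"question": "请问什么是FastAPI?", "model_name": "Qwen"}'
# ---------------------------------------------------------------------------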