# This module wraps requests to api.py so it can be shared by different webuis.
# ApiRequest and AsyncApiRequest provide synchronous / asynchronous calls.
from typing import *
from pathlib import Path

# The configs imported here are those of the machine issuing the requests (e.g. the
# WEBUI) and are mainly used as front-end defaults. In a distributed deployment
# they may differ from the configs on the server.
from configs import (
    EMBEDDING_MODEL,
    DEFAULT_VS_TYPE,
    LLM_MODELS,
    TEMPERATURE,
    SCORE_THRESHOLD,
    CHUNK_SIZE,
    OVERLAP_SIZE,
    ZH_TITLE_ENHANCE,
    VECTOR_SEARCH_TOP_K,
    SEARCH_ENGINE_TOP_K,
    HTTPX_DEFAULT_TIMEOUT,
    logger, log_verbose,
)
import httpx
import contextlib
import json
import os
from io import BytesIO
from server.utils import set_httpx_config, api_address, get_httpx_client
from pprint import pprint
from langchain_core._api import deprecated

set_httpx_config()
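
# NOTE: set_httpx_config() is expected to apply project-wide httpx defaults
# (such as the default timeout and proxy settings) before any client below is
# created; see server/utils.py for the exact behavior.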


class ApiRequest:
    '''
    Synchronous wrapper around the api.py endpoints, simplifying API calls.
    '''
    def __init__(
        self,
        base_url: str = api_address(),
        timeout: float = HTTPX_DEFAULT_TIMEOUT,
    ):
        self.base_url = base_url
        self.timeout = timeout
        self._use_async = False
        self._client = None

    @property
    def client(self):
        if self._client is None or self._client.is_closed:
            self._client = get_httpx_client(base_url=self.base_url,
                                            use_async=self._use_async,
                                            timeout=self.timeout)
        return self._client

    def get(
        self,
        url: str,
        params: Union[Dict, List[Tuple], bytes] = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        while retry > 0:
            try:
                if stream:
                    return self.client.stream("GET", url, params=params, **kwargs)
                else:
                    return self.client.get(url, params=params, **kwargs)
            except Exception as e:
                msg = f"error when get {url}: {e}"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
                retry -= 1

    def post(
        self,
        url: str,
        data: Dict = None,
        json: Dict = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        while retry > 0:
            try:
                if stream:
                    return self.client.stream("POST", url, data=data, json=json, **kwargs)
                else:
                    return self.client.post(url, data=data, json=json, **kwargs)
            except Exception as e:
                msg = f"error when post {url}: {e}"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
                retry -= 1

    def delete(
        self,
        url: str,
        data: Dict = None,
        json: Dict = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        while retry > 0:
            try:
                if stream:
                    return self.client.stream("DELETE", url, data=data, json=json, **kwargs)
                else:
                    return self.client.delete(url, data=data, json=json, **kwargs)
            except Exception as e:
                msg = f"error when delete {url}: {e}"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
                retry -= 1

    def _httpx_stream2generator(
        self,
        response: contextlib._GeneratorContextManager,
        as_json: bool = False,
    ):
        '''
        Convert the GeneratorContextManager returned by httpx.stream into a plain generator.
        '''
        async def ret_async(response, as_json):
            try:
                async with response as r:
                    async for chunk in r.aiter_text(None):
                        if not chunk:  # fastchat api yields empty bytes at start and end
                            continue
                        if as_json:
                            try:
                                if chunk.startswith("data: "):
                                    data = json.loads(chunk[6:-2])
                                elif chunk.startswith(":"):  # skip sse comment line
                                    continue
                                else:
                                    data = json.loads(chunk)
                                yield data
                            except Exception as e:
                                msg = f"API returned invalid json: '{chunk}'. Error: {e}."
                                logger.error(f'{e.__class__.__name__}: {msg}',
                                             exc_info=e if log_verbose else None)
                        else:
                            yield chunk
            except httpx.ConnectError as e:
                msg = f"Cannot connect to the API server. Please make sure 'api.py' is running. ({e})"
                logger.error(msg)
                yield {"code": 500, "msg": msg}
            except httpx.ReadTimeout as e:
                msg = f"API request timed out. Please make sure FastChat and the API service are running (see Wiki '5. Start the API service or Web UI'). ({e})"
                logger.error(msg)
                yield {"code": 500, "msg": msg}
            except Exception as e:
                msg = f"Error communicating with the API: {e}"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
                yield {"code": 500, "msg": msg}

        def ret_sync(response, as_json):
            try:
                with response as r:
                    for chunk in r.iter_text(None):
                        if not chunk:  # fastchat api yields empty bytes at start and end
                            continue
                        if as_json:
                            try:
                                if chunk.startswith("data: "):
                                    data = json.loads(chunk[6:-2])
                                elif chunk.startswith(":"):  # skip sse comment line
                                    continue
                                else:
                                    data = json.loads(chunk)
                                yield data
                            except Exception as e:
                                msg = f"API returned invalid json: '{chunk}'. Error: {e}."
                                logger.error(f'{e.__class__.__name__}: {msg}',
                                             exc_info=e if log_verbose else None)
                        else:
                            yield chunk
            except httpx.ConnectError as e:
                msg = f"Cannot connect to the API server. Please make sure 'api.py' is running. ({e})"
                logger.error(msg)
                yield {"code": 500, "msg": msg}
            except httpx.ReadTimeout as e:
                msg = f"API request timed out. Please make sure FastChat and the API service are running (see Wiki '5. Start the API service or Web UI'). ({e})"
                logger.error(msg)
                yield {"code": 500, "msg": msg}
            except Exception as e:
                msg = f"Error communicating with the API: {e}"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
                yield {"code": 500, "msg": msg}

        if self._use_async:
            return ret_async(response, as_json)
        else:
            return ret_sync(response, as_json)

    def _get_response_value(
        self,
        response: httpx.Response,
        as_json: bool = False,
        value_func: Callable = None,
    ):
        '''
        Convert the response of a synchronous or asynchronous request.
        `as_json`: parse the body as json
        `value_func`: custom conversion of the return value; receives the response (or its json)
        '''
        def to_json(r):
            try:
                return r.json()
            except Exception as e:
                msg = "The API did not return valid JSON. " + str(e)
                if log_verbose:
                    logger.error(f'{e.__class__.__name__}: {msg}',
                                 exc_info=e if log_verbose else None)
                return {"code": 500, "msg": msg, "data": None}

        if value_func is None:
            value_func = (lambda r: r)

        async def ret_async(response):
            if as_json:
                return value_func(to_json(await response))
            else:
                return value_func(await response)

        if self._use_async:
            return ret_async(response)
        else:
            if as_json:
                return value_func(to_json(response))
            else:
                return value_func(response)

    # Server information

    def get_server_configs(self, **kwargs) -> Dict:
        response = self.post("/server/configs", **kwargs)
        return self._get_response_value(response, as_json=True)

    def get_prompt_template(
        self,
        type: str = "llm_chat",
        name: str = "default",
        **kwargs,
    ) -> str:
        data = {
            "type": type,
            "name": name,
        }
        response = self.post("/server/get_prompt_template", json=data, **kwargs)
        return self._get_response_value(response, value_func=lambda r: r.text)

    # Chat operations

    def chat_chat(
        self,
        query: str,
        conversation_id: str = None,
        history_len: int = -1,
        history: List[Dict] = [],
        stream: bool = True,
        model: str = LLM_MODELS[0],
        temperature: float = TEMPERATURE,
        max_tokens: int = None,
        prompt_name: str = "default",
        **kwargs,
    ):
        '''
        Corresponds to the /chat/chat endpoint in api.py.
        '''
        data = {
            "query": query,
            "conversation_id": conversation_id,
            "history_len": history_len,
            "history": history,
            "stream": stream,
            "model_name": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "prompt_name": prompt_name,
        }
        response = self.post("/chat/chat", json=data, stream=True, **kwargs)
        return self._httpx_stream2generator(response, as_json=True)
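
    # Usage sketch for chat_chat (not executed at import; assumes the server
    # streams JSON dicts with a "text" field, as the bundled webui expects):
    #
    #     api = ApiRequest()
    #     for d in api.chat_chat("hello"):
    #         if err := check_error_msg(d):  # an error dict such as {"code": 500, ...}
    #             print(err)
    #         elif isinstance(d, dict):
    #             print(d.get("text", ""), end="", flush=True)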

    def agent_chat(
        self,
        query: str,
        history: List[Dict] = [],
        stream: bool = True,
        model: str = LLM_MODELS[0],
        temperature: float = TEMPERATURE,
        max_tokens: int = None,
        prompt_name: str = "default",
    ):
        '''
        Corresponds to the /chat/agent_chat endpoint in api.py.
        '''
        data = {
            "query": query,
            "history": history,
            "stream": stream,
            "model_name": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "prompt_name": prompt_name,
        }
        response = self.post("/chat/agent_chat", json=data, stream=True)
        return self._httpx_stream2generator(response, as_json=True)

    def knowledge_base_chat(
        self,
        query: str,
        knowledge_base_name: str,
        top_k: int = VECTOR_SEARCH_TOP_K,
        score_threshold: float = SCORE_THRESHOLD,
        history: List[Dict] = [],
        stream: bool = True,
        model: str = LLM_MODELS[0],
        temperature: float = TEMPERATURE,
        max_tokens: int = None,
        prompt_name: str = "default",
    ):
        '''
        Corresponds to the /chat/knowledge_base_chat endpoint in api.py.
        '''
        data = {
            "query": query,
            "knowledge_base_name": knowledge_base_name,
            "top_k": top_k,
            "score_threshold": score_threshold,
            "history": history,
            "stream": stream,
            "model_name": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "prompt_name": prompt_name,
        }
        response = self.post(
            "/chat/knowledge_base_chat",
            json=data,
            stream=True,
        )
        return self._httpx_stream2generator(response, as_json=True)

    def upload_temp_docs(
        self,
        files: List[Union[str, Path, bytes]],
        knowledge_id: str = None,
        chunk_size=CHUNK_SIZE,
        chunk_overlap=OVERLAP_SIZE,
        zh_title_enhance=ZH_TITLE_ENHANCE,
    ):
        '''
        Corresponds to the /knowledge_base/upload_temp_docs endpoint in api.py.
        '''
        def convert_file(file, filename=None):
            if isinstance(file, bytes):  # raw bytes
                file = BytesIO(file)
            elif hasattr(file, "read"):  # a file-like object
                filename = filename or file.name
            else:  # a local path
                file = Path(file).absolute().open("rb")
                filename = filename or os.path.split(file.name)[-1]
            return filename, file

        files = [convert_file(file) for file in files]
        data = {
            "knowledge_id": knowledge_id,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
            "zh_title_enhance": zh_title_enhance,
        }
        response = self.post(
            "/knowledge_base/upload_temp_docs",
            data=data,
            files=[("files", (filename, file)) for filename, file in files],
        )
        return self._get_response_value(response, as_json=True)

    def file_chat(
        self,
        query: str,
        knowledge_id: str,
        top_k: int = VECTOR_SEARCH_TOP_K,
        score_threshold: float = SCORE_THRESHOLD,
        history: List[Dict] = [],
        stream: bool = True,
        model: str = LLM_MODELS[0],
        temperature: float = TEMPERATURE,
        max_tokens: int = None,
        prompt_name: str = "default",
    ):
        '''
        Corresponds to the /chat/file_chat endpoint in api.py.
        '''
        data = {
            "query": query,
            "knowledge_id": knowledge_id,
            "top_k": top_k,
            "score_threshold": score_threshold,
            "history": history,
            "stream": stream,
            "model_name": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "prompt_name": prompt_name,
        }
        response = self.post(
            "/chat/file_chat",
            json=data,
            stream=True,
        )
        return self._httpx_stream2generator(response, as_json=True)
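
    # Usage sketch pairing upload_temp_docs with file_chat (hedged: assumes the
    # upload endpoint returns {"data": {"id": ...}} and that the chat endpoint
    # streams dicts with an "answer" field, as the bundled webui expects):
    #
    #     resp = api.upload_temp_docs(["/path/to/report.pdf"])
    #     kb_id = resp.get("data", {}).get("id")
    #     for d in api.file_chat("Summarize the report.", knowledge_id=kb_id):
    #         print(d.get("answer", ""), end="", flush=True)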

    def search_engine_chat(
        self,
        query: str,
        search_engine_name: str,
        top_k: int = SEARCH_ENGINE_TOP_K,
        history: List[Dict] = [],
        stream: bool = True,
        model: str = LLM_MODELS[0],
        temperature: float = TEMPERATURE,
        max_tokens: int = None,
        prompt_name: str = "default",
        split_result: bool = False,
    ):
        '''
        Corresponds to the /chat/search_engine_chat endpoint in api.py.
        '''
        data = {
            "query": query,
            "search_engine_name": search_engine_name,
            "top_k": top_k,
            "history": history,
            "stream": stream,
            "model_name": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "prompt_name": prompt_name,
            "split_result": split_result,
        }
        response = self.post(
            "/chat/search_engine_chat",
            json=data,
            stream=True,
        )
        return self._httpx_stream2generator(response, as_json=True)

    # Knowledge base operations

    def list_knowledge_bases(
        self,
    ):
        '''
        Corresponds to the /knowledge_base/list_knowledge_bases endpoint in api.py.
        '''
        response = self.get("/knowledge_base/list_knowledge_bases")
        return self._get_response_value(response,
                                        as_json=True,
                                        value_func=lambda r: r.get("data", []))

    def create_knowledge_base(
        self,
        knowledge_base_name: str,
        vector_store_type: str = DEFAULT_VS_TYPE,
        embed_model: str = EMBEDDING_MODEL,
    ):
        '''
        Corresponds to the /knowledge_base/create_knowledge_base endpoint in api.py.
        '''
        data = {
            "knowledge_base_name": knowledge_base_name,
            "vector_store_type": vector_store_type,
            "embed_model": embed_model,
        }
        response = self.post(
            "/knowledge_base/create_knowledge_base",
            json=data,
        )
        return self._get_response_value(response, as_json=True)

    def delete_knowledge_base(
        self,
        knowledge_base_name: str,
    ):
        '''
        Corresponds to the /knowledge_base/delete_knowledge_base endpoint in api.py.
        '''
        response = self.post(
            "/knowledge_base/delete_knowledge_base",
            json=f"{knowledge_base_name}",
        )
        return self._get_response_value(response, as_json=True)

    def list_kb_docs(
        self,
        knowledge_base_name: str,
    ):
        '''
        Corresponds to the /knowledge_base/list_files endpoint in api.py.
        '''
        response = self.get(
            "/knowledge_base/list_files",
            params={"knowledge_base_name": knowledge_base_name}
        )
        return self._get_response_value(response,
                                        as_json=True,
                                        value_func=lambda r: r.get("data", []))

    def search_kb_docs(
        self,
        knowledge_base_name: str,
        query: str = "",
        top_k: int = VECTOR_SEARCH_TOP_K,
        score_threshold: float = SCORE_THRESHOLD,
        file_name: str = "",
        metadata: dict = {},
    ) -> List:
        '''
        Corresponds to the /knowledge_base/search_docs endpoint in api.py.
        '''
        data = {
            "query": query,
            "knowledge_base_name": knowledge_base_name,
            "top_k": top_k,
            "score_threshold": score_threshold,
            "file_name": file_name,
            "metadata": metadata,
        }
        response = self.post(
            "/knowledge_base/search_docs",
            json=data,
        )
        return self._get_response_value(response, as_json=True)

    def update_docs_by_id(
        self,
        knowledge_base_name: str,
        docs: Dict[str, Dict],
    ) -> bool:
        '''
        Corresponds to the /knowledge_base/update_docs_by_id endpoint in api.py.
        '''
        data = {
            "knowledge_base_name": knowledge_base_name,
            "docs": docs,
        }
        response = self.post(
            "/knowledge_base/update_docs_by_id",
            json=data
        )
        return self._get_response_value(response)

    def upload_kb_docs(
        self,
        files: List[Union[str, Path, bytes]],
        knowledge_base_name: str,
        override: bool = False,
        to_vector_store: bool = True,
        chunk_size=CHUNK_SIZE,
        chunk_overlap=OVERLAP_SIZE,
        zh_title_enhance=ZH_TITLE_ENHANCE,
        docs: Dict = {},
        not_refresh_vs_cache: bool = False,
    ):
        '''
        Corresponds to the /knowledge_base/upload_docs endpoint in api.py.
        '''
        def convert_file(file, filename=None):
            if isinstance(file, bytes):  # raw bytes
                file = BytesIO(file)
            elif hasattr(file, "read"):  # a file-like object
                filename = filename or file.name
            else:  # a local path
                file = Path(file).absolute().open("rb")
                filename = filename or os.path.split(file.name)[-1]
            return filename, file

        files = [convert_file(file) for file in files]
        data = {
            "knowledge_base_name": knowledge_base_name,
            "override": override,
            "to_vector_store": to_vector_store,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
            "zh_title_enhance": zh_title_enhance,
            "docs": docs,
            "not_refresh_vs_cache": not_refresh_vs_cache,
        }
        if isinstance(data["docs"], dict):
            data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
        response = self.post(
            "/knowledge_base/upload_docs",
            data=data,
            files=[("files", (filename, file)) for filename, file in files],
        )
        return self._get_response_value(response, as_json=True)
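
    # Usage sketch (the knowledge base name below is an assumption; files may be
    # local paths, Path objects, raw bytes, or file-like objects):
    #
    #     api.upload_kb_docs(
    #         files=["/path/to/faq.md", Path("notes.txt")],
    #         knowledge_base_name="samples",
    #         override=True,
    #     )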

    def delete_kb_docs(
        self,
        knowledge_base_name: str,
        file_names: List[str],
        delete_content: bool = False,
        not_refresh_vs_cache: bool = False,
    ):
        '''
        Corresponds to the /knowledge_base/delete_docs endpoint in api.py.
        '''
        data = {
            "knowledge_base_name": knowledge_base_name,
            "file_names": file_names,
            "delete_content": delete_content,
            "not_refresh_vs_cache": not_refresh_vs_cache,
        }
        response = self.post(
            "/knowledge_base/delete_docs",
            json=data,
        )
        return self._get_response_value(response, as_json=True)

    def update_kb_info(self, knowledge_base_name, kb_info):
        '''
        Corresponds to the /knowledge_base/update_info endpoint in api.py.
        '''
        data = {
            "knowledge_base_name": knowledge_base_name,
            "kb_info": kb_info,
        }
        response = self.post(
            "/knowledge_base/update_info",
            json=data,
        )
        return self._get_response_value(response, as_json=True)

    def update_kb_docs(
        self,
        knowledge_base_name: str,
        file_names: List[str],
        override_custom_docs: bool = False,
        chunk_size=CHUNK_SIZE,
        chunk_overlap=OVERLAP_SIZE,
        zh_title_enhance=ZH_TITLE_ENHANCE,
        docs: Dict = {},
        not_refresh_vs_cache: bool = False,
    ):
        '''
        Corresponds to the /knowledge_base/update_docs endpoint in api.py.
        '''
        data = {
            "knowledge_base_name": knowledge_base_name,
            "file_names": file_names,
            "override_custom_docs": override_custom_docs,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
            "zh_title_enhance": zh_title_enhance,
            "docs": docs,
            "not_refresh_vs_cache": not_refresh_vs_cache,
        }
        if isinstance(data["docs"], dict):
            data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
        response = self.post(
            "/knowledge_base/update_docs",
            json=data,
        )
        return self._get_response_value(response, as_json=True)

    def recreate_vector_store(
        self,
        knowledge_base_name: str,
        allow_empty_kb: bool = True,
        vs_type: str = DEFAULT_VS_TYPE,
        embed_model: str = EMBEDDING_MODEL,
        chunk_size=CHUNK_SIZE,
        chunk_overlap=OVERLAP_SIZE,
        zh_title_enhance=ZH_TITLE_ENHANCE,
    ):
        '''
        Corresponds to the /knowledge_base/recreate_vector_store endpoint in api.py.
        '''
        data = {
            "knowledge_base_name": knowledge_base_name,
            "allow_empty_kb": allow_empty_kb,
            "vs_type": vs_type,
            "embed_model": embed_model,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
            "zh_title_enhance": zh_title_enhance,
        }
        response = self.post(
            "/knowledge_base/recreate_vector_store",
            json=data,
            stream=True,
            timeout=None,
        )
        return self._httpx_stream2generator(response, as_json=True)

    # LLM model operations

    def list_running_models(
        self,
        controller_address: str = None,
    ):
        '''
        Get the list of models currently running in Fastchat.
        '''
        data = {
            "controller_address": controller_address,
        }

        if log_verbose:
            logger.info(f'{self.__class__.__name__}:data: {data}')

        response = self.post(
            "/llm_model/list_running_models",
            json=data,
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", []))

    def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
        '''
        Get the currently running LLM model from the server.
        When local_first=True, prefer a running local model;
        otherwise follow the order configured in LLM_MODELS.
        Returns (model_name, is_local_model).
        '''
        def ret_sync():
            running_models = self.list_running_models()
            if not running_models:
                return "", False

            model = ""
            for m in LLM_MODELS:
                if m not in running_models:
                    continue
                is_local = not running_models[m].get("online_api")
                if local_first and not is_local:
                    continue
                else:
                    model = m
                    break

            if not model:  # none of the models configured in LLM_MODELS is in running_models
                model = list(running_models)[0]
                is_local = not running_models[model].get("online_api")

            return model, is_local

        async def ret_async():
            running_models = await self.list_running_models()
            if not running_models:
                return "", False

            model = ""
            for m in LLM_MODELS:
                if m not in running_models:
                    continue
                is_local = not running_models[m].get("online_api")
                if local_first and not is_local:
                    continue
                else:
                    model = m
                    break

            if not model:  # none of the models configured in LLM_MODELS is in running_models
                model = list(running_models)[0]
                is_local = not running_models[model].get("online_api")

            return model, is_local

        if self._use_async:
            return ret_async()
        else:
            return ret_sync()
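
    # Usage sketch: pick a default model for the front end.
    #
    #     model, is_local = api.get_default_llm_model()
    #     print(f"default model: {model} (local: {is_local})")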

    def list_config_models(
        self,
        types: List[str] = ["local", "online"],
    ) -> Dict[str, Dict]:
        '''
        Get the models configured in the server's configs,
        returned as {"type": {model_name: config}, ...}.
        '''
        data = {
            "types": types,
        }
        response = self.post(
            "/llm_model/list_config_models",
            json=data,
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

    def get_model_config(
        self,
        model_name: str = None,
    ) -> Dict:
        '''
        Get a model's configuration from the server.
        '''
        data = {
            "model_name": model_name,
        }
        response = self.post(
            "/llm_model/get_model_config",
            json=data,
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

    def list_search_engines(self) -> List[str]:
        '''
        Get the search engines supported by the server.
        '''
        response = self.post(
            "/server/list_search_engines",
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", []))

    def stop_llm_model(
        self,
        model_name: str,
        controller_address: str = None,
    ):
        '''
        Stop an LLM model.
        Note: because of how Fastchat is implemented, this actually stops the
        model_worker that hosts the LLM model.
        '''
        data = {
            "model_name": model_name,
            "controller_address": controller_address,
        }
        response = self.post(
            "/llm_model/stop",
            json=data,
        )
        return self._get_response_value(response, as_json=True)

    def change_llm_model(
        self,
        model_name: str,
        new_model_name: str,
        controller_address: str = None,
    ):
        '''
        Ask the fastchat controller to switch the LLM model.
        '''
        if not model_name or not new_model_name:
            return {
                "code": 500,
                "msg": "No model name specified."
            }

        def ret_sync():
            running_models = self.list_running_models()
            if new_model_name == model_name or new_model_name in running_models:
                return {
                    "code": 200,
                    "msg": "No switch needed."
                }

            if model_name not in running_models:
                return {
                    "code": 500,
                    "msg": f"The model '{model_name}' is not running. Currently running models: {running_models}"
                }

            config_models = self.list_config_models()
            if new_model_name not in config_models.get("local", {}):
                return {
                    "code": 500,
                    "msg": f"The target model '{new_model_name}' is not configured in configs."
                }

            data = {
                "model_name": model_name,
                "new_model_name": new_model_name,
                "controller_address": controller_address,
            }
            response = self.post(
                "/llm_model/change",
                json=data,
            )
            return self._get_response_value(response, as_json=True)

        async def ret_async():
            running_models = await self.list_running_models()
            if new_model_name == model_name or new_model_name in running_models:
                return {
                    "code": 200,
                    "msg": "No switch needed."
                }

            if model_name not in running_models:
                return {
                    "code": 500,
                    "msg": f"The model '{model_name}' is not running. Currently running models: {running_models}"
                }

            config_models = await self.list_config_models()
            if new_model_name not in config_models.get("local", {}):
                return {
                    "code": 500,
                    "msg": f"The target model '{new_model_name}' is not configured in configs."
                }

            data = {
                "model_name": model_name,
                "new_model_name": new_model_name,
                "controller_address": controller_address,
            }
            response = self.post(
                "/llm_model/change",
                json=data,
            )
            return self._get_response_value(response, as_json=True)

        if self._use_async:
            return ret_async()
        else:
            return ret_sync()
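
    # Usage sketch (both model names are assumptions and must exist in configs):
    #
    #     r = api.change_llm_model("chatglm3-6b", "Qwen-7B-Chat")
    #     print(check_error_msg(r) or r.get("msg"))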

    def embed_texts(
        self,
        texts: List[str],
        embed_model: str = EMBEDDING_MODEL,
        to_query: bool = False,
    ) -> List[List[float]]:
        '''
        Vectorize texts. Available models include local embed_models and
        online models that support embeddings.
        '''
        data = {
            "texts": texts,
            "embed_model": embed_model,
            "to_query": to_query,
        }
        resp = self.post(
            "/other/embed_texts",
            json=data,
        )
        return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
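
    # Usage sketch: embed two sentences with the default embedding model.
    #
    #     vecs = api.embed_texts(["hello", "world"])
    #     print(len(vecs), len(vecs[0]) if vecs else 0)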

    def chat_feedback(
        self,
        message_id: str,
        score: int,
        reason: str = "",
    ) -> int:
        '''
        Submit feedback (a rating) for a chat message.
        '''
        data = {
            "message_id": message_id,
            "score": score,
            "reason": reason,
        }
        resp = self.post("/chat/feedback", json=data)
        return self._get_response_value(resp)


class AsyncApiRequest(ApiRequest):
    def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
        super().__init__(base_url, timeout)
        self._use_async = True
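
# Usage sketch for the async variant: with _use_async=True every wrapped call
# returns an awaitable (or an async generator for streaming endpoints).
#
#     import asyncio
#
#     async def demo():
#         aapi = AsyncApiRequest()
#         print(await aapi.list_knowledge_bases())
#         async for d in aapi.chat_chat("hello"):
#             print(d.get("text", ""), end="", flush=True)
#
#     asyncio.run(demo())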


def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
    '''
    Return an error message if an error occurred when requesting the API.
    '''
    if isinstance(data, dict):
        if key in data:
            return data[key]
        if "code" in data and data["code"] != 200:
            return data["msg"]
    return ""


def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
    '''
    Return the message if the API request succeeded (code == 200).
    '''
    if (isinstance(data, dict)
            and key in data
            and "code" in data
            and data["code"] == 200):
        return data[key]
    return ""


if __name__ == "__main__":
    api = ApiRequest()
    aapi = AsyncApiRequest()

    # r = api.chat_chat("hello")
    # for t in r:
    #     print(t)

    # print(api.list_knowledge_bases())