# 该文件封装了对api.py的请求,可以被不同的webui使用 # 通过ApiRequest和AsyncApiRequest支持同步/异步调用 from typing import * from pathlib import Path # 此处导入的配置为发起请求(如WEBUI)机器上的配置,主要用于为前端设置默认值。分布式部署时可以与服务器上的不同 from configs import ( EMBEDDING_MODEL, DEFAULT_VS_TYPE, LLM_MODELS, TEMPERATURE, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, VECTOR_SEARCH_TOP_K, SEARCH_ENGINE_TOP_K, HTTPX_DEFAULT_TIMEOUT, logger, log_verbose, ) import httpx import contextlib import json import os from io import BytesIO from server.utils import set_httpx_config, api_address, get_httpx_client from pprint import pprint from langchain_core._api import deprecated set_httpx_config() class ApiRequest: ''' api.py调用的封装(同步模式),简化api调用方式 ''' def __init__( self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT, ): self.base_url = base_url self.timeout = timeout self._use_async = False self._client = None @property def client(self): if self._client is None or self._client.is_closed: self._client = get_httpx_client(base_url=self.base_url, use_async=self._use_async, timeout=self.timeout) return self._client def get( self, url: str, params: Union[Dict, List[Tuple], bytes] = None, retry: int = 3, stream: bool = False, **kwargs: Any, ) -> Union[httpx.Response, Iterator[httpx.Response], None]: while retry > 0: try: if stream: return self.client.stream("GET", url, params=params, **kwargs) else: return self.client.get(url, params=params, **kwargs) except Exception as e: msg = f"error when get {url}: {e}" logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) retry -= 1 def post( self, url: str, data: Dict = None, json: Dict = None, retry: int = 3, stream: bool = False, **kwargs: Any ) -> Union[httpx.Response, Iterator[httpx.Response], None]: while retry > 0: try: # print(kwargs) if stream: return self.client.stream("POST", url, data=data, json=json, **kwargs) else: return self.client.post(url, data=data, json=json, **kwargs) except Exception as e: msg = f"error when post {url}: {e}" logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) retry -= 1 def delete( self, url: str, data: Dict = None, json: Dict = None, retry: int = 3, stream: bool = False, **kwargs: Any ) -> Union[httpx.Response, Iterator[httpx.Response], None]: while retry > 0: try: if stream: return self.client.stream("DELETE", url, data=data, json=json, **kwargs) else: return self.client.delete(url, data=data, json=json, **kwargs) except Exception as e: msg = f"error when delete {url}: {e}" logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) retry -= 1 def _httpx_stream2generator( self, response: contextlib._GeneratorContextManager, as_json: bool = False, ): ''' 将httpx.stream返回的GeneratorContextManager转化为普通生成器 ''' async def ret_async(response, as_json): try: async with response as r: async for chunk in r.aiter_text(None): if not chunk: # fastchat api yield empty bytes on start and end continue if as_json: try: if chunk.startswith("data: "): data = json.loads(chunk[6:-2]) elif chunk.startswith(":"): # skip sse comment line continue else: data = json.loads(chunk) yield data except Exception as e: msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。" logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) else: # print(chunk, end="", flush=True) yield chunk except httpx.ConnectError as e: msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})" logger.error(msg) yield {"code": 500, "msg": msg} except httpx.ReadTimeout as e: msg = f"API通信超时,请确认已启动FastChat与API服务(详见Wiki '5. 启动 API 服务或 Web UI')。({e})" logger.error(msg) yield {"code": 500, "msg": msg} except Exception as e: msg = f"API通信遇到错误:{e}" logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) yield {"code": 500, "msg": msg} def ret_sync(response, as_json): try: with response as r: for chunk in r.iter_text(None): if not chunk: # fastchat api yield empty bytes on start and end continue if as_json: try: if chunk.startswith("data: "): data = json.loads(chunk[6:-2]) elif chunk.startswith(":"): # skip sse comment line continue else: data = json.loads(chunk) yield data except Exception as e: msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。" logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) else: # print(chunk, end="", flush=True) yield chunk except httpx.ConnectError as e: msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})" logger.error(msg) yield {"code": 500, "msg": msg} except httpx.ReadTimeout as e: msg = f"API通信超时,请确认已启动FastChat与API服务(详见Wiki '5. 启动 API 服务或 Web UI')。({e})" logger.error(msg) yield {"code": 500, "msg": msg} except Exception as e: msg = f"API通信遇到错误:{e}" logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) yield {"code": 500, "msg": msg} if self._use_async: return ret_async(response, as_json) else: return ret_sync(response, as_json) def _get_response_value( self, response: httpx.Response, as_json: bool = False, value_func: Callable = None, ): ''' 转换同步或异步请求返回的响应 `as_json`: 返回json `value_func`: 用户可以自定义返回值,该函数接受response或json ''' def to_json(r): try: return r.json() except Exception as e: msg = "API未能返回正确的JSON。" + str(e) if log_verbose: logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) return {"code": 500, "msg": msg, "data": None} if value_func is None: value_func = (lambda r: r) async def ret_async(response): if as_json: return value_func(to_json(await response)) else: return value_func(await response) if self._use_async: return ret_async(response) else: if as_json: return value_func(to_json(response)) else: return value_func(response) # 服务器信息 def get_server_configs(self, **kwargs) -> Dict: response = self.post("/server/configs", **kwargs) return self._get_response_value(response, as_json=True) def list_search_engines(self, **kwargs) -> List: response = self.post("/server/list_search_engines", **kwargs) return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"]) def get_prompt_template( self, type: str = "llm_chat", name: str = "default", **kwargs, ) -> str: data = { "type": type, "name": name, } response = self.post("/server/get_prompt_template", json=data, **kwargs) return self._get_response_value(response, value_func=lambda r: r.text) # 对话相关操作 def chat_chat( self, query: str, conversation_id: str = None, history_len: int = -1, history: List[Dict] = [], stream: bool = True, model: str = LLM_MODELS[0], temperature: float = TEMPERATURE, max_tokens: int = None, prompt_name: str = "default", **kwargs, ): ''' 对应api.py/chat/chat接口 ''' data = { "query": query, "conversation_id": conversation_id, "history_len": history_len, "history": history, "stream": stream, "model_name": model, "temperature": temperature, "max_tokens": max_tokens, "prompt_name": prompt_name, } # print(f"received input message:") # pprint(data) response = self.post("/chat/chat", json=data, stream=True, **kwargs) return self._httpx_stream2generator(response, as_json=True) @deprecated( since="0.3.0", message="自定义Agent问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃", removal="0.3.0") def agent_chat( self, query: str, history: List[Dict] = [], stream: bool = True, model: str = LLM_MODELS[0], temperature: float = TEMPERATURE, max_tokens: int = None, prompt_name: str = "default", ): ''' 对应api.py/chat/agent_chat 接口 ''' data = { "query": query, "history": history, "stream": stream, "model_name": model, "temperature": temperature, "max_tokens": max_tokens, "prompt_name": prompt_name, } # print(f"received input message:") # pprint(data) response = self.post("/chat/agent_chat", json=data, stream=True) return self._httpx_stream2generator(response, as_json=True) def knowledge_base_chat( self, query: str, knowledge_base_name: str, top_k: int = VECTOR_SEARCH_TOP_K, score_threshold: float = SCORE_THRESHOLD, history: List[Dict] = [], stream: bool = True, model: str = LLM_MODELS[0], temperature: float = TEMPERATURE, max_tokens: int = None, prompt_name: str = "default", ): ''' 对应api.py/chat/knowledge_base_chat接口 ''' data = { "query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k, "score_threshold": score_threshold, "history": history, "stream": stream, "model_name": model, "temperature": temperature, "max_tokens": max_tokens, "prompt_name": prompt_name, } # print(f"received input message:") # pprint(data) response = self.post( "/chat/knowledge_base_chat", json=data, stream=True, ) return self._httpx_stream2generator(response, as_json=True) def upload_temp_docs( self, files: List[Union[str, Path, bytes]], knowledge_id: str = None, chunk_size=CHUNK_SIZE, chunk_overlap=OVERLAP_SIZE, zh_title_enhance=ZH_TITLE_ENHANCE, ): ''' 对应api.py/knowledge_base/upload_tmep_docs接口 ''' def convert_file(file, filename=None): if isinstance(file, bytes): # raw bytes file = BytesIO(file) elif hasattr(file, "read"): # a file io like object filename = filename or file.name else: # a local path file = Path(file).absolute().open("rb") filename = filename or os.path.split(file.name)[-1] return filename, file files = [convert_file(file) for file in files] data = { "knowledge_id": knowledge_id, "chunk_size": chunk_size, "chunk_overlap": chunk_overlap, "zh_title_enhance": zh_title_enhance, } response = self.post( "/knowledge_base/upload_temp_docs", data=data, files=[("files", (filename, file)) for filename, file in files], ) return self._get_response_value(response, as_json=True) def file_chat( self, query: str, knowledge_id: str, top_k: int = VECTOR_SEARCH_TOP_K, score_threshold: float = SCORE_THRESHOLD, history: List[Dict] = [], stream: bool = True, model: str = LLM_MODELS[0], temperature: float = TEMPERATURE, max_tokens: int = None, prompt_name: str = "default", ): ''' 对应api.py/chat/file_chat接口 ''' data = { "query": query, "knowledge_id": knowledge_id, "top_k": top_k, "score_threshold": score_threshold, "history": history, "stream": stream, "model_name": model, "temperature": temperature, "max_tokens": max_tokens, "prompt_name": prompt_name, } response = self.post( "/chat/file_chat", json=data, stream=True, ) return self._httpx_stream2generator(response, as_json=True) @deprecated( since="0.3.0", message="搜索引擎问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃", removal="0.3.0" ) def search_engine_chat( self, query: str, search_engine_name: str, top_k: int = SEARCH_ENGINE_TOP_K, history: List[Dict] = [], stream: bool = True, model: str = LLM_MODELS[0], temperature: float = TEMPERATURE, max_tokens: int = None, prompt_name: str = "default", split_result: bool = False, ): ''' 对应api.py/chat/search_engine_chat接口 ''' data = { "query": query, "search_engine_name": search_engine_name, "top_k": top_k, "history": history, "stream": stream, "model_name": model, "temperature": temperature, "max_tokens": max_tokens, "prompt_name": prompt_name, "split_result": split_result, } # print(f"received input message:") # pprint(data) response = self.post( "/chat/search_engine_chat", json=data, stream=True, ) return self._httpx_stream2generator(response, as_json=True) # 知识库相关操作 def list_knowledge_bases( self, ): ''' 对应api.py/knowledge_base/list_knowledge_bases接口 ''' response = self.get("/knowledge_base/list_knowledge_bases") return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", [])) def create_knowledge_base( self, knowledge_base_name: str, vector_store_type: str = DEFAULT_VS_TYPE, embed_model: str = EMBEDDING_MODEL, ): ''' 对应api.py/knowledge_base/create_knowledge_base接口 ''' data = { "knowledge_base_name": knowledge_base_name, "vector_store_type": vector_store_type, "embed_model": embed_model, } response = self.post( "/knowledge_base/create_knowledge_base", json=data, ) return self._get_response_value(response, as_json=True) def delete_knowledge_base( self, knowledge_base_name: str, ): ''' 对应api.py/knowledge_base/delete_knowledge_base接口 ''' response = self.post( "/knowledge_base/delete_knowledge_base", json=f"{knowledge_base_name}", ) return self._get_response_value(response, as_json=True) def list_kb_docs( self, knowledge_base_name: str, ): ''' 对应api.py/knowledge_base/list_files接口 ''' response = self.get( "/knowledge_base/list_files", params={"knowledge_base_name": knowledge_base_name} ) return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", [])) def search_kb_docs( self, knowledge_base_name: str, query: str = "", top_k: int = VECTOR_SEARCH_TOP_K, score_threshold: int = SCORE_THRESHOLD, file_name: str = "", metadata: dict = {}, ) -> List: ''' 对应api.py/knowledge_base/search_docs接口 ''' data = { "query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k, "score_threshold": score_threshold, "file_name": file_name, "metadata": metadata, } response = self.post( "/knowledge_base/search_docs", json=data, ) return self._get_response_value(response, as_json=True) def update_docs_by_id( self, knowledge_base_name: str, docs: Dict[str, Dict], ) -> bool: ''' 对应api.py/knowledge_base/update_docs_by_id接口 ''' data = { "knowledge_base_name": knowledge_base_name, "docs": docs, } response = self.post( "/knowledge_base/update_docs_by_id", json=data ) return self._get_response_value(response) def upload_kb_docs( self, files: List[Union[str, Path, bytes]], knowledge_base_name: str, override: bool = False, to_vector_store: bool = True, chunk_size=CHUNK_SIZE, chunk_overlap=OVERLAP_SIZE, zh_title_enhance=ZH_TITLE_ENHANCE, docs: Dict = {}, not_refresh_vs_cache: bool = False, ): ''' 对应api.py/knowledge_base/upload_docs接口 ''' def convert_file(file, filename=None): if isinstance(file, bytes): # raw bytes file = BytesIO(file) elif hasattr(file, "read"): # a file io like object filename = filename or file.name else: # a local path file = Path(file).absolute().open("rb") filename = filename or os.path.split(file.name)[-1] return filename, file files = [convert_file(file) for file in files] data = { "knowledge_base_name": knowledge_base_name, "override": override, "to_vector_store": to_vector_store, "chunk_size": chunk_size, "chunk_overlap": chunk_overlap, "zh_title_enhance": zh_title_enhance, "docs": docs, "not_refresh_vs_cache": not_refresh_vs_cache, } if isinstance(data["docs"], dict): data["docs"] = json.dumps(data["docs"], ensure_ascii=False) response = self.post( "/knowledge_base/upload_docs", data=data, files=[("files", (filename, file)) for filename, file in files], ) return self._get_response_value(response, as_json=True) def delete_kb_docs( self, knowledge_base_name: str, file_names: List[str], delete_content: bool = False, not_refresh_vs_cache: bool = False, ): ''' 对应api.py/knowledge_base/delete_docs接口 ''' data = { "knowledge_base_name": knowledge_base_name, "file_names": file_names, "delete_content": delete_content, "not_refresh_vs_cache": not_refresh_vs_cache, } response = self.post( "/knowledge_base/delete_docs", json=data, ) return self._get_response_value(response, as_json=True) def update_kb_info(self, knowledge_base_name, kb_info): ''' 对应api.py/knowledge_base/update_info接口 ''' data = { "knowledge_base_name": knowledge_base_name, "kb_info": kb_info, } response = self.post( "/knowledge_base/update_info", json=data, ) return self._get_response_value(response, as_json=True) def update_kb_docs( self, knowledge_base_name: str, file_names: List[str], override_custom_docs: bool = False, chunk_size=CHUNK_SIZE, chunk_overlap=OVERLAP_SIZE, zh_title_enhance=ZH_TITLE_ENHANCE, docs: Dict = {}, not_refresh_vs_cache: bool = False, ): ''' 对应api.py/knowledge_base/update_docs接口 ''' data = { "knowledge_base_name": knowledge_base_name, "file_names": file_names, "override_custom_docs": override_custom_docs, "chunk_size": chunk_size, "chunk_overlap": chunk_overlap, "zh_title_enhance": zh_title_enhance, "docs": docs, "not_refresh_vs_cache": not_refresh_vs_cache, } if isinstance(data["docs"], dict): data["docs"] = json.dumps(data["docs"], ensure_ascii=False) response = self.post( "/knowledge_base/update_docs", json=data, ) return self._get_response_value(response, as_json=True) def recreate_vector_store( self, knowledge_base_name: str, allow_empty_kb: bool = True, vs_type: str = DEFAULT_VS_TYPE, embed_model: str = EMBEDDING_MODEL, chunk_size=CHUNK_SIZE, chunk_overlap=OVERLAP_SIZE, zh_title_enhance=ZH_TITLE_ENHANCE, ): ''' 对应api.py/knowledge_base/recreate_vector_store接口 ''' data = { "knowledge_base_name": knowledge_base_name, "allow_empty_kb": allow_empty_kb, "vs_type": vs_type, "embed_model": embed_model, "chunk_size": chunk_size, "chunk_overlap": chunk_overlap, "zh_title_enhance": zh_title_enhance, } response = self.post( "/knowledge_base/recreate_vector_store", json=data, stream=True, timeout=None, ) return self._httpx_stream2generator(response, as_json=True) # LLM模型相关操作 def list_running_models( self, controller_address: str = None, ): ''' 获取Fastchat中正运行的模型列表 ''' data = { "controller_address": controller_address, } if log_verbose: logger.info(f'{self.__class__.__name__}:data: {data}') response = self.post( "/llm_model/list_running_models", json=data, ) return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", [])) def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]: ''' 从服务器上获取当前运行的LLM模型。 当 local_first=True 时,优先返回运行中的本地模型,否则优先按LLM_MODELS配置顺序返回。 返回类型为(model_name, is_local_model) ''' def ret_sync(): running_models = self.list_running_models() if not running_models: return "", False model = "" for m in LLM_MODELS: if m not in running_models: continue is_local = not running_models[m].get("online_api") if local_first and not is_local: continue else: model = m break if not model: # LLM_MODELS中配置的模型都不在running_models里 model = list(running_models)[0] is_local = not running_models[model].get("online_api") return model, is_local async def ret_async(): running_models = await self.list_running_models() if not running_models: return "", False model = "" for m in LLM_MODELS: if m not in running_models: continue is_local = not running_models[m].get("online_api") if local_first and not is_local: continue else: model = m break if not model: # LLM_MODELS中配置的模型都不在running_models里 model = list(running_models)[0] is_local = not running_models[model].get("online_api") return model, is_local if self._use_async: return ret_async() else: return ret_sync() def list_config_models( self, types: List[str] = ["local", "online"], ) -> Dict[str, Dict]: ''' 获取服务器configs中配置的模型列表,返回形式为{"type": {model_name: config}, ...}。 ''' data = { "types": types, } response = self.post( "/llm_model/list_config_models", json=data, ) return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {})) def get_model_config( self, model_name: str = None, ) -> Dict: ''' 获取服务器上模型配置 ''' data = { "model_name": model_name, } response = self.post( "/llm_model/get_model_config", json=data, ) return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {})) def list_search_engines(self) -> List[str]: ''' 获取服务器支持的搜索引擎 ''' response = self.post( "/server/list_search_engines", ) return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {})) def stop_llm_model( self, model_name: str, controller_address: str = None, ): ''' 停止某个LLM模型。 注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。 ''' data = { "model_name": model_name, "controller_address": controller_address, } response = self.post( "/llm_model/stop", json=data, ) return self._get_response_value(response, as_json=True) def change_llm_model( self, model_name: str, new_model_name: str, controller_address: str = None, ): ''' 向fastchat controller请求切换LLM模型。 ''' if not model_name or not new_model_name: return { "code": 500, "msg": f"未指定模型名称" } def ret_sync(): running_models = self.list_running_models() if new_model_name == model_name or new_model_name in running_models: return { "code": 200, "msg": "无需切换" } if model_name not in running_models: return { "code": 500, "msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}" } config_models = self.list_config_models() if new_model_name not in config_models.get("local", {}): return { "code": 500, "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。" } data = { "model_name": model_name, "new_model_name": new_model_name, "controller_address": controller_address, } response = self.post( "/llm_model/change", json=data, ) return self._get_response_value(response, as_json=True) async def ret_async(): running_models = await self.list_running_models() if new_model_name == model_name or new_model_name in running_models: return { "code": 200, "msg": "无需切换" } if model_name not in running_models: return { "code": 500, "msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}" } config_models = await self.list_config_models() if new_model_name not in config_models.get("local", {}): return { "code": 500, "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。" } data = { "model_name": model_name, "new_model_name": new_model_name, "controller_address": controller_address, } response = self.post( "/llm_model/change", json=data, ) return self._get_response_value(response, as_json=True) if self._use_async: return ret_async() else: return ret_sync() def embed_texts( self, texts: List[str], embed_model: str = EMBEDDING_MODEL, to_query: bool = False, ) -> List[List[float]]: ''' 对文本进行向量化,可选模型包括本地 embed_models 和支持 embeddings 的在线模型 ''' data = { "texts": texts, "embed_model": embed_model, "to_query": to_query, } resp = self.post( "/other/embed_texts", json=data, ) return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data")) def chat_feedback( self, message_id: str, score: int, reason: str = "", ) -> int: ''' 反馈对话评价 ''' data = { "message_id": message_id, "score": score, "reason": reason, } resp = self.post("/chat/feedback", json=data) return self._get_response_value(resp) class AsyncApiRequest(ApiRequest): def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT): super().__init__(base_url, timeout) self._use_async = True def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str: ''' return error message if error occured when requests API ''' if isinstance(data, dict): if key in data: return data[key] if "code" in data and data["code"] != 200: return data["msg"] return "" def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str: ''' return error message if error occured when requests API ''' if (isinstance(data, dict) and key in data and "code" in data and data["code"] == 200): return data[key] return "" if __name__ == "__main__": api = ApiRequest() aapi = AsyncApiRequest() # with api.chat_chat("你好") as r: # for t in r.iter_text(None): # print(t) # r = api.chat_chat("你好", no_remote_api=True) # for t in r: # print(t) # r = api.duckduckgo_search_chat("室温超导最新研究进展", no_remote_api=True) # for t in r: # print(t) # print(api.list_knowledge_bases())