KevinHuSh committed
Commit 63df91a · 1 Parent(s): e693841

Support Xinference (#320)

### What problem does this PR solve?

Issue link: #299

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
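
All of the changes below hinge on one fact: Xinference serves its models behind an OpenAI-compatible REST API, so RAGFlow can drive it with the stock `openai` SDK pointed at a local base URL. A minimal smoke test of that idea (the port and model UID below are placeholders for whatever your local Xinference instance actually serves):

```python
# Xinference exposes an OpenAI-compatible endpoint, so the stock `openai`
# client works against it. The port and model UID are illustrative.
from openai import OpenAI

client = OpenAI(
    api_key="xxx",  # Xinference ignores the key, but the SDK requires one
    base_url="http://localhost:9997/v1",
)

resp = client.chat.completions.create(
    model="my-local-model",  # hypothetical model UID launched in Xinference
    messages=[{"role": "user", "content": "Say hello."}],
)
print(resp.choices[0].message.content)
```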

README.md CHANGED
@@ -172,6 +172,7 @@ $ docker compose up -d
 
 ## 🆕 Latest Features
 
+- 2024-04-11 Support [Xinference](./docs/xinference.md) for local LLM deployment.
 - 2024-04-10 Add a new layout recognize model for method 'Laws'.
 - 2024-04-08 Support [Ollama](./docs/ollama.md) for local LLM deployment.
 - 2024-04-07 Support Chinese UI.

README_ja.md CHANGED
@@ -171,6 +171,8 @@ $ docker compose up -d
 ```
 
 ## 🆕 最新の新機能
+
+- 2024-04-11 ローカル LLM デプロイメント用に [Xinference](./docs/xinference.md) をサポートします。
 - 2024-04-10 メソッド「Laws」に新しいレイアウト認識モデルを追加します。
 - 2024-04-08 [Ollama](./docs/ollama.md) を使用した大規模モデルのローカライズされたデプロイメントをサポートします。
 - 2024-04-07 中国語インターフェースをサポートします。

README_zh.md CHANGED
@@ -172,6 +172,7 @@ $ docker compose up -d
 
 ## 🆕 最近新特性
 
+- 2024-04-11 支持用 [Xinference](./docs/xinference.md) 对大模型进行本地化部署。
 - 2024-04-10 为‘Laws’版面分析增加了模型。
 - 2024-04-08 支持用 [Ollama](./docs/ollama.md) 对大模型进行本地化部署。
 - 2024-04-07 支持中文界面。

api/apps/__init__.py CHANGED
@@ -22,6 +22,7 @@ from werkzeug.wrappers.request import Request
 from flask_cors import CORS
 
 from api.db import StatusEnum
+from api.db.db_models import close_connection
 from api.db.services import UserService
 from api.utils import CustomJSONEncoder
 
@@ -42,7 +43,7 @@ for h in access_logger.handlers:
 Request.json = property(lambda self: self.get_json(force=True, silent=True))
 
 app = Flask(__name__)
-CORS(app, supports_credentials=True,max_age = 2592000)
+CORS(app, supports_credentials=True,max_age=2592000)
 app.url_map.strict_slashes = False
 app.json_encoder = CustomJSONEncoder
 app.errorhandler(Exception)(server_error_response)
@@ -94,8 +95,6 @@ client_urls_prefix = [
 ]
 
 
-
-
 @login_manager.request_loader
 def load_user(web_request):
     jwt = Serializer(secret_key=SECRET_KEY)
@@ -112,4 +111,9 @@ def load_user(web_request):
         stat_logger.exception(e)
         return None
     else:
-        return None
+        return None
+
+
+@app.teardown_request
+def _db_close(exc):
+    close_connection()

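The new `teardown_request` hook is how the web app avoids leaking database connections: Flask runs teardown callbacks after every request, whether or not the view raised. A minimal sketch of the same pattern, assuming a peewee database like the one behind `close_connection`:

```python
# Request-scoped connection cleanup, sketched with a stand-in SQLite database
# (RAGFlow itself uses a pooled MySQL database behind close_connection).
from flask import Flask
from peewee import SqliteDatabase

app = Flask(__name__)
db = SqliteDatabase("demo.db")

@app.teardown_request
def _db_close(exc):
    # Runs after every request; `exc` is the unhandled exception or None.
    if not db.is_closed():
        db.close()
```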
api/apps/conversation_app.py CHANGED
@@ -360,6 +360,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
              "|" for r in tbl["rows"]]
     rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
     rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
+
     if not docid_idx or not docnm_idx:
         chat_logger.warning("SQL missing field: " + sql)
         return {

api/db/init_data.py CHANGED
@@ -109,6 +109,12 @@ factory_infos = [{
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING",
     "status": "1",
+},
+{
+    "name": "Xinference",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+    "status": "1",
 },
 # {
 #     "name": "文心一言",

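For orientation: entries in `factory_infos` declare each provider's capabilities as a comma-separated `tags` string. A hypothetical consumer could filter on it like this (illustration only, not RAGFlow's actual query):

```python
# The "tags" field is a comma-separated capability list; the new Xinference
# row can be matched against a required capability like so.
factory = {
    "name": "Xinference",
    "logo": "",
    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
    "status": "1",
}

capabilities = {t.strip() for t in factory["tags"].split(",")}
assert "TEXT EMBEDDING" in capabilities  # Xinference can serve embedding models
```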
docker/docker-compose-CN.yml CHANGED
@@ -20,7 +20,6 @@ services:
       - 443:443
     volumes:
       - ./service_conf.yaml:/ragflow/conf/service_conf.yaml
-      - ./entrypoint.sh:/ragflow/entrypoint.sh
       - ./ragflow-logs:/ragflow/logs
       - ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
       - ./nginx/proxy.conf:/etc/nginx/proxy.conf

docker/docker-compose.yml CHANGED
@@ -19,7 +19,6 @@ services:
       - 443:443
     volumes:
       - ./service_conf.yaml:/ragflow/conf/service_conf.yaml
-      - ./entrypoint.sh:/ragflow/entrypoint.sh
       - ./ragflow-logs:/ragflow/logs
       - ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
       - ./nginx/proxy.conf:/etc/nginx/proxy.conf

rag/llm/__init__.py CHANGED
@@ -21,6 +21,7 @@ from .cv_model import *
 EmbeddingModel = {
     "Ollama": OllamaEmbed,
     "OpenAI": OpenAIEmbed,
+    "Xinference": XinferenceEmbed,
     "Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
     "ZHIPU-AI": ZhipuEmbed,
     "Moonshot": HuEmbedding
@@ -30,6 +31,7 @@ EmbeddingModel = {
 CvModel = {
     "OpenAI": GptV4,
     "Ollama": OllamaCV,
+    "Xinference": XinferenceCV,
     "Tongyi-Qianwen": QWenCV,
     "ZHIPU-AI": Zhipu4V,
     "Moonshot": LocalCV
@@ -41,6 +43,7 @@ ChatModel = {
     "ZHIPU-AI": ZhipuChat,
     "Tongyi-Qianwen": QWenChat,
     "Ollama": OllamaChat,
+    "Xinference": XinferenceChat,
     "Moonshot": MoonshotChat
 }
 
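Registering the three `Xinference*` classes in these dictionaries is what makes the provider selectable by name. A sketch of how such a lookup table is typically consumed (not necessarily RAGFlow's exact call site):

```python
# Pick the implementation class by factory name, then construct it.
from rag.llm import ChatModel

factory = "Xinference"
chat_cls = ChatModel[factory]
# base_url points at the local Xinference server; the model UID is whatever
# you launched there (both values here are placeholders).
mdl = chat_cls(key=None, model_name="my-chat-model",
               base_url="http://localhost:9997/v1")
```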
rag/llm/chat_model.py CHANGED
@@ -158,6 +158,28 @@ class OllamaChat(Base):
             return "**ERROR**: " + str(e), 0
 
 
+class XinferenceChat(Base):
+    def __init__(self, key=None, model_name="", base_url=""):
+        self.client = OpenAI(api_key="xxx", base_url=base_url)
+        self.model_name = model_name
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            ans = response.choices[0].message.content.strip()
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, response.usage.completion_tokens
+        except openai.APIError as e:
+            return "**ERROR**: " + str(e), 0
+
+
 class LocalLLM(Base):
     class RPCProxy:
         def __init__(self, host, port):
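A usage sketch for the new class; the endpoint and model UID are placeholders. Note that `chat` inserts the system message into the caller's `history` list in place and returns the completion-token count alongside the answer:

```python
from rag.llm.chat_model import XinferenceChat

mdl = XinferenceChat(model_name="my-chat-model",
                     base_url="http://localhost:9997/v1")
history = [{"role": "user", "content": "What is RAG in one sentence?"}]
# gen_conf is forwarded to chat.completions.create as **kwargs.
ans, used_tokens = mdl.chat(system="You are a concise assistant.",
                            history=history,
                            gen_conf={"temperature": 0.3, "max_tokens": 128})
print(ans, used_tokens)
```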
rag/llm/cv_model.py CHANGED
@@ -161,6 +161,22 @@ class OllamaCV(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
+class XinferenceCV(Base):
+    def __init__(self, key, model_name="", lang="Chinese", base_url=""):
+        self.client = OpenAI(api_key=key, base_url=base_url)
+        self.model_name = model_name
+        self.lang = lang
+
+    def describe(self, image, max_tokens=300):
+        b64 = self.image2base64(image)
+
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=self.prompt(b64),
+            max_tokens=max_tokens,
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
+
 
 class LocalCV(Base):
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
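`XinferenceCV` leans on base-class helpers (`image2base64`, `prompt`) that are not part of this diff. For a vision model behind an OpenAI-compatible API, the message payload such a helper builds plausibly looks like the following sketch (a hypothetical shape, not RAGFlow's actual prompt):

```python
# Standard OpenAI-style vision payload: one user turn mixing a text part
# with a base64 data-URL image part.
def vision_messages(b64: str, question: str = "Describe this image."):
    return [{
        "role": "user",
        "content": [
            {"type": "text", "text": question},
            {"type": "image_url",
             "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
        ],
    }]
```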
rag/llm/embedding_model.py CHANGED
@@ -170,3 +170,20 @@ class OllamaEmbed(Base):
         res = self.client.embeddings(prompt=text,
                                      model=self.model_name)
         return np.array(res["embedding"]), 128
+
+
+class XinferenceEmbed(Base):
+    def __init__(self, key, model_name="", base_url=""):
+        self.client = OpenAI(api_key="xxx", base_url=base_url)
+        self.model_name = model_name
+
+    def encode(self, texts: list, batch_size=32):
+        res = self.client.embeddings.create(input=texts,
+                                            model=self.model_name)
+        return np.array([d.embedding for d in res.data]
+                        ), res.usage.total_tokens
+
+    def encode_queries(self, text):
+        res = self.client.embeddings.create(input=[text],
+                                            model=self.model_name)
+        return np.array(res.data[0].embedding), res.usage.total_tokens
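A usage sketch for the embedding class (endpoint and model UID are placeholders): `encode` embeds a batch of documents, `encode_queries` a single query, and both report token usage:

```python
import numpy as np
from rag.llm.embedding_model import XinferenceEmbed

mdl = XinferenceEmbed(key="xxx", model_name="my-embedding-model",
                      base_url="http://localhost:9997/v1")
vecs, _ = mdl.encode(["hello world", "goodbye world"])
q, _ = mdl.encode_queries("hi there")
# Cosine similarity of the query against each document vector.
sims = vecs @ q / (np.linalg.norm(vecs, axis=1) * np.linalg.norm(q))
print(sims)
```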
rag/settings.py CHANGED
@@ -34,7 +34,7 @@ LoggerFactory.set_directory(
         "logs",
         "rag"))
 # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
-LoggerFactory.LEVEL = 10
+LoggerFactory.LEVEL = 30
 
 es_logger = getLogger("es")
 minio_logger = getLogger("minio")
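The numeric levels follow Python's stdlib `logging` conventions, so this change raises the default from DEBUG to WARNING:

```python
# 10 -> 30 means debug and info chatter is no longer emitted by default.
import logging

assert logging.DEBUG == 10
assert logging.WARNING == 30
```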
rag/svr/task_executor.py CHANGED
@@ -24,6 +24,8 @@ import sys
 import time
 import traceback
 from functools import partial
+
+from api.db.db_models import close_connection
 from rag.settings import database_logger
 from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
 from multiprocessing import Pool
@@ -302,3 +304,4 @@ if __name__ == "__main__":
     comm = MPI.COMM_WORLD
     while True:
         main(int(sys.argv[2]), int(sys.argv[1]))
+        close_connection()
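Calling `close_connection()` at the bottom of the worker loop keeps this long-running process from sitting on an idle MySQL connection between passes, which the server would eventually kill after its idle timeout; peewee reconnects lazily on the next query. Schematically:

```python
# Sketch of the pattern (poll is a stand-in for one batch of task work).
from api.db.db_models import close_connection

def run_forever(poll):
    while True:
        poll()               # process one batch of queued tasks
        close_connection()   # release the DB connection until the next pass
```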