0000sir 0000sir Kevin Hu committed on
Commit
13b2570
·
1 Parent(s): 660ec98

Fix keys of Xinference-deployed models, especially those that share a model name with publicly hosted models. (#2832)

Browse files

### What problem does this PR solve?

Fix keys of Xinference-deployed models, especially those that share a
model name with publicly hosted models.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: 0000sir <0000sir@gmail.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>

api/apps/llm_app.py CHANGED
@@ -343,10 +343,10 @@ def list_app():
343
  for m in llms:
344
  m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deploied
345
 
346
- llm_set = set([m["llm_name"] for m in llms])
347
  for o in objs:
348
  if not o.api_key:continue
349
- if o.llm_name in llm_set:continue
350
  llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
351
 
352
  res = {}
 
343
  for m in llms:
344
  m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deploied
345
 
346
+ llm_set = set([m["llm_name"]+"@"+m["fid"] for m in llms])
347
  for o in objs:
348
  if not o.api_key:continue
349
+ if o.llm_name+"@"+o.llm_factory in llm_set:continue
350
  llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
351
 
352
  res = {}
api/apps/sdk/doc.py CHANGED
@@ -494,25 +494,24 @@ def set(tenant_id,dataset_id,document_id,chunk_id):
494
 
495
 
496
 
497
- @manager.route('/retrieval', methods=['GET'])
498
  @token_required
499
  def retrieval_test(tenant_id):
500
- req = request.args
501
- req_json = request.json
502
- if not req_json.get("datasets"):
503
  return get_error_data_result("`datasets` is required.")
504
- for id in req_json.get("datasets"):
 
 
505
  if not KnowledgebaseService.query(id=id,tenant_id=tenant_id):
506
  return get_error_data_result(f"You don't own the dataset {id}.")
507
  if "question" not in req_json:
508
  return get_error_data_result("`question` is required.")
509
  page = int(req.get("offset", 1))
510
  size = int(req.get("limit", 30))
511
- question = req_json["question"]
512
- kb_id = req_json["datasets"]
513
- if isinstance(kb_id, str): kb_id = [kb_id]
514
- doc_ids = req_json.get("documents", [])
515
- similarity_threshold = float(req.get("similarity_threshold", 0.0))
516
  vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
517
  top = int(req.get("top_k", 1024))
518
  if req.get("highlight")=="False" or req.get("highlight")=="false":
 
494
 
495
 
496
 
497
+ @manager.route('/retrieval', methods=['POST'])
498
  @token_required
499
  def retrieval_test(tenant_id):
500
+ req = request.json
501
+ if not req.get("datasets"):
 
502
  return get_error_data_result("`datasets` is required.")
503
+ kb_id = req["datasets"]
504
+ if isinstance(kb_id, str): kb_id = [kb_id]
505
+ for id in kb_id:
506
  if not KnowledgebaseService.query(id=id,tenant_id=tenant_id):
507
  return get_error_data_result(f"You don't own the dataset {id}.")
508
  if "question" not in req_json:
509
  return get_error_data_result("`question` is required.")
510
  page = int(req.get("offset", 1))
511
  size = int(req.get("limit", 30))
512
+ question = req["question"]
513
+ doc_ids = req.get("documents", [])
514
+ similarity_threshold = float(req.get("similarity_threshold", 0.2))
 
 
515
  vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
516
  top = int(req.get("top_k", 1024))
517
  if req.get("highlight")=="False" or req.get("highlight")=="false":
rag/llm/cv_model.py CHANGED
@@ -453,7 +453,7 @@ class XinferenceCV(Base):
453
  def __init__(self, key, model_name="", lang="Chinese", base_url=""):
454
  if base_url.split("/")[-1] != "v1":
455
  base_url = os.path.join(base_url, "v1")
456
- self.client = OpenAI(api_key="xxx", base_url=base_url)
457
  self.model_name = model_name
458
  self.lang = lang
459
 
 
453
  def __init__(self, key, model_name="", lang="Chinese", base_url=""):
454
  if base_url.split("/")[-1] != "v1":
455
  base_url = os.path.join(base_url, "v1")
456
+ self.client = OpenAI(api_key=key, base_url=base_url)
457
  self.model_name = model_name
458
  self.lang = lang
459
 
rag/llm/embedding_model.py CHANGED
@@ -274,7 +274,7 @@ class XinferenceEmbed(Base):
274
  def __init__(self, key, model_name="", base_url=""):
275
  if base_url.split("/")[-1] != "v1":
276
  base_url = os.path.join(base_url, "v1")
277
- self.client = OpenAI(api_key="xxx", base_url=base_url)
278
  self.model_name = model_name
279
 
280
  def encode(self, texts: list, batch_size=32):
 
274
  def __init__(self, key, model_name="", base_url=""):
275
  if base_url.split("/")[-1] != "v1":
276
  base_url = os.path.join(base_url, "v1")
277
+ self.client = OpenAI(api_key=key, base_url=base_url)
278
  self.model_name = model_name
279
 
280
  def encode(self, texts: list, batch_size=32):
rag/llm/rerank_model.py CHANGED
@@ -162,7 +162,8 @@ class XInferenceRerank(Base):
162
  self.base_url = base_url
163
  self.headers = {
164
  "Content-Type": "application/json",
165
- "accept": "application/json"
 
166
  }
167
 
168
  def similarity(self, query: str, texts: list):
 
162
  self.base_url = base_url
163
  self.headers = {
164
  "Content-Type": "application/json",
165
+ "accept": "application/json",
166
+ "Authorization": f"Bearer {key}"
167
  }
168
 
169
  def similarity(self, query: str, texts: list):
rag/llm/sequence2txt_model.py CHANGED
@@ -90,6 +90,7 @@ class XinferenceSeq2txt(Base):
90
  def __init__(self,key,model_name="whisper-small",**kwargs):
91
  self.base_url = kwargs.get('base_url', None)
92
  self.model_name = model_name
 
93
 
94
  def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
95
  if isinstance(audio, str):
 
90
  def __init__(self,key,model_name="whisper-small",**kwargs):
91
  self.base_url = kwargs.get('base_url', None)
92
  self.model_name = model_name
93
+ self.key = key
94
 
95
  def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
96
  if isinstance(audio, str):
sdk/python/ragflow/ragflow.py CHANGED
@@ -74,6 +74,12 @@ class RAGFlow:
74
  if res.get("code") != 0:
75
  raise Exception(res["message"])
76
 
 
 
 
 
 
 
77
  def list_datasets(self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True,
78
  id: str = None, name: str = None) -> \
79
  List[DataSet]:
 
74
  if res.get("code") != 0:
75
  raise Exception(res["message"])
76
 
77
+ def get_dataset(self,name: str):
78
+ _list = self.list_datasets(name=name)
79
+ if len(_list) > 0:
80
+ return _list[0]
81
+ raise Exception("Dataset %s not found" % name)
82
+
83
  def list_datasets(self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True,
84
  id: str = None, name: str = None) -> \
85
  List[DataSet]: