jinhai-2012 commited on
Commit
0d51511
·
1 Parent(s): c34c86c

Fix error of changing embedding model (#4184)

Browse files

### What problem does this PR solve?

1. Change embedding model of knowledge base won't change the default
embedding model.
2. Retrieval test bug

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Signed-off-by: jinhai <haijin.chn@gmail.com>

Files changed (1) hide show
  1. rag/llm/embedding_model.py +3 -3
rag/llm/embedding_model.py CHANGED
@@ -61,11 +61,11 @@ class DefaultEmbedding(Base):
61
  ^_-
62
 
63
  """
64
- if not settings.LIGHTEN and not DefaultEmbedding._model:
65
  with DefaultEmbedding._model_lock:
66
  from FlagEmbedding import FlagModel
67
  import torch
68
- if not DefaultEmbedding._model:
69
  try:
70
  DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
71
  query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
@@ -261,7 +261,7 @@ class FastEmbed(DefaultEmbedding):
261
  threads: int | None = None,
262
  **kwargs,
263
  ):
264
- if not settings.LIGHTEN and not FastEmbed._model:
265
  with FastEmbed._model_lock:
266
  from fastembed import TextEmbedding
267
  if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
 
61
  ^_-
62
 
63
  """
64
+ if not settings.LIGHTEN:
65
  with DefaultEmbedding._model_lock:
66
  from FlagEmbedding import FlagModel
67
  import torch
68
+ if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
69
  try:
70
  DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
71
  query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
 
261
  threads: int | None = None,
262
  **kwargs,
263
  ):
264
+ if not settings.LIGHTEN:
265
  with FastEmbed._model_lock:
266
  from fastembed import TextEmbedding
267
  if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name: