Michael Feil committed on
Commit
b43a465
·
1 Parent(s): 6840107

infinity: Update embedding_model.py (#1109)

Browse files

### What problem does this PR solve?

I implemented infinity, a fast vector embeddings engine.

### Type of change


- [x] Performance Improvement
- [X] Other (please describe):

Files changed (1) hide show
  1. rag/llm/embedding_model.py +42 -1
rag/llm/embedding_model.py CHANGED
@@ -26,6 +26,7 @@ import dashscope
26
  from openai import OpenAI
27
  from FlagEmbedding import FlagModel
28
  import torch
 
29
  import numpy as np
30
 
31
  from api.utils.file_utils import get_home_cache_dir
@@ -304,4 +305,44 @@ class JinaEmbed(Base):
304
 
305
  def encode_queries(self, text):
306
  embds, cnt = self.encode([text])
307
- return np.array(embds[0]), cnt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  from openai import OpenAI
27
  from FlagEmbedding import FlagModel
28
  import torch
29
+ import asyncio
30
  import numpy as np
31
 
32
  from api.utils.file_utils import get_home_cache_dir
 
305
 
306
  def encode_queries(self, text):
307
  embds, cnt = self.encode([text])
308
+ return np.array(embds[0]), cnt
309
+
310
+
311
class InfinityEmbed(Base):
    """Embedding model backed by the ``infinity_emb`` engine.

    Runs one or more local embedding models through infinity's
    ``AsyncEngineArray``. The first entry of ``model_names`` is used
    whenever a call does not name a model explicitly.
    """

    _model = None

    def __init__(
        self,
        model_names: list[str] = ("BAAI/bge-small-en-v1.5",),
        engine_kwargs: dict | None = None,
        key=None,
    ):
        """Build one async engine per entry in ``model_names``.

        ``model_names`` defaults to an immutable tuple on purpose; the
        annotation says ``list`` but any sequence of names works here.
        ``key`` is accepted for interface compatibility and unused.
        """
        # Imported lazily so the module loads even when infinity_emb
        # is not installed and this class is never used.
        from infinity_emb import EngineArgs
        from infinity_emb.engine import AsyncEngineArray

        # Avoid the mutable-default-argument pitfall: a `{}` default
        # would be shared across every instantiation.
        engine_kwargs = engine_kwargs or {}

        self._default_model = model_names[0]
        self.engine_array = AsyncEngineArray.from_args(
            [
                EngineArgs(model_name_or_path=name, **engine_kwargs)
                for name in model_names
            ]
        )

    async def _embed(self, sentences: list[str], model_name: str = ""):
        """Embed ``sentences`` with ``model_name`` (default: first model).

        Starts the engine on demand and stops it again only if this call
        started it, so an externally managed (already running) engine is
        left running. Returns ``(embeddings, usage)`` as produced by the
        engine, where ``usage`` is the token count reported by infinity.
        """
        if not model_name:
            model_name = self._default_model
        engine = self.engine_array[model_name]
        was_already_running = engine.is_running
        if not was_already_running:
            await engine.astart()
        try:
            embeddings, usage = await engine.embed(sentences=sentences)
        finally:
            # Always stop an engine we started, even when embedding
            # raises — otherwise the engine (and its model) would leak.
            if not was_already_running:
                await engine.astop()
        return embeddings, usage

    def encode(self, texts: list[str], model_name: str = "") -> tuple[np.ndarray, int]:
        """Synchronously embed ``texts``.

        Returns ``(matrix of embeddings, total token usage)``. Each call
        drives the async engine via ``asyncio.run``, so it must not be
        invoked from inside a running event loop.
        """
        embeddings, usage = asyncio.run(self._embed(texts, model_name))
        return np.array(embeddings), usage

    def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
        """Embed a single query string via a one-element batch.

        NOTE(review): unlike the sibling classes' ``encode_queries``,
        this returns the full (1, dim) matrix rather than the first row
        — confirm callers expect that shape.
        """
        return self.encode([text])