janar commited on
Commit
9c7a6f3
1 Parent(s): 0760a61

add support for different embeddings

Browse files
Files changed (1) hide show
  1. api/db/vector_store.py +17 -4
api/db/vector_store.py CHANGED
@@ -2,24 +2,37 @@ from abc import abstractmethod
2
  import os
3
  from qdrant_client import QdrantClient
4
  from langchain.embeddings import OpenAIEmbeddings, ElasticsearchEmbeddings
 
5
  from langchain.vectorstores import Qdrant, ElasticVectorSearch, VectorStore
6
  from qdrant_client.models import VectorParams, Distance
7
 
8
 
9
  class ToyVectorStore:
10
 
 
 
 
 
 
 
 
 
 
 
11
  @staticmethod
12
  def get_instance():
13
  vector_store = os.getenv("STORE")
 
14
  if vector_store == "ELASTIC":
15
- return ElasticVectorStore()
16
  elif vector_store == "QDRANT":
17
- return QdrantVectorStore()
18
  else:
19
  raise ValueError(f"Invalid vector store {vector_store}")
20
 
21
- def __init__(self):
22
- self.embeddings = OpenAIEmbeddings()
 
23
 
24
  @abstractmethod
25
  def get_collection(self, collection: str="test") -> VectorStore:
 
2
  import os
3
  from qdrant_client import QdrantClient
4
  from langchain.embeddings import OpenAIEmbeddings, ElasticsearchEmbeddings
5
+ from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
6
  from langchain.vectorstores import Qdrant, ElasticVectorSearch, VectorStore
7
  from qdrant_client.models import VectorParams, Distance
8
 
9
 
10
  class ToyVectorStore:
11
 
12
+ @staticmethod
13
+ def get_embedding():
14
+ embedding = os.getenv("EMBEDDING")
15
+ if "SENTENCE" == embedding:
16
+ return SentenceTransformerEmbeddings()
17
+ elif "ELASTIC" == embedding:
18
+ return ElasticsearchEmbeddings()
19
+ else:
20
+ return OpenAIEmbeddings()
21
+
22
  @staticmethod
23
  def get_instance():
24
  vector_store = os.getenv("STORE")
25
+
26
  if vector_store == "ELASTIC":
27
+ return ElasticVectorStore(ToyVectorStore.get_embedding())
28
  elif vector_store == "QDRANT":
29
+ return QdrantVectorStore(ToyVectorStore.get_embedding())
30
  else:
31
  raise ValueError(f"Invalid vector store {vector_store}")
32
 
33
+
34
+ def __init__(self, embeddings):
35
+ self.embeddings = embeddings
36
 
37
  @abstractmethod
38
  def get_collection(self, collection: str="test") -> VectorStore: