janar committed on
Commit 1bec7d8
1 Parent(s): 9c7a6f3

Complete SentenceTransformer integration


Introduced a way to mix and match embeddings.

Only a marginal improvement with sentence embeddings.

TODO: clean up the config mess; a new EMBEDDING environment variable was introduced.

Files changed (3)
  1. api/db/vector_store.py +17 -16
  2. api/routes/search.py +11 -4
  3. requirements.txt +2 -0
api/db/vector_store.py CHANGED
```diff
@@ -1,10 +1,10 @@
 from abc import abstractmethod
 import os
 from qdrant_client import QdrantClient
-from langchain.embeddings import OpenAIEmbeddings, ElasticsearchEmbeddings
 from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
 from langchain.vectorstores import Qdrant, ElasticVectorSearch, VectorStore
 from qdrant_client.models import VectorParams, Distance
+from db.embedding import Embedding, EMBEDDINGS
 
 
 class ToyVectorStore:
@@ -12,13 +12,10 @@ class ToyVectorStore:
     @staticmethod
     def get_embedding():
         embedding = os.getenv("EMBEDDING")
-        if "SENTENCE" == embedding:
-            return SentenceTransformerEmbeddings()
-        elif "ELASTIC" == embedding:
-            return ElasticsearchEmbeddings()
-        else:
-            return OpenAIEmbeddings()
-
+        if not embedding:
+            return EMBEDDINGS["OPEN_AI"]
+        return EMBEDDINGS[embedding]
+
     @staticmethod
     def get_instance():
         vector_store = os.getenv("STORE")
@@ -31,8 +28,8 @@ class ToyVectorStore:
         raise ValueError(f"Invalid vector store {vector_store}")
 
 
-    def __init__(self, embeddings):
-        self.embeddings = embeddings
+    def __init__(self, embedding: Embedding):
+        self.embedding = embedding
 
     @abstractmethod
     def get_collection(self, collection: str="test") -> VectorStore:
@@ -41,7 +38,7 @@ class ToyVectorStore:
         of collection
         """
         pass
-
+
     @abstractmethod
     def create_collection(self, collection: str) -> None:
         """
@@ -51,11 +48,13 @@ class ToyVectorStore:
         pass
 
 class ElasticVectorStore(ToyVectorStore):
+    def __init__(self, embeddings):
+        super().__init__(embeddings)
 
     def get_collection(self, collection:str) -> ElasticVectorSearch:
         return ElasticVectorSearch(elasticsearch_url= os.getenv("ES_URL"),
-                                   index_name= collection, embedding=self.embeddings)
-
+                                   index_name= collection, embedding=self.embedding.embedding)
+
     def create_collection(self, collection: str) -> None:
         store = self.get_collection(collection)
         store.create_index(store.client,collection, dict())
@@ -63,15 +62,17 @@ class ElasticVectorStore(ToyVectorStore):
 
 class QdrantVectorStore(ToyVectorStore):
 
-    def __init__(self):
+    def __init__(self, embeddings):
+        super().__init__(embeddings)
         self.client = QdrantClient(url=os.getenv("QDRANT_URL"),
                                    api_key=os.getenv("QDRANT_API_KEY"))
 
     def get_collection(self, collection: str) -> Qdrant:
         return Qdrant(client=self.client,collection_name=collection,
-                      embeddings=self.embeddings)
+                      embeddings=self.embedding.embedding)
 
     def create_collection(self, collection: str) -> None:
         self.client.create_collection(collection_name=collection,
-                                      vectors_config=VectorParams(size=1536, distance=Distance.COSINE))
+                                      vectors_config=VectorParams(size=self.embedding.dimension,
+                                                                  distance=Distance.COSINE))
 
```
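The new code imports `Embedding` and `EMBEDDINGS` from `db.embedding`, a module that is not among the three changed files, and relies on the attributes `embedding.embedding` and `embedding.dimension`. Below is a minimal sketch of what that registry could look like under those assumptions; only the `OPEN_AI` key is confirmed by the diff, while the `SENTENCE` key and both dimension values are guesses based on the old code and the libraries' defaults.

```python
# db/embedding.py -- hypothetical sketch, not part of this commit.
from dataclasses import dataclass
from typing import Any

from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings


@dataclass
class Embedding:
    embedding: Any   # LangChain embeddings instance handed to the vector store
    dimension: int   # vector size used for Qdrant's VectorParams(size=...)


# Keyed by the EMBEDDING environment variable.
EMBEDDINGS = {
    "OPEN_AI": Embedding(OpenAIEmbeddings(), 1536),               # matches the old hard-coded size=1536
    "SENTENCE": Embedding(SentenceTransformerEmbeddings(), 768),  # assumes the default all-mpnet-base-v2 model
}
```

Building the dict eagerly instantiates both backends at import time; lazy factories would avoid loading the sentence-transformers model when it is not selected.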
api/routes/search.py CHANGED
```diff
@@ -11,6 +11,7 @@ from langchain.vectorstores import Qdrant
 from langchain.schema import Document
 from langchain.chains.question_answering import load_qa_chain
 from langchain.llms import OpenAI
+from langchain.text_splitter import SentenceTransformersTokenTextSplitter
 from db.vector_store import ToyVectorStore
 
 router = APIRouter()
@@ -50,20 +51,26 @@ async def answer(name: str, query: str):
 
 async def generate_documents(file: UploadFile, file_name: str):
     num=0
-    async for txt in convert_documents(file):
+    async for txts in convert_documents(file):
         num += 1
-        document = Document(page_content=txt,metadata={"file": file_name, "page": num})
-        yield document
+        for txt in txts:
+            document = Document(page_content=txt,metadata={"file": file_name, "page": num})
+            yield document
 
 async def convert_documents(file: UploadFile):
+    splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
+
     #parse pdf document
     if file.content_type == 'application/pdf':
         content = await file.read()
         pdf_reader = PdfReader(io.BytesIO(content))
         try:
             for page in pdf_reader.pages:
-                yield page.extract_text()
+                yield splitter.split_text(page.extract_text())
         except Exception as e:
             print(f"Exception {e}")
+    elif "text" in file.content_type:
+        content = await file.read()
+        yield splitter.split_text(content.decode("utf-8"))
     else:
         return
```
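`convert_documents` now yields a list of chunks per page instead of a single string, which is why `generate_documents` gained the inner `for txt in txts` loop. A standalone illustration of the splitter call follows; the model and token window are whatever `SentenceTransformersTokenTextSplitter` defaults to, and the model is downloaded on first use.

```python
from langchain.text_splitter import SentenceTransformersTokenTextSplitter

splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)

# Stand-in for page.extract_text() or a decoded text upload.
page_text = "RAG splits long documents into smaller chunks before embedding them. " * 100

chunks = splitter.split_text(page_text)  # list[str], each within the model's token window
print(len(chunks), len(chunks[0]))
```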
requirements.txt CHANGED
```diff
@@ -9,3 +9,5 @@ tiktoken
 faiss-cpu
 qdrant-client
 elasticsearch
+sentence_transformers
+
```
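With `sentence_transformers` installed, the embedding backend is chosen at runtime through the new EMBEDDING variable, falling back to `OPEN_AI` when it is unset. A hedged usage sketch follows; the body of `get_instance()` is not shown in this commit, so the `STORE` value and the way `get_embedding()` is wired into the concrete store are assumptions.

```python
import os

from db.vector_store import ToyVectorStore

os.environ["EMBEDDING"] = "SENTENCE"   # key assumed to exist in db.embedding.EMBEDDINGS
os.environ["STORE"] = "QDRANT"         # assumed value; get_instance() raises ValueError on unknown stores

store = ToyVectorStore.get_instance()  # e.g. QdrantVectorStore(ToyVectorStore.get_embedding())
store.create_collection("demo")
vectors = store.get_collection("demo") # a LangChain VectorStore backed by the chosen embedding
```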