Complete SentenceTransformer integration
Introduced a way to mix and match embeddings.
Only marginal improvement with Sentence Embedding.
TODO:
Clean up the config mess; a new environment variable EMBEDDING was introduced.
- api/db/vector_store.py +17 -16
- api/routes/search.py +11 -4
- requirements.txt +2 -0
api/db/vector_store.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
from abc import abstractmethod
|
2 |
import os
|
3 |
from qdrant_client import QdrantClient
|
4 |
-
from langchain.embeddings import OpenAIEmbeddings, ElasticsearchEmbeddings
|
5 |
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
|
6 |
from langchain.vectorstores import Qdrant, ElasticVectorSearch, VectorStore
|
7 |
from qdrant_client.models import VectorParams, Distance
|
|
|
8 |
|
9 |
|
10 |
class ToyVectorStore:
|
@@ -12,13 +12,10 @@ class ToyVectorStore:
|
|
12 |
@staticmethod
|
13 |
def get_embedding():
|
14 |
embedding = os.getenv("EMBEDDING")
|
15 |
-
if
|
16 |
-
return
|
17 |
-
|
18 |
-
|
19 |
-
else:
|
20 |
-
return OpenAIEmbeddings()
|
21 |
-
|
22 |
@staticmethod
|
23 |
def get_instance():
|
24 |
vector_store = os.getenv("STORE")
|
@@ -31,8 +28,8 @@ class ToyVectorStore:
|
|
31 |
raise ValueError(f"Invalid vector store {vector_store}")
|
32 |
|
33 |
|
34 |
-
def __init__(self,
|
35 |
-
self.
|
36 |
|
37 |
@abstractmethod
|
38 |
def get_collection(self, collection: str="test") -> VectorStore:
|
@@ -41,7 +38,7 @@ class ToyVectorStore:
|
|
41 |
of collection
|
42 |
"""
|
43 |
pass
|
44 |
-
|
45 |
@abstractmethod
|
46 |
def create_collection(self, collection: str) -> None:
|
47 |
"""
|
@@ -51,11 +48,13 @@ class ToyVectorStore:
|
|
51 |
pass
|
52 |
|
53 |
class ElasticVectorStore(ToyVectorStore):
|
|
|
|
|
54 |
|
55 |
def get_collection(self, collection:str) -> ElasticVectorSearch:
|
56 |
return ElasticVectorSearch(elasticsearch_url= os.getenv("ES_URL"),
|
57 |
-
index_name= collection, embedding=self.
|
58 |
-
|
59 |
def create_collection(self, collection: str) -> None:
|
60 |
store = self.get_collection(collection)
|
61 |
store.create_index(store.client,collection, dict())
|
@@ -63,15 +62,17 @@ class ElasticVectorStore(ToyVectorStore):
|
|
63 |
|
64 |
class QdrantVectorStore(ToyVectorStore):
|
65 |
|
66 |
-
def __init__(self):
|
|
|
67 |
self.client = QdrantClient(url=os.getenv("QDRANT_URL"),
|
68 |
api_key=os.getenv("QDRANT_API_KEY"))
|
69 |
|
70 |
def get_collection(self, collection: str) -> Qdrant:
|
71 |
return Qdrant(client=self.client,collection_name=collection,
|
72 |
-
embeddings=self.
|
73 |
|
74 |
def create_collection(self, collection: str) -> None:
|
75 |
self.client.create_collection(collection_name=collection,
|
76 |
-
vectors_config=VectorParams(size=
|
|
|
77 |
|
|
|
1 |
from abc import abstractmethod
|
2 |
import os
|
3 |
from qdrant_client import QdrantClient
|
|
|
4 |
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
|
5 |
from langchain.vectorstores import Qdrant, ElasticVectorSearch, VectorStore
|
6 |
from qdrant_client.models import VectorParams, Distance
|
7 |
+
from db.embedding import Embedding, EMBEDDINGS
|
8 |
|
9 |
|
10 |
class ToyVectorStore:
|
|
|
12 |
@staticmethod
|
13 |
def get_embedding():
|
14 |
embedding = os.getenv("EMBEDDING")
|
15 |
+
if not embedding:
|
16 |
+
return EMBEDDINGS["OPEN_AI"]
|
17 |
+
return EMBEDDINGS[embedding]
|
18 |
+
|
|
|
|
|
|
|
19 |
@staticmethod
|
20 |
def get_instance():
|
21 |
vector_store = os.getenv("STORE")
|
|
|
28 |
raise ValueError(f"Invalid vector store {vector_store}")
|
29 |
|
30 |
|
31 |
+
def __init__(self, embedding: Embedding):
|
32 |
+
self.embedding = embedding
|
33 |
|
34 |
@abstractmethod
|
35 |
def get_collection(self, collection: str="test") -> VectorStore:
|
|
|
38 |
of collection
|
39 |
"""
|
40 |
pass
|
41 |
+
|
42 |
@abstractmethod
|
43 |
def create_collection(self, collection: str) -> None:
|
44 |
"""
|
|
|
48 |
pass
|
49 |
|
50 |
class ElasticVectorStore(ToyVectorStore):
|
51 |
+
def __init__(self, embeddings):
|
52 |
+
super().__init__(embeddings)
|
53 |
|
54 |
def get_collection(self, collection:str) -> ElasticVectorSearch:
|
55 |
return ElasticVectorSearch(elasticsearch_url= os.getenv("ES_URL"),
|
56 |
+
index_name= collection, embedding=self.embedding.embedding)
|
57 |
+
|
58 |
def create_collection(self, collection: str) -> None:
|
59 |
store = self.get_collection(collection)
|
60 |
store.create_index(store.client,collection, dict())
|
|
|
62 |
|
63 |
class QdrantVectorStore(ToyVectorStore):
|
64 |
|
65 |
+
def __init__(self, embeddings):
|
66 |
+
super().__init__(embeddings)
|
67 |
self.client = QdrantClient(url=os.getenv("QDRANT_URL"),
|
68 |
api_key=os.getenv("QDRANT_API_KEY"))
|
69 |
|
70 |
def get_collection(self, collection: str) -> Qdrant:
|
71 |
return Qdrant(client=self.client,collection_name=collection,
|
72 |
+
embeddings=self.embedding.embedding)
|
73 |
|
74 |
def create_collection(self, collection: str) -> None:
|
75 |
self.client.create_collection(collection_name=collection,
|
76 |
+
vectors_config=VectorParams(size=self.embedding.dimension,
|
77 |
+
distance=Distance.COSINE))
|
78 |
|
api/routes/search.py
CHANGED
@@ -11,6 +11,7 @@ from langchain.vectorstores import Qdrant
|
|
11 |
from langchain.schema import Document
|
12 |
from langchain.chains.question_answering import load_qa_chain
|
13 |
from langchain.llms import OpenAI
|
|
|
14 |
from db.vector_store import ToyVectorStore
|
15 |
|
16 |
router = APIRouter()
|
@@ -50,20 +51,26 @@ async def answer(name: str, query: str):
|
|
50 |
|
51 |
async def generate_documents(file: UploadFile, file_name: str):
|
52 |
num=0
|
53 |
-
async for
|
54 |
num += 1
|
55 |
-
|
56 |
-
|
|
|
57 |
|
58 |
async def convert_documents(file: UploadFile):
|
|
|
|
|
59 |
#parse pdf document
|
60 |
if file.content_type == 'application/pdf':
|
61 |
content = await file.read()
|
62 |
pdf_reader = PdfReader(io.BytesIO(content))
|
63 |
try:
|
64 |
for page in pdf_reader.pages:
|
65 |
-
yield page.extract_text()
|
66 |
except Exception as e:
|
67 |
print(f"Exception {e}")
|
|
|
|
|
|
|
68 |
else:
|
69 |
return
|
|
|
11 |
from langchain.schema import Document
|
12 |
from langchain.chains.question_answering import load_qa_chain
|
13 |
from langchain.llms import OpenAI
|
14 |
+
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
|
15 |
from db.vector_store import ToyVectorStore
|
16 |
|
17 |
router = APIRouter()
|
|
|
51 |
|
52 |
async def generate_documents(file: UploadFile, file_name: str):
|
53 |
num=0
|
54 |
+
async for txts in convert_documents(file):
|
55 |
num += 1
|
56 |
+
for txt in txts:
|
57 |
+
document = Document(page_content=txt,metadata={"file": file_name, "page": num})
|
58 |
+
yield document
|
59 |
|
60 |
async def convert_documents(file: UploadFile):
|
61 |
+
splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
|
62 |
+
|
63 |
#parse pdf document
|
64 |
if file.content_type == 'application/pdf':
|
65 |
content = await file.read()
|
66 |
pdf_reader = PdfReader(io.BytesIO(content))
|
67 |
try:
|
68 |
for page in pdf_reader.pages:
|
69 |
+
yield splitter.split_text(page.extract_text())
|
70 |
except Exception as e:
|
71 |
print(f"Exception {e}")
|
72 |
+
elif "text" in file.content_type:
|
73 |
+
content = await file.read()
|
74 |
+
yield splitter.split_text(content.decode("utf-8"))
|
75 |
else:
|
76 |
return
|
requirements.txt
CHANGED
@@ -9,3 +9,5 @@ tiktoken
|
|
9 |
faiss-cpu
|
10 |
qdrant-client
|
11 |
elasticsearch
|
|
|
|
|
|
9 |
faiss-cpu
|
10 |
qdrant-client
|
11 |
elasticsearch
|
12 |
+
sentence_transformers
|
13 |
+
|