janar committed on
Commit
ca2fff7
1 Parent(s): 3add6e8
api/db/vector_store.py CHANGED
@@ -1,4 +1,5 @@
 from abc import abstractmethod
+from functools import cache
 import os
 from qdrant_client import QdrantClient
 from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
@@ -7,7 +8,7 @@ from qdrant_client.models import VectorParams, Distance
 from db.embedding import Embedding, EMBEDDINGS
 
 
-class ToyVectorStore:
+class Store:
 
     @staticmethod
     def get_embedding():
@@ -17,13 +18,14 @@ class ToyVectorStore:
         return EMBEDDINGS[embedding]
 
     @staticmethod
+    @cache
     def get_instance():
         vector_store = os.getenv("STORE")
 
         if vector_store == "ELASTIC":
-            return ElasticVectorStore(ToyVectorStore.get_embedding())
+            return ElasticVectorStore(Store.get_embedding())
         elif vector_store == "QDRANT":
-            return QdrantVectorStore(ToyVectorStore.get_embedding())
+            return QdrantVectorStore(Store.get_embedding())
         else:
             raise ValueError(f"Invalid vector store {vector_store}")
 
@@ -47,7 +49,14 @@ class ToyVectorStore:
         """
         pass
 
-class ElasticVectorStore(ToyVectorStore):
+    @abstractmethod
+    def list_collections(self) -> list[dict]:
+        """
+        Return a list of collections in the vector store.
+        """
+        pass
+
+class ElasticVectorStore(Store):
     def __init__(self, embeddings):
         super().__init__(embeddings)
 
@@ -59,8 +68,11 @@ class ElasticVectorStore(ToyVectorStore):
         store = self.get_collection(collection)
         store.create_index(store.client, collection, dict())
 
-
-class QdrantVectorStore(ToyVectorStore):
+    def list_collections(self) -> list[dict]:
+        # TODO: not implemented
+        return []
+
+class QdrantVectorStore(Store):
 
     def __init__(self, embeddings):
         super().__init__(embeddings)
@@ -75,4 +87,9 @@ class QdrantVectorStore(ToyVectorStore):
         self.client.create_collection(collection_name=collection,
                                       vectors_config=VectorParams(size=self.embedding.dimension,
                                                                   distance=Distance.COSINE))
+
+    def list_collections(self) -> list[dict]:
+        """Return a list of collections.
+        """
+        return list(self.client.get_collections().collections)
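With `@cache` stacked beneath `@staticmethod`, `get_instance()` is now memoized: the first call builds the configured store and every later call returns that same object. The order matters, since `cache` needs a plain callable to wrap. A minimal sketch of the pattern (illustrative names, not from this repo):

```python
from functools import cache

class Demo:
    @staticmethod
    @cache                       # cache wraps the raw function; staticmethod wraps the cached one
    def get_instance():
        print("building store")  # printed only on the first call
        return object()

a = Demo.get_instance()
b = Demo.get_instance()
assert a is b                    # later calls return the memoized instance
```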
 
api/document_parsing.py ADDED
@@ -0,0 +1,35 @@
+from typing import Annotated
+
+from fastapi import APIRouter, UploadFile, File, Body
+from langchain.schema import Document
+import io
+import os
+from pypdf import PdfReader
+from langchain.text_splitter import SentenceTransformersTokenTextSplitter
+from db.vector_store import Store
+
+async def generate_documents(file: UploadFile, file_name: str):
+    num = 0
+    async for txts in convert_documents(file):
+        num += 1
+        for txt in txts:
+            document = Document(page_content=txt, metadata={"file": file_name, "page": num})
+            yield document
+
+async def convert_documents(file: UploadFile):
+    splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
+
+    # parse pdf document
+    if file.content_type == 'application/pdf':
+        content = await file.read()
+        pdf_reader = PdfReader(io.BytesIO(content))
+        try:
+            for page in pdf_reader.pages:
+                yield splitter.split_text(page.extract_text())
+        except Exception as e:
+            print(f"Exception {e}")
+    elif "text" in file.content_type:
+        content = await file.read()
+        yield splitter.split_text(content.decode("utf-8"))
+    else:
+        return
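The new module yields one `split_text()` batch per PDF page and wraps each chunk in a `Document` tagged with the source file and page number. The same pipeline can be exercised without FastAPI; a minimal sketch reusing the calls from the module above, where `sample.pdf` is a hypothetical local file:

```python
from pypdf import PdfReader
from langchain.schema import Document
from langchain.text_splitter import SentenceTransformersTokenTextSplitter

splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
reader = PdfReader("sample.pdf")  # hypothetical input file

docs = []
for num, page in enumerate(reader.pages, start=1):
    # one split per page, one Document per token-sized chunk
    for chunk in splitter.split_text(page.extract_text()):
        docs.append(Document(page_content=chunk,
                             metadata={"file": "sample.pdf", "page": num}))
print(f"{len(docs)} chunks generated")
```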
api/main.py CHANGED
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 from fastapi import FastAPI
-from routes import embeddings, search, admin
+from routes import search, upload
 from fastapi.middleware import Middleware
 from fastapi.middleware.cors import CORSMiddleware
 from datetime import datetime
@@ -16,9 +16,8 @@ logger.addHandler(handler)
 
 # Create the FastAPI instance
 app = FastAPI()
-app.include_router(embeddings.router)
 app.include_router(search.router)
-app.include_router(admin.router)
+app.include_router(upload.router)
 app.exception_handler(generic_exception_handler)
 
 app.add_middleware(CORSMiddleware, allow_origins = ["*"],
api/routes/admin.py DELETED
@@ -1,16 +0,0 @@
-# This is to init the vector store
-
-from typing import Annotated
-
-from fastapi import APIRouter, Body
-from db.vector_store import ToyVectorStore
-
-router = APIRouter()
-
-@router.put("/admin/v1/db")
-async def recreate_collection(name: Annotated[str, Body(embed=True)]):
-    """ `name` of the collection to be created.
-    If one exists, delete and recreate.
-    """
-    print(f"creating collection {name} in db")
-    return ToyVectorStore.get_instance().create_collection(name)
api/routes/embeddings.py DELETED
@@ -1,15 +0,0 @@
-from fastapi import APIRouter, UploadFile, File
-import openai
-import io
-import os
-from pypdf import PdfReader
-
-router = APIRouter()
-
-openai.api_key = os.getenv("OPENAI_API_KEY")
-
-@router.post("/v1/embeddings")
-async def embed_doc(file: UploadFile = File(...)):
-    # for now just truncate based on length of words
-    content = await file.read()
-    return openai.Embedding.create(input=content.decode("utf-8"), model="text-embedding-ada-002")
api/routes/search.py CHANGED
@@ -10,65 +10,28 @@ from langchain.schema import Document
 from langchain.chains.question_answering import load_qa_chain
 from langchain.llms import OpenAI
 from langchain.text_splitter import SentenceTransformersTokenTextSplitter
-from db.vector_store import ToyVectorStore
+from db.vector_store import Store
 
 router = APIRouter()
 _chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", verbose=True)
 
-@router.post("/v1/docs")
-async def create_or_update(name: Annotated[str, Body()], file_name: Annotated[str, Body()], file: UploadFile = File(...)):
-    """Create or update an existing collection with information from the file
-    `name` of the collection
-    `file` to upload.
-    `fileName` name of the file.
-    """
-
-    _db = ToyVectorStore.get_instance().get_collection(name)
-    if not _db:
-        # todo: fix this to create a collection, maybe.
-        return JSONResponse(status_code=404, content={})
-
-    async for doc in generate_documents(file, file_name):
-        print(doc)
-        _db.add_documents([doc])
-    # todo: return something sensible
-    return JSONResponse(status_code=200, content={"name": name})
-
-@router.get("/v1/doc/{name}/answer")
+
+@router.get("/v1/docs/{name}/answer")
 async def answer(name: str, query: str):
-    """ Answer a question from the collection
-    `name` of the collection.
+    """ Answer a question from the doc
+    `name` of the doc.
     `query` to be answered.
     """
-    _db = ToyVectorStore.get_instance().get_collection(name)
+    _db = Store.get_instance().get_collection(name)
     print(query)
     docs = _db.similarity_search_with_score(query=query)
     print(docs)
     answer = _chain.run(input_documents=[tup[0] for tup in docs], question=query)
     return JSONResponse(status_code=200, content={"answer": answer, "file_score": [[f"{d[0].metadata['file']} : {d[0].metadata['page']}", d[1]] for d in docs]})
 
-async def generate_documents(file: UploadFile, file_name: str):
-    num = 0
-    async for txts in convert_documents(file):
-        num += 1
-        for txt in txts:
-            document = Document(page_content=txt, metadata={"file": file_name, "page": num})
-            yield document
-
-async def convert_documents(file: UploadFile):
-    splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
-
-    # parse pdf document
-    if file.content_type == 'application/pdf':
-        content = await file.read()
-        pdf_reader = PdfReader(io.BytesIO(content))
-        try:
-            for page in pdf_reader.pages:
-                yield splitter.split_text(page.extract_text())
-        except Exception as e:
-            print(f"Exception {e}")
-    elif "text" in file.content_type:
-        content = await file.read()
-        yield splitter.split_text(content.decode("utf-8"))
-    else:
-        return
+
+@router.get("/v1/docs")
+async def list_docs() -> list[dict]:
+    """ List all the docs.
+    """
+    return Store.get_instance().list_collections()
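With uploads moved out to `upload.py`, `search.py` is now read-only: one route answers a query against a named doc, the other lists the docs. A hedged client-side sketch using `requests`, assuming the API is served at `localhost:8000` and a doc named `notes` already exists:

```python
import requests

BASE = "http://localhost:8000"  # assumed host and port

# GET /v1/docs lists the collections via Store.list_collections()
print(requests.get(f"{BASE}/v1/docs").json())

# GET /v1/docs/{name}/answer takes the question as a query-string parameter
resp = requests.get(f"{BASE}/v1/docs/notes/answer",
                    params={"query": "What is this document about?"})
body = resp.json()
print(body["answer"])
print(body["file_score"])  # [["file : page", score], ...] pairs from the route
```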
api/routes/upload.py ADDED
@@ -0,0 +1,38 @@
+# This is to init the vector store
+
+from typing import Annotated
+
+from db.vector_store import Store
+from document_parsing import generate_documents
+
+from fastapi import APIRouter, UploadFile, File, Body
+from fastapi.responses import JSONResponse
+
+router = APIRouter()
+
+@router.put("/v1/docs")
+async def recreate_collection(name: Annotated[str, Body(embed=True)]):
+    """ `name` of the doc to be created.
+    If one exists, delete and recreate.
+    """
+    print(f"creating collection {name} in db")
+    return Store.get_instance().create_collection(name)
+
+@router.post("/v1/docs")
+async def update(name: Annotated[str, Body()], file_name: Annotated[str, Body()], file: UploadFile = File(...)):
+    """Update an existing document with information from the file.
+    If one doesn't exist with `name`, it creates a new document to update.
+    `name` of the collection
+    `file` to upload.
+    `fileName` name of the file. This is used for metadata purposes only.
+    """
+
+    _db = Store.get_instance().get_collection(name)
+    if not _db:
+        return JSONResponse(status_code=404, content={})
+
+    async for doc in generate_documents(file, file_name):
+        print(doc)
+        _db.add_documents([doc])
+    return JSONResponse(status_code=200, content={"name": name})
+
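The two routes pair up as create-then-fill. A hedged client-side sketch, again assuming `localhost:8000`; since the POST declares `Body()` fields next to an `UploadFile`, the sketch sends `name` and `file_name` as multipart form fields alongside the file (an assumption about how this FastAPI version reads them, not something the commit spells out):

```python
import requests

BASE = "http://localhost:8000"  # assumed host and port

# PUT /v1/docs: Body(embed=True) means `name` arrives wrapped in a JSON object
requests.put(f"{BASE}/v1/docs", json={"name": "notes"})

# POST /v1/docs: multipart upload into the collection just created
with open("notes.pdf", "rb") as fh:  # hypothetical local file
    r = requests.post(f"{BASE}/v1/docs",
                      data={"name": "notes", "file_name": "notes.pdf"},
                      files={"file": ("notes.pdf", fh, "application/pdf")})
print(r.status_code, r.json())
```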