mitulagr2 commited on
Commit
1ff6584
1 Parent(s): f562d60

Migrate to llama-index

Browse files
Files changed (4) hide show
  1. app/main.py +13 -11
  2. app/rag.py +81 -51
  3. requirements.txt +4 -5
  4. start_service.sh +3 -0
app/main.py CHANGED
@@ -38,17 +38,19 @@ def upload(files: list[UploadFile]):
38
  session_assistant.clear()
39
  session_messages = []
40
 
41
- for file in files:
42
- path = f"files/{file.filename}"
43
- try:
44
- suffix = Path(file.filename).suffix
45
- with NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
46
- shutil.copyfileobj(file.file, tmp)
47
- tmp_path = Path(tmp.name)
48
- session_assistant.ingest(tmp_path)
49
- os.remove(tmp_path)
50
- finally:
51
- file.file.close()
 
 
52
 
53
  return "Files inserted!"
54
 
 
38
  session_assistant.clear()
39
  session_messages = []
40
 
41
+ try:
42
+ for file in files:
43
+ path = f"files/{file.filename}"
44
+ try:
45
+ suffix = Path(file.filename).suffix
46
+ with NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
47
+ shutil.copyfileobj(file.file, tmp)
48
+
49
+ finally:
50
+ file.file.close()
51
+ finally:
52
+ session_assistant.ingest("files/")
53
+ os.remove("files/")
54
 
55
  return "Files inserted!"
56
 
app/rag.py CHANGED
@@ -1,67 +1,97 @@
1
- from langchain_community.vectorstores import Chroma
2
- from langchain_community.llms import Ollama
3
- from langchain_community.embeddings import FastEmbedEmbeddings
4
- from langchain.schema.output_parser import StrOutputParser
5
- from langchain_community.document_loaders import PyMuPDFLoader
6
- from langchain.text_splitter import RecursiveCharacterTextSplitter
7
- from langchain.schema.runnable import RunnablePassthrough
8
- from langchain.prompts import PromptTemplate
9
- from langchain_community.vectorstores.utils import filter_complex_metadata
 
 
 
 
 
 
 
10
 
11
 
12
  class ChatPDF:
13
- vector_store = None
14
- retriever = None
15
- chain = None
16
 
17
  def __init__(self):
18
- self.model = Ollama(
19
- model="qwen:1.8b",
20
- keep_alive=-1,
21
- temperature=0,
22
- num_predict=512,
23
- repeat_penalty=1.3,
24
- repeat_last_n=-1
25
- )
26
 
27
- self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
28
- self.prompt = PromptTemplate.from_template(
29
- """
30
- <|im_start|> You are an assistant for question-answering tasks. Use the following pieces of retrieved context to
31
- answer the question. If you don't know the answer, just say that you don't know. Use 512 characters
32
- maximum and keep the answer concise.
33
- Question: {question}
34
- Context: {context}
35
- Answer: <|im_end|>
36
- """
37
- )
 
38
 
39
- def ingest(self, pdf_file_path: str):
40
- docs = PyMuPDFLoader(file_path=pdf_file_path).load()
41
- chunks = self.text_splitter.split_documents(docs)
42
- chunks = filter_complex_metadata(chunks)
43
 
44
- vector_store = Chroma.from_documents(documents=chunks, embedding=FastEmbedEmbeddings())
45
- self.retriever = vector_store.as_retriever(
46
- search_type="similarity_score_threshold",
47
- search_kwargs={
48
- "k": 4,
49
- "score_threshold": 0.5,
50
- },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  )
52
 
53
- self.chain = ({"context": self.retriever, "question": RunnablePassthrough()}
54
- | self.prompt
55
- | self.model
56
- | StrOutputParser())
 
 
 
 
 
 
 
57
 
58
  def ask(self, query: str):
59
- if not self.chain:
60
  return "Please, add a PDF document first."
61
 
62
- return self.chain.invoke(query)
 
 
 
63
 
64
  def clear(self):
65
- self.vector_store = None
66
- self.retriever = None
67
- self.chain = None
 
1
+ from llama_index.core import (
2
+ SimpleDirectoryReader,
3
+ VectorStoreIndex,
4
+ StorageContext,
5
+ Settings,
6
+ get_response_synthesizer)
7
+ from llama_index.core.query_engine import RetrieverQueryEngine, TransformQueryEngine
8
+ from llama_index.core.node_parser import SentenceSplitter
9
+ from llama_index.core.schema import TextNode, MetadataMode
10
+ from llama_index.vector_stores.qdrant import QdrantVectorStore
11
+ from llama_index.embeddings.ollama import OllamaEmbedding
12
+ from llama_index.llms.ollama import Ollama
13
+ from llama_index.core.retrievers import VectorIndexRetriever
14
+ from llama_index.core.indices.query.query_transform import HyDEQueryTransform
15
+ import qdrant_client
16
+ import logging
17
 
18
 
19
  class ChatPDF:
20
+ text_chunks = []
21
+ doc_ids = []
22
+ nodes = []
23
 
24
  def __init__(self):
25
+ logging.basicConfig(level=logging.INFO)
26
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
27
 
28
+ text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=100)
29
+
30
+ logger.info("initializing the vector store related objects")
31
+ client = qdrant_client.QdrantClient(host="localhost", port=6333)
32
+ vector_store = QdrantVectorStore(client=client, collection_name="rag_documents")
33
+
34
+ logger.info("initializing the OllamaEmbedding")
35
+ embed_model = OllamaEmbedding(model_name='mxbai-embed-large', request_timeout=1000000)
36
+ logger.info("initializing the global settings")
37
+ Settings.embed_model = embed_model
38
+ Settings.llm = Ollama(model="qwen:1.8b", request_timeout=1000000)
39
+ Settings.transformations = [text_parser]
40
 
41
+ def ingest(self, dir_path: str):
42
+ docs = SimpleDirectoryReader(input_dir=dir_path).load_data()
 
 
43
 
44
+ logger.info("enumerating docs")
45
+ for doc_idx, doc in enumerate(docs):
46
+ curr_text_chunks = text_parser.split_text(doc.text)
47
+ text_chunks.extend(curr_text_chunks)
48
+ doc_ids.extend([doc_idx] * len(curr_text_chunks))
49
+
50
+ logger.info("enumerating text_chunks")
51
+ for idx, text_chunk in enumerate(text_chunks):
52
+ node = TextNode(text=text_chunk)
53
+ src_doc = docs[doc_ids[idx]]
54
+ node.metadata = src_doc.metadata
55
+ nodes.append(node)
56
+
57
+ logger.info("enumerating nodes")
58
+ for node in nodes:
59
+ node_embedding = embed_model.get_text_embedding(
60
+ node.get_content(metadata_mode=MetadataMode.ALL)
61
+ )
62
+ node.embedding = node_embedding
63
+
64
+ logger.info("initializing the storage context")
65
+ storage_context = StorageContext.from_defaults(vector_store=vector_store)
66
+ logger.info("indexing the nodes in VectorStoreIndex")
67
+ index = VectorStoreIndex(
68
+ nodes=nodes,
69
+ storage_context=storage_context,
70
+ transformations=Settings.transformations,
71
  )
72
 
73
+ logger.info("initializing the VectorIndexRetriever with top_k as 5")
74
+ vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=5)
75
+ response_synthesizer = get_response_synthesizer()
76
+ logger.info("creating the RetrieverQueryEngine instance")
77
+ vector_query_engine = RetrieverQueryEngine(
78
+ retriever=vector_retriever,
79
+ response_synthesizer=response_synthesizer,
80
+ )
81
+ logger.info("creating the HyDEQueryTransform instance")
82
+ hyde = HyDEQueryTransform(include_original=True)
83
+ self.hyde_query_engine = TransformQueryEngine(vector_query_engine, hyde)
84
 
85
  def ask(self, query: str):
86
+ if not self.hyde_query_engine:
87
  return "Please, add a PDF document first."
88
 
89
+ logger.info("retrieving the response to the query")
90
+ response = self.hyde_query_engine.query(str_or_query_bundle=query)
91
+ print(response)
92
+ return response
93
 
94
  def clear(self):
95
+ self.text_chunks = []
96
+ self.doc_ids = []
97
+ self.nodes = []
requirements.txt CHANGED
@@ -1,6 +1,5 @@
1
  fastapi
2
- pymupdf
3
- langchain
4
- langchain-community
5
- fastembed
6
- chromadb
 
1
  fastapi
2
+ llama-index
3
+ llama-index-vector-stores-qdrant
4
+ llama-index-embeddings-ollama
5
+ llama-index-llms-ollama
 
start_service.sh CHANGED
@@ -6,6 +6,9 @@ ollama serve &
6
  # Wait for Ollama to start
7
  sleep 5
8
 
 
 
 
9
  # Pull and run <YOUR_MODEL_NAME>
10
  ollama pull qwen:1.8b
11
 
 
6
  # Wait for Ollama to start
7
  sleep 5
8
 
9
+ #
10
+ ollama pull mxbai-embed-large
11
+
12
  # Pull and run <YOUR_MODEL_NAME>
13
  ollama pull qwen:1.8b
14