thomasjacob04 commited on
Commit
5061d0f
·
verified ·
1 Parent(s): 89e9286

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -80
app.py CHANGED
@@ -1,103 +1,117 @@
1
  import os
2
-
3
  from dotenv import load_dotenv
4
-
5
- load_dotenv()
6
-
7
  from typing import Iterator
8
-
9
  from langchain_core.document_loaders import BaseLoader
10
  from langchain_core.documents import Document as LCDocument
11
-
12
  from docling.document_converter import DocumentConverter
 
 
 
 
 
 
 
 
13
 
14
- # import gradio as gr
15
- from typing import Iterator
 
16
 
17
  class DoclingPDFLoader(BaseLoader):
18
-
19
  def __init__(self, file_path: str | list[str]) -> None:
20
  self._file_paths = file_path if isinstance(file_path, list) else [file_path]
21
  self._converter = DocumentConverter()
22
-
23
  def lazy_load(self) -> Iterator[LCDocument]:
24
  for source in self._file_paths:
25
  dl_doc = self._converter.convert(source).document
26
  text = dl_doc.export_to_markdown()
27
  yield LCDocument(page_content=text)
28
 
29
- FILE_PATH = "10_Pages_Vol_5.pdf" # test paper
30
-
31
- from langchain_text_splitters import RecursiveCharacterTextSplitter
32
-
33
- loader = DoclingPDFLoader(file_path=FILE_PATH)
34
- text_splitter = RecursiveCharacterTextSplitter(
35
- chunk_size=1000,
36
- chunk_overlap=200,
37
- )
38
-
39
- docs = loader.load()
40
- splits = text_splitter.split_documents(docs)
41
-
42
- from langchain_huggingface.embeddings import HuggingFaceEmbeddings
43
-
44
- HF_EMBED_MODEL_ID = "BAAI/bge-small-en-v1.5"
45
- embeddings = HuggingFaceEmbeddings(model_name=HF_EMBED_MODEL_ID)
46
-
47
- from tempfile import TemporaryDirectory
48
-
49
- from langchain_milvus import Milvus
50
-
51
- MILVUS_URI = os.environ.get(
52
- "MILVUS_URI", f"{(tmp_dir := TemporaryDirectory()).name}/milvus_demo.db"
53
- )
54
-
55
- vectorstore = Milvus.from_documents(
56
- splits,
57
- embeddings,
58
- connection_args={"uri": MILVUS_URI},
59
- drop_old=True,
60
- index_params={"index_type": "IVF_FLAT", "metric_type": "L2"},
61
- )
62
-
63
- from langchain_huggingface import HuggingFaceEndpoint
64
-
65
- HF_API_KEY = os.environ.get("HF_API_KEY")
66
- HF_LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
67
-
68
- llm = HuggingFaceEndpoint(
69
- repo_id=HF_LLM_MODEL_ID,
70
- huggingfacehub_api_token=HF_API_KEY,
71
- task="text-generation", # Add this line to specify the task
72
- )
73
-
74
- from typing import Iterable
75
-
76
- from langchain_core.documents import Document as LCDocument
77
- from langchain_core.output_parsers import StrOutputParser
78
- from langchain_core.prompts import PromptTemplate
79
- from langchain_core.runnables import RunnablePassthrough
80
-
81
-
82
- def format_docs(docs: Iterable[LCDocument]):
83
  return "\n\n".join(doc.page_content for doc in docs)
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- retriever = vectorstore.as_retriever()
87
-
88
- prompt = PromptTemplate.from_template(
89
- "Context information is below.\n---------------------\n{context}\n---------------------\nUse the context of the work you have been currently trained on, not your prior knowledge, to answer the queries asked. Please use Chapter numbers and page numbers as references as well.\nQuery: {question}\nAnswer:\n"
90
- )
91
-
92
- rag_chain = (
93
- {"context": retriever | format_docs, "question": RunnablePassthrough()}
94
- | prompt
95
- | llm
96
- | StrOutputParser()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  )
98
 
99
- rag_chain.invoke("who are the members of the Sanhedrin who are present?")
100
-
101
-
102
-
103
-
 
1
  import os
2
+ import gradio as gr
3
  from dotenv import load_dotenv
 
 
 
4
  from typing import Iterator
 
5
  from langchain_core.document_loaders import BaseLoader
6
  from langchain_core.documents import Document as LCDocument
 
7
  from docling.document_converter import DocumentConverter
8
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
9
+ from langchain_huggingface.embeddings import HuggingFaceEmbeddings
10
+ from langchain_milvus import Milvus
11
+ from langchain_huggingface import HuggingFaceEndpoint
12
+ from langchain_core.prompts import PromptTemplate
13
+ from langchain_core.runnables import RunnablePassthrough
14
+ from langchain_core.output_parsers import StrOutputParser
15
+ from tempfile import TemporaryDirectory
16
 
17
+ # Load environment variables
18
+ load_dotenv()
19
+ HF_API_KEY = os.environ.get("HF_API_KEY")
20
 
21
  class DoclingPDFLoader(BaseLoader):
 
22
  def __init__(self, file_path: str | list[str]) -> None:
23
  self._file_paths = file_path if isinstance(file_path, list) else [file_path]
24
  self._converter = DocumentConverter()
25
+
26
  def lazy_load(self) -> Iterator[LCDocument]:
27
  for source in self._file_paths:
28
  dl_doc = self._converter.convert(source).document
29
  text = dl_doc.export_to_markdown()
30
  yield LCDocument(page_content=text)
31
 
32
+ def format_docs(docs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  return "\n\n".join(doc.page_content for doc in docs)
34
 
35
+ def setup_rag_chain(pdf_path):
36
+ # Initialize loader and split documents
37
+ loader = DoclingPDFLoader(file_path=pdf_path)
38
+ text_splitter = RecursiveCharacterTextSplitter(
39
+ chunk_size=1000,
40
+ chunk_overlap=200,
41
+ )
42
+ docs = loader.load()
43
+ splits = text_splitter.split_documents(docs)
44
+
45
+ # Setup embeddings
46
+ embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
47
+
48
+ # Setup Milvus vectorstore
49
+ tmp_dir = TemporaryDirectory()
50
+ MILVUS_URI = f"{tmp_dir.name}/milvus_demo.db"
51
+ vectorstore = Milvus.from_documents(
52
+ splits,
53
+ embeddings,
54
+ connection_args={"uri": MILVUS_URI},
55
+ drop_old=True,
56
+ index_params={"index_type": "IVF_FLAT", "metric_type": "L2"},
57
+ )
58
+
59
+ # Setup LLM
60
+ llm = HuggingFaceEndpoint(
61
+ repo_id="mistralai/Mistral-7B-Instruct-v0.3",
62
+ huggingfacehub_api_token=HF_API_KEY,
63
+ task="text-generation",
64
+ )
65
+
66
+ # Setup RAG chain
67
+ retriever = vectorstore.as_retriever()
68
+ prompt = PromptTemplate.from_template(
69
+ "Context information is below.\n---------------------\n{context}\n---------------------\nUse the context of the work you have been currently trained on, not your prior knowledge, to answer the queries asked. Please use Chapter numbers and page numbers as references as well.\nQuery: {question}\nAnswer:\n"
70
+ )
71
+
72
+ return (
73
+ {"context": retriever | format_docs, "question": RunnablePassthrough()}
74
+ | prompt
75
+ | llm
76
+ | StrOutputParser()
77
+ )
78
 
79
+ def process_query(pdf_file, query):
80
+ if pdf_file is None:
81
+ return "Please upload a PDF file first."
82
+
83
+ # Save the uploaded file temporarily
84
+ temp_pdf_path = "temp_upload.pdf"
85
+ with open(temp_pdf_path, "wb") as f:
86
+ f.write(pdf_file)
87
+
88
+ try:
89
+ # Setup and run the RAG chain
90
+ rag_chain = setup_rag_chain(temp_pdf_path)
91
+ response = rag_chain.invoke(query)
92
+ return response
93
+ except Exception as e:
94
+ return f"An error occurred: {str(e)}"
95
+ finally:
96
+ # Clean up temporary file
97
+ if os.path.exists(temp_pdf_path):
98
+ os.remove(temp_pdf_path)
99
+
100
+ # Create Gradio interface
101
+ demo = gr.Interface(
102
+ fn=process_query,
103
+ inputs=[
104
+ gr.File(label="Upload PDF", file_types=[".pdf"]),
105
+ gr.Textbox(label="Enter your question")
106
+ ],
107
+ outputs=gr.Textbox(label="Answer"),
108
+ title="PDF Question Answering System",
109
+ description="Upload a PDF and ask questions about its content. The system will use RAG to provide relevant answers.",
110
+ examples=[
111
+ [None, "Who are the members of the Sanhedrin who are present?"],
112
+ [None, "What are the main themes discussed in the document?"]
113
+ ]
114
  )
115
 
116
+ if __name__ == "__main__":
117
+ demo.launch()