thomasjacob04 committed
Commit 314bb9f · verified · 1 parent: f8f632b

Upload 2 files

Files changed (3)
  1. .gitattributes +1 -0
  2. test.pdf +3 -0
  3. train.py +137 -0
.gitattributes CHANGED
@@ -38,3 +38,4 @@ vol2.pdf filter=lfs diff=lfs merge=lfs -text
  vol3.pdf filter=lfs diff=lfs merge=lfs -text
  vol4.pdf filter=lfs diff=lfs merge=lfs -text
  vol5.pdf filter=lfs diff=lfs merge=lfs -text
+ test.pdf filter=lfs diff=lfs merge=lfs -text
test.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ecd8e1207b3be0e246d40823509a2c774594319601bd1c28171722f735058a2e
+ size 381011
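
Note that test.pdf is committed as a Git LFS pointer rather than raw bytes: the three lines above record the pointer spec version, the SHA-256 of the actual blob, and its size in bytes. A minimal sketch of checking a downloaded copy against such a pointer; the function name and file paths are illustrative, not part of this commit:

    import hashlib
    from pathlib import Path

    def verify_lfs_pointer(pointer_path: str, blob_path: str) -> bool:
        # Parse the pointer's "key value" lines: version, oid, size
        fields = dict(
            line.split(" ", 1)
            for line in Path(pointer_path).read_text().splitlines()
            if " " in line
        )
        expected_oid = fields["oid"].strip().removeprefix("sha256:")
        data = Path(blob_path).read_bytes()
        # A match requires both the recorded size and the SHA-256 to agree
        return (
            len(data) == int(fields["size"])
            and hashlib.sha256(data).hexdigest() == expected_oid
        )

For the pointer above, the check passes only for a 381,011-byte file hashing to ecd8e120….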
train.py ADDED
@@ -0,0 +1,137 @@
+ import fitz  # PyMuPDF
+ from transformers import AutoTokenizer, AutoModel
+ import torch
+ from milvus import Milvus  # legacy Milvus 1.x Python SDK
+ import os
+ from langchain_community.llms import HuggingFaceEndpoint
+ from langchain.prompts import PromptTemplate
+ from langchain.schema import StrOutputParser
+ from langchain.schema.runnable import RunnablePassthrough
+ from langchain_core.document_loaders import BaseLoader
+ from langchain_core.documents import Document as LCDocument
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from docling.document_converter import DocumentConverter
+ import gradio as gr
+ from typing import Iterator
+
+ # Initialize Milvus
+ milvus = Milvus(host='localhost', port='19530')
+
+ # Load BAAI embedding model
+ tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")
+ model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5")
+
+ # Docling PDF Loader
+ class DoclingPDFLoader(BaseLoader):
+     def __init__(self, file_path: str | list[str]) -> None:
+         self._file_paths = file_path if isinstance(file_path, list) else [file_path]
+         self._converter = DocumentConverter()
+
+     def lazy_load(self) -> Iterator[LCDocument]:
+         for source in self._file_paths:
+             dl_doc = self._converter.convert(source).document
+             text = dl_doc.export_to_markdown()
+             yield LCDocument(page_content=text)
+
+     def load(self) -> list[LCDocument]:
+         return list(self.lazy_load())
+
+ # Function to extract and split text from PDF
+ def extract_text_from_pdf(pdf_path):
+     loader = DoclingPDFLoader(file_path=pdf_path)
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=1000,
+         chunk_overlap=200,
+     )
+     docs = loader.load()
+     splits = text_splitter.split_documents(docs)
+     return " ".join([doc.page_content for doc in splits])
+
+ # Set up LLM
+ HF_API_KEY = os.environ.get("HF_API_KEY")
+ HF_LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
+ llm = HuggingFaceEndpoint(
+     repo_id=HF_LLM_MODEL_ID,
+     huggingfacehub_api_token=HF_API_KEY,
+ )
+
+ # Plain-text PDF extraction with PyMuPDF. Renamed from extract_text_from_pdf:
+ # the original second definition shadowed the Docling-based extractor above,
+ # so the Docling loader and text splitter were never actually used.
+ def extract_text_with_fitz(pdf_path):
+     doc = fitz.open(pdf_path)
+     text = ""
+     for page in doc:
+         text += page.get_text()
+     return text
+
+ # Function to generate embeddings
+ def generate_embeddings(text):
+     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
+
+ # Function to insert embeddings into Milvus
+ def insert_into_milvus(embeddings):
+     collection_name = "pdf_embeddings"
+     if not milvus.has_collection(collection_name):
+         milvus.create_collection({
+             "collection_name": collection_name,
+             "dimension": embeddings.shape[0],
+             "index_file_size": 1024,
+             "metric_type": "L2"
+         })
+     milvus.insert(collection_name, [embeddings])
+
+ # Function to query Milvus
+ def query_milvus(query_embedding, top_k=5):
+     collection_name = "pdf_embeddings"
+     search_params = {"metric_type": "L2", "params": {"nprobe": 16}}
+     results = milvus.search(collection_name, [query_embedding], top_k, search_params)
+     return results
+
+ # Generate a response from the LLM using a RAG prompt over the retrieved context
+ def generate_response(query, context):
+     prompt = PromptTemplate.from_template(
+         "Context information is below.\n---------------------\n{context}\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: {question}\nAnswer:\n"
+     )
+
+     # The input map must hold runnables, not plain strings: "context" is
+     # closed over, "question" passes the invoked query through unchanged.
+     rag_chain = (
+         {"context": lambda _: context, "question": RunnablePassthrough()}
+         | prompt
+         | llm
+         | StrOutputParser()
+     )
+
+     return rag_chain.invoke(query)
+
+ # Main function
+ def main(pdf_path, query):
+     # Step 1: Extract text from PDF
+     text = extract_text_from_pdf(pdf_path)
+
+     # Step 2: Generate embeddings for the text
+     embeddings = generate_embeddings(text)
+
+     # Step 3: Insert embeddings into Milvus
+     insert_into_milvus(embeddings)
+
+     # Step 4: Generate embeddings for the query
+     query_embedding = generate_embeddings(query)
+
+     # Step 5: Query Milvus for similar embeddings. Only vectors are stored,
+     # so the hits carry ids and distances rather than chunk text; a complete
+     # pipeline would map the returned ids back to stored chunks here.
+     results = query_milvus(query_embedding)
+
+     # Step 6: Generate response using the LLM
+     context = " ".join(str(result) for result in results)
+     response = generate_response(query, context)
+
+     return response
+
+ if __name__ == "__main__":
+     # Gradio interface: answers questions against the bundled test.pdf.
+     # main() must return its response (not print it) for Gradio to display it.
+     def ask_question(question):
+         pdf_path = "test.pdf"
+         return main(pdf_path, question)
+
+     iface = gr.Interface(fn=ask_question, inputs="text", outputs="text")
+     iface.launch()
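
One note on the embedding step: generate_embeddings() mean-pools the encoder's last hidden state, while the BAAI bge model cards recommend taking the [CLS] token's hidden state and L2-normalizing it. A minimal sketch of that variant, reusing the script's tokenizer and model; this function is an illustration, not part of the commit:

    import torch
    import torch.nn.functional as F

    def generate_embeddings_cls(text):
        # [CLS]-token pooling + L2 normalization, per the bge usage examples
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            outputs = model(**inputs)
        cls = outputs.last_hidden_state[:, 0]  # hidden state of the [CLS] token
        return F.normalize(cls, p=2, dim=1).squeeze().numpy()

Normalized embeddings also make the L2 distances used by the Milvus collection behave like cosine distance, which is how these retrieval models are usually evaluated.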