Nithin1905 committed
Commit eed1c5d
1 Parent(s): 593935c

Create main.py

Files changed (1)
  1. main.py +330 -0
main.py ADDED
import os
import uuid
import sys
import psycopg2
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pinecone import Pinecone, ServerlessSpec
import time
from openai import OpenAI
import fitz
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from pydantic import BaseModel
import hashlib
import cohere
import asyncio  # for non-blocking sleeps inside async handlers

load_dotenv()

TOGETHER_API_KEY = os.getenv("TOGETHER_API")
COHERE_API = os.getenv("COHERE_API")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
HELICON_API_KEY = os.getenv("HELICON_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
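
# A minimal .env this file expects (variable names taken from the os.getenv
# calls above; the values are placeholders, and TOGETHER_API / GOOGLE_API_KEY
# are read but not used below):
#
#   TOGETHER_API=...
#   COHERE_API=...
#   GROQ_API_KEY=...
#   HELICON_API_KEY=...
#   GOOGLE_API_KEY=...
#   PINECONE_API_KEY=...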

app = FastAPI()

SysPromptDefault = "You are now in the role of an expert AI."
SummaryTextPrompt = "You are an assistant tasked with summarizing TEXT for retrieval. These summaries will be embedded and used to retrieve the raw text elements. Give a concise summary of the TEXT that is well optimized for retrieval."
GenerationPrompt = "You are in the role of an expert AI whose task is to give an ANSWER to the user's QUESTION based only on the provided CONTEXT. Rely solely on the CONTEXT; do not use your own knowledge. Keep the answer under 300 words per QUESTION. You must respond in markdown format; don't use big headings."


# Schema for the question payload. Note that the endpoint below receives the
# query as a multipart form field instead, because a JSON body cannot be
# combined with a file upload in a single request.
class QuestionInput(BaseModel):
    query: str


# Global in-memory storage (consider a proper database or caching mechanism for production)
file_store = {}


def pinecone_server():
    # Return a handle to the 'law-compliance' serverless index, creating it on first use.
    pc = Pinecone(api_key=PINECONE_API_KEY)
    index_name = 'law-compliance'
    if index_name not in pc.list_indexes().names():
        pc.create_index(
            index_name,
            dimension=1024,  # matches Cohere embed-english-v3.0 vectors
            metric='cosine',
            spec=ServerlessSpec(
                cloud='aws',
                region='us-east-1'
            )
        )
        time.sleep(1)  # give a freshly created index a moment to become ready
    index = pc.Index(index_name)
    index.describe_index_stats()  # sanity check that the index is reachable
    return index


def extract_text_from_pdf(pdf_content):
    # Callers pass raw PDF bytes (from an upload), so open the document from a
    # memory stream rather than a file path.
    doc = fitz.open(stream=pdf_content, filetype="pdf")
    texts = []

    for page_number in range(len(doc)):
        page = doc.load_page(page_number)
        text = page.get_text()
        texts.append(text)

    doc.close()

    return texts


def split(texts):
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=384, chunk_overlap=10)
    text = "\n".join(texts)
    chunks = text_splitter.split_text(text)
    return chunks


def response(message, model="llama3-8b-8192", SysPrompt=SysPromptDefault, temperature=0.2):
    # Chat completion against Groq, routed through the Helicone gateway.
    client = OpenAI(
        api_key=GROQ_API_KEY,
        base_url="https://gateway.hconeai.com/openai/v1",
        default_headers={
            "Helicone-Auth": f"Bearer {HELICON_API_KEY}",
            "Helicone-Target-Url": "https://api.groq.com"
        }
    )

    messages = [{"role": "system", "content": SysPrompt}, {"role": "user", "content": message}]
    completion = client.chat.completions.create(  # local name avoids shadowing this function
        model=model,
        messages=messages,
        temperature=temperature,
    )
    return completion.choices[0].message.content


def generate_text_summaries(texts, summarize_texts):
    # Summarize each chunk separately so the result stays aligned one-to-one
    # with the chunks it is later zipped against; with summarize_texts=False
    # the chunks pass through unchanged.
    text_summaries = []
    if texts and summarize_texts:
        model = "llama3-8b-8192"
        text_summaries = [
            response(message=f"TEXT:\n\n{text}", model=model, SysPrompt=SummaryTextPrompt, temperature=0)
            for text in texts
        ]
    elif texts:
        text_summaries = texts

    return text_summaries


def get_digest(pdf_content):
    h = hashlib.sha256()
    h.update(pdf_content)  # hash the binary content of the PDF
    return h.hexdigest()


# Connection parameters are hardcoded; in production they belong in
# environment variables alongside the API keys above.
def db_connection():
    return psycopg2.connect(
        dbname="postgres",
        user="postgres.kstfnkkxavowoutfytoq",
        password="nI20th0in3@",
        host="aws-0-us-east-1.pooler.supabase.com",
        port="5432"
    )


def ensure_table(cur):
    # Lookup table mapping a file's SHA-256 digest to its Pinecone namespace.
    cur.execute('''
        CREATE TABLE IF NOT EXISTS law_research_pro (
            file_id VARCHAR(255) PRIMARY KEY,
            file_name VARCHAR(255),
            name_space VARCHAR(255)
        );
    ''')


def fetch_vectorstore_from_db(file_id):
    conn = db_connection()
    cur = conn.cursor()
    ensure_table(cur)
    conn.commit()
    fetch_query = '''
        SELECT name_space
        FROM law_research_pro
        WHERE file_id = %s;
    '''
    cur.execute(fetch_query, (file_id,))
    result = cur.fetchone()
    cur.close()
    conn.close()
    return result[0] if result else None


def get_next_namespace():
    # Namespaces are assigned sequentially (pdf-1, pdf-2, ...) from the row count.
    conn = db_connection()
    cur = conn.cursor()
    cur.execute("SELECT COUNT(*) FROM law_research_pro")
    count = cur.fetchone()[0]
    next_namespace = f"pdf-{count + 1}"
    cur.close()
    conn.close()
    return next_namespace


def insert_data(file_id, file_name, name_space):
    conn = db_connection()
    cur = conn.cursor()
    ensure_table(cur)
    insert_query = '''
        INSERT INTO law_research_pro (file_id, file_name, name_space)
        VALUES (%s, %s, %s)
        ON CONFLICT (file_id) DO NOTHING;
    '''
    cur.execute(insert_query, (file_id, file_name, name_space))
    conn.commit()
    cur.close()
    conn.close()


def create_documents(chunks, summaries):
    documents = []
    retrieve_contents = []

    for chunk, summary in zip(chunks, summaries):
        doc_id = str(uuid.uuid4())
        doc = {
            'page_content': summary,
            'metadata': {
                'id': doc_id,
                'type': 'text',
                'original_content': chunk
            }
        }
        retrieve_contents.append((doc_id, chunk))
        documents.append(doc)

    return documents, retrieve_contents


async def embed_and_upsert(documents, cohere_api_key, name_space):
    cohere_client = cohere.Client(cohere_api_key)
    summaries = [doc['page_content'] for doc in documents]
    pinecone_index = pinecone_server()
    # cohere.Client is synchronous, so the embed call must not be awaited
    embeddings = cohere_client.embed(
        texts=summaries,
        input_type='search_document',
        model="embed-english-v3.0"
    ).embeddings

    pinecone_data = []
    for doc, embedding in zip(documents, embeddings):
        pinecone_data.append({
            'id': doc['metadata']['id'],
            'values': embedding,
            'metadata': doc['metadata']
        })

    pinecone_index.upsert(vectors=pinecone_data, namespace=name_space)
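
# Illustrative shape of one record upserted above (embedding values shortened):
#   {'id': '<uuid4>',
#    'values': [0.013, -0.027, ...],   # 1024-dim Cohere vector
#    'metadata': {'id': '<uuid4>', 'type': 'text', 'original_content': '<chunk>'}}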


async def embedding_creation(pdf_content, cohere_api_key, name_space):
    texts = extract_text_from_pdf(pdf_content)
    chunks = split(texts)
    text_summaries = generate_text_summaries(chunks, summarize_texts=False)
    documents, retrieve_contents = create_documents(chunks, text_summaries)
    await embed_and_upsert(documents, cohere_api_key, name_space)
    print("Embeddings created and upserted successfully into Pinecone.")


def embed(question):
    cohere_client = cohere.Client(COHERE_API)
    embeddings = cohere_client.embed(
        texts=[question],
        model="embed-english-v3.0",
        input_type='search_query'
    ).embeddings
    return embeddings


def process_rerank_response(rerank_response, docs):
    rerank_docs = []
    for item in rerank_response.results:
        index = item.index
        if 0 <= index < len(docs):
            rerank_docs.append(docs[index])
        else:
            print(f"Warning: Index {index} is out of range for documents list.")
    return rerank_docs


async def get_name_space(pdf_content, file_name):
    # Reuse the namespace if this exact PDF (by SHA-256 digest) was seen before;
    # otherwise embed it into a fresh namespace and record the mapping.
    file_id = get_digest(pdf_content)
    existing_namespace = fetch_vectorstore_from_db(file_id)

    if existing_namespace:
        print("Document already exists. Using existing namespace.")
        name_space = existing_namespace
    else:
        print("Document is new. Creating embeddings and new namespace.")
        name_space = get_next_namespace()
        await embedding_creation(pdf_content, COHERE_API, name_space)
        insert_data(file_id, file_name, name_space)
        await asyncio.sleep(10)  # non-blocking pause while the upsert becomes queryable

    return name_space


async def get_docs(question, pdf_content, file_name):
    index = pinecone_server()
    co = cohere.Client(COHERE_API)
    xq = embed(question)[0]
    name_space = await get_name_space(pdf_content, file_name)
    print(name_space)
    res = index.query(namespace=name_space, vector=xq, top_k=5, include_metadata=True)
    print(res)
    docs = [x["metadata"]['original_content'] for x in res["matches"]]

    if not docs:
        print("No matching documents found.")
        return []

    results = co.rerank(query=question, documents=docs, top_n=3, model='rerank-english-v3.0')
    reranked_docs = process_rerank_response(results, docs)
    return reranked_docs


async def answer(question, pdf_content, file_name):
    docs = await get_docs(question, pdf_content, file_name)
    if not docs:
        return "No relevant documents found for the given question."

    context = "\n\n".join(docs)
    message = f"CONTEXT:\n\n{context}\n\nQUESTION:\n\n{question}\n\nANSWER:\n"
    model = "llama3-8b-8192"
    output = response(message=message, model=model, SysPrompt=GenerationPrompt, temperature=0)
    return output


@app.post("/ask-question")
async def ask_question(query: str = Form(...), file: UploadFile = File(...)):
    # The query arrives as a form field: FastAPI cannot parse a JSON body
    # (QuestionInput) and a multipart file upload from the same request.
    if not file:
        raise HTTPException(status_code=400, detail="PDF file not provided")
    file_content = await file.read()
    filename = file.filename

    # Store the file content in the global store
    file_id = get_digest(file_content)
    file_store[file_id] = {
        "pdf_content": file_content,
        "filename": filename
    }
    answer_output = await answer(query, file_content, filename)

    return {"answer": answer_output}
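

# Minimal local entry point; assumes uvicorn is installed alongside fastapi.
# Once running, the endpoint can be exercised with a multipart request, e.g.:
#   curl -X POST http://localhost:8000/ask-question \
#        -F "query=What does this document cover?" \
#        -F "file=@sample.pdf"
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)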