sahanes committed on
Commit
7b695f2
·
verified ·
1 Parent(s): 9470817

Upload main.py

Browse files
Files changed (1) hide show
  1. app/main.py +332 -0
app/main.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""FastAPI service: upload a PDF, embed its chunks, and answer questions over it (RAG)."""
from fastapi import FastAPI, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import PyPDF2
import openai
import numpy as np
import faiss
import tiktoken
from typing import List
import io
from dotenv import load_dotenv
import os

app = FastAPI()

# Add CORS middleware.
# NOTE(review): wide-open origins with allow_credentials=True is fine for a
# demo but should be tightened for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory storage
class DocumentStore:
    """Holds the uploaded document's chunks plus their embeddings and index."""

    def __init__(self):
        self.documents: List[str] = []  # text chunks of the current document
        self.embeddings = None          # float32 matrix, one row per chunk (set on upload)
        self.index = None               # FAISS index over the embeddings (set on upload)

    def reset(self):
        """Drop all stored state, returning the store to its freshly-built form."""
        self.__init__()


doc_store = DocumentStore()
44
class Question(BaseModel):
    """Request body for the /ask endpoint."""

    # The user's free-text question to answer against the uploaded document.
    text: str
48
def get_embedding(text: str) -> List[float]:
    """Embed *text* with OpenAI's text-embedding-3-small model.

    Returns the embedding vector as a list of floats.
    """
    result = openai.embeddings.create(
        model="text-embedding-3-small",
        input=text,
    )
    return result.data[0].embedding
56
def chunk_text(text: str, chunk_size: int = 1000) -> List[str]:
    """Split *text* into whitespace-delimited chunks of roughly *chunk_size* characters.

    Words are never broken; a chunk is flushed as soon as its running size
    (each word's length plus one separator) reaches chunk_size, so chunks can
    slightly exceed the target. Returns [] for empty/whitespace-only input.
    """
    pieces: List[str] = []
    buffer: List[str] = []
    size = 0

    for token in text.split():
        buffer.append(token)
        size += len(token) + 1
        if size >= chunk_size:
            pieces.append(" ".join(buffer))
            buffer, size = [], 0

    # Flush any trailing partial chunk.
    if buffer:
        pieces.append(" ".join(buffer))

    return pieces
77
@app.post("/upload")
async def upload_pdf(file: UploadFile):
    """Ingest a PDF: extract its text, chunk it, embed the chunks, build the index.

    Replaces whatever document was previously stored. Returns a message and
    the number of chunks created. Raises 400 for non-PDF uploads or PDFs with
    no extractable text, 500 for any downstream failure.
    """
    # Accept .pdf case-insensitively; filename may be None for odd clients.
    if not file.filename or not file.filename.lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail="File must be a PDF")

    try:
        # Reset the document store so stale chunks never mix with the new upload
        doc_store.reset()

        # Read PDF content
        content = await file.read()
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(content))
        text = ""
        for page in pdf_reader.pages:
            # extract_text() can yield empty results for image-only pages
            text += page.extract_text() or ""

        # Chunk the text
        chunks = chunk_text(text)
        if not chunks:
            # Previously this fell through to embeddings[0] (IndexError -> 500);
            # fail with a clear client error instead.
            raise HTTPException(
                status_code=400, detail="No extractable text found in PDF")
        doc_store.documents = chunks

        # Create embeddings
        embeddings = [get_embedding(chunk) for chunk in chunks]
        doc_store.embeddings = np.array(embeddings, dtype=np.float32)

        # Create FAISS index
        dimension = len(embeddings[0])
        doc_store.index = faiss.IndexFlatL2(dimension)
        doc_store.index.add(doc_store.embeddings)

        return {"message": "PDF processed successfully", "chunks": len(chunks)}

    except HTTPException:
        # Don't let the generic handler rewrap deliberate 4xx errors as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
112
@app.post("/ask")
async def ask_question(question: Question):
    """Answer *question* via retrieval-augmented generation over the stored PDF.

    Raises 400 when no document has been uploaded yet, 500 on downstream
    failures (embedding, search, or chat-completion errors).
    """
    # Compare to None explicitly: object truthiness is not a reliable
    # "has the index been built" signal.
    if doc_store.index is None:
        raise HTTPException(
            status_code=400, detail="No document has been uploaded yet")

    try:
        # Get question embedding
        question_embedding = get_embedding(question.text)

        # Search similar chunks. Cap k at the number of stored chunks: FAISS
        # pads short result sets with index -1, which would otherwise silently
        # select documents[-1].
        k = min(10, len(doc_store.documents))
        _distances, indices = doc_store.index.search(
            np.array([question_embedding], dtype=np.float32), k)

        # Get relevant chunks (skip any -1 padding defensively)
        relevant_chunks = [doc_store.documents[i] for i in indices[0] if i >= 0]

        # Create prompt
        prompt = f"""Based on the following context, please answer the question.
If the answer cannot be found in the context, say "I cannot find the answer in the document." You may also use the context to infer information that is not explicitly stated in the context. For example, if the context does not explicitly state what the paper is about, you may infer that the paper is about the topic of the question or the retrieved context.
Context:
{' '.join(relevant_chunks)}
Question: {question.text}
"""

        # Get response from OpenAI
        response = openai.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are a helpful assistant that answers questions based on the provided context."},
                {"role": "user", "content": prompt}
            ]
        )

        return {"answer": response.choices[0].message.content}

    except HTTPException:
        # Re-raise deliberate HTTP errors untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
153
# Configure OpenAI API key (module scope, so it's set before any request runs)
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")

if __name__ == "__main__":
    import uvicorn
    # NOTE(review): uvicorn ignores `workers` when reload=True; kept for
    # parity with a production invocation where reload would be off.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="info",
        workers=1,
    )