ADKU committed
Commit 7ef58b7 · verified · Parent: 23c923c

Update app.py

Files changed (1)
  1. app.py +195 -103
app.py CHANGED
@@ -8,6 +8,8 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModel
 import google.generativeai as genai
 import logging
+from PyPDF2 import PdfReader
+import io
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
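The two new imports support the PDF-upload path added further down. As far as this diff shows, only `PdfReader` is actually used; the sketch below is a hypothetical standalone use illustrating both imports (`"paper.pdf"` is a placeholder path, not part of the commit):

```python
# Hypothetical sketch of what the new imports enable; PdfReader also
# accepts a file path or file object directly, as the committed code does.
import io
from PyPDF2 import PdfReader

with open("paper.pdf", "rb") as f:            # placeholder path
    reader = PdfReader(io.BytesIO(f.read()))  # io.BytesIO wraps raw bytes as a stream
text = "".join(page.extract_text() or "" for page in reader.pages)
```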
@@ -64,8 +66,8 @@ except Exception as e:
     logger.error(f"Model loading failed: {e}")
     raise
 
-# Generate SciBERT embeddings
-def generate_embeddings_sci_bert(texts, batch_size=32):
+# Generate SciBERT embeddings (optimized with larger batch size)
+def generate_embeddings_sci_bert(texts, batch_size=64):  # Increased batch size for efficiency
     try:
         all_embeddings = []
         for i in range(0, len(texts), batch_size):
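The hunk truncates the body of `generate_embeddings_sci_bert`. For orientation, a minimal sketch of a batched SciBERT embedder, assuming mean pooling over the last hidden state (the actual pooling used in app.py is not visible in this diff):

```python
# Minimal batched SciBERT embedder; mean pooling is an assumption.
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")

def embed(texts, batch_size=64):
    out = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            enc = tokenizer(batch, padding=True, truncation=True,
                            max_length=512, return_tensors="pt")
            hidden = model(**enc).last_hidden_state            # (B, T, 768)
            mask = enc["attention_mask"].unsqueeze(-1).float()  # (B, T, 1)
            pooled = (hidden * mask).sum(1) / mask.sum(1)       # mean over real tokens
            out.append(pooled.cpu().numpy())
    return np.vstack(out)
```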
@@ -94,7 +96,7 @@ except Exception as e:
     logger.error(f"FAISS index creation failed: {e}")
     raise
 
-# Hybrid search function (return indices instead of truncated strings)
+# Hybrid search function (unchanged from original)
 def get_relevant_papers(query):
     if not query.strip():
         return [], "Please enter a search query."
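The ranking that `get_relevant_papers` applies (visible in the next hunk) unions the FAISS hits with the top BM25 hits, then orders the pooled candidates by BM25 score. A toy run of that fusion with hypothetical scores:

```python
import numpy as np

bm25_scores = np.array([0.1, 2.3, 0.7, 1.5])  # hypothetical per-document scores
faiss_hits = [0, 2]                            # hypothetical dense-retrieval indices
bm25_top = np.argsort(bm25_scores)[::-1][:2]   # highest BM25 first -> [1, 3]
pool = set(faiss_hits) | set(bm25_top)         # union of both retrievers
ranked = sorted(pool, key=lambda i: -bm25_scores[i])  # -> [1, 3, 2, 0]
```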
@@ -106,73 +108,127 @@ def get_relevant_papers(query):
         bm25_top_indices = np.argsort(bm25_scores)[::-1][:5]
         combined_indices = list(set(indices[0]) | set(bm25_top_indices))
         ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
-        # Return formatted strings for dropdown and indices for full data
         papers = [f"{i+1}. {df.iloc[idx]['title']} - Abstract: {df.iloc[idx]['abstract'][:200]}..." for i, idx in enumerate(ranked_results[:5])]
         return papers, ranked_results[:5], "Search completed."
     except Exception as e:
         logger.error(f"Search failed: {e}")
         return [], [], "Search failed. Please try again."
 
-# Gemini API QA function with full context
-def answer_question(selected_index, question, history):
-    if selected_index is None:
-        return [(question, "Please select a paper first!")], history
+# Process uploaded PDF for RAG
+def process_uploaded_pdf(file):
+    try:
+        pdf_reader = PdfReader(file)
+        text = ""
+        for page in pdf_reader.pages:
+            text += page.extract_text() or ""
+        cleaned_text = clean_text(text)
+        chunks = [cleaned_text[i:i+1000] for i in range(0, len(cleaned_text), 1000)]  # Chunk for efficiency
+        embeddings = generate_embeddings_sci_bert(chunks)
+        faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
+        faiss_index.add(embeddings.astype(np.float32))
+        tokenized_chunks = [chunk.split() for chunk in chunks]
+        bm25_rag = BM25Okapi(tokenized_chunks)
+        return {"chunks": chunks, "embeddings": embeddings, "faiss_index": faiss_index, "bm25": bm25_rag}, "Document processed successfully"
+    except Exception as e:
+        logger.error(f"PDF processing failed: {e}")
+        return None, "Failed to process document"
+
+# Hybrid search for RAG
+def get_relevant_chunks(query, uploaded_doc):
+    if not query.strip():
+        return [], "Please enter a question."
+    try:
+        query_embedding = generate_embeddings_sci_bert([query])
+        distances, indices = uploaded_doc["faiss_index"].search(query_embedding.astype(np.float32), 3)
+        bm25_scores = uploaded_doc["bm25"].get_scores(query.split())
+        combined_indices = list(set(indices[0]) | set(np.argsort(bm25_scores)[::-1][:3]))
+        ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
+        return [uploaded_doc["chunks"][idx] for idx in ranked_results[:3]], "Retrieval completed."
+    except Exception as e:
+        logger.error(f"RAG retrieval failed: {e}")
+        return [], "Retrieval failed."
+
+# Unified QA function
+def answer_question(mode, selected_index, question, history, uploaded_doc=None):
     if not question.strip():
         return [(question, "Please ask a question!")], history
     if question.lower() in ["exit", "done"]:
-        return [("Conversation ended.", "Select a new paper or search again!")], []
+        return [("Conversation ended.", "Start a new conversation!")], []
 
     try:
-        # Get full paper data from DataFrame using index
-        paper_data = df.iloc[selected_index]
-        title = paper_data["title"]
-        abstract = paper_data["abstract"]  # Full abstract, not truncated
-        authors = ", ".join(paper_data["authors"])
-        doi = paper_data["doi"]
-
-        # Build prompt with all fields
-        prompt = (
-            "You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
-            "Your goal is to provide concise, accurate, and well-structured answers based on the given paper's details. "
-            "When asked about tech stacks or methods, follow these guidelines:\n"
-            "1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
-            "2. If the abstract is vague (e.g., 'machine learning techniques'), infer the most likely tech stacks based on the context of crop prediction and modern research practices, and explain your reasoning.\n"
-            "3. Always respond in a clear, concise format—use bullet points for lists (e.g., tech stacks) and short paragraphs for explanations.\n"
-            "4. If the question requires prior conversation context, refer to it naturally to maintain coherence.\n"
-            "5. If the abstract lacks enough detail, supplement with plausible, domain-specific suggestions and note they are inferred.\n"
-            "6. Avoid speculation or fluff—stick to facts or educated guesses grounded in the field.\n\n"
-            "Here’s the paper:\n"
-            f"Title: {title}\n"
-            f"Authors: {authors}\n"
-            f"Abstract: {abstract}\n"
-            f"DOI: {doi}\n\n"
-        )
-
-        # Add history if present
-        if history:
-            prompt += "Previous conversation (use for context):\n"
-            for user_q, bot_a in history[-2:]:
-                prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
-
-        prompt += f"Now, answer this question: {question}"
-
-        logger.info(f"Prompt sent to Gemini API: {prompt[:200]}...")
-
-        # Call Gemini API (Gemini 1.5 Flash)
-        model = genai.GenerativeModel("gemini-1.5-flash")
-        response = model.generate_content(prompt)
-        answer = response.text.strip()
-
-        # Fallback for poor responses
-        if not answer or len(answer) < 15:
-            answer = (
-                "The abstract doesn’t provide specific technologies, but based on crop prediction with machine learning and deep learning, likely tech stacks include:\n"
-                "- Python: Core language for ML/DL.\n"
-                "- TensorFlow or PyTorch: Frameworks for deep learning models.\n"
-                "- Scikit-learn: For traditional ML algorithms.\n"
-                "- Pandas/NumPy: For data handling and preprocessing."
+        if mode == "research":
+            if selected_index is None:
+                return [(question, "Please select a paper first!")], history
+            paper_data = df.iloc[selected_index]
+            title = paper_data["title"]
+            abstract = paper_data["abstract"]
+            authors = ", ".join(paper_data["authors"])
+            doi = paper_data["doi"]
+            prompt = (
+                "You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
+                "Your goal is to provide concise, accurate, and well-structured answers based on the given paper's details. "
+                "When asked about tech stacks or methods, follow these guidelines:\n"
+                "1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
+                "2. If the abstract is vague (e.g., 'machine learning techniques'), infer the most likely tech stacks based on the context of crop prediction and modern research practices, and explain your reasoning.\n"
+                "3. Always respond in a clear, concise format—use bullet points for lists (e.g., tech stacks) and short paragraphs for explanations.\n"
+                "4. If the question requires prior conversation context, refer to it naturally to maintain coherence.\n"
+                "5. If the abstract lacks enough detail, supplement with plausible, domain-specific suggestions and note they are inferred.\n"
+                "6. Avoid speculation or fluff—stick to facts or educated guesses grounded in the field.\n\n"
+                "Here’s the paper:\n"
+                f"Title: {title}\n"
+                f"Authors: {authors}\n"
+                f"Abstract: {abstract}\n"
+                f"DOI: {doi}\n\n"
+            )
+            if history:
+                prompt += "Previous conversation (use for context):\n"
+                for user_q, bot_a in history[-2:]:
+                    prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
+            prompt += f"Now, answer this question: {question}"
+            model = genai.GenerativeModel("gemini-1.5-flash")
+            response = model.generate_content(prompt)
+            answer = response.text.strip()
+            if not answer or len(answer) < 15:
+                answer = (
+                    "The abstract doesn’t provide specific technologies, but based on crop prediction with machine learning and deep learning, likely tech stacks include:\n"
+                    "- Python: Core language for ML/DL.\n"
+                    "- TensorFlow or PyTorch: Frameworks for deep learning models.\n"
+                    "- Scikit-learn: For traditional ML algorithms.\n"
+                    "- Pandas/NumPy: For data handling and preprocessing."
+                )
+
+        elif mode == "rag":
+            if uploaded_doc is None:
+                return [(question, "Please upload a document first!")], history
+            relevant_chunks, _ = get_relevant_chunks(question, uploaded_doc)
+            context = "\n".join(relevant_chunks)
+            prompt = (
+                "You are an expert AI assistant specializing in answering questions based on uploaded documents. "
+                "Provide concise, accurate answers based on the following document content:\n"
+                f"Content: {context}\n\n"
             )
-
+            if history:
+                prompt += "Previous conversation (use for context):\n"
+                for user_q, bot_a in history[-2:]:
+                    prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
+            prompt += f"Now, answer this question: {question}"
+            model = genai.GenerativeModel("gemini-1.5-flash")
+            response = model.generate_content(prompt)
+            answer = response.text.strip()
+
+        else:  # general mode
+            prompt = (
+                "You are a highly knowledgeable AI assistant. Answer the following question concisely and accurately:\n"
+            )
+            if history:
+                prompt += "Previous conversation (use for context):\n"
+                for user_q, bot_a in history[-2:]:
+                    prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
+            prompt += f"Question: {question}"
+            model = genai.GenerativeModel("gemini-1.5-flash")
+            response = model.generate_content(prompt)
+            answer = response.text.strip()
+
         history.append((question, answer))
         return history, history
     except Exception as e:
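Taken together, the new helpers form a small RAG pipeline: chunk the PDF, index the chunks in both FAISS and BM25, retrieve per question, and prompt Gemini with the retrieved context. A hypothetical end-to-end call sequence (`"my_paper.pdf"` is a placeholder path; assumes the Gemini API key is already configured as elsewhere in app.py):

```python
doc_state, status = process_uploaded_pdf("my_paper.pdf")  # placeholder path
print(status)                                             # "Document processed successfully"

chunks, _ = get_relevant_chunks("What dataset is used?", doc_state)
print(chunks[0][:100])                                    # first retrieved chunk, truncated

history, _ = answer_question("rag", None, "What dataset is used?", [], uploaded_doc=doc_state)
print(history[-1][1])                                     # Gemini's answer grounded in the chunks
```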
@@ -183,70 +239,106 @@ def answer_question(selected_index, question, history):
 # Gradio UI
 with gr.Blocks(
     css="""
-    .chatbot {height: 600px; overflow-y: auto;}
-    .sidebar {width: 300px;}
-    #main {display: flex; flex-direction: row;}
+    .chatbot {height: 500px; overflow-y: auto; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
+    .sidebar {width: 350px; padding: 15px; background: #f8f9fa; border-radius: 10px;}
+    #main {display: flex; flex-direction: row; gap: 20px; padding: 20px;}
+    .tab-content {padding: 20px; background: #ffffff; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
+    .gr-button {background: #007bff; color: white; border-radius: 5px; transition: background 0.3s;}
+    .gr-button:hover {background: #0056b3;}
+    h1 {color: #007bff; text-align: center; margin-bottom: 20px;}
     """,
-    theme=gr.themes.Default(primary_hue="blue")
+    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")
 ) as demo:
-    gr.Markdown("# ResearchGPT - Paper Search & Chat")
+    gr.Markdown("# Triad: ResearchGPT, RAG, & General Chat")
+
     with gr.Row(elem_id="main"):
-        # Sidebar for search
-        with gr.Column(scale=1, min_width=300, elem_classes="sidebar"):
-            gr.Markdown("### Search Papers")
-            query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
-            search_btn = gr.Button("Search")
-            paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
-            search_status = gr.Textbox(label="Search Status", interactive=False)
-
-            # States to store paper choices and indices
-            paper_choices_state = gr.State([])
-            paper_indices_state = gr.State([])
+        # Sidebar
+        with gr.Column(scale=1, min_width=350, elem_classes="sidebar"):
+            mode_tabs = gr.Tabs()
+            with mode_tabs:
+                # Research Mode (unchanged backend)
+                with gr.TabItem("Research Mode"):
+                    gr.Markdown("### Search Papers")
+                    query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
+                    search_btn = gr.Button("Search")
+                    paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
+                    search_status = gr.Textbox(label="Search Status", interactive=False)
+                    paper_choices_state = gr.State([])
+                    paper_indices_state = gr.State([])
+
+                    search_btn.click(
+                        fn=get_relevant_papers,
+                        inputs=query_input,
+                        outputs=[paper_choices_state, paper_indices_state, search_status]
+                    ).then(
+                        fn=lambda choices: gr.update(choices=choices, value=None),
+                        inputs=paper_choices_state,
+                        outputs=paper_dropdown
+                    )
+
+                # RAG Mode
+                with gr.TabItem("RAG Mode"):
+                    gr.Markdown("### Upload Document")
+                    file_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
+                    upload_status = gr.Textbox(label="Upload Status", interactive=False)
+                    uploaded_doc_state = gr.State(None)
+                    file_upload.change(
+                        fn=process_uploaded_pdf,
+                        inputs=file_upload,
+                        outputs=[uploaded_doc_state, upload_status]
+                    )
+
+                # General Mode
+                with gr.TabItem("General Chat"):
+                    gr.Markdown("Ask anything, powered by Gemini!")
 
-        search_btn.click(
-            fn=get_relevant_papers,
-            inputs=query_input,
-            outputs=[paper_choices_state, paper_indices_state, search_status]
-        ).then(
-            fn=lambda choices: gr.update(choices=choices, value=None),
-            inputs=paper_choices_state,
-            outputs=paper_dropdown
-        )
-
         # Main chat area
-        with gr.Column(scale=3):
-            gr.Markdown("### Chat with Selected Paper")
-            selected_paper = gr.Textbox(label="Selected Paper", interactive=False)
+        with gr.Column(scale=3, elem_classes="tab-content"):
+            gr.Markdown("### Chat Area")
+            selected_display = gr.Markdown(label="Selected Context", value="Select a mode to begin!")
             chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot")
             question_input = gr.Textbox(label="Ask a question", placeholder="e.g., What methods are used?")
             chat_btn = gr.Button("Send")
 
-            # State to store conversation history and selected index
             history_state = gr.State([])
             selected_index_state = gr.State(None)
 
-            # Update selected paper and index
-            def update_selected_paper(choice, indices):
-                if choice is None:
-                    return "", None
-                index = int(choice.split(".")[0]) - 1  # Extract rank (e.g., "1." -> 0)
-                selected_idx = indices[index]
-                return choice, selected_idx
+            def update_display(mode, choice, indices, uploaded_doc):
+                if mode == "research" and choice:
+                    index = int(choice.split(".")[0]) - 1
+                    selected_idx = indices[index]
+                    paper = df.iloc[selected_idx]
+                    return f"**{paper['title']}**<br>DOI: [{paper['doi']}](https://doi.org/{paper['doi']})", selected_idx
+                elif mode == "rag" and uploaded_doc:
+                    return "Uploaded Document Ready", None
+                elif mode == "general":
+                    return "General Chat Mode", None
+                return "Select a mode to begin!", None
 
-            paper_dropdown.change(
-                fn=update_selected_paper,
-                inputs=[paper_dropdown, paper_indices_state],
-                outputs=[selected_paper, selected_index_state]
+            mode_tabs.select(
+                fn=lambda tab: ("research" if tab == "Research Mode" else "rag" if tab == "RAG Mode" else "general"),
+                inputs=None,
+                outputs=None,
+                _js="tab => tab"
+            ).then(
+                fn=update_display,
+                inputs=[mode_tabs, paper_dropdown, paper_indices_state, uploaded_doc_state],
+                outputs=[selected_display, selected_index_state]
             ).then(
                 fn=lambda: [],
                 inputs=None,
-                outputs=chatbot
+                outputs=[chatbot, history_state]
             )
-
-            # Handle chat
+
+            paper_dropdown.change(
+                fn=update_display,
+                inputs=[mode_tabs, paper_dropdown, paper_indices_state, uploaded_doc_state],
+                outputs=[selected_display, selected_index_state]
+            )
+
             chat_btn.click(
                 fn=answer_question,
-                inputs=[selected_index_state, question_input, history_state],
+                inputs=[mode_tabs, selected_index_state, question_input, history_state, uploaded_doc_state],
                 outputs=[chatbot, history_state]
             ).then(
                 fn=lambda: "",