Shreyas094 commited on
Commit
6775be9
1 Parent(s): cad15d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +450 -60
app.py CHANGED
@@ -1,64 +1,454 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  if __name__ == "__main__":
64
- demo.launch()
 
 
 
 
 
 
1
  import gradio as gr
2
+ import asyncio
3
+ import aiohttp
4
+ import logging
5
+ import math
6
+ import io
7
+ import numpy as np
8
+ from newspaper import Article
9
+ import PyPDF2
10
+ from collections import Counter
11
+ import json
12
+ from datetime import datetime
13
+ from sentence_transformers import SentenceTransformer
14
+ from rank_bm25 import BM25Okapi
15
+ from sentence_transformers.util import pytorch_cos_sim
16
+ from enum import Enum
17
+ from groq import Groq
18
+ import os
19
+ from typing import List, Dict, Any, Set
20
+ from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # Load environment variables from .env file
23
+ load_dotenv()
24
+
25
+ # Initialize Groq client
26
+ groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
27
+
28
+ class ScoringMethod(Enum):
29
+ BM25 = "bm25"
30
+ TFIDF = "tfidf"
31
+ COMBINED = "combined"
32
+
33
+ async def get_available_engines(session, base_url, headers):
34
+ """Fetch available search engines from SearxNG instance."""
35
+ try:
36
+ # First try the search endpoint to get engines
37
+ params = {
38
+ "q": "test",
39
+ "format": "json",
40
+ "engines": "all"
41
+ }
42
+ async with session.get(f"{base_url}/search", headers=headers, params=params) as response:
43
+ data = await response.json()
44
+ available_engines = set()
45
+ # Extract unique engine names from the response
46
+ if "search" in data:
47
+ for engine_data in data["search"]:
48
+ if isinstance(engine_data, dict) and "engine" in engine_data:
49
+ available_engines.add(engine_data["engine"])
50
+
51
+ # If no engines found, try alternate endpoint
52
+ if not available_engines:
53
+ async with session.get(f"{base_url}/engines", headers=headers) as response:
54
+ engines_data = await response.json()
55
+ available_engines = set(engine["name"] for engine in engines_data if engine.get("enabled", True))
56
+
57
+ return list(available_engines)
58
+ except Exception as e:
59
+ logging.error(f'Error fetching search engines: {e}')
60
+ # Return default engines if unable to fetch
61
+ return ["google", "bing", "duckduckgo", "brave", "wikipedia"]
62
+
63
+ def select_search_engines(available_engines: List[str]) -> Set[str]:
64
+ """Let user select search engines from available options."""
65
+ print("\nAvailable search engines:")
66
+ engines_list = sorted(available_engines)
67
+ for i, engine in enumerate(engines_list, 1):
68
+ print(f"{i}. {engine}")
69
+
70
+ print("\nEnter the numbers of engines you want to use (comma-separated), or 'all' for all engines:")
71
+ selection = input("Your selection: ").strip().lower()
72
+
73
+ if selection == 'all':
74
+ return set(engines_list)
75
+
76
+ try:
77
+ selected_indices = [int(idx.strip()) - 1 for idx in selection.split(',')]
78
+ return {engines_list[idx] for idx in selected_indices if 0 <= idx < len(engines_list)}
79
+ except (ValueError, IndexError):
80
+ logging.error("Invalid selection, using all engines as fallback")
81
+ return set(engines_list)
82
+
83
+
84
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')
85
+
86
+ async def scrape_url(url, max_chars):
87
+ logging.info(f'Scraping URL: {url}')
88
+ if url.endswith(".pdf"):
89
+ return await scrape_pdf(url, max_chars)
90
+ else:
91
+ return await scrape_html(url, max_chars)
92
+
93
+ async def scrape_html(url, max_chars):
94
+ try:
95
+ article = Article(url)
96
+ article.download()
97
+ article.parse()
98
+ text = article.text[:max_chars]
99
+ publish_date = article.publish_date
100
+ logging.info(f'Scraped HTML content from {url}')
101
+ return {"content": text, "publish_date": publish_date.isoformat() if publish_date else None}
102
+ except Exception as e:
103
+ logging.error(f'Error scraping HTML content from {url}: {e}')
104
+ return None
105
+
106
+ async def scrape_pdf(url, max_chars):
107
+ try:
108
+ async with aiohttp.ClientSession() as session:
109
+ async with session.get(url) as response:
110
+ pdf_bytes = await response.read()
111
+ pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
112
+ text = ""
113
+ for page_num in range(len(pdf_reader.pages)):
114
+ page = pdf_reader.pages[page_num]
115
+ text += page.extract_text()
116
+ text = text[:max_chars]
117
+ logging.info(f'Scraped PDF content from {url}')
118
+ return {"content": text, "publish_date": None}
119
+ except Exception as e:
120
+ logging.error(f'Error scraping PDF content from {url}: {e}')
121
+ return None
122
+
123
+ def normalize_scores(scores):
124
+ """Normalize scores to [0, 1] range using min-max normalization"""
125
+ if not isinstance(scores, np.ndarray):
126
+ scores = np.array(scores)
127
+
128
+ if len(scores) == 0:
129
+ return []
130
+
131
+ min_score = np.min(scores)
132
+ max_score = np.max(scores)
133
+
134
+ if max_score - min_score > 0:
135
+ normalized = (scores - min_score) / (max_score - min_score)
136
+ else:
137
+ normalized = np.ones_like(scores)
138
+
139
+ return normalized.tolist()
140
+
141
+ async def calculate_bm25(query, documents):
142
+ """Calculate BM25 scores for documents."""
143
+ try:
144
+ if not documents:
145
+ return []
146
+
147
+ bm25 = BM25Okapi([doc.split() for doc in documents])
148
+ scores = bm25.get_scores(query.split())
149
+ return normalize_scores(scores)
150
+
151
+ except Exception as e:
152
+ logging.error(f'Error calculating BM25 scores: {e}')
153
+ return [0] * len(documents)
154
+
155
+ async def calculate_tfidf(query, documents, measure="cosine"):
156
+ """Calculate TF-IDF based similarity scores."""
157
+ try:
158
+ if not documents:
159
+ return []
160
+
161
+ model = SentenceTransformer('all-MiniLM-L6-v2')
162
+ query_embedding = model.encode(query)
163
+ document_embeddings = model.encode(documents)
164
+
165
+ # Normalize embeddings
166
+ query_embedding = query_embedding / np.linalg.norm(query_embedding)
167
+ document_embeddings = document_embeddings / np.linalg.norm(document_embeddings, axis=1)[:, np.newaxis]
168
+
169
+ if measure == "cosine":
170
+ # Calculate cosine similarity
171
+ scores = np.dot(document_embeddings, query_embedding)
172
+ return normalize_scores(scores)
173
+ else:
174
+ raise ValueError("Unsupported similarity measure.")
175
+
176
+ except Exception as e:
177
+ logging.error(f'Error calculating TF-IDF scores: {e}')
178
+ return [0] * len(documents)
179
+
180
+ def combine_scores(bm25_score, tfidf_score, weights=(0.5, 0.5)):
181
+ """Combine scores using weighted average."""
182
+ return weights[0] * bm25_score + weights[1] * tfidf_score
183
+
184
+ async def get_document_scores(query, documents, scoring_method: ScoringMethod):
185
+ """Calculate document scores based on the chosen scoring method."""
186
+ if not documents:
187
+ return []
188
+
189
+ if scoring_method == ScoringMethod.BM25:
190
+ scores = await calculate_bm25(query, documents)
191
+ return [(score, 0) for score in scores]
192
+ elif scoring_method == ScoringMethod.TFIDF:
193
+ scores = await calculate_tfidf(query, documents)
194
+ return [(0, score) for score in scores]
195
+ else: # COMBINED
196
+ bm25_scores = await calculate_bm25(query, documents)
197
+ tfidf_scores = await calculate_tfidf(query, documents)
198
+ return list(zip(bm25_scores, tfidf_scores))
199
+
200
+ def get_total_score(scores, scoring_method: ScoringMethod):
201
+ """Calculate total score based on the scoring method."""
202
+ bm25_score, tfidf_score = scores
203
+ if scoring_method == ScoringMethod.BM25:
204
+ return bm25_score
205
+ elif scoring_method == ScoringMethod.TFIDF:
206
+ return tfidf_score
207
+ else: # COMBINED
208
+ return combine_scores(bm25_score, tfidf_score)
209
+
210
+ async def generate_summary(query: str, articles: List[Dict[str, Any]], temperature: float = 0.7) -> str:
211
+ """
212
+ Generate a summary of the articles using Groq's LLama 3.1 8b model.
213
+ """
214
+ try:
215
+ # Format the articles into a structured JSON string
216
+ json_input = json.dumps(articles, indent=2)
217
+
218
+ system_prompt = """You are Sentinel, a world-class AI model who is expert at searching the web and answering user's queries. You are also an expert at summarizing web pages or documents and searching for content in them."""
219
+
220
+ user_prompt = f"""
221
+ Please provide a comprehensive summary based on the following JSON input:
222
+ {json_input}
223
+
224
+ Original Query: {query}
225
+
226
+ Instructions:
227
+ 1. Analyze the query and the provided documents.
228
+ 2. Write a detailed, long, and complete research document that is informative and relevant to the user's query based on provided context.
229
+ 3. Use this context to answer the user's query in the best way possible. Use an unbiased and journalistic tone.
230
+ 4. Use an unbiased and professional tone in your response.
231
+ 5. Do not repeat text verbatim from the input.
232
+ 6. Provide the answer in the response itself.
233
+ 7. Use markdown to format your response.
234
+ 8. Use bullet points to list information where appropriate.
235
+ 9. Cite the answer using [number] notation along with the appropriate source URL embedded in the notation.
236
+ 10. Place these citations at the end of the relevant sentences.
237
+ 11. You can cite the same sentence multiple times if it's relevant.
238
+ 12. Make sure the answer is not short and is informative.
239
+ 13. Your response should be detailed, informative, accurate, and directly relevant to the user's query."""
240
+
241
+ messages = [
242
+ {"role": "system", "content": system_prompt},
243
+ {"role": "user", "content": user_prompt}
244
+ ]
245
+
246
+ response = groq_client.chat.completions.create(
247
+ messages=messages,
248
+ model="llama-3.1-70b-versatile", # Using LLama 3.1 8b model
249
+ max_tokens=5000,
250
+ temperature=temperature,
251
+ top_p=0.9,
252
+ presence_penalty=1.2,
253
+ stream=False
254
+ )
255
+
256
+ return response.choices[0].message.content.strip()
257
+
258
+ except Exception as e:
259
+ logging.error(f'Error generating summary: {e}')
260
+ return f"Error generating summary: {str(e)}"
261
+
262
+ class ChatBot:
263
+ def __init__(self):
264
+ self.scoring_method = ScoringMethod.COMBINED
265
+ self.num_results = 10
266
+ self.max_chars = 10000
267
+ self.score_threshold = 0.8
268
+ self.temperature = 0.1
269
+ self.history = []
270
+ self.base_url = "http://localhost:8888"
271
+ self.headers = {
272
+ "X-Searx-API-Key": "f9f07f93b37b8483aadb5ba717f556f3a4ac507b281b4ca01e6c6288aa3e3ae5"
273
+ }
274
+ # Default search engines in case we can't fetch from SearxNG
275
+ self.default_engines = ["google", "bing", "duckduckgo", "brave"]
276
+
277
+ async def get_search_results(self,
278
+ query: str,
279
+ num_results: int,
280
+ max_chars: int,
281
+ score_threshold: float,
282
+ temperature: float,
283
+ scoring_method_str: str,
284
+ selected_engines: List[str]) -> str:
285
+ try:
286
+ # Convert scoring method string to enum
287
+ scoring_method_map = {
288
+ "BM25": ScoringMethod.BM25,
289
+ "TF-IDF": ScoringMethod.TFIDF,
290
+ "Combined": ScoringMethod.COMBINED
291
+ }
292
+ self.scoring_method = scoring_method_map[scoring_method_str]
293
+
294
+ async with aiohttp.ClientSession() as session:
295
+ # Use the selected engines from the interface
296
+ logging.info(f'Using engines: {", ".join(selected_engines)}')
297
+ logging.info(f'Parameters: Results={num_results}, Chars={max_chars}, Threshold={score_threshold}, Temp={temperature}, Method={scoring_method_str}')
298
+
299
+ # Perform search
300
+ params = {
301
+ "q": query,
302
+ "format": "json",
303
+ "engines": ",".join(selected_engines),
304
+ "limit": num_results
305
+ }
306
+
307
+ try:
308
+ async with session.get(f"{self.base_url}/search", headers=self.headers, params=params) as response:
309
+ data = await response.json()
310
+ except Exception as e:
311
+ return f"Error: Could not connect to search service. Please check if SearxNG is running at {self.base_url}. Error: {str(e)}"
312
+
313
+ if "results" not in data or not data["results"]:
314
+ return "No results found."
315
+
316
+ results = data["results"][:num_results]
317
+ tasks = [scrape_url(result["url"], max_chars) for result in results]
318
+ scraped_data = await asyncio.gather(*tasks)
319
+
320
+ valid_results = [(result, article)
321
+ for result, article in zip(results, scraped_data)
322
+ if article is not None]
323
+
324
+ if not valid_results:
325
+ return "No valid articles found after scraping."
326
+
327
+ results, scraped_data = zip(*valid_results)
328
+ contents = [article["content"] for article in scraped_data]
329
+
330
+ scores = await get_document_scores(query, contents, self.scoring_method)
331
+
332
+ scored_articles = []
333
+ for i, (score_tuple, article) in enumerate(zip(scores, scraped_data)):
334
+ total_score = get_total_score(score_tuple, self.scoring_method)
335
+ if total_score >= self.score_threshold:
336
+ scored_articles.append({
337
+ "url": results[i]["url"],
338
+ "title": results[i]["title"],
339
+ "content": article["content"],
340
+ "publish_date": article["publish_date"],
341
+ "score": round(total_score, 4),
342
+ "bm25_score": round(score_tuple[0], 4),
343
+ "tfidf_score": round(score_tuple[1], 4),
344
+ "engine": results[i].get("engine", "unknown")
345
+ })
346
+
347
+ scored_articles.sort(key=lambda x: x["score"], reverse=True)
348
+ unique_articles = []
349
+ seen_content = set()
350
+
351
+ for article in scored_articles:
352
+ if article["content"] not in seen_content:
353
+ seen_content.add(article["content"])
354
+ unique_articles.append(article)
355
+
356
+ # Generate summary using Groq API
357
+ summary = await generate_summary(query, unique_articles, self.temperature)
358
+
359
+ # Format the response for chat
360
+ response = f"**Search Parameters:**\n"
361
+ response += f"- Results: {num_results}\n"
362
+ response += f"- Max Characters: {max_chars}\n"
363
+ response += f"- Score Threshold: {score_threshold}\n"
364
+ response += f"- Temperature: {temperature}\n"
365
+ response += f"- Scoring Method: {scoring_method_str}\n"
366
+ response += f"- Search Engines: {', '.join(selected_engines)}\n\n"
367
+ response += f"**Summary of Search Results:**\n\n{summary}\n\n"
368
+ response += "\n**Sources:**\n"
369
+ for i, article in enumerate(unique_articles, 1):
370
+ response += f"{i}. [{article['title']}]({article['url']}) (Score: {article['score']})\n"
371
+
372
+ return response
373
+
374
+ except Exception as e:
375
+ logging.error(f'Error in search_and_summarize: {e}')
376
+ return f"Error occurred: {str(e)}"
377
+
378
+ def chat(self,
379
+ message: str,
380
+ history: List[List[str]],
381
+ num_results: int,
382
+ max_chars: int,
383
+ score_threshold: float,
384
+ temperature: float,
385
+ scoring_method: str,
386
+ engines: List[str]) -> str:
387
+ """
388
+ Process chat messages and return responses with custom parameters.
389
+ """
390
+ # Run the async search function in the sync context
391
+ response = asyncio.run(self.get_search_results(
392
+ message,
393
+ num_results,
394
+ max_chars,
395
+ score_threshold,
396
+ temperature,
397
+ scoring_method,
398
+ engines
399
+ ))
400
+ return response
401
+
402
+ def create_gradio_interface() -> gr.Interface:
403
+ chatbot = ChatBot()
404
+
405
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
406
+ gr.Markdown("# Research Assistant")
407
+ gr.Markdown("Enter your search query, and I'll search, analyze, and summarize relevant articles for you.")
408
+
409
+ with gr.Row():
410
+ with gr.Column(scale=3):
411
+ chatbot_interface = gr.ChatInterface(
412
+ fn=chatbot.chat,
413
+ additional_inputs=[
414
+ gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of Results"),
415
+ gr.Slider(minimum=1000, maximum=50000, value=10000, step=1000, label="Max Characters per Article"),
416
+ gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.05, label="Score Threshold"),
417
+ gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Temperature"),
418
+ gr.Radio(["BM25", "TF-IDF", "Combined"], value="Combined", label="Scoring Method"),
419
+ gr.CheckboxGroup(
420
+ choices=["google", "bing", "duckduckgo", "brave", "wikipedia"],
421
+ value=["google", "bing", "duckduckgo"],
422
+ label="Search Engines"
423
+ )
424
+ ],
425
+ examples=[
426
+ ["What are the latest developments in quantum computing?"],
427
+ ["Explain the impact of artificial intelligence on healthcare"],
428
+ ["What are the current trends in renewable energy?"]
429
+ ]
430
+ )
431
+
432
+ with gr.Column(scale=1):
433
+ gr.Markdown("### Parameter Descriptions")
434
+ gr.Markdown("""
435
+ - **Number of Results**: Number of search results to fetch
436
+ - **Max Characters**: Maximum characters to analyze per article
437
+ - **Score Threshold**: Minimum relevance score (0-1) for including articles
438
+ - **Temperature**: Controls creativity in summary generation (0=focused, 1=creative)
439
+ - **Scoring Method**: Algorithm for ranking article relevance
440
+ - BM25: Traditional keyword-based ranking
441
+ - TF-IDF: Semantic similarity-based ranking
442
+ - Combined: Balanced approach using both methods
443
+ - **Search Engines**: Select which search engines to use
444
+ """)
445
+
446
+ return demo
447
 
448
  if __name__ == "__main__":
449
+ # Configure logging
450
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')
451
+
452
+ # Create and launch the interface
453
+ demo = create_gradio_interface()
454
+ demo.launch(share=True)