Spaces:
Running
Running
Shreyas094
commited on
Commit
•
6775be9
1
Parent(s):
cad15d1
Update app.py
Browse files
app.py
CHANGED
@@ -1,64 +1,454 @@
|
|
1 |
import gradio as gr
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
if val[0]:
|
22 |
-
messages.append({"role": "user", "content": val[0]})
|
23 |
-
if val[1]:
|
24 |
-
messages.append({"role": "assistant", "content": val[1]})
|
25 |
-
|
26 |
-
messages.append({"role": "user", "content": message})
|
27 |
-
|
28 |
-
response = ""
|
29 |
-
|
30 |
-
for message in client.chat_completion(
|
31 |
-
messages,
|
32 |
-
max_tokens=max_tokens,
|
33 |
-
stream=True,
|
34 |
-
temperature=temperature,
|
35 |
-
top_p=top_p,
|
36 |
-
):
|
37 |
-
token = message.choices[0].delta.content
|
38 |
-
|
39 |
-
response += token
|
40 |
-
yield response
|
41 |
-
|
42 |
-
|
43 |
-
"""
|
44 |
-
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
45 |
-
"""
|
46 |
-
demo = gr.ChatInterface(
|
47 |
-
respond,
|
48 |
-
additional_inputs=[
|
49 |
-
gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
|
50 |
-
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
51 |
-
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
52 |
-
gr.Slider(
|
53 |
-
minimum=0.1,
|
54 |
-
maximum=1.0,
|
55 |
-
value=0.95,
|
56 |
-
step=0.05,
|
57 |
-
label="Top-p (nucleus sampling)",
|
58 |
-
),
|
59 |
-
],
|
60 |
-
)
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
if __name__ == "__main__":
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import asyncio
|
3 |
+
import aiohttp
|
4 |
+
import logging
|
5 |
+
import math
|
6 |
+
import io
|
7 |
+
import numpy as np
|
8 |
+
from newspaper import Article
|
9 |
+
import PyPDF2
|
10 |
+
from collections import Counter
|
11 |
+
import json
|
12 |
+
from datetime import datetime
|
13 |
+
from sentence_transformers import SentenceTransformer
|
14 |
+
from rank_bm25 import BM25Okapi
|
15 |
+
from sentence_transformers.util import pytorch_cos_sim
|
16 |
+
from enum import Enum
|
17 |
+
from groq import Groq
|
18 |
+
import os
|
19 |
+
from typing import List, Dict, Any, Set
|
20 |
+
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
+
# Load environment variables from .env file
|
23 |
+
load_dotenv()
|
24 |
+
|
25 |
+
# Initialize Groq client
|
26 |
+
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
|
27 |
+
|
28 |
+
class ScoringMethod(Enum):
|
29 |
+
BM25 = "bm25"
|
30 |
+
TFIDF = "tfidf"
|
31 |
+
COMBINED = "combined"
|
32 |
+
|
33 |
+
async def get_available_engines(session, base_url, headers):
|
34 |
+
"""Fetch available search engines from SearxNG instance."""
|
35 |
+
try:
|
36 |
+
# First try the search endpoint to get engines
|
37 |
+
params = {
|
38 |
+
"q": "test",
|
39 |
+
"format": "json",
|
40 |
+
"engines": "all"
|
41 |
+
}
|
42 |
+
async with session.get(f"{base_url}/search", headers=headers, params=params) as response:
|
43 |
+
data = await response.json()
|
44 |
+
available_engines = set()
|
45 |
+
# Extract unique engine names from the response
|
46 |
+
if "search" in data:
|
47 |
+
for engine_data in data["search"]:
|
48 |
+
if isinstance(engine_data, dict) and "engine" in engine_data:
|
49 |
+
available_engines.add(engine_data["engine"])
|
50 |
+
|
51 |
+
# If no engines found, try alternate endpoint
|
52 |
+
if not available_engines:
|
53 |
+
async with session.get(f"{base_url}/engines", headers=headers) as response:
|
54 |
+
engines_data = await response.json()
|
55 |
+
available_engines = set(engine["name"] for engine in engines_data if engine.get("enabled", True))
|
56 |
+
|
57 |
+
return list(available_engines)
|
58 |
+
except Exception as e:
|
59 |
+
logging.error(f'Error fetching search engines: {e}')
|
60 |
+
# Return default engines if unable to fetch
|
61 |
+
return ["google", "bing", "duckduckgo", "brave", "wikipedia"]
|
62 |
+
|
63 |
+
def select_search_engines(available_engines: List[str]) -> Set[str]:
|
64 |
+
"""Let user select search engines from available options."""
|
65 |
+
print("\nAvailable search engines:")
|
66 |
+
engines_list = sorted(available_engines)
|
67 |
+
for i, engine in enumerate(engines_list, 1):
|
68 |
+
print(f"{i}. {engine}")
|
69 |
+
|
70 |
+
print("\nEnter the numbers of engines you want to use (comma-separated), or 'all' for all engines:")
|
71 |
+
selection = input("Your selection: ").strip().lower()
|
72 |
+
|
73 |
+
if selection == 'all':
|
74 |
+
return set(engines_list)
|
75 |
+
|
76 |
+
try:
|
77 |
+
selected_indices = [int(idx.strip()) - 1 for idx in selection.split(',')]
|
78 |
+
return {engines_list[idx] for idx in selected_indices if 0 <= idx < len(engines_list)}
|
79 |
+
except (ValueError, IndexError):
|
80 |
+
logging.error("Invalid selection, using all engines as fallback")
|
81 |
+
return set(engines_list)
|
82 |
+
|
83 |
+
|
84 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')
|
85 |
+
|
86 |
+
async def scrape_url(url, max_chars):
|
87 |
+
logging.info(f'Scraping URL: {url}')
|
88 |
+
if url.endswith(".pdf"):
|
89 |
+
return await scrape_pdf(url, max_chars)
|
90 |
+
else:
|
91 |
+
return await scrape_html(url, max_chars)
|
92 |
+
|
93 |
+
async def scrape_html(url, max_chars):
|
94 |
+
try:
|
95 |
+
article = Article(url)
|
96 |
+
article.download()
|
97 |
+
article.parse()
|
98 |
+
text = article.text[:max_chars]
|
99 |
+
publish_date = article.publish_date
|
100 |
+
logging.info(f'Scraped HTML content from {url}')
|
101 |
+
return {"content": text, "publish_date": publish_date.isoformat() if publish_date else None}
|
102 |
+
except Exception as e:
|
103 |
+
logging.error(f'Error scraping HTML content from {url}: {e}')
|
104 |
+
return None
|
105 |
+
|
106 |
+
async def scrape_pdf(url, max_chars):
|
107 |
+
try:
|
108 |
+
async with aiohttp.ClientSession() as session:
|
109 |
+
async with session.get(url) as response:
|
110 |
+
pdf_bytes = await response.read()
|
111 |
+
pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
|
112 |
+
text = ""
|
113 |
+
for page_num in range(len(pdf_reader.pages)):
|
114 |
+
page = pdf_reader.pages[page_num]
|
115 |
+
text += page.extract_text()
|
116 |
+
text = text[:max_chars]
|
117 |
+
logging.info(f'Scraped PDF content from {url}')
|
118 |
+
return {"content": text, "publish_date": None}
|
119 |
+
except Exception as e:
|
120 |
+
logging.error(f'Error scraping PDF content from {url}: {e}')
|
121 |
+
return None
|
122 |
+
|
123 |
+
def normalize_scores(scores):
|
124 |
+
"""Normalize scores to [0, 1] range using min-max normalization"""
|
125 |
+
if not isinstance(scores, np.ndarray):
|
126 |
+
scores = np.array(scores)
|
127 |
+
|
128 |
+
if len(scores) == 0:
|
129 |
+
return []
|
130 |
+
|
131 |
+
min_score = np.min(scores)
|
132 |
+
max_score = np.max(scores)
|
133 |
+
|
134 |
+
if max_score - min_score > 0:
|
135 |
+
normalized = (scores - min_score) / (max_score - min_score)
|
136 |
+
else:
|
137 |
+
normalized = np.ones_like(scores)
|
138 |
+
|
139 |
+
return normalized.tolist()
|
140 |
+
|
141 |
+
async def calculate_bm25(query, documents):
|
142 |
+
"""Calculate BM25 scores for documents."""
|
143 |
+
try:
|
144 |
+
if not documents:
|
145 |
+
return []
|
146 |
+
|
147 |
+
bm25 = BM25Okapi([doc.split() for doc in documents])
|
148 |
+
scores = bm25.get_scores(query.split())
|
149 |
+
return normalize_scores(scores)
|
150 |
+
|
151 |
+
except Exception as e:
|
152 |
+
logging.error(f'Error calculating BM25 scores: {e}')
|
153 |
+
return [0] * len(documents)
|
154 |
+
|
155 |
+
async def calculate_tfidf(query, documents, measure="cosine"):
|
156 |
+
"""Calculate TF-IDF based similarity scores."""
|
157 |
+
try:
|
158 |
+
if not documents:
|
159 |
+
return []
|
160 |
+
|
161 |
+
model = SentenceTransformer('all-MiniLM-L6-v2')
|
162 |
+
query_embedding = model.encode(query)
|
163 |
+
document_embeddings = model.encode(documents)
|
164 |
+
|
165 |
+
# Normalize embeddings
|
166 |
+
query_embedding = query_embedding / np.linalg.norm(query_embedding)
|
167 |
+
document_embeddings = document_embeddings / np.linalg.norm(document_embeddings, axis=1)[:, np.newaxis]
|
168 |
+
|
169 |
+
if measure == "cosine":
|
170 |
+
# Calculate cosine similarity
|
171 |
+
scores = np.dot(document_embeddings, query_embedding)
|
172 |
+
return normalize_scores(scores)
|
173 |
+
else:
|
174 |
+
raise ValueError("Unsupported similarity measure.")
|
175 |
+
|
176 |
+
except Exception as e:
|
177 |
+
logging.error(f'Error calculating TF-IDF scores: {e}')
|
178 |
+
return [0] * len(documents)
|
179 |
+
|
180 |
+
def combine_scores(bm25_score, tfidf_score, weights=(0.5, 0.5)):
|
181 |
+
"""Combine scores using weighted average."""
|
182 |
+
return weights[0] * bm25_score + weights[1] * tfidf_score
|
183 |
+
|
184 |
+
async def get_document_scores(query, documents, scoring_method: ScoringMethod):
|
185 |
+
"""Calculate document scores based on the chosen scoring method."""
|
186 |
+
if not documents:
|
187 |
+
return []
|
188 |
+
|
189 |
+
if scoring_method == ScoringMethod.BM25:
|
190 |
+
scores = await calculate_bm25(query, documents)
|
191 |
+
return [(score, 0) for score in scores]
|
192 |
+
elif scoring_method == ScoringMethod.TFIDF:
|
193 |
+
scores = await calculate_tfidf(query, documents)
|
194 |
+
return [(0, score) for score in scores]
|
195 |
+
else: # COMBINED
|
196 |
+
bm25_scores = await calculate_bm25(query, documents)
|
197 |
+
tfidf_scores = await calculate_tfidf(query, documents)
|
198 |
+
return list(zip(bm25_scores, tfidf_scores))
|
199 |
+
|
200 |
+
def get_total_score(scores, scoring_method: ScoringMethod):
|
201 |
+
"""Calculate total score based on the scoring method."""
|
202 |
+
bm25_score, tfidf_score = scores
|
203 |
+
if scoring_method == ScoringMethod.BM25:
|
204 |
+
return bm25_score
|
205 |
+
elif scoring_method == ScoringMethod.TFIDF:
|
206 |
+
return tfidf_score
|
207 |
+
else: # COMBINED
|
208 |
+
return combine_scores(bm25_score, tfidf_score)
|
209 |
+
|
210 |
+
async def generate_summary(query: str, articles: List[Dict[str, Any]], temperature: float = 0.7) -> str:
|
211 |
+
"""
|
212 |
+
Generate a summary of the articles using Groq's LLama 3.1 8b model.
|
213 |
+
"""
|
214 |
+
try:
|
215 |
+
# Format the articles into a structured JSON string
|
216 |
+
json_input = json.dumps(articles, indent=2)
|
217 |
+
|
218 |
+
system_prompt = """You are Sentinel, a world-class AI model who is expert at searching the web and answering user's queries. You are also an expert at summarizing web pages or documents and searching for content in them."""
|
219 |
+
|
220 |
+
user_prompt = f"""
|
221 |
+
Please provide a comprehensive summary based on the following JSON input:
|
222 |
+
{json_input}
|
223 |
+
|
224 |
+
Original Query: {query}
|
225 |
+
|
226 |
+
Instructions:
|
227 |
+
1. Analyze the query and the provided documents.
|
228 |
+
2. Write a detailed, long, and complete research document that is informative and relevant to the user's query based on provided context.
|
229 |
+
3. Use this context to answer the user's query in the best way possible. Use an unbiased and journalistic tone.
|
230 |
+
4. Use an unbiased and professional tone in your response.
|
231 |
+
5. Do not repeat text verbatim from the input.
|
232 |
+
6. Provide the answer in the response itself.
|
233 |
+
7. Use markdown to format your response.
|
234 |
+
8. Use bullet points to list information where appropriate.
|
235 |
+
9. Cite the answer using [number] notation along with the appropriate source URL embedded in the notation.
|
236 |
+
10. Place these citations at the end of the relevant sentences.
|
237 |
+
11. You can cite the same sentence multiple times if it's relevant.
|
238 |
+
12. Make sure the answer is not short and is informative.
|
239 |
+
13. Your response should be detailed, informative, accurate, and directly relevant to the user's query."""
|
240 |
+
|
241 |
+
messages = [
|
242 |
+
{"role": "system", "content": system_prompt},
|
243 |
+
{"role": "user", "content": user_prompt}
|
244 |
+
]
|
245 |
+
|
246 |
+
response = groq_client.chat.completions.create(
|
247 |
+
messages=messages,
|
248 |
+
model="llama-3.1-70b-versatile", # Using LLama 3.1 8b model
|
249 |
+
max_tokens=5000,
|
250 |
+
temperature=temperature,
|
251 |
+
top_p=0.9,
|
252 |
+
presence_penalty=1.2,
|
253 |
+
stream=False
|
254 |
+
)
|
255 |
+
|
256 |
+
return response.choices[0].message.content.strip()
|
257 |
+
|
258 |
+
except Exception as e:
|
259 |
+
logging.error(f'Error generating summary: {e}')
|
260 |
+
return f"Error generating summary: {str(e)}"
|
261 |
+
|
262 |
+
class ChatBot:
|
263 |
+
def __init__(self):
|
264 |
+
self.scoring_method = ScoringMethod.COMBINED
|
265 |
+
self.num_results = 10
|
266 |
+
self.max_chars = 10000
|
267 |
+
self.score_threshold = 0.8
|
268 |
+
self.temperature = 0.1
|
269 |
+
self.history = []
|
270 |
+
self.base_url = "http://localhost:8888"
|
271 |
+
self.headers = {
|
272 |
+
"X-Searx-API-Key": "f9f07f93b37b8483aadb5ba717f556f3a4ac507b281b4ca01e6c6288aa3e3ae5"
|
273 |
+
}
|
274 |
+
# Default search engines in case we can't fetch from SearxNG
|
275 |
+
self.default_engines = ["google", "bing", "duckduckgo", "brave"]
|
276 |
+
|
277 |
+
async def get_search_results(self,
|
278 |
+
query: str,
|
279 |
+
num_results: int,
|
280 |
+
max_chars: int,
|
281 |
+
score_threshold: float,
|
282 |
+
temperature: float,
|
283 |
+
scoring_method_str: str,
|
284 |
+
selected_engines: List[str]) -> str:
|
285 |
+
try:
|
286 |
+
# Convert scoring method string to enum
|
287 |
+
scoring_method_map = {
|
288 |
+
"BM25": ScoringMethod.BM25,
|
289 |
+
"TF-IDF": ScoringMethod.TFIDF,
|
290 |
+
"Combined": ScoringMethod.COMBINED
|
291 |
+
}
|
292 |
+
self.scoring_method = scoring_method_map[scoring_method_str]
|
293 |
+
|
294 |
+
async with aiohttp.ClientSession() as session:
|
295 |
+
# Use the selected engines from the interface
|
296 |
+
logging.info(f'Using engines: {", ".join(selected_engines)}')
|
297 |
+
logging.info(f'Parameters: Results={num_results}, Chars={max_chars}, Threshold={score_threshold}, Temp={temperature}, Method={scoring_method_str}')
|
298 |
+
|
299 |
+
# Perform search
|
300 |
+
params = {
|
301 |
+
"q": query,
|
302 |
+
"format": "json",
|
303 |
+
"engines": ",".join(selected_engines),
|
304 |
+
"limit": num_results
|
305 |
+
}
|
306 |
+
|
307 |
+
try:
|
308 |
+
async with session.get(f"{self.base_url}/search", headers=self.headers, params=params) as response:
|
309 |
+
data = await response.json()
|
310 |
+
except Exception as e:
|
311 |
+
return f"Error: Could not connect to search service. Please check if SearxNG is running at {self.base_url}. Error: {str(e)}"
|
312 |
+
|
313 |
+
if "results" not in data or not data["results"]:
|
314 |
+
return "No results found."
|
315 |
+
|
316 |
+
results = data["results"][:num_results]
|
317 |
+
tasks = [scrape_url(result["url"], max_chars) for result in results]
|
318 |
+
scraped_data = await asyncio.gather(*tasks)
|
319 |
+
|
320 |
+
valid_results = [(result, article)
|
321 |
+
for result, article in zip(results, scraped_data)
|
322 |
+
if article is not None]
|
323 |
+
|
324 |
+
if not valid_results:
|
325 |
+
return "No valid articles found after scraping."
|
326 |
+
|
327 |
+
results, scraped_data = zip(*valid_results)
|
328 |
+
contents = [article["content"] for article in scraped_data]
|
329 |
+
|
330 |
+
scores = await get_document_scores(query, contents, self.scoring_method)
|
331 |
+
|
332 |
+
scored_articles = []
|
333 |
+
for i, (score_tuple, article) in enumerate(zip(scores, scraped_data)):
|
334 |
+
total_score = get_total_score(score_tuple, self.scoring_method)
|
335 |
+
if total_score >= self.score_threshold:
|
336 |
+
scored_articles.append({
|
337 |
+
"url": results[i]["url"],
|
338 |
+
"title": results[i]["title"],
|
339 |
+
"content": article["content"],
|
340 |
+
"publish_date": article["publish_date"],
|
341 |
+
"score": round(total_score, 4),
|
342 |
+
"bm25_score": round(score_tuple[0], 4),
|
343 |
+
"tfidf_score": round(score_tuple[1], 4),
|
344 |
+
"engine": results[i].get("engine", "unknown")
|
345 |
+
})
|
346 |
+
|
347 |
+
scored_articles.sort(key=lambda x: x["score"], reverse=True)
|
348 |
+
unique_articles = []
|
349 |
+
seen_content = set()
|
350 |
+
|
351 |
+
for article in scored_articles:
|
352 |
+
if article["content"] not in seen_content:
|
353 |
+
seen_content.add(article["content"])
|
354 |
+
unique_articles.append(article)
|
355 |
+
|
356 |
+
# Generate summary using Groq API
|
357 |
+
summary = await generate_summary(query, unique_articles, self.temperature)
|
358 |
+
|
359 |
+
# Format the response for chat
|
360 |
+
response = f"**Search Parameters:**\n"
|
361 |
+
response += f"- Results: {num_results}\n"
|
362 |
+
response += f"- Max Characters: {max_chars}\n"
|
363 |
+
response += f"- Score Threshold: {score_threshold}\n"
|
364 |
+
response += f"- Temperature: {temperature}\n"
|
365 |
+
response += f"- Scoring Method: {scoring_method_str}\n"
|
366 |
+
response += f"- Search Engines: {', '.join(selected_engines)}\n\n"
|
367 |
+
response += f"**Summary of Search Results:**\n\n{summary}\n\n"
|
368 |
+
response += "\n**Sources:**\n"
|
369 |
+
for i, article in enumerate(unique_articles, 1):
|
370 |
+
response += f"{i}. [{article['title']}]({article['url']}) (Score: {article['score']})\n"
|
371 |
+
|
372 |
+
return response
|
373 |
+
|
374 |
+
except Exception as e:
|
375 |
+
logging.error(f'Error in search_and_summarize: {e}')
|
376 |
+
return f"Error occurred: {str(e)}"
|
377 |
+
|
378 |
+
def chat(self,
|
379 |
+
message: str,
|
380 |
+
history: List[List[str]],
|
381 |
+
num_results: int,
|
382 |
+
max_chars: int,
|
383 |
+
score_threshold: float,
|
384 |
+
temperature: float,
|
385 |
+
scoring_method: str,
|
386 |
+
engines: List[str]) -> str:
|
387 |
+
"""
|
388 |
+
Process chat messages and return responses with custom parameters.
|
389 |
+
"""
|
390 |
+
# Run the async search function in the sync context
|
391 |
+
response = asyncio.run(self.get_search_results(
|
392 |
+
message,
|
393 |
+
num_results,
|
394 |
+
max_chars,
|
395 |
+
score_threshold,
|
396 |
+
temperature,
|
397 |
+
scoring_method,
|
398 |
+
engines
|
399 |
+
))
|
400 |
+
return response
|
401 |
+
|
402 |
+
def create_gradio_interface() -> gr.Interface:
|
403 |
+
chatbot = ChatBot()
|
404 |
+
|
405 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
406 |
+
gr.Markdown("# Research Assistant")
|
407 |
+
gr.Markdown("Enter your search query, and I'll search, analyze, and summarize relevant articles for you.")
|
408 |
+
|
409 |
+
with gr.Row():
|
410 |
+
with gr.Column(scale=3):
|
411 |
+
chatbot_interface = gr.ChatInterface(
|
412 |
+
fn=chatbot.chat,
|
413 |
+
additional_inputs=[
|
414 |
+
gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of Results"),
|
415 |
+
gr.Slider(minimum=1000, maximum=50000, value=10000, step=1000, label="Max Characters per Article"),
|
416 |
+
gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.05, label="Score Threshold"),
|
417 |
+
gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Temperature"),
|
418 |
+
gr.Radio(["BM25", "TF-IDF", "Combined"], value="Combined", label="Scoring Method"),
|
419 |
+
gr.CheckboxGroup(
|
420 |
+
choices=["google", "bing", "duckduckgo", "brave", "wikipedia"],
|
421 |
+
value=["google", "bing", "duckduckgo"],
|
422 |
+
label="Search Engines"
|
423 |
+
)
|
424 |
+
],
|
425 |
+
examples=[
|
426 |
+
["What are the latest developments in quantum computing?"],
|
427 |
+
["Explain the impact of artificial intelligence on healthcare"],
|
428 |
+
["What are the current trends in renewable energy?"]
|
429 |
+
]
|
430 |
+
)
|
431 |
+
|
432 |
+
with gr.Column(scale=1):
|
433 |
+
gr.Markdown("### Parameter Descriptions")
|
434 |
+
gr.Markdown("""
|
435 |
+
- **Number of Results**: Number of search results to fetch
|
436 |
+
- **Max Characters**: Maximum characters to analyze per article
|
437 |
+
- **Score Threshold**: Minimum relevance score (0-1) for including articles
|
438 |
+
- **Temperature**: Controls creativity in summary generation (0=focused, 1=creative)
|
439 |
+
- **Scoring Method**: Algorithm for ranking article relevance
|
440 |
+
- BM25: Traditional keyword-based ranking
|
441 |
+
- TF-IDF: Semantic similarity-based ranking
|
442 |
+
- Combined: Balanced approach using both methods
|
443 |
+
- **Search Engines**: Select which search engines to use
|
444 |
+
""")
|
445 |
+
|
446 |
+
return demo
|
447 |
|
448 |
if __name__ == "__main__":
|
449 |
+
# Configure logging
|
450 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')
|
451 |
+
|
452 |
+
# Create and launch the interface
|
453 |
+
demo = create_gradio_interface()
|
454 |
+
demo.launch(share=True)
|