Shreyas094 committed on
Commit
10660a7
·
verified ·
1 Parent(s): 82e8e41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -9
app.py CHANGED
@@ -3,8 +3,12 @@ import json
3
  import re
4
  import gradio as gr
5
  import pandas as pd
 
 
 
6
  from tempfile import NamedTemporaryFile
7
  from typing import List
 
8
  from langchain_core.prompts import ChatPromptTemplate
9
  from langchain_community.vectorstores import FAISS
10
  from langchain_community.document_loaders import PyPDFLoader
@@ -119,7 +123,78 @@ def is_related_to_history(question, history, threshold=0.3):
119
  similarity = get_similarity(question, history_text)
120
  return similarity > threshold
121
 
122
- def ask_question(question, temperature, top_p, repetition_penalty):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  global conversation_history
124
 
125
  if not question:
@@ -129,17 +204,21 @@ def ask_question(question, temperature, top_p, repetition_penalty):
129
  answer = memory_database[question]
130
  else:
131
  embed = get_embeddings()
132
- database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
133
  model = get_model(temperature, top_p, repetition_penalty)
134
 
135
  history_str = "\n".join([f"Q: {item['question']}\nA: {item['answer']}" for item in conversation_history])
136
 
137
- if is_related_to_history(question, conversation_history):
138
- context_str = "No additional context needed. Please refer to the conversation history."
 
139
  else:
140
- retriever = database.as_retriever()
141
- relevant_docs = retriever.get_relevant_documents(question)
142
- context_str = "\n".join([doc.page_content for doc in relevant_docs])
 
 
 
 
143
 
144
  prompt_val = ChatPromptTemplate.from_template(prompt)
145
  formatted_prompt = prompt_val.format(history=history_str, context=context_str, question=question)
@@ -220,9 +299,10 @@ with gr.Blocks() as demo:
220
  temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
221
  top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1)
222
  repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
 
223
 
224
  def chat(question, history):
225
- answer = ask_question(question, temperature_slider.value, top_p_slider.value, repetition_penalty_slider.value)
226
  history.append((question, answer))
227
  return "", history
228
 
@@ -241,4 +321,4 @@ with gr.Blocks() as demo:
241
  clear_button.click(clear_cache, inputs=[], outputs=clear_output)
242
 
243
  if __name__ == "__main__":
244
- demo.launch()
 
3
  import re
4
  import gradio as gr
5
  import pandas as pd
6
+ import requests
7
+ import random
8
+ import urllib.parse
9
  from tempfile import NamedTemporaryFile
10
  from typing import List
11
+ from bs4 import BeautifulSoup
12
  from langchain_core.prompts import ChatPromptTemplate
13
  from langchain_community.vectorstores import FAISS
14
  from langchain_community.document_loaders import PyPDFLoader
 
123
  similarity = get_similarity(question, history_text)
124
  return similarity > threshold
125
 
126
def extract_text_from_webpage(html):
    """Return the visible text of an HTML document, one phrase per line.

    ``<script>`` and ``<style>`` elements are removed before the text is
    extracted. Each line of the extracted text is stripped, split into
    phrases on runs of two spaces, and the non-empty phrases are joined
    with newlines.

    Args:
        html: HTML markup as a string. Empty or ``None`` input yields "".

    Returns:
        The cleaned text as a single newline-separated string.
    """
    if not html:
        return ""

    soup = BeautifulSoup(html, 'html.parser')

    # Scripts and styles are never visible text; drop them entirely.
    for tag in soup(["script", "style"]):
        tag.decompose()

    phrases = []
    for raw_line in soup.get_text().splitlines():
        # Split on double spaces so that multi-part headlines are separated
        # while ordinary sentences stay intact (splitting on a single space
        # would put every word on its own line).
        for phrase in raw_line.strip().split("  "):
            phrase = phrase.strip()
            if phrase:
                phrases.append(phrase)

    return '\n'.join(phrases)
135
+
136
+ _useragent_list = [
137
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
138
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
139
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
140
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
141
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
142
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
143
+ ]
144
+
145
def google_search(term, num_results=5, lang="en", timeout=5, safe="active", ssl_verify=None):
    """Search Google for ``term`` and fetch the text of each result page.

    The Google results page is requested with a randomly chosen User-Agent,
    every ``div.g`` result block is inspected for its first link, and that
    link is downloaded and reduced to visible text with
    ``extract_text_from_webpage``.

    Args:
        term: Search query; passed as the ``q`` parameter (requests handles
            URL-encoding, so it must not be pre-escaped).
        num_results: Maximum number of result entries to collect.
        lang: Value of the ``hl`` query parameter.
        timeout: Timeout in seconds for every HTTP request.
        safe: Value of the ``safe`` query parameter.
        ssl_verify: Forwarded as ``verify`` to every request; ``None``
            falls back to the session default.

    Returns:
        A list of at most ``num_results`` dicts with keys ``"link"`` and
        ``"text"``. ``"text"`` is ``None`` when the page could not be
        fetched; both are ``None`` when a result block had no link. An
        empty or partial list is returned if the search request fails.
    """
    # Limit the number of characters from each webpage to stay under the token limit
    max_chars_per_page = 8000
    all_results = []
    start = 0

    with requests.Session() as session:
        while start < num_results:
            headers = {'User-Agent': random.choice(_useragent_list)}
            try:
                resp = session.get(
                    url="https://www.google.com/search",
                    headers=headers,
                    params={
                        "q": term,
                        "num": num_results - start,
                        "hl": lang,
                        "start": start,
                        "safe": safe,
                    },
                    timeout=timeout,
                    verify=ssl_verify,
                )
                resp.raise_for_status()
            except requests.exceptions.RequestException:
                # Best effort: return whatever was collected so far.
                break

            soup = BeautifulSoup(resp.text, "html.parser")
            result_block = soup.find_all("div", attrs={"class": "g"})
            if not result_block:
                break

            # One entry is appended per result block, so ``start`` equals
            # len(all_results) here; slice so we never exceed num_results
            # (Google may return more blocks than requested).
            for result in result_block[:num_results - start]:
                anchor = result.find("a", href=True)
                if not anchor:
                    all_results.append({"link": None, "text": None})
                    continue

                link = anchor["href"]
                try:
                    webpage = session.get(
                        link,
                        headers=headers,
                        timeout=timeout,
                        verify=ssl_verify,
                    )
                    webpage.raise_for_status()
                except requests.exceptions.RequestException:
                    # Unreachable or invalid URL: keep the link, no text.
                    all_results.append({"link": link, "text": None})
                    continue

                visible_text = extract_text_from_webpage(webpage.text)
                if len(visible_text) > max_chars_per_page:
                    visible_text = visible_text[:max_chars_per_page] + "..."
                all_results.append({"link": link, "text": visible_text})

            start += len(result_block)

    return all_results
196
+
197
+ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
198
  global conversation_history
199
 
200
  if not question:
 
204
  answer = memory_database[question]
205
  else:
206
  embed = get_embeddings()
 
207
  model = get_model(temperature, top_p, repetition_penalty)
208
 
209
  history_str = "\n".join([f"Q: {item['question']}\nA: {item['answer']}" for item in conversation_history])
210
 
211
+ if web_search:
212
+ search_results = google_search(question)
213
+ context_str = "\n".join([result["text"] for result in search_results if result["text"]])
214
  else:
215
+ database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
216
+ if is_related_to_history(question, conversation_history):
217
+ context_str = "No additional context needed. Please refer to the conversation history."
218
+ else:
219
+ retriever = database.as_retriever()
220
+ relevant_docs = retriever.get_relevant_documents(question)
221
+ context_str = "\n".join([doc.page_content for doc in relevant_docs])
222
 
223
  prompt_val = ChatPromptTemplate.from_template(prompt)
224
  formatted_prompt = prompt_val.format(history=history_str, context=context_str, question=question)
 
299
  temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
300
  top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1)
301
  repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
302
+ web_search_checkbox = gr.Checkbox(label="Enable Web Search", value=False)
303
 
304
  def chat(question, history):
305
+ answer = ask_question(question, temperature_slider.value, top_p_slider.value, repetition_penalty_slider.value, web_search_checkbox.value)
306
  history.append((question, answer))
307
  return "", history
308
 
 
321
  clear_button.click(clear_cache, inputs=[], outputs=clear_output)
322
 
323
  if __name__ == "__main__":
324
+ demo.launch()