Spaces:
Paused
Paused
Shreyas094
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -3,8 +3,12 @@ import json
|
|
3 |
import re
|
4 |
import gradio as gr
|
5 |
import pandas as pd
|
|
|
|
|
|
|
6 |
from tempfile import NamedTemporaryFile
|
7 |
from typing import List
|
|
|
8 |
from langchain_core.prompts import ChatPromptTemplate
|
9 |
from langchain_community.vectorstores import FAISS
|
10 |
from langchain_community.document_loaders import PyPDFLoader
|
@@ -119,7 +123,78 @@ def is_related_to_history(question, history, threshold=0.3):
|
|
119 |
similarity = get_similarity(question, history_text)
|
120 |
return similarity > threshold
|
121 |
|
122 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
global conversation_history
|
124 |
|
125 |
if not question:
|
@@ -129,17 +204,21 @@ def ask_question(question, temperature, top_p, repetition_penalty):
|
|
129 |
answer = memory_database[question]
|
130 |
else:
|
131 |
embed = get_embeddings()
|
132 |
-
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
|
133 |
model = get_model(temperature, top_p, repetition_penalty)
|
134 |
|
135 |
history_str = "\n".join([f"Q: {item['question']}\nA: {item['answer']}" for item in conversation_history])
|
136 |
|
137 |
-
if
|
138 |
-
|
|
|
139 |
else:
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
143 |
|
144 |
prompt_val = ChatPromptTemplate.from_template(prompt)
|
145 |
formatted_prompt = prompt_val.format(history=history_str, context=context_str, question=question)
|
@@ -220,9 +299,10 @@ with gr.Blocks() as demo:
|
|
220 |
temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
|
221 |
top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1)
|
222 |
repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
|
|
|
223 |
|
224 |
def chat(question, history):
|
225 |
-
answer = ask_question(question, temperature_slider.value, top_p_slider.value, repetition_penalty_slider.value)
|
226 |
history.append((question, answer))
|
227 |
return "", history
|
228 |
|
@@ -241,4 +321,4 @@ with gr.Blocks() as demo:
|
|
241 |
clear_button.click(clear_cache, inputs=[], outputs=clear_output)
|
242 |
|
243 |
if __name__ == "__main__":
|
244 |
-
demo.launch()
|
|
|
3 |
import re
|
4 |
import gradio as gr
|
5 |
import pandas as pd
|
6 |
+
import requests
|
7 |
+
import random
|
8 |
+
import urllib.parse
|
9 |
from tempfile import NamedTemporaryFile
|
10 |
from typing import List
|
11 |
+
from bs4 import BeautifulSoup
|
12 |
from langchain_core.prompts import ChatPromptTemplate
|
13 |
from langchain_community.vectorstores import FAISS
|
14 |
from langchain_community.document_loaders import PyPDFLoader
|
|
|
123 |
similarity = get_similarity(question, history_text)
|
124 |
return similarity > threshold
|
125 |
|
126 |
+
def extract_text_from_webpage(html):
    """Extract the human-visible text from an HTML document.

    Removes <script> and <style> elements, then collapses the remaining
    text into newline-separated, non-empty chunks.

    Args:
        html: Raw HTML source as a string.

    Returns:
        A newline-joined string of the visible text chunks.
    """
    soup = BeautifulSoup(html, 'html.parser')
    for script in soup(["script", "style"]):
        script.extract()  # Remove scripts and styles
    text = soup.get_text()
    lines = (line.strip() for line in text.splitlines())
    # Split on double spaces (layout/column gaps), not single spaces:
    # splitting on a single space would put every word on its own line
    # after the '\n'.join below.
    chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
    text = '\n'.join(chunk for chunk in chunks if chunk)
    return text
|
135 |
+
|
136 |
+
# Pool of common desktop browser User-Agent strings. google_search() picks
# one at random per request to make the scraper traffic look less uniform.
_useragent_list = [
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
]
|
144 |
+
|
145 |
+
def google_search(term, num_results=5, lang="en", timeout=5, safe="active", ssl_verify=None):
    """Scrape Google web-search results for *term* and fetch each hit's page text.

    Args:
        term: The search query string (sent verbatim as the ``q`` parameter).
        num_results: Approximate number of results to collect before stopping.
        lang: Interface language, passed as Google's ``hl`` parameter.
        timeout: Per-request timeout in seconds.
        safe: SafeSearch setting (e.g. "active").
        ssl_verify: Passed through to ``requests`` as ``verify`` for the
            search request.

    Returns:
        A list of dicts with keys ``link`` and ``text``. ``text`` is None when
        the result page could not be fetched; both are None for results
        without an anchor tag. Best-effort: returns whatever was collected so
        far if the search request itself fails.
    """
    all_results = []
    start = 0
    # Cap the text taken from each page to stay under the model's token limit.
    max_chars_per_page = 8000

    with requests.Session() as session:
        while start < num_results:
            headers = {
                'User-Agent': random.choice(_useragent_list)
            }
            try:
                resp = session.get(
                    url="https://www.google.com/search",
                    headers=headers,
                    params={
                        "q": term,
                        "num": num_results - start,
                        "hl": lang,
                        "start": start,
                        "safe": safe,
                    },
                    timeout=timeout,
                    verify=ssl_verify,
                )
                resp.raise_for_status()
            except requests.exceptions.RequestException:
                # Search page unreachable: stop and return what we have.
                break

            soup = BeautifulSoup(resp.text, "html.parser")
            result_block = soup.find_all("div", attrs={"class": "g"})
            if not result_block:
                break
            for result in result_block:
                link_tag = result.find("a", href=True)
                if link_tag:
                    link = link_tag["href"]
                    try:
                        webpage = session.get(link, headers=headers, timeout=timeout)
                        webpage.raise_for_status()
                        visible_text = extract_text_from_webpage(webpage.text)
                        if len(visible_text) > max_chars_per_page:
                            visible_text = visible_text[:max_chars_per_page] + "..."
                        all_results.append({"link": link, "text": visible_text})
                    except requests.exceptions.RequestException:
                        # Keep the link even when its page could not be fetched.
                        all_results.append({"link": link, "text": None})
                else:
                    all_results.append({"link": None, "text": None})
            start += len(result_block)
    return all_results
|
196 |
+
|
197 |
+
def ask_question(question, temperature, top_p, repetition_penalty, web_search):
|
198 |
global conversation_history
|
199 |
|
200 |
if not question:
|
|
|
204 |
answer = memory_database[question]
|
205 |
else:
|
206 |
embed = get_embeddings()
|
|
|
207 |
model = get_model(temperature, top_p, repetition_penalty)
|
208 |
|
209 |
history_str = "\n".join([f"Q: {item['question']}\nA: {item['answer']}" for item in conversation_history])
|
210 |
|
211 |
+
if web_search:
|
212 |
+
search_results = google_search(question)
|
213 |
+
context_str = "\n".join([result["text"] for result in search_results if result["text"]])
|
214 |
else:
|
215 |
+
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
|
216 |
+
if is_related_to_history(question, conversation_history):
|
217 |
+
context_str = "No additional context needed. Please refer to the conversation history."
|
218 |
+
else:
|
219 |
+
retriever = database.as_retriever()
|
220 |
+
relevant_docs = retriever.get_relevant_documents(question)
|
221 |
+
context_str = "\n".join([doc.page_content for doc in relevant_docs])
|
222 |
|
223 |
prompt_val = ChatPromptTemplate.from_template(prompt)
|
224 |
formatted_prompt = prompt_val.format(history=history_str, context=context_str, question=question)
|
|
|
299 |
temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
|
300 |
top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1)
|
301 |
repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
|
302 |
+
web_search_checkbox = gr.Checkbox(label="Enable Web Search", value=False)
|
303 |
|
304 |
def chat(question, history):
    # Answer *question* via ask_question() and append the (question, answer)
    # pair to the Gradio chat history; the empty string clears the input box.
    # NOTE(review): reading temperature_slider.value / top_p_slider.value /
    # repetition_penalty_slider.value / web_search_checkbox.value here returns
    # each component's *initial* value, not the live UI state — Gradio only
    # delivers live values through an event's `inputs` list. Verify the
    # sliders/checkbox are wired as event inputs if live values are intended.
    answer = ask_question(question, temperature_slider.value, top_p_slider.value, repetition_penalty_slider.value, web_search_checkbox.value)
    history.append((question, answer))
    return "", history
|
308 |
|
|
|
321 |
clear_button.click(clear_cache, inputs=[], outputs=clear_output)
|
322 |
|
323 |
if __name__ == "__main__":
|
324 |
+
demo.launch()
|