# app.py from a Hugging Face Space by Shreyas094
# (commit 54eab8b, "Update app.py", 10.4 kB).
import functools
import json
import logging
import os
import random
import re
import shutil
import urllib.parse
from tempfile import NamedTemporaryFile
from typing import List

import gradio as gr
import pandas as pd
import requests
from bs4 import BeautifulSoup
from duckduckgo_search import ddg
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
# Global variables
# Hugging Face Hub API token taken from the HUGGINGFACE_TOKEN environment
# variable (None when the variable is unset); get_model hands it to
# HuggingFaceHub for authentication.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
def get_model(temperature, top_p, repetition_penalty, max_length=1000):
    """Build the Hugging Face Hub LLM client used to answer questions.

    Args:
        temperature: Sampling temperature sent to the model.
        top_p: Nucleus-sampling probability mass sent to the model.
        repetition_penalty: Penalty for repeated tokens sent to the model.
        max_length: Value sent as the "max_length" generation parameter.
            Defaults to 1000, the value that used to be hard-coded, so
            existing three-argument callers behave exactly as before.

    Returns:
        A ``HuggingFaceHub`` client for mistralai/Mistral-7B-Instruct-v0.3,
        authenticated with the module-level ``huggingface_token``.
    """
    return HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.3",
        model_kwargs={
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "max_length": max_length,
        },
        huggingfacehub_api_token=huggingface_token,
    )
def load_document(file) -> List[Document]:
    """Load one PDF and split it into LangChain documents.

    Args:
        file: The uploaded PDF. This is either a file-like object exposing
            ``.name`` (the original contract) or a plain filesystem path
            string. Depending on the Gradio version, ``gr.Files`` may yield
            either form.

    Returns:
        The page chunks produced by ``PyPDFLoader.load_and_split``.
    """
    # A str has no .name attribute. Fall back to the value itself so that
    # path strings work as well as tempfile-like objects.
    path = getattr(file, "name", file)
    return PyPDFLoader(path).load_and_split()
def update_vectors(files):
    """Chunk the uploaded PDFs, embed them, and merge them into the FAISS index.

    The index lives in the local "faiss_database" folder. It is loaded and
    extended when it already exists, and created otherwise.

    Args:
        files: The uploaded files, each handed to ``load_document``.
            ``None`` or an empty list means nothing was uploaded.

    Returns:
        A human-readable status string for the UI.
    """
    if not files:
        return "Please upload at least one PDF file."

    all_data = []
    for file in files:
        all_data.extend(load_document(file))
    total_chunks = len(all_data)

    # FAISS.from_documents cannot build an index from zero documents, so
    # report the problem instead of raising. This happens, for example, with
    # image-only PDFs that contain no extractable text.
    if not all_data:
        return f"No text could be extracted from the {len(files)} uploaded file(s)."

    embed = get_embeddings()
    if os.path.exists("faiss_database"):
        # load_local unpickles index metadata. This is acceptable only because
        # the folder is written by this app's own save_local call below.
        database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
        database.add_documents(all_data)
    else:
        database = FAISS.from_documents(all_data, embed)
    database.save_local("faiss_database")
    return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files."
@functools.lru_cache(maxsize=1)
def get_embeddings():
    """Return the embedding model used to build and query the FAISS index.

    The instance is cached and shared. Previously a new HuggingFaceEmbeddings
    was constructed on every upload and every chat message. Construction
    presumably loads the sentence-transformers weights each time, which is
    costly. The model name must stay the same as the one used when the saved
    index was built.
    """
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
def clear_cache():
    """Delete the persisted FAISS index ("faiss_database"), if present.

    Returns:
        "Cache cleared successfully." after deleting it, or
        "No cache to clear." when nothing exists at that path.
    """
    path = "faiss_database"
    if not os.path.exists(path):
        return "No cache to clear."
    # FAISS.save_local writes a directory (with the index files inside).
    # os.remove raises OSError on a directory, so remove the tree instead.
    if os.path.isdir(path):
        shutil.rmtree(path)
    else:
        os.remove(path)
    return "Cache cleared successfully."
def extract_text_from_webpage(html):
    """Return the human-visible text of an HTML document.

    The function removes ``<script>`` and ``<style>`` elements and strips
    every text line. It then puts pieces separated by two or more spaces
    onto their own lines and drops empty pieces.

    Args:
        html: The raw HTML markup.

    Returns:
        The cleaned text, one phrase per line.
    """
    soup = BeautifulSoup(html, 'html.parser')
    for tag in soup(["script", "style"]):
        tag.extract()
    lines = (line.strip() for line in soup.get_text().splitlines())
    # Split on a *double* space. A single space separates words inside a
    # phrase, so splitting on " " would put every word on its own line.
    chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
    return '\n'.join(chunk for chunk in chunks if chunk)
# Desktop browser User-Agent strings. google_search picks one at random for
# each results-page request and reuses it for the page fetches that follow.
# NOTE(review): these identify 2021-era browsers (Chrome/Edge 91), and the
# Edge entries use the legacy "Edge/" token. Google may serve different
# markup to such clients, so confirm the "div.g" selector still matches.
_useragent_list = [
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
]
def _fetch_visible_text(session, link, headers, timeout, verify, max_chars):
    """Fetch *link* and return its visible text, or None if the request fails.

    Text longer than *max_chars* characters is cut to that length and gets
    a trailing "...".
    """
    try:
        page = session.get(link, headers=headers, timeout=timeout, verify=verify)
        page.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"Error retrieving webpage content: {e}")
        return None
    text = extract_text_from_webpage(page.text)
    if len(text) > max_chars:
        text = text[:max_chars] + "..."
    return text


def google_search(term, num_results=5, lang="en", timeout=5, safe="active", ssl_verify=None):
    """Scrape Google results for *term* and fetch the visible text of each hit.

    Args:
        term: The search query. ``requests`` URL-encodes it via ``params``.
        num_results: Maximum number of results to return.
        lang: Interface language passed as Google's ``hl`` parameter.
        timeout: Per-request timeout in seconds for the search and page fetches.
        safe: Value for Google's ``safe`` (SafeSearch) parameter.
        ssl_verify: Passed as ``verify`` to every request. ``None`` keeps the
            session default.

    Returns:
        At most ``num_results`` dicts of the form ``{"link": str, "text": str | None}``.
        ``text`` is None when that page could not be fetched. When nothing is
        collected, the function returns a single placeholder
        ``{"link": None, "text": "No information found ..."}``.
    """
    max_chars_per_page = 8000
    all_results = []
    start = 0
    with requests.Session() as session:
        while start < num_results and len(all_results) < num_results:
            headers = {'User-Agent': random.choice(_useragent_list)}
            try:
                resp = session.get(
                    url="https://www.google.com/search",
                    headers=headers,
                    params={
                        "q": term,
                        "num": num_results - start,
                        "hl": lang,
                        "start": start,
                        "safe": safe,
                    },
                    timeout=timeout,
                    verify=ssl_verify,
                )
                resp.raise_for_status()
            except requests.exceptions.RequestException as e:
                print(f"Error retrieving search results: {e}")
                break
            result_block = BeautifulSoup(resp.text, "html.parser").find_all("div", attrs={"class": "g"})
            if not result_block:
                break
            for result in result_block:
                # Google may return more blocks than requested. Stop at the
                # cap so we don't fetch pages nobody will use.
                if len(all_results) >= num_results:
                    break
                anchor = result.find("a", href=True)
                if anchor is None:
                    # A result container without a link has nothing to cite or fetch.
                    continue
                link = anchor["href"]
                all_results.append({
                    "link": link,
                    "text": _fetch_visible_text(
                        session, link, headers, timeout, ssl_verify, max_chars_per_page
                    ),
                })
            start += len(result_block)
    if not all_results:
        return [{"link": None, "text": "No information found in the web search results."}]
    return all_results
def duckduckgo_search(query, max_results=5):
    """Search DuckDuckGo and return results in the same shape as google_search.

    Args:
        query: The search query.
        max_results: Maximum number of results to request.

    Returns:
        A list of ``{"link": str, "text": "<title>. <body>"}`` dicts. If the
        search fails or yields nothing, the function returns the single
        placeholder ``{"link": None, "text": "No information found ..."}``,
        matching google_search.
    """
    no_results = [{"link": None, "text": "No information found in the web search results."}]
    try:
        results = ddg(query, region='wt-wt', safesearch='Moderate', time=None, max_results=max_results)
    except Exception as e:
        # Best-effort: a search outage should degrade to "no information",
        # not crash the chat turn.
        print(f"Error in DuckDuckGo search: {e}")
        return no_results
    formatted_results = [
        {
            "link": result.get('href') or '',
            # Guard against None fields so the string concatenation cannot fail.
            "text": (result.get('title') or '') + '. ' + (result.get('body') or ''),
        }
        for result in results or []
    ]
    return formatted_results or no_results
def respond(
    message,
    history: list[tuple[str, str]],
    temperature,
    top_p,
    repetition_penalty,
    max_tokens,
    search_engine
):
    """Answer *message* from web search results plus any indexed PDF context.

    Args:
        message: The user's question.
        history: Prior chat turns. This is currently not used when building
            the prompt.
        temperature: Sampling temperature forwarded to get_model.
        top_p: Nucleus-sampling mass forwarded to get_model.
        repetition_penalty: Repetition penalty forwarded to get_model.
        max_tokens: Generation-length limit from the UI, sent as the model's
            "max_length" parameter.
        search_engine: "Google" selects google_search. Any other value
            uses duckduckgo_search.

    Returns:
        The model's answer, followed by a "Sources:" list of the distinct
        result links in first-seen order. The list is omitted when there are
        no links.
    """
    model = get_model(temperature, top_p, repetition_penalty)
    if max_tokens:
        # get_model fixes max_length. Honour the UI's "Max tokens" slider,
        # which was previously ignored.
        # NOTE(review): the HF text-generation API documents `max_new_tokens`;
        # confirm which key the endpoint actually honours.
        model.model_kwargs = {**(model.model_kwargs or {}), "max_length": int(max_tokens)}

    if search_engine == "Google":
        search_results = google_search(message)
    else:
        search_results = duckduckgo_search(message)

    inputs = {
        "search_results": "\n".join(f"- {result['text']}" for result in search_results if result['text']),
        "message": message,
    }

    if os.path.exists("faiss_database"):
        embed = get_embeddings()
        database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
        relevant_docs = database.as_retriever().get_relevant_documents(message)
        # Pass the document text as a template *variable*. Splicing it into
        # the template string meant any "{" or "}" in a PDF was parsed as a
        # placeholder and broke PromptTemplate.
        inputs["context"] = "\n".join(doc.page_content for doc in relevant_docs)
        prompt_template = """
Answer the question based on the following context and web search results:
Context from documents:
{context}
Web Search Results:
{search_results}
Question: {message}
If the context and web search results don't contain relevant information, state that the information is not available.
Provide a concise and direct answer to the question.
"""
    else:
        prompt_template = """
Answer the question based on the following web search results:
Web Search Results:
{search_results}
Question: {message}
If the web search results don't contain relevant information, state that the information is not available.
Provide a concise and direct answer to the question.
"""

    prompt = PromptTemplate(input_variables=list(inputs), template=prompt_template)
    chain = LLMChain(llm=model, prompt=prompt)
    response = chain.run(**inputs)

    # dict.fromkeys de-duplicates while keeping first-seen order.
    # A set made the source order vary between runs.
    sources = list(dict.fromkeys(result["link"] for result in search_results if result["link"]))
    if sources:
        response += "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
    return response
# Gradio interface
def _chat(message, history, temperature, top_p, repetition_penalty, max_tokens, search_engine):
    """Run respond() for one turn and append it to the Chatbot history.

    gr.Chatbot's value is a list of (user, bot) message pairs, but respond()
    returns a plain string. Wiring respond straight to the Chatbot output
    therefore replaced the conversation with an invalid value.

    Returns:
        A ``(history, textbox_value)`` tuple: the updated history, and "" to
        clear the input. Blank messages leave both unchanged.
    """
    history = list(history or [])
    if not message or not message.strip():
        return history, message
    answer = respond(message, history, temperature, top_p, repetition_penalty, max_tokens, search_engine)
    history.append((message, answer))
    return history, ""


demo = gr.Blocks()
with demo:
    gr.Markdown("# Chat with your PDF documents and Web Search")
    with gr.Row():
        file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
        update_button = gr.Button("Upload PDF")
    update_output = gr.Textbox(label="Update Status")
    update_button.click(update_vectors, inputs=[file_input], outputs=update_output)

    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation")
            message_input = gr.Textbox(label="Enter your message")
            submit_button = gr.Button("Submit")
        with gr.Column(scale=1):
            temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
            repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty")
            max_tokens = gr.Slider(minimum=1, maximum=1000, value=500, step=1, label="Max tokens")
            search_engine = gr.Dropdown(["DuckDuckGo", "Google"], value="DuckDuckGo", label="Search Engine")

    chat_inputs = [
        message_input,
        chatbot,
        temperature,
        top_p,
        repetition_penalty,
        max_tokens,
        search_engine,
    ]
    # _chat returns (updated history, "") so the textbox clears after sending.
    chat_outputs = [chatbot, message_input]
    submit_button.click(_chat, inputs=chat_inputs, outputs=chat_outputs)
    message_input.submit(_chat, inputs=chat_inputs, outputs=chat_outputs)

    clear_button = gr.Button("Clear Cache")
    clear_output = gr.Textbox(label="Cache Status")
    clear_button.click(clear_cache, inputs=[], outputs=clear_output)

if __name__ == "__main__":
    demo.launch()