# Chat_RAG / app.py
import os
import subprocess
import asyncio
import re
from itertools import chain

import numpy as np
import faiss
import streamlit as st
from sentence_transformers import SentenceTransformer
from langchain_community.document_loaders import AsyncChromiumLoader
from langchain_community.document_transformers import Html2TextTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq

# Install the Chromium browser that AsyncChromiumLoader needs
subprocess.run(["playwright", "install"], check=True)
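
# Pipeline overview: scrape URLs with a headless browser, convert HTML to text,
# split the text into chunks, embed the chunks with a SentenceTransformer model,
# index them in FAISS, and answer questions by retrieving the most similar chunks
# and passing them to a Groq-hosted LLM.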
# Scraping and splitting function
async def process_urls(urls):
    # Load multiple URLs asynchronously with a headless Chromium browser
    loader = AsyncChromiumLoader(urls)
    docs = await loader.aload()

    # Transform HTML to plain text
    text_transformer = Html2TextTransformer()
    transformed_docs = text_transformer.transform_documents(docs)

    # Split the text into chunks and retain metadata
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=500)
    split_docs_nested = [text_splitter.split_documents([doc]) for doc in transformed_docs]
    split_docs = list(chain.from_iterable(split_docs_nested))

    # Attach the source URL to each split document
    for doc in split_docs:
        doc.metadata["source_url"] = doc.metadata.get("source", "Unknown")  # Ensure URL metadata exists

    return split_docs
def clean_text(text):
    """Remove unnecessary whitespace, line breaks, and special characters."""
    text = re.sub(r'\s+', ' ', text).strip()  # Collapse excessive whitespace
    text = re.sub(r'\[.*?\]|\(.*?\)', '', text)  # Remove bracketed text (e.g., [advert])
    return text
def embed_text(text_list):
    """Embed a list of strings with the Nomic embedding model."""
    # Note: the model is reloaded on every call; it could be cached for speed.
    model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True)
    embeddings = model.encode(text_list)
    if embeddings is None or len(embeddings) == 0:
        raise ValueError("Embedding function returned an empty result.")
    return embeddings
def store_embeddings(docs):
    """Convert text into embeddings and store them in FAISS."""
    # Build the text and source lists together so they stay aligned
    all_text = []
    text_sources = []
    for doc in docs:
        if hasattr(doc, "page_content") and doc.page_content:
            all_text.append(clean_text(doc.page_content))
            text_sources.append(doc.metadata.get("source_url", "Unknown"))

    embeddings = embed_text(all_text)
    if embeddings is None or embeddings.size == 0:
        raise ValueError("Embedding function returned None or an empty array.")
    embeddings = np.array(embeddings, dtype=np.float32)

    # Normalize embeddings so inner product behaves as cosine similarity
    faiss.normalize_L2(embeddings)

    d = embeddings.shape[1]
    index = faiss.IndexFlatIP(d)  # Inner product (cosine similarity after normalization)
    index.add(embeddings)
    return index, all_text, text_sources
def search_faiss(index, query_embedding, text_data, text_sources, top_k=5, min_score=0.5):
    """Return the top_k stored chunks most similar to the query embedding."""
    query_embedding = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
    faiss.normalize_L2(query_embedding)  # Normalize the query for cosine similarity

    distances, indices = index.search(query_embedding, top_k)

    results = []
    if indices.size > 0:
        for i in range(len(indices[0])):
            if distances[0][i] >= min_score:  # Ignore low-similarity results
                idx = indices[0][i]
                if idx < len(text_data):
                    results.append({"source": text_sources[idx], "content": text_data[idx]})
    return results
def query_llm(index, text_data, text_sources, query):
    # Read the Groq API key from the environment instead of hard-coding it
    groq_api = os.environ.get("GROQ_API_KEY")
    chat = ChatGroq(model="llama-3.2-1b-preview", groq_api_key=groq_api, temperature=0)

    # Embed the query
    query_embedding = embed_text([query])[0]

    # Search FAISS for relevant documents
    relevant_docs = search_faiss(index, query_embedding, text_data, text_sources, top_k=3)

    # If no relevant docs, return a default message
    if not relevant_docs:
        return "No relevant information found."

    # Query the LLM once per retrieved chunk
    responses = []
    for doc in relevant_docs:
        if isinstance(doc, dict) and "source" in doc and "content" in doc:
            source_url = doc["source"]
            content = doc["content"][:10000]
        else:
            continue  # Skip unexpectedly formatted results

        prompt = f"""
Based on the following content, answer the question: "{query}"

Content (from {source_url}):
{content}
"""
        response = chat.invoke(prompt)
        responses.append({"source": source_url, "response": response})

    return responses
# Streamlit UI
st.title("Web Scraper & AI Query Interface")

urls = st.text_area("Enter URLs (one per line)", "https://en.wikipedia.org/wiki/Nigeria\nhttps://en.wikipedia.org/wiki/Ghana")

if st.button("Run Scraper"):
    st.write("Fetching and processing URLs...")

    async def run_scraper():
        url_list = [u.strip() for u in urls.split("\n") if u.strip()]
        split_docs = await process_urls(url_list)
        return store_embeddings(split_docs)

    # Run the async scraper inside Streamlit and keep the results across reruns
    index, text_data, text_sources = asyncio.run(run_scraper())
    st.session_state["rag_data"] = (index, text_data, text_sources)
    st.write("Data processed! Now you can ask questions about the scraped content.")

user_query = st.text_input("Ask a question about the scraped data", "Where is Nigeria located?")

if st.button("Query Model"):
    if "rag_data" not in st.session_state:
        st.write("Please run the scraper first.")
    else:
        index, text_data, text_sources = st.session_state["rag_data"]
        result = query_llm(index, text_data, text_sources, user_query)
        if isinstance(result, str):
            st.write(result)
        else:
            for entry in result:
                st.subheader(f"Source: {entry['source']}")
                st.write(f"Response: {entry['response'].content}")