Spaces:
Sleeping
Sleeping
import asyncio
import subprocess
import sys

import streamlit as st
from decouple import config
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_groq import ChatGroq

from conversation.talks import clean_input, small_talks
from embedding.vector_store import initialize_vector_store, clear_chroma_db
from scraper.scraper import process_urls
# Optional Playwright bootstrap, left disabled (presumably needed once on a
# fresh host before scraping works -- confirm before enabling):
# subprocess.run(["playwright", "install"], check=True)
# subprocess.run(["playwright", "install-deps"], check=True)

# Clear ChromaDB to clean up any data left over from a previous run.
# NOTE(review): Streamlit re-executes this script from the top on every user
# interaction, so this call runs on each rerun, not only at process start --
# confirm that clear_chroma_db() is safe to call while a vector store built
# during this session is still in use.
clear_chroma_db()

# Groq API key, looked up through python-decouple's config().
groq_api = config("GROQ_API_KEY")

# Chat model used to answer questions. temperature=0 asks for the least
# random sampling; no conversation memory is attached to the model here.
llm = ChatGroq(model="llama-3.2-1b-preview", groq_api_key=groq_api, temperature=0)

# Ensure proper asyncio handling on Windows by selecting the proactor loop.
if sys.platform.startswith("win"):
    asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
# Async helper function
def run_asyncio_coroutine(coro):
    """Run *coro* to completion on a private event loop and return its result.

    A fresh loop is created for every call, so this can be used from
    synchronous code that has no event loop of its own. The loop is always
    closed afterwards, even when *coro* raises, so repeated calls do not
    accumulate open loops.

    Args:
        coro: The coroutine (or other awaitable) to execute.

    Returns:
        Whatever *coro* returns.

    Raises:
        Any exception raised by *coro* is propagated unchanged.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coro)
    finally:
        try:
            # Give pending async generators a chance to run their cleanup.
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            # Do not leave a closed loop installed as the current one.
            asyncio.set_event_loop(None)
            loop.close()
# ---------------------------------------------------------------------------
# WebGPT Streamlit UI: scrape the given URLs into a vector store, then answer
# chat questions strictly from the retrieved page content.
# ---------------------------------------------------------------------------

# Reply used whenever nothing relevant is retrieved or the model returns an
# empty answer. It is also compared against below to suppress the source line.
NOT_FOUND_RESPONSE = "I can't find your request in the provided context."

# Prompt consumed by the stuff-documents chain: {context} receives the
# retrieved documents and {input} the cleaned user question. The context is
# interpolated exactly once to avoid doubling the prompt size.
QA_PROMPT = PromptTemplate(
    input_variables=["context", "input"],
    template="""
You are WebGPT, an AI assistant for question-answering tasks that **only answers questions based on the provided context**.
- Understand the Context first and provide a relevant answer.
- If the answer is **not** found in the Context, reply with: "I can't find your request in the provided context."
- If the question is **unrelated** to the Context, reply with: "I can't answer that. do not generate responses."
- **Do not** use external knowledge, assumptions, or filler responses. Stick to the context provided.
- Keep responses clear, concise, and relevant to the user's query.
Context:
{context}
Now, answer the user's question:
{input}
""",
)


def _parse_urls(raw_text):
    """Return the non-blank lines of *raw_text*, stripped of whitespace."""
    return [line.strip() for line in raw_text.splitlines() if line.strip()]


def _answer_from_context(vector_store, query):
    """Answer *query* from the scraped content.

    The top 5 documents are retrieved once and passed straight to the
    stuff-documents chain. Wrapping the chain in create_retrieval_chain
    would run the retriever a second time for the same query.

    Args:
        vector_store: Store exposing ``as_retriever`` (built by the scraper step).
        query: The cleaned user question.

    Returns:
        A ``(response, source_url)`` tuple. ``source_url`` is the ``source``
        metadata of the best-ranked document, or a placeholder string when
        nothing was found or generation failed.
    """
    retriever = vector_store.as_retriever(search_kwargs={"k": 5})
    retrieved_docs = retriever.invoke(query)
    if not retrieved_docs:
        return NOT_FOUND_RESPONSE, "No source found"

    answer_chain = create_stuff_documents_chain(llm=llm, prompt=QA_PROMPT)
    try:
        response = answer_chain.invoke({"context": retrieved_docs, "input": query}).strip()
    except Exception as e:
        # UI boundary: show the failure in the chat instead of crashing the app.
        return f"Error generating response: {str(e)}", "Error"

    # Fallback if the model produced nothing usable.
    if not response:
        return NOT_FOUND_RESPONSE, "No source found"
    return response, retrieved_docs[0].metadata.get("source", "Unknown")


st.title("WebGPT 1.0 🤖")

# URL inputs
urls = st.text_area("Enter URLs (one per line)")
run_scraper = st.button("Run Scraper", disabled=not urls.strip())

# Sessions & states: values that must survive Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []  # Chat history
if "history" not in st.session_state:
    st.session_state.history = ""  # Stores past Q&A for memory
if "scraping_done" not in st.session_state:
    st.session_state.scraping_done = False
if "vector_store" not in st.session_state:
    st.session_state.vector_store = None

# Run scraper
if run_scraper:
    st.write("Fetching and processing URLs... This may take a while.")
    # Blank lines and stray whitespace/CRs are dropped before scraping.
    split_docs = run_asyncio_coroutine(process_urls(_parse_urls(urls)))
    st.session_state.vector_store = initialize_vector_store(split_docs)
    st.session_state.scraping_done = True
    st.success("Scraping and processing completed!")

# Clear chat button
if st.button("Clear Chat"):
    st.session_state.messages = []  # Reset message history
    st.session_state.history = ""  # Reset history tracking
    st.success("Chat cleared!")

# Ensuring chat only enables after scraping
if not st.session_state.scraping_done:
    st.warning("Scrape some data first to enable chat!")
else:
    st.write("### Chat With WebGPT 💬")

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["text"])

    # Takes in user input
    user_query = st.chat_input("Ask a question...")
    if user_query:
        st.session_state.messages.append({"role": "user", "text": user_query})
        with st.chat_message("user"):
            st.write(user_query)

        user_query_cleaned = clean_input(user_query)

        # Check for small talk responses before touching the vector store.
        if user_query_cleaned in small_talks:
            response = small_talks[user_query_cleaned]
            source_url = "Knowledge base"  # Small talk comes from the knowledge base
        else:
            response, source_url = _answer_from_context(
                st.session_state.vector_store, user_query_cleaned
            )

        # Format the response; the source is omitted for "not found" replies.
        formatted_response = f"**Answer:** {response}"
        if response != NOT_FOUND_RESPONSE and source_url:
            formatted_response += f"\n\n**Source:** {source_url}"
        st.session_state.messages.append({"role": "assistant", "text": formatted_response})

        # Track history only after the assistant reply has been recorded, so
        # the transcript includes the latest answer.
        st.session_state.history = "\n".join(
            f"User: {msg['text']}" if msg["role"] == "user" else f"AI: {msg['text']}"
            for msg in st.session_state.messages
        )

        with st.chat_message("assistant"):
            st.write(formatted_response)