import logging
import os
import re

import streamlit as st
from datasets import load_dataset
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_aws import BedrockEmbeddings, ChatBedrock
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_qdrant import Qdrant
from qdrant_client import QdrantClient
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def load_environment():
    """Load and validate environment variables."""
    try:
        load_dotenv()
        required_vars = ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'AWS_REGION', 'QDRANT_URL', 'QDRANT_API_KEY']
        missing_vars = [var for var in required_vars if not os.getenv(var)]
        if missing_vars:
            logger.error(f"Missing environment variables: {missing_vars}")
            st.error(f"Missing environment variables: {missing_vars}")
            raise ValueError(f"Missing environment variables: {missing_vars}")
        logger.info("Environment variables loaded successfully")
    except Exception as e:
        logger.error(f"Error loading environment variables: {e}")
        st.error(f"Error loading environment variables: {e}")
        raise
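
# A minimal example .env for local runs (all values are placeholders, not real credentials):
#   AWS_ACCESS_KEY_ID=AKIA...
#   AWS_SECRET_ACCESS_KEY=...
#   AWS_REGION=us-east-1
#   QDRANT_URL=https://your-cluster.cloud.qdrant.io:6333
#   QDRANT_API_KEY=...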

# Cache each pipeline stage so Streamlit reruns don't repeat expensive work on every interaction.
@st.cache_resource
def load_wikipedia_documents():
    """Load 100 Wikipedia documents from Cohere's HF dataset."""
    try:
        dataset = load_dataset(
            "Cohere/wikipedia-22-12-simple-embeddings",
            split="train[:100]"  # Load only 100 entries
        )
        documents = [Document(page_content=item["text"]) for item in dataset]
        logger.info(f"Loaded {len(documents)} Wikipedia documents")
        if not documents:
            logger.error("No documents loaded from dataset")
            st.error("No documents loaded from dataset")
            return []
        return documents
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        st.error(f"Failed to load dataset: {e}")
        return []
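
# Note: each record in this dataset also ships a precomputed Cohere embedding ("emb") and a
# "title" field, both ignored here since the app re-embeds with Bedrock. If you wanted source
# titles in answers, one sketch would be to attach them as metadata when building each Document:
#   Document(page_content=item["text"], metadata={"title": item.get("title", "")})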

@st.cache_resource
def split_documents(_documents):
    """Split documents into chunks (the underscore arg is excluded from Streamlit's cache key)."""
    try:
        if not _documents:
            logger.error("No documents provided for splitting")
            st.error("No documents provided for splitting")
            return []
        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        chunks = splitter.split_documents(_documents)
        logger.info(f"Split into {len(chunks)} chunks")
        if not chunks:
            logger.error("No chunks created from documents")
            st.error("No chunks created from documents")
            return []
        return chunks
    except Exception as e:
        logger.error(f"Error splitting documents: {e}")
        st.error(f"Failed to split documents: {e}")
        return []
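
# RecursiveCharacterTextSplitter tries separators in order ("\n\n", "\n", " ", "") until each
# piece fits in chunk_size; 500 characters with a 50-character overlap keeps adjacent chunks
# sharing roughly a sentence of context, so answers spanning a chunk boundary stay retrievable.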

@st.cache_resource
def initialize_embeddings():
    """Initialize AWS Bedrock embeddings."""
    try:
        embeddings = BedrockEmbeddings(
            model_id="amazon.titan-embed-text-v1",
            region_name=os.getenv("AWS_REGION")
        )
        logger.info("Initialized Bedrock embeddings")
        return embeddings
    except Exception as e:
        logger.error(f"Error initializing embeddings: {e}")
        st.error(f"Failed to initialize embeddings: {e}")
        return None
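
# amazon.titan-embed-text-v1 produces 1536-dimensional vectors (per AWS docs at the time of
# writing); Qdrant.from_documents below infers the vector size when it creates the collection.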

@st.cache_resource
def store_in_qdrant(_chunks, _embeddings):
    """Store document chunks in a hosted Qdrant instance after deleting all collections."""
    try:
        # Initialize Qdrant client
        client = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
            timeout=30
        )
        # Test Qdrant connection
        try:
            client.get_collections()
            logger.info(f"Successfully connected to Qdrant at {os.getenv('QDRANT_URL')}")
        except Exception as e:
            logger.error(f"Failed to connect to Qdrant: {e}")
            st.error(f"Failed to connect to Qdrant: {e}")
            return None
        # Delete all existing collections
        try:
            collections = client.get_collections().collections
            for collection in collections:
                client.delete_collection(collection.name)
                logger.info(f"Deleted Qdrant collection: {collection.name}")
            logger.info("All Qdrant collections deleted")
        except Exception as e:
            logger.warning(f"Error deleting collections: {e}")
            st.warning(f"Error deleting collections: {e}")
        # Validate input chunks
        if not _chunks:
            logger.error("No chunks provided for Qdrant storage")
            st.error("No chunks provided for Qdrant storage")
            return None
        # Create and populate new collection
        collection_name = "wikipedia_chunks"
        try:
            vector_store = Qdrant.from_documents(
                documents=_chunks,
                embedding=_embeddings,
                url=os.getenv("QDRANT_URL"),
                api_key=os.getenv("QDRANT_API_KEY"),
                collection_name=collection_name,
                force_recreate=True  # Ensure a fresh collection
            )
            logger.info(f"Created Qdrant collection {collection_name} with {len(_chunks)} chunks")
        except Exception as e:
            logger.error(f"Error creating Qdrant collection: {e}")
            st.error(f"Failed to create Qdrant collection: {e}")
            return None
        # Verify storage
        try:
            collection_info = client.get_collection(collection_name)
            stored_points = collection_info.points_count
            logger.info(f"Stored {stored_points} points in Qdrant collection {collection_name}")
            if stored_points == 0:
                logger.error("No documents stored in Qdrant collection")
                st.error("No documents stored in Qdrant collection")
                return None
            if stored_points != len(_chunks):
                logger.warning(f"Expected {len(_chunks)} chunks, but stored {stored_points} in Qdrant")
                st.warning(f"Expected {len(_chunks)} chunks, but stored {stored_points} in Qdrant")
            return vector_store
        except Exception as e:
            logger.error(f"Error verifying Qdrant storage: {e}")
            st.error(f"Failed to verify Qdrant storage: {e}")
            return None
    except Exception as e:
        logger.error(f"Error in Qdrant storage process: {e}")
        st.error(f"Failed to store documents in Qdrant: {e}")
        return None
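
# Design note: force_recreate=True above already drops and recreates "wikipedia_chunks", so the
# delete-all loop only clears unrelated leftover collections. On a shared Qdrant cluster that
# wipe is destructive; consider scoping it to collection names this app owns.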

@st.cache_resource
def initialize_llm():
    """Initialize AWS Bedrock Claude 3.5 Sonnet model."""
    try:
        llm = ChatBedrock(
            model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
            region_name=os.getenv("AWS_REGION"),
            model_kwargs={"max_tokens": 1000}
        )
        logger.info("Initialized Claude 3.5 Sonnet")
        return llm
    except Exception as e:
        logger.error(f"Error initializing LLM: {e}")
        st.error(f"Failed to initialize LLM: {e}")
        return None

def extract_score_from_text(text):
    """Extract the first float between 0 and 1 from the text using a regex."""
    try:
        matches = re.findall(r'\b0(?:\.\d+)?\b|\b1(?:\.0+)?\b', text)
        if not matches:
            logger.warning("No score found in text")
            return None
        score = float(matches[0])
        if 0.0 <= score <= 1.0:
            return score
        logger.warning(f"Score {score} out of expected range 0-1")
        return None
    except ValueError as e:
        logger.warning(f"Cannot convert match to float: {e}")
        return None
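
# Example behavior (hypothetical model outputs):
#   extract_score_from_text("0.8")              -> 0.8
#   extract_score_from_text("Relevance: 0.35")  -> 0.35
#   extract_score_from_text("highly relevant")  -> None (the caller then falls back to 0.0)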

def claude_rerank(docs, query, llm, top_n=5):
    """Rerank documents based on relevance using the LLM."""
    try:
        rerank_prompt = ChatPromptTemplate.from_template(
            """
            Given the query: "{query}" and the document chunk: "{chunk}", please rate
            the relevance on a scale from 0 to 1 (0=not relevant, 1=highly relevant).
            Respond with a number only, like: 0.8
            """
        )
        scored_docs = []
        for idx, doc in enumerate(docs):
            # format_messages yields proper chat messages rather than a role-prefixed string
            messages = rerank_prompt.format_messages(query=query, chunk=doc.page_content)
            response = llm.invoke(messages)
            text = response.content.strip()
            logger.info(f"Doc {idx} rerank raw output: {text}")
            score = extract_score_from_text(text)
            if score is None:
                logger.warning(f"Failed to extract valid score for doc {idx}. Assigning 0.")
                score = 0.0
            scored_docs.append((doc, score))
        scored_docs.sort(key=lambda x: x[1], reverse=True)
        logger.info(f"Reranked top {top_n} docs based on scores")
        return [doc for doc, _ in scored_docs[:top_n]]
    except Exception as e:
        logger.error(f"Error in reranking: {e}")
        st.error(f"Error in reranking: {e}")
        return docs[:top_n]  # Fall back to the original retrieval order
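
# Cost note: reranking issues one Bedrock call per candidate document, so with the k=20
# retriever configured below each query pays ~20 extra LLM round trips; latency and spend
# grow linearly with the candidate pool size.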

def create_rag_chain(vector_store, llm, use_rerank=False):
    """Create a RAG chain with or without reranking."""
    try:
        prompt_template = ChatPromptTemplate.from_template(
            "You are a helpful assistant. Use the following context to answer the question concisely.\n\n"
            "Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
        )
        # Retrieve a wider candidate pool when reranking so the LLM has something to reorder
        retriever = vector_store.as_retriever(search_kwargs={"k": 20 if use_rerank else 5})

        def rerank_context(inputs):
            try:
                docs = retriever.invoke(inputs["question"])
                if not docs:
                    logger.warning("No documents retrieved for query")
                    return {"context": "", "question": inputs["question"]}
                if use_rerank:
                    docs = claude_rerank(docs, inputs["question"], llm)
                return {"context": "\n\n".join(doc.page_content for doc in docs), "question": inputs["question"]}
            except Exception as e:
                logger.error(f"Error in rerank_context: {e}")
                return {"context": "", "question": inputs["question"]}

        chain = RunnableLambda(rerank_context) | prompt_template | llm | StrOutputParser()
        logger.info(f"Initialized {'re-ranked' if use_rerank else 'baseline'} RAG chain")
        return chain
    except Exception as e:
        logger.error(f"Error creating RAG chain: {e}")
        st.error(f"Failed to create RAG chain: {e}")
        return None
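
# Each chain takes a dict with a "question" key and returns a plain string, e.g. (hypothetical):
#   answer = chain.invoke({"question": "Who founded Rome?"})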

def main():
    st.title("Wikipedia Q&A with RAG (Qdrant + AWS Bedrock)")
    st.write("Enter a question to get answers using baseline and reranked retrieval methods.")
    # Load environment variables
    try:
        load_environment()
    except Exception:
        # load_environment has already logged and surfaced the error
        return
    # Initialize components
    documents = load_wikipedia_documents()
    if not documents:
        st.error("Cannot proceed without documents")
        return
    chunks = split_documents(documents)
    if not chunks:
        st.error("Cannot proceed without document chunks")
        return
    embeddings = initialize_embeddings()
    if embeddings is None:
        st.error("Cannot proceed without embeddings")
        return
    vector_store = store_in_qdrant(chunks, embeddings)
    if vector_store is None:
        st.error("Cannot proceed without vector store")
        return
    llm = initialize_llm()
    if llm is None:
        st.error("Cannot proceed without LLM")
        return
    baseline_chain = create_rag_chain(vector_store, llm, use_rerank=False)
    if baseline_chain is None:
        st.error("Cannot proceed without baseline chain")
        return
    rerank_chain = create_rag_chain(vector_store, llm, use_rerank=True)
    if rerank_chain is None:
        st.error("Cannot proceed without rerank chain")
        return
    # Streamlit input
    query = st.text_input("Enter your question:", placeholder="e.g., What are the main causes of climate change?")
    if query:
        with st.spinner("Processing your query..."):
            try:
                baseline_response = baseline_chain.invoke({"question": query})
                rerank_response = rerank_chain.invoke({"question": query})
                st.subheader("Results")
                st.write("**Query:**", query)
                st.write("**Baseline Answer:**")
                st.write(baseline_response)
                st.write("**Reranked Answer:**")
                st.write(rerank_response)
            except Exception as e:
                logger.error(f"Error processing query: {e}")
                st.error(f"Error processing query: {e}")


if __name__ == "__main__":
    main()
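
# To run locally (assuming this file is saved as app.py, the Hugging Face Spaces convention):
#   streamlit run app.py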