# chat-w-csv / test.py
# Author: DrishtiSharma
# Last commit: 0e95be6 ("Update test.py", verified)
import streamlit as st
import pandas as pd
import os
import traceback
from dotenv import load_dotenv
from llama_index.readers.file.paged_csv.base import PagedCSVReader
from llama_index.core import Settings, VectorStoreIndex
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
from langchain_community.vectorstores import FAISS as LangChainFAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.documents import Document
import faiss
import tempfile
# Load variables from a local .env file (if present) into os.environ.
# python-dotenv does not override variables already set in the real
# environment (override=False by default).
load_dotenv()

# Check the OpenAI API key before any OpenAI client is constructed. Without
# it, every call below would fail, so halt this script run after reporting.
if not os.getenv("OPENAI_API_KEY"):
    st.error("⚠️ OpenAI API Key is missing! Please check your .env file or environment variables.")
    st.stop()

# Probe the LangChain embedding model once to learn its vector size, so the
# FAISS index and the LlamaIndex embedding model are created with a matching
# dimension.
embedding_function = OpenAIEmbeddings()
test_vector = embedding_function.embed_query("test")
faiss_dimension = len(test_vector)

# Global LlamaIndex defaults: GPT-4o for generation and an embedding model
# that emits vectors of the probed dimension.
Settings.llm = OpenAI(model="gpt-4o")
Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-small", dimensions=faiss_dimension)
# Streamlit app
st.title("Chat with CSV Files - LangChain vs LlamaIndex")

# Upper bound (in characters) on the retrieved context that is displayed and
# sent to the LLM.
MAX_CONTEXT_CHARS = 3000
# Number of nearest rows the FAISS retriever returns per question
# (LangChain's default would be 4).
RETRIEVER_K = 15

# Prompt for the LangChain answer step. The retrieved rows are supplied as the
# {context} variable instead of being formatted into the template text, so
# curly braces inside CSV values cannot break prompt formatting.
QA_PROMPT = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are an assistant for question-answering tasks. "
            "Use the following pieces of retrieved context to answer "
            "the question. Keep the answer concise.\n\n"
            "{context}",
        ),
        ("human", "{input}"),
    ]
)


def _rows_to_documents(frame):
    """Return one LangChain Document per DataFrame row.

    Each document's text is the whole row rendered as
    ``"col: value | col: value | ..."``.
    """
    return [
        Document(page_content=" | ".join(f"{col}: {row[col]}" for col in frame.columns))
        for _, row in frame.iterrows()
    ]


def _answer_query(retriever, query):
    """Retrieve rows relevant to *query*, show them, and display the LLM answer."""
    retrieved_docs = retriever.invoke(query)
    retrieved_context = "\n\n".join(doc.page_content for doc in retrieved_docs)
    retrieved_context = retrieved_context[:MAX_CONTEXT_CHARS]

    # Show the retrieved context so the user can see what the answer is based on.
    st.write("🔍 **Retrieved Context Preview:**")
    st.text(retrieved_context)

    chain = QA_PROMPT | ChatOpenAI(model="gpt-4o")
    response = chain.invoke({"context": retrieved_context, "input": query})
    st.write("🚀 Query processed successfully.")
    st.write(f"**Answer:** {response.content}")


def _render_langchain_tab(frame):
    """Index *frame* row-by-row in an in-memory FAISS store and answer questions.

    Errors are reported in the UI rather than raised.
    """
    st.subheader("LangChain Query")
    try:
        st.write("Processing CSV with a custom loader...")
        documents = _rows_to_documents(frame)
        if not documents:
            # FAISS cannot add an empty batch; there is nothing to search anyway.
            st.warning("The uploaded CSV has no data rows to index.")
            return

        st.write(f"✅ Initializing FAISS with dimension: {faiss_dimension}")
        vector_store = LangChainFAISS(
            embedding_function=embedding_function,
            index=faiss.IndexFlatL2(faiss_dimension),
            docstore=InMemoryDocstore(),
            index_to_docstore_id={},
        )
        try:
            vector_store.add_documents(documents)
        except Exception as e:
            st.error(f"Error adding documents to FAISS: {e}")
            st.text(traceback.format_exc())
            # Querying an empty store would only return empty context.
            return
        st.write("✅ Documents successfully added to FAISS VectorStore.")

        retriever = vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})

        query = st.text_input("Ask a question about your data (LangChain):")
        if query:
            try:
                _answer_query(retriever, query)
            except Exception as e:
                st.error(f"Error processing query: {e}")
                st.text(traceback.format_exc())
    except Exception as e:
        st.error(f"Error processing with LangChain: {e}")
        st.text(traceback.format_exc())


# File uploader
uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])
if uploaded_file is not None:
    temp_file_path = None
    try:
        # Read and preview CSV data using pandas
        data = pd.read_csv(uploaded_file)
        st.write("Preview of uploaded data:")
        st.dataframe(data)

        # Save a UTF-8 copy of the upload. Write through the open handle
        # (rather than reopening the path while it is still open, which fails
        # on Windows); pandas requires newline="" for text file objects.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".csv", mode="w", encoding="utf-8", newline=""
        ) as temp_file:
            temp_file_path = temp_file.name
            data.to_csv(temp_file, index=False)

        # Tabs for LangChain and LlamaIndex
        tab1, tab2 = st.tabs(["Chat w CSV using LangChain", "Chat w CSV using LlamaIndex"])
        with tab1:
            _render_langchain_tab(data)
        # NOTE(review): nothing is rendered into tab2 (LlamaIndex) yet.
    except Exception as e:
        st.error(f"Error reading uploaded file: {e}")
        st.text(traceback.format_exc())
    finally:
        # The temp copy is created with delete=False, so remove it explicitly
        # once this script run is done with it.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass  # best-effort cleanup; a leftover temp file is harmless