# app.py (commit d1204dd): Update app.py
import streamlit as st
import pickle
import os
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.retrievers import EnsembleRetriever, BM25Retriever, ContextualCompressionRetriever
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
from langchain.text_splitter import CharacterTextSplitter
from analysis import calculate_word_overlaps, calculate_duplication_rate, cosine_similarity_score, jaccard_similarity_score, display_similarity_results
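# A minimal guard (added here as a sketch): fail fast with a clear message if
# the prebuilt artifacts are missing. Assumes docs_data.pkl and the
# faiss_index directory come from a separate ingestion step not shown in this file.
if not (os.path.exists("docs_data.pkl") and os.path.exists("faiss_index")):
    st.error("docs_data.pkl or faiss_index not found; run the ingestion step first.")
    st.stop()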
# Load the pre-chunked source texts saved during ingestion.
with open("docs_data.pkl", "rb") as file:
    docs = pickle.load(file)

# Accumulators for metadata attached to the compressed results.
metadata_list = []
unique_metadata_list = []
seen = set()
# Dense retriever: HuggingFace sentence embeddings over a prebuilt FAISS index.
embeddings = HuggingFaceEmbeddings()
vectorstore = FAISS.load_local("faiss_index", embeddings)
retriever = vectorstore.as_retriever(search_type="similarity")
# Compression pipeline: re-split retrieved docs into sentence-level chunks,
# drop near-duplicates, then keep only chunks sufficiently similar to the query.
splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=". ")
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.5)
pipeline_compressor = DocumentCompressorPipeline(
    transformers=[splitter, redundant_filter, relevant_filter]
)
# Sparse retriever: BM25 over the raw texts (from_texts expects plain strings).
bm25_retriever = BM25Retriever.from_texts(docs)
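# Note: Streamlit re-executes this script on every widget interaction, so the
# embeddings, FAISS index, and BM25 retriever above are rebuilt on each rerun.
# Wrapping that setup in a function decorated with @st.cache_resource is one
# way to avoid the repeated work; the original structure is kept here.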
st.title("Document Retrieval App")
vectorstore_k = st.number_input("Set k value for Dense Retriever:", value=5, min_value=1, step=1)
bm25_k = st.number_input("Set k value for Sparse Retriever:", value=2, min_value=1, step=1)
retriever.search_kwargs["k"] = vectorstore_k
bm25_retriever.k = bm25_k
compressed_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=retriever)
bm25_compression_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=bm25_retriever)
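# A ContextualCompressionRetriever first pulls candidates from its base
# retriever and then post-processes them with the compressor pipeline, so the
# two columns below compare raw ensemble output against its compressed form.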
query = st.text_input("Enter a query:", "what is a horizontal conflict")
if st.button("Retrieve Documents"):
    # EnsembleRetriever fuses the ranked lists from its sub-retrievers with
    # weighted Reciprocal Rank Fusion; equal 0.5 weights treat the dense and
    # sparse results symmetrically.
    compressed_ensemble_retriever = EnsembleRetriever(retrievers=[compressed_retriever, bm25_compression_retriever], weights=[0.5, 0.5])
    ensemble_retriever = EnsembleRetriever(retrievers=[retriever, bm25_retriever], weights=[0.5, 0.5])
    with st.expander("Retrieved Documents"):
        col1, col2 = st.columns(2)
        with col1:
            st.header("Without Compression")
            normal_results = ensemble_retriever.get_relevant_documents(query)
            for doc in normal_results:
                st.write(doc.page_content)
                st.write("---")
        with col2:
            st.header("With Compression")
            compressed_results = compressed_ensemble_retriever.get_relevant_documents(query)
            for doc in compressed_results:
                st.write(doc.page_content)
                st.write("---")
                if hasattr(doc, 'metadata'):
                    metadata = doc.metadata
                    metadata_list.append(metadata)
            # Deduplicate the collected metadata; dicts are unhashable, so
            # compare them as tuples of their items.
            for metadata in metadata_list:
                metadata_tuple = tuple(metadata.items())
                if metadata_tuple not in seen:
                    unique_metadata_list.append(metadata)
                    seen.add(metadata_tuple)
            st.write(unique_metadata_list)
    with st.expander("Analysis"):
        st.write("Analysis of Retrieval Results")
        total_words_normal = sum(len(doc.page_content.split()) for doc in normal_results)
        total_words_compressed = sum(len(doc.page_content.split()) for doc in compressed_results)
        # Guard against an empty result set to avoid dividing by zero.
        reduction_percentage = ((total_words_normal - total_words_compressed) / total_words_normal) * 100 if total_words_normal else 0.0
        col1, col2 = st.columns(2)
        st.write(f"Total words in documents (Normal): {total_words_normal}")
        st.write(f"Total words in documents (Compressed): {total_words_compressed}")
        st.write(f"Reduction Percentage: {reduction_percentage:.2f}%")
        average_word_overlap_normal = calculate_word_overlaps([doc.page_content for doc in normal_results], query)
        average_word_overlap_compressed = calculate_word_overlaps([doc.page_content for doc in compressed_results], query)
        duplication_rate_normal = calculate_duplication_rate([doc.page_content for doc in normal_results])
        duplication_rate_compressed = calculate_duplication_rate([doc.page_content for doc in compressed_results])
        cosine_scores_normal = cosine_similarity_score([doc.page_content for doc in normal_results], query)
        jaccard_scores_normal = jaccard_similarity_score([doc.page_content for doc in normal_results], query)
        cosine_scores_compressed = cosine_similarity_score([doc.page_content for doc in compressed_results], query)
        jaccard_scores_compressed = jaccard_similarity_score([doc.page_content for doc in compressed_results], query)
        with col1:
            st.subheader("Normal")
            st.write(f"Average Word Overlap: {average_word_overlap_normal:.2f}")
            st.write(f"Duplication Rate: {duplication_rate_normal:.2%}")
            st.write("Results without Compression:")
            display_similarity_results(cosine_scores_normal, jaccard_scores_normal, "")
        with col2:
            st.subheader("Compressed")
            st.write(f"Average Word Overlap: {average_word_overlap_compressed:.2f}")
            st.write(f"Duplication Rate: {duplication_rate_compressed:.2%}")
            st.write("Results with Compression:")
            display_similarity_results(cosine_scores_compressed, jaccard_scores_compressed, "")
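# Launch locally with: streamlit run app.py
# The similarity and duplication helpers come from the companion analysis
# module imported at the top of this file.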