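# Streamlit app: compare ensemble retrieval (dense FAISS + sparse BM25)
# with and without contextual compression, then report word-overlap,
# duplication, and similarity metrics for the two result sets.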
import streamlit as st
import pickle
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.retrievers import EnsembleRetriever, BM25Retriever, ContextualCompressionRetriever
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
from langchain.text_splitter import CharacterTextSplitter
from analysis import (
    calculate_word_overlaps,
    calculate_duplication_rate,
    cosine_similarity_score,
    jaccard_similarity_score,
    display_similarity_results,
)
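
# Load the corpus pickled at index-build time; BM25 needs the raw texts.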
with open("docs_data.pkl", "rb") as file: | |
docs = pickle.load(file) | |
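
# Accumulators for metadata gathered from the compressed results.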
metadata_list = []
unique_metadata_list = []
seen = set()
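
# Dense retriever: HuggingFace sentence embeddings over a prebuilt FAISS index.
# NOTE: recent langchain-community releases also require passing
# allow_dangerous_deserialization=True to FAISS.load_local.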
embeddings = HuggingFaceEmbeddings()
vectorstore = FAISS.load_local("faiss_index", embeddings)
retriever = vectorstore.as_retriever(search_type="similarity")
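
# Compression pipeline, applied in order to each retrieved document:
# split into ~300-character chunks on sentence boundaries, drop
# near-duplicate chunks, then keep only chunks with query similarity >= 0.5.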
splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=". ")
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.5)
pipeline_compressor = DocumentCompressorPipeline(
    transformers=[splitter, redundant_filter, relevant_filter]
)
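
# Sparse keyword retriever. from_texts expects raw strings; use
# BM25Retriever.from_documents instead if docs holds Document objects.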
bm25_retriever = BM25Retriever.from_texts(docs)
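
# --- Streamlit UI ---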
st.title("Document Retrieval App")
vectorstore_k = st.number_input("Set k value for Dense Retriever:", value=5, min_value=1, step=1)
bm25_k = st.number_input("Set k value for Sparse Retriever:", value=2, min_value=1, step=1)
retriever.search_kwargs["k"] = vectorstore_k
bm25_retriever.k = bm25_k
compressed_retriever = ContextualCompressionRetriever(
    base_compressor=pipeline_compressor, base_retriever=retriever
)
bm25_compression_retriever = ContextualCompressionRetriever(
    base_compressor=pipeline_compressor, base_retriever=bm25_retriever
)
query = st.text_input("Enter a query:", "what is a horizontal conflict")
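
# Both ensembles blend dense and sparse results with equal weights;
# LangChain's EnsembleRetriever fuses the rankings via weighted
# Reciprocal Rank Fusion.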
if st.button("Retrieve Documents"):
    compressed_ensemble_retriever = EnsembleRetriever(
        retrievers=[compressed_retriever, bm25_compression_retriever], weights=[0.5, 0.5]
    )
    ensemble_retriever = EnsembleRetriever(retrievers=[retriever, bm25_retriever], weights=[0.5, 0.5])
    with st.expander("Retrieved Documents"):
        col1, col2 = st.columns(2)
        with col1:
            st.header("Without Compression")
            normal_results = ensemble_retriever.get_relevant_documents(query)
            for doc in normal_results:
                st.write(doc.page_content)
                st.write("---")
        with col2:
            st.header("With Compression")
            compressed_results = compressed_ensemble_retriever.get_relevant_documents(query)
            for doc in compressed_results:
                st.write(doc.page_content)
                st.write("---")
                if hasattr(doc, "metadata"):
                    metadata_list.append(doc.metadata)
            # Deduplicate the metadata dicts before displaying them.
            for metadata in metadata_list:
                metadata_tuple = tuple(metadata.items())
                if metadata_tuple not in seen:
                    unique_metadata_list.append(metadata)
                    seen.add(metadata_tuple)
            st.write(unique_metadata_list)
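
    # Compare verbosity, redundancy, and query similarity of the two result
    # sets using the helpers imported from analysis.py.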
with st.expander("Analysis"): | |
st.write("Analysis of Retrieval Results") | |
total_words_normal = sum(len(doc.page_content.split()) for doc in normal_results) | |
total_words_compressed = sum(len(doc.page_content.split()) for doc in compressed_results) | |
reduction_percentage = ((total_words_normal - total_words_compressed) / total_words_normal) * 100 | |
col1, col2 = st.columns(2) | |
st.write(f"Total words in documents (Normal): {total_words_normal}") | |
st.write(f"Total words in documents (Compressed): {total_words_compressed}") | |
st.write(f"Reduction Percentage: {reduction_percentage:.2f}%") | |
average_word_overlap_normal = calculate_word_overlaps([doc.page_content for doc in normal_results], query) | |
average_word_overlap_compressed = calculate_word_overlaps([doc.page_content for doc in compressed_results], query) | |
duplication_rate_normal = calculate_duplication_rate([doc.page_content for doc in normal_results]) | |
duplication_rate_compressed = calculate_duplication_rate([doc.page_content for doc in compressed_results]) | |
cosine_scores_normal = cosine_similarity_score([doc.page_content for doc in normal_results], query) | |
jaccard_scores_normal = jaccard_similarity_score([doc.page_content for doc in normal_results], query) | |
cosine_scores_compressed = cosine_similarity_score([doc.page_content for doc in compressed_results], query) | |
jaccard_scores_compressed = jaccard_similarity_score([doc.page_content for doc in compressed_results], query) | |
with col1: | |
st.subheader("Normal") | |
st.write(f"Average Word Overlap: {average_word_overlap_normal:.2f}") | |
st.write(f"Duplication Rate: {duplication_rate_normal:.2%}") | |
st.write("Results without Compression:") | |
display_similarity_results(cosine_scores_normal, jaccard_scores_normal, "") | |
with col2: | |
st.subheader("Compressed") | |
st.write(f"Average Word Overlap: {average_word_overlap_compressed:.2f}") | |
st.write(f"Duplication Rate: {duplication_rate_compressed:.2%}") | |
st.write("Results with Compression:") | |
display_similarity_results(cosine_scores_compressed, jaccard_scores_compressed, "") |
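
# Run with: streamlit run app.py (filename is illustrative).
#
# One-time index build (a minimal sketch, not part of the app): the app
# assumes "faiss_index/" and "docs_data.pkl" already exist and were built
# with the same embedding model. `raw_texts` below is a hypothetical list
# of document strings.
#
#   from langchain.embeddings import HuggingFaceEmbeddings
#   from langchain.vectorstores import FAISS
#   import pickle
#
#   embeddings = HuggingFaceEmbeddings()
#   FAISS.from_texts(raw_texts, embeddings).save_local("faiss_index")
#   with open("docs_data.pkl", "wb") as f:
#       pickle.dump(raw_texts, f)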