talexm commited on
Commit
6dd2090
·
1 Parent(s): 5a9370b
app.py CHANGED
@@ -5,6 +5,8 @@ from PIL import Image
5
  from rag_sec.document_search_system import DocumentSearchSystem
6
  from chainguard.blockchain_logger import BlockchainLogger
7
  from rag_sec.document_search_system import main
 
 
8
 
9
  # Blockchain Logger
10
  blockchain_logger = BlockchainLogger()
@@ -65,29 +67,44 @@ if st.button("Validate Blockchain Integrity"):
65
 
66
  # Query System
67
  st.subheader("Query Files")
68
- system = main() # Initialize system with Neo4j and load documents
69
 
70
- # Query Input
71
- query = st.text_input("Enter your query", placeholder="E.g., 'Good comedy'")
72
- if st.button("Search"):
73
- if query:
74
- # Process the query
75
- result = system.process_query(query)
76
 
77
- # Display the results
78
- st.write("Query Status:", result.get("status"))
79
- st.write("Query Response:", result.get("response"))
 
 
80
 
81
- if "retrieved_documents" in result:
82
- st.write("Retrieved Documents:")
83
- for doc in result["retrieved_documents"]:
84
- st.markdown(f"- {doc}")
85
 
86
- if "blockchain_details" in result:
87
- st.write("Blockchain Details:")
88
- st.json(result["blockchain_details"])
89
 
90
- if result.get("status") == "rejected":
91
- st.error(f"Query Blocked: {result.get('message')}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  else:
93
  st.warning("Please enter a query to search.")
 
 
 
 
 
5
  from rag_sec.document_search_system import DocumentSearchSystem
6
  from chainguard.blockchain_logger import BlockchainLogger
7
  from rag_sec.document_search_system import main
8
+ import streamlit as st
9
+ from rag_sec.document_retriver import DocumentRetriever
10
 
11
  # Blockchain Logger
12
  blockchain_logger = BlockchainLogger()
 
67
 
68
  # Query System
69
  st.subheader("Query Files")
 
70
 
71
+ # Initialize DocumentRetriever
72
+ retriever = DocumentRetriever()
 
 
 
 
73
 
74
+ @st.cache(allow_output_mutation=True)
75
+ def load_retriever():
76
+ """Load documents into the retriever."""
77
+ retriever.load_documents()
78
+ return retriever
79
 
80
+ # Load the retriever and documents
81
+ st.write("Loading documents...")
82
+ retriever = load_retriever()
83
+ st.write("Documents successfully loaded!")
84
 
85
+ # Streamlit UI
86
+ st.title("Document Search App")
87
+ st.subheader("Enter a query to search for related documents")
88
 
89
+ # Query Input
90
+ query = st.text_input("Enter your query (e.g., 'sports news', 'machine learning')")
91
+
92
+ if st.button("Search"):
93
+ if query:
94
+ # Retrieve documents
95
+ results = retriever.retrieve(query)
96
+ if results == ["Document retrieval is not initialized."]:
97
+ st.error("Document retrieval is not initialized. Please reload the app.")
98
+ elif not results:
99
+ st.warning("No relevant documents found for your query.")
100
+ else:
101
+ st.success(f"Found {len(results)} relevant document(s).")
102
+ for idx, doc in enumerate(results, start=1):
103
+ st.write(f"### Document {idx}")
104
+ st.write(doc[:500]) # Display first 500 characters of each document
105
  else:
106
  st.warning("Please enter a query to search.")
107
+
108
+ # Debugging Section
109
+ if st.checkbox("Show Debug Information"):
110
+ st.write(f"Total documents loaded: {len(retriever.documents)}")
rag_sec/document_retriver.py CHANGED
@@ -1,23 +1,17 @@
1
- import faiss
2
- from sklearn.feature_extraction.text import TfidfVectorizer
3
- import numpy as np
4
  from sklearn.datasets import fetch_20newsgroups
5
 
6
  class DocumentRetriever:
7
  def __init__(self):
8
  self.documents = []
9
 
10
- def load_documents(self):
11
- """Load 20 Newsgroups dataset."""
12
  newsgroups_data = fetch_20newsgroups(subset='all')
13
- self.documents = newsgroups_data.data
14
- if not self.documents:
15
- print("No documents loaded!")
16
 
17
  def retrieve(self, query):
18
  """Retrieve documents related to the query."""
19
  if not self.documents:
20
  return ["Document retrieval is not initialized."]
21
- # Simple keyword match (can replace with advanced semantic similarity later)
22
  return [doc for doc in self.documents if query.lower() in doc.lower()]
23
-
 
 
 
 
1
  from sklearn.datasets import fetch_20newsgroups
2
 
3
  class DocumentRetriever:
4
  def __init__(self):
5
  self.documents = []
6
 
7
+ def load_documents(self, subset_size=500):
8
+ """Load a subset of 20 Newsgroups dataset."""
9
  newsgroups_data = fetch_20newsgroups(subset='all')
10
+ self.documents = newsgroups_data.data[:subset_size] # Load only the first `subset_size` documents
11
+ print(f"Loaded {len(self.documents)} documents.")
 
12
 
13
  def retrieve(self, query):
14
  """Retrieve documents related to the query."""
15
  if not self.documents:
16
  return ["Document retrieval is not initialized."]
 
17
  return [doc for doc in self.documents if query.lower() in doc.lower()]
 
rag_sec/document_search_system.py CHANGED
@@ -7,10 +7,10 @@ import sys
7
  from os import path
8
 
9
  sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
10
- from .bad_query_detector import BadQueryDetector
11
- from .query_transformer import QueryTransformer
12
- from .document_retriver import DocumentRetriever
13
- from .senamtic_response_generator import SemanticResponseGenerator
14
 
15
 
16
  class DataTransformer:
@@ -189,9 +189,27 @@ def main():
189
 
190
  return system
191
 
192
-
193
  if __name__ == "__main__":
194
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
  # home_dir = Path(os.getenv("HOME", "/"))
197
  # data_dir = home_dir / "data-sets/aclImdb/train"
 
7
  from os import path
8
 
9
  sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
10
+ from bad_query_detector import BadQueryDetector
11
+ from query_transformer import QueryTransformer
12
+ from document_retriver import DocumentRetriever
13
+ from senamtic_response_generator import SemanticResponseGenerator
14
 
15
 
16
  class DataTransformer:
 
189
 
190
  return system
191
 
 
192
  if __name__ == "__main__":
193
+ retriever = DocumentRetriever()
194
+ retriever.load_documents()
195
+
196
+ # Test queries
197
+ queries = [
198
+ "sports news",
199
+ "political debates",
200
+ "machine learning",
201
+ "space exploration"
202
+ ]
203
+
204
+ for query in queries:
205
+ print(f"\nQuery: {query}")
206
+ results = retriever.retrieve(query)
207
+ for idx, doc in enumerate(results, start=1):
208
+ print(f"\nResult {idx}:\n{doc[:500]}...\n") # Show first 500 characters of each document
209
+
210
+
211
+ # if __name__ == "__main__":
212
+ # main()
213
 
214
  # home_dir = Path(os.getenv("HOME", "/"))
215
  # data_dir = home_dir / "data-sets/aclImdb/train"