Spaces:
Running
Running
talexm
commited on
Commit
·
6dd2090
1
Parent(s):
5a9370b
update
Browse files- app.py +36 -19
- rag_sec/document_retriver.py +4 -10
- rag_sec/document_search_system.py +24 -6
app.py
CHANGED
@@ -5,6 +5,8 @@ from PIL import Image
|
|
5 |
from rag_sec.document_search_system import DocumentSearchSystem
|
6 |
from chainguard.blockchain_logger import BlockchainLogger
|
7 |
from rag_sec.document_search_system import main
|
|
|
|
|
8 |
|
9 |
# Blockchain Logger
|
10 |
blockchain_logger = BlockchainLogger()
|
@@ -65,29 +67,44 @@ if st.button("Validate Blockchain Integrity"):
|
|
65 |
|
66 |
# Query System
|
67 |
st.subheader("Query Files")
|
68 |
-
system = main() # Initialize system with Neo4j and load documents
|
69 |
|
70 |
-
#
|
71 |
-
|
72 |
-
if st.button("Search"):
|
73 |
-
if query:
|
74 |
-
# Process the query
|
75 |
-
result = system.process_query(query)
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
80 |
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
else:
|
93 |
st.warning("Please enter a query to search.")
|
|
|
|
|
|
|
|
|
|
5 |
from rag_sec.document_search_system import DocumentSearchSystem
|
6 |
from chainguard.blockchain_logger import BlockchainLogger
|
7 |
from rag_sec.document_search_system import main
|
8 |
+
import streamlit as st
|
9 |
+
from rag_sec.document_retriver import DocumentRetriever
|
10 |
|
11 |
# Blockchain Logger
|
12 |
blockchain_logger = BlockchainLogger()
|
|
|
67 |
|
68 |
# Query System
|
69 |
st.subheader("Query Files")
|
|
|
70 |
|
71 |
+
# Initialize DocumentRetriever
|
72 |
+
retriever = DocumentRetriever()
|
|
|
|
|
|
|
|
|
73 |
|
74 |
+
@st.cache(allow_output_mutation=True)
|
75 |
+
def load_retriever():
|
76 |
+
"""Load documents into the retriever."""
|
77 |
+
retriever.load_documents()
|
78 |
+
return retriever
|
79 |
|
80 |
+
# Load the retriever and documents
|
81 |
+
st.write("Loading documents...")
|
82 |
+
retriever = load_retriever()
|
83 |
+
st.write("Documents successfully loaded!")
|
84 |
|
85 |
+
# Streamlit UI
|
86 |
+
st.title("Document Search App")
|
87 |
+
st.subheader("Enter a query to search for related documents")
|
88 |
|
89 |
+
# Query Input
|
90 |
+
query = st.text_input("Enter your query (e.g., 'sports news', 'machine learning')")
|
91 |
+
|
92 |
+
if st.button("Search"):
|
93 |
+
if query:
|
94 |
+
# Retrieve documents
|
95 |
+
results = retriever.retrieve(query)
|
96 |
+
if results == ["Document retrieval is not initialized."]:
|
97 |
+
st.error("Document retrieval is not initialized. Please reload the app.")
|
98 |
+
elif not results:
|
99 |
+
st.warning("No relevant documents found for your query.")
|
100 |
+
else:
|
101 |
+
st.success(f"Found {len(results)} relevant document(s).")
|
102 |
+
for idx, doc in enumerate(results, start=1):
|
103 |
+
st.write(f"### Document {idx}")
|
104 |
+
st.write(doc[:500]) # Display first 500 characters of each document
|
105 |
else:
|
106 |
st.warning("Please enter a query to search.")
|
107 |
+
|
108 |
+
# Debugging Section
|
109 |
+
if st.checkbox("Show Debug Information"):
|
110 |
+
st.write(f"Total documents loaded: {len(retriever.documents)}")
|
rag_sec/document_retriver.py
CHANGED
@@ -1,23 +1,17 @@
|
|
1 |
-
import faiss
|
2 |
-
from sklearn.feature_extraction.text import TfidfVectorizer
|
3 |
-
import numpy as np
|
4 |
from sklearn.datasets import fetch_20newsgroups
|
5 |
|
6 |
class DocumentRetriever:
|
7 |
def __init__(self):
|
8 |
self.documents = []
|
9 |
|
10 |
-
def load_documents(self):
|
11 |
-
"""Load 20 Newsgroups dataset."""
|
12 |
newsgroups_data = fetch_20newsgroups(subset='all')
|
13 |
-
self.documents = newsgroups_data.data
|
14 |
-
|
15 |
-
print("No documents loaded!")
|
16 |
|
17 |
def retrieve(self, query):
|
18 |
"""Retrieve documents related to the query."""
|
19 |
if not self.documents:
|
20 |
return ["Document retrieval is not initialized."]
|
21 |
-
# Simple keyword match (can replace with advanced semantic similarity later)
|
22 |
return [doc for doc in self.documents if query.lower() in doc.lower()]
|
23 |
-
|
|
|
|
|
|
|
|
|
1 |
from sklearn.datasets import fetch_20newsgroups
|
2 |
|
3 |
class DocumentRetriever:
|
4 |
def __init__(self):
|
5 |
self.documents = []
|
6 |
|
7 |
+
def load_documents(self, subset_size=500):
|
8 |
+
"""Load a subset of 20 Newsgroups dataset."""
|
9 |
newsgroups_data = fetch_20newsgroups(subset='all')
|
10 |
+
self.documents = newsgroups_data.data[:subset_size] # Load only the first `subset_size` documents
|
11 |
+
print(f"Loaded {len(self.documents)} documents.")
|
|
|
12 |
|
13 |
def retrieve(self, query):
|
14 |
"""Retrieve documents related to the query."""
|
15 |
if not self.documents:
|
16 |
return ["Document retrieval is not initialized."]
|
|
|
17 |
return [doc for doc in self.documents if query.lower() in doc.lower()]
|
|
rag_sec/document_search_system.py
CHANGED
@@ -7,10 +7,10 @@ import sys
|
|
7 |
from os import path
|
8 |
|
9 |
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
|
10 |
-
from
|
11 |
-
from
|
12 |
-
from
|
13 |
-
from
|
14 |
|
15 |
|
16 |
class DataTransformer:
|
@@ -189,9 +189,27 @@ def main():
|
|
189 |
|
190 |
return system
|
191 |
|
192 |
-
|
193 |
if __name__ == "__main__":
|
194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
|
196 |
# home_dir = Path(os.getenv("HOME", "/"))
|
197 |
# data_dir = home_dir / "data-sets/aclImdb/train"
|
|
|
7 |
from os import path
|
8 |
|
9 |
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
|
10 |
+
from bad_query_detector import BadQueryDetector
|
11 |
+
from query_transformer import QueryTransformer
|
12 |
+
from document_retriver import DocumentRetriever
|
13 |
+
from senamtic_response_generator import SemanticResponseGenerator
|
14 |
|
15 |
|
16 |
class DataTransformer:
|
|
|
189 |
|
190 |
return system
|
191 |
|
|
|
192 |
if __name__ == "__main__":
|
193 |
+
retriever = DocumentRetriever()
|
194 |
+
retriever.load_documents()
|
195 |
+
|
196 |
+
# Test queries
|
197 |
+
queries = [
|
198 |
+
"sports news",
|
199 |
+
"political debates",
|
200 |
+
"machine learning",
|
201 |
+
"space exploration"
|
202 |
+
]
|
203 |
+
|
204 |
+
for query in queries:
|
205 |
+
print(f"\nQuery: {query}")
|
206 |
+
results = retriever.retrieve(query)
|
207 |
+
for idx, doc in enumerate(results, start=1):
|
208 |
+
print(f"\nResult {idx}:\n{doc[:500]}...\n") # Show first 500 characters of each document
|
209 |
+
|
210 |
+
|
211 |
+
# if __name__ == "__main__":
|
212 |
+
# main()
|
213 |
|
214 |
# home_dir = Path(os.getenv("HOME", "/"))
|
215 |
# data_dir = home_dir / "data-sets/aclImdb/train"
|