Spaces:
Configuration error
Configuration error
Merge pull request #1 from noelty/dev
Browse files- app.py +77 -0
- collection.py +32 -0
- documents.py +39 -0
- indexing.py +128 -0
- new.py +14 -0
- preprocess.py +39 -0
- querying.py +117 -0
- requirements.txt +6 -0
- tt.xml +131 -0
app.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import nltk
|
| 3 |
+
import gradio as gr
|
| 4 |
+
from documents import process_docx, process_pdf, process_txt
|
| 5 |
+
from indexing import index_document
|
| 6 |
+
from querying import query_documents
|
| 7 |
+
import preprocess
|
| 8 |
+
|
| 9 |
+
# One-time NLTK data setup (runs once at app start): fetch each resource
# only when it is not already installed locally.
for _resource_path, _package in [
    ("corpora/wordnet", "wordnet"),
    ("corpora/stopwords", "stopwords"),
    ("tokenizers/punkt", "punkt"),
]:
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package)

# Directory where uploaded documents are stored.
UPLOAD_FOLDER = 'uploads'
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def process_and_query(file, query_text):
    """
    Process an uploaded document, index it, and run a query against it.

    This is the main callback wired into the Gradio interface.

    Args:
        file: Gradio file object (NamedTemporaryFile-like, with a ``.name``
            path), or None when nothing was uploaded.
        query_text (str): The user's search query.

    Returns:
        tuple: ``(status_message, query_results)`` — a status string plus
        whatever ``query_documents`` returns (empty list on failure).
    """
    if not file:
        return "No file uploaded", []

    file_path = file.name  # Gradio passes a NamedTemporaryFile

    # Dispatch on extension. Each extractor returns {'text': ...} on
    # success or {'error': ...} on failure.
    if file.name.endswith('.docx'):
        result = process_docx(file_path)
    elif file.name.endswith('.pdf'):
        result = process_pdf(file_path)
    elif file.name.endswith('.txt'):
        result = process_txt(file_path)
    else:
        return "Unsupported file type", []

    # Bug fix: the extractors report failure via an 'error' key; the old
    # code crashed with KeyError on result['text'] in that case.
    if 'error' in result:
        return f"Error reading file: {result['error']}", []

    preprocessed_text = preprocess.preprocess_text(result['text'])

    # Index the document, then run the query against the same collection.
    index_result = index_document("documents", file.name, preprocessed_text)
    query_results = query_documents("documents", query_text)

    return f"Indexing result: {index_result}", query_results
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Gradio UI: one form accepting a document plus a free-text query, showing
# the indexing status alongside the retrieved results.
iface = gr.Interface(
    fn=process_and_query,
    inputs=[
        gr.File(label="Upload Document"),
        gr.Textbox(label="Enter Query"),
    ],
    outputs=[
        gr.Textbox(label="Indexing Result"),
        gr.JSON(label="Query Results"),  # query results rendered as JSON
    ],
    title="Document Processing and Query",
    description="Upload a document (docx, pdf, or txt), enter a query, and get the results.",
)


if __name__ == '__main__':
    iface.launch()
|
collection.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from qdrant_client import QdrantClient
from qdrant_client.http.models import VectorParams, Distance
import os

# Qdrant connection settings, overridable via environment variables.
QDRANT_HOST = os.environ.get("QDRANT_HOST", "localhost")
QDRANT_PORT = int(os.environ.get("QDRANT_PORT", 6333))

client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)

# Collection name and vector parameters (environment-driven).
COLLECTION_NAME = os.environ.get("QDRANT_COLLECTION_NAME", "documents")
VECTOR_SIZE = int(os.environ.get("QDRANT_VECTOR_SIZE", 384))  # must match the embedding dimension

# Resolve the configured metric string to the Distance enum; anything
# unrecognized (including the default "Cosine") falls back to COSINE.
_METRIC_MAP = {"euclid": Distance.EUCLID, "dot": Distance.DOT}
DISTANCE_METRIC_STRING = os.environ.get("QDRANT_DISTANCE_METRIC", "Cosine").lower()
DISTANCE_METRIC = _METRIC_MAP.get(DISTANCE_METRIC_STRING, Distance.COSINE)

# (Re)create the collection. NOTE: recreate_collection drops existing data.
try:
    client.recreate_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=VECTOR_SIZE, distance=DISTANCE_METRIC),
    )
    print(f"Collection '{COLLECTION_NAME}' created/recreated successfully!")
except Exception as e:
    print(f"Error creating/recreating collection '{COLLECTION_NAME}': {e}")
|
documents.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import docx
|
| 2 |
+
import fitz # PyMuPDF
|
| 3 |
+
|
| 4 |
+
def process_docx(file_path):
    """Extract all paragraph text from a .docx file.

    Args:
        file_path (str): Path to the .docx document.

    Returns:
        dict: ``{'text': <stripped full text>}`` on success,
              ``{'error': <message>}`` on failure.
    """
    try:
        document = docx.Document(file_path)
        # Join paragraphs with newlines so downstream chunking can split on them.
        text = '\n'.join(para.text for para in document.paragraphs)
        # Fix: removed leftover debug prints that dumped up to 500 chars of
        # document content to stdout on every call.
        return {'text': text.strip()}
    except Exception as e:
        return {'error': str(e)}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def process_pdf(file_path):
    """Extract the full text of a .pdf file via PyMuPDF.

    Returns ``{'text': ...}`` on success or ``{'error': ...}`` on failure.
    """
    try:
        pages = []
        doc = fitz.open(file_path)
        for page in doc:
            pages.append(page.get_text())
        doc.close()
        return {'text': "".join(pages).strip()}
    except Exception as e:
        return {'error': str(e)}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def process_txt(file_path):
    """Read a UTF-8 .txt file and return its stripped contents.

    Returns ``{'text': ...}`` on success or ``{'error': ...}`` on failure.
    """
    try:
        with open(file_path, encoding='utf-8') as handle:
            contents = handle.read()
        return {'text': contents.strip()}
    except Exception as e:
        return {'error': str(e)}
|
indexing.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
import re
import logging
import nltk
from qdrant_client import QdrantClient
from qdrant_client.http.models import VectorParams, Distance
from sentence_transformers import SentenceTransformer

# The punkt models must be present before sent_tokenize can run.
nltk.download("punkt")
from nltk.tokenize import sent_tokenize

# Shared Qdrant connection and embedding model for this module.
qdrant_client = QdrantClient(host="localhost", port=6333)
model = SentenceTransformer('all-MiniLM-L6-v2')

# Module-wide logging configuration.
logging.basicConfig(level=logging.INFO)
|
| 19 |
+
|
| 20 |
+
def create_collection_if_not_exists(collection_name):
    """Ensure a Qdrant collection with the expected vector config exists.

    Creates the collection (384-dim vectors, cosine distance) only when it
    is not already present; logs and re-raises any client error.
    """
    try:
        existing = {col.name for col in qdrant_client.get_collections().collections}
        if collection_name in existing:
            logging.info(f"Collection '{collection_name}' already exists.")
        else:
            qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=384,  # must match the embedding model's output dimension
                    distance=Distance.COSINE,
                ),
            )
            logging.info(f"Collection '{collection_name}' created.")
    except Exception as e:
        logging.error(f" Error creating collection '{collection_name}': {e}")
        raise
|
| 40 |
+
|
| 41 |
+
def split_text_into_chunks(text, max_chunk_size=256):
    """
    Split document text into chunks suitable for embedding.

    Prefers newline boundaries when the text contains any; otherwise falls
    back to NLTK sentence tokenization. Chunks longer than
    ``max_chunk_size`` *characters* are split again at sentence-ending
    punctuation (a chunk with no such punctuation is kept as-is).

    Args:
        text (str): Full document text.
        max_chunk_size (int): Maximum character length per chunk before the
            punctuation-based re-split kicks in.

    Returns:
        list: Non-empty, stripped chunks.
    """
    if "\n" in text:
        candidates = [part.strip() for part in text.split("\n") if part.strip()]
    else:
        candidates = sent_tokenize(text)

    result = []
    for candidate in candidates:
        if len(candidate) > max_chunk_size:
            # Re-split oversized chunks after '.', '?' or '!' + whitespace.
            pieces = re.split(r'(?<=[.?!])\s+', candidate)
            result.extend(piece.strip() for piece in pieces if piece.strip())
        else:
            result.append(candidate)

    logging.info(f" Split document into {len(result)} chunks.")
    return result
|
| 74 |
+
|
| 75 |
+
def index_document(collection_name, document_id, text, batch_size=100):
    """
    Chunk, embed, and upsert a document into Qdrant.

    Args:
        collection_name (str): Target collection.
        document_id (str): Identifier stored with every chunk (also written
            to the ``file_name`` payload field).
        text (str): Full document text.
        batch_size (int): Number of chunks embedded/upserted per batch.

    Returns:
        dict: ``{"status": "success", "chunks": <count>}`` on success,
              ``{"status": "error", "message": <reason>}`` on failure.
    """
    try:
        create_collection_if_not_exists(collection_name)

        chunks = split_text_into_chunks(text)
        if not chunks:
            logging.warning(" No valid chunks extracted for indexing.")
            return {"status": "error", "message": "No valid chunks extracted"}

        # Embed and upsert batch_size chunks at a time to bound memory use.
        for start in range(0, len(chunks), batch_size):
            batch = chunks[start:start + batch_size]
            vectors = model.encode(batch).tolist()

            points = [
                {
                    "id": str(uuid.uuid4()),
                    "vector": vector,
                    "payload": {
                        "document_id": document_id,
                        "text": chunk_text,
                        "chunk_index": start + offset,
                        "file_name": document_id,
                    },
                }
                for offset, (chunk_text, vector) in enumerate(zip(batch, vectors))
            ]

            qdrant_client.upsert(collection_name=collection_name, points=points)
            logging.info(f" Indexed batch {start // batch_size + 1} ({len(batch)} chunks).")

        logging.info(f" Successfully indexed {len(chunks)} chunks for document '{document_id}'.")
        return {"status": "success", "chunks": len(chunks)}

    except Exception as e:
        logging.error(f"Error indexing document '{document_id}': {e}")
        return {"status": "error", "message": str(e)}
|
new.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from qdrant_client import QdrantClient
import os

# Diagnostic script: connect to the local Qdrant instance, list the
# available collections, and dump info for the "documents" collection.
client = QdrantClient(host="localhost", port=6333)

collections = client.get_collections()
print("Available Collections:", collections)

collection_name = "documents"  # Change if needed
info = client.get_collection(collection_name)
print("Collection Info:", info)
|
preprocess.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import nltk
|
| 3 |
+
from nltk.corpus import stopwords
|
| 4 |
+
from nltk.tokenize import word_tokenize
|
| 5 |
+
from nltk.stem import WordNetLemmatizer
|
| 6 |
+
import string # Import the string module
|
| 7 |
+
|
| 8 |
+
# Initialize lemmatizer and stopwords
|
| 9 |
+
lemmatizer = WordNetLemmatizer()
|
| 10 |
+
stop_words = set(stopwords.words('english'))
|
| 11 |
+
|
| 12 |
+
# Text preprocessing function
|
| 13 |
+
def preprocess_text(text):
    """Normalize raw text for embedding/search.

    Steps: lowercase; collapse whitespace; insert spaces between letter and
    digit runs; word-tokenize; then keep lemmatized non-stopword words,
    numbers, and special-character tokens (bare punctuation is dropped).
    """
    normalized = text.lower()

    # Collapse all whitespace (including line breaks) to single spaces.
    normalized = re.sub(r'\s+', ' ', normalized.strip())

    # Separate alphanumeric mixes, e.g. "hello1234world" -> "hello 1234 world".
    normalized = re.sub(r'([a-zA-Z]+)(\d+)', r'\1 \2', normalized)
    normalized = re.sub(r'(\d+)([a-zA-Z]+)', r'\1 \2', normalized)

    kept = []
    for token in word_tokenize(normalized):
        if token.isalpha():
            # Words: drop stopwords, lemmatize the rest.
            if token not in stop_words:
                kept.append(lemmatizer.lemmatize(token))
        elif token.isnumeric():
            # Numbers pass through unchanged.
            kept.append(token)
        elif not token.isalnum() and token not in string.punctuation:
            # Special-character tokens that are not plain punctuation.
            kept.append(token)

    return ' '.join(kept)
|
querying.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import logging
|
| 4 |
+
import preprocess
|
| 5 |
+
from qdrant_client import QdrantClient
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 8 |
+
|
| 9 |
+
# Configure Logging
|
| 10 |
+
logging.basicConfig(level=logging.INFO)
|
| 11 |
+
|
| 12 |
+
# Load Qdrant Configuration from Environment
|
| 13 |
+
QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
|
| 14 |
+
QDRANT_PORT = int(os.getenv("QDRANT_PORT", 6333))
|
| 15 |
+
|
| 16 |
+
# Initialize Qdrant Client
|
| 17 |
+
qdrant_client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
|
| 18 |
+
|
| 19 |
+
# Load Sentence Transformer for Query Embeddings
|
| 20 |
+
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 21 |
+
|
| 22 |
+
# Load GPT-2 from Hugging Face
|
| 23 |
+
GPT2_MODEL_NAME = "gpt2" # You can also use "gpt2-medium", "gpt2-large", "gpt2-xl" for larger versions
|
| 24 |
+
tokenizer = AutoTokenizer.from_pretrained(GPT2_MODEL_NAME)
|
| 25 |
+
gpt2_model = AutoModelForCausalLM.from_pretrained(
|
| 26 |
+
GPT2_MODEL_NAME,
|
| 27 |
+
torch_dtype=torch.float16, # Lower memory usage
|
| 28 |
+
device_map="auto" # Auto-select GPU if available
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
# Function to Generate Answer Using GPT-2
|
| 32 |
+
def generate_answer(query, context):
    """Generates a response using GPT-2 based on the retrieved context."""
    if not context.strip():
        return "I couldn't find relevant information."

    # NOTE(review): prompt whitespace reconstructed — the scrape lost the
    # original indentation inside this triple-quoted string; verify against
    # the repository source.
    prompt = f"""
Context: {context}

Question: {query}

Answer:
"""

    inputs = tokenizer(prompt, return_tensors="pt").to(gpt2_model.device)
    outputs = gpt2_model.generate(
        **inputs,
        max_new_tokens=200,
        temperature=0.7,
        top_p=0.9,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 53 |
+
|
| 54 |
+
# Function to Query Documents from Qdrant
|
| 55 |
+
def query_documents(collection_name, user_query, top_k=5, score_threshold=0.5):
    """Queries Qdrant, retrieves matching documents, and generates an answer using GPT-2."""
    try:
        logging.info(f"🔍 Original Query: {user_query}")
        cleaned = preprocess.preprocess_text(user_query)
        logging.info(f" Preprocessed Query: {cleaned}")

        # Embed the normalized query.
        vector = embedding_model.encode(cleaned).tolist()

        # Nearest-neighbor search in Qdrant.
        hits = qdrant_client.search(
            collection_name=collection_name,
            query_vector=vector,
            limit=top_k,
            with_payload=True,
        )

        if not hits:
            logging.warning(" No results found. Try increasing top_k or checking indexing.")

        # Keep only sufficiently-scored hits that carry a text payload.
        relevant = []
        for hit in hits:
            if hit.score >= score_threshold and "text" in hit.payload:
                relevant.append({
                    "id": hit.id,
                    "score": hit.score,
                    "text": hit.payload.get("text", ""),
                })

        # Concatenate retrieved chunks into the generation context.
        context = " ".join(item["text"] for item in relevant) or "No relevant information found."
        answer = generate_answer(user_query, context)

        return {"answer": answer, "chunks": relevant}

    except Exception as e:
        logging.error(f"Error during query: {e}")
        return {"error": str(e)}
|
| 95 |
+
|
| 96 |
+
# Command-Line Execution
# Bug fix: was `if _name_ == "_main_":` — `_name_` is undefined, so merely
# importing this module raised NameError. Correct dunder is __name__/__main__.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Query documents with GPT-2")
    parser.add_argument("--collection", type=str, default="documents", help="Qdrant collection name")
    parser.add_argument("--query", type=str, required=True, help="Your search query")
    parser.add_argument("--top-k", type=int, default=3, help="Number of results to return")
    args = parser.parse_args()

    logging.info(f"Querying for: '{args.query}'")
    result = query_documents(args.collection, args.query, args.top_k)

    if "error" in result:
        logging.error(f" Error: {result['error']}")
    else:
        logging.info("\n=== Generated Answer ===")
        print(result["answer"])

        logging.info("\n=== Relevant Chunks ===")
        for i, chunk in enumerate(result["chunks"]):
            print(f"\nChunk {i+1} (Score: {chunk['score']:.3f}):\n{chunk['text']}")
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
flask
python-docx
PyPDF2
sentence-transformers
gradio
nltk
PyMuPDF
qdrant-client
transformers
torch
|
tt.xml
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<mxfile host="app.diagrams.net" modified="2024-05-16T16:32:12.198Z" agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36" version="24.5.3" etag="4Yj3hH6JYlqVxI2xUZbO" type="device">
|
| 2 |
+
<diagram name="Page-1" id="XIDM6lB2L0j4NQYgTgD7">
|
| 3 |
+
<mxGraphModel dx="1386" dy="778" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
|
| 4 |
+
<root>
|
| 5 |
+
<mxCell id="0" />
|
| 6 |
+
<mxCell id="1" parent="0" />
|
| 7 |
+
<mxCell id="2" value="User/API" style="swimlane;fontStyle=1;align=center;verticalAlign=top;childLayout=stackLayout;horizontal=1;startSize=26;horizontalStack=0;resizeParent=1;resizeParentMax=0;resizeLast=0;collapsible=1;marginBottom=0;rounded=1;shadow=1;" vertex="1" parent="1">
|
| 8 |
+
<mxGeometry x="80" y="40" width="160" height="130" as="geometry" />
|
| 9 |
+
</mxCell>
|
| 10 |
+
<mxCell id="3" value="Uploads PDF" style="text;strokeColor=none;fillColor=none;align=left;verticalAlign=top;spacingLeft=4;spacingRight=4;overflow=hidden;rotatable=0;points=[[0,0.5],[1,0.5]];portConstraint=eastwest;" vertex="1" parent="2">
|
| 11 |
+
<mxGeometry y="26" width="160" height="26" as="geometry" />
|
| 12 |
+
</mxCell>
|
| 13 |
+
<mxCell id="4" value="Submits Query" style="text;strokeColor=none;fillColor=none;align=left;verticalAlign=top;spacingLeft=4;spacingRight=4;overflow=hidden;rotatable=0;points=[[0,0.5],[1,0.5]];portConstraint=eastwest;" vertex="1" parent="2">
|
| 14 |
+
<mxGeometry y="52" width="160" height="26" as="geometry" />
|
| 15 |
+
</mxCell>
|
| 16 |
+
<mxCell id="5" value="Receives Answer" style="text;strokeColor=none;fillColor=none;align=left;verticalAlign=top;spacingLeft=4;spacingRight=4;overflow=hidden;rotatable=0;points=[[0,0.5],[1,0.5]];portConstraint=eastwest;" vertex="1" parent="2">
|
| 17 |
+
<mxGeometry y="78" width="160" height="26" as="geometry" />
|
| 18 |
+
</mxCell>
|
| 19 |
+
<mxCell id="6" value="Document Loader (PyPDFLoader)" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;" vertex="1" parent="1">
|
| 20 |
+
<mxGeometry x="320" y="80" width="160" height="60" as="geometry" />
|
| 21 |
+
</mxCell>
|
| 22 |
+
<mxCell id="7" value="Text Splitter" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;" vertex="1" parent="1">
|
| 23 |
+
<mxGeometry x="560" y="80" width="160" height="60" as="geometry" />
|
| 24 |
+
</mxCell>
|
| 25 |
+
<mxCell id="8" value="Embedding Model (SentenceTransformers)" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#fff2cc;strokeColor=#d6b656;" vertex="1" parent="1">
|
| 26 |
+
<mxGeometry x="800" y="80" width="160" height="60" as="geometry" />
|
| 27 |
+
</mxCell>
|
| 28 |
+
<mxCell id="9" value="Qdrant DB" style="shape=cylinder3;whiteSpace=wrap;html=1;boundedLbl=1;backgroundOutline=1;size=15;fillColor=#f8cecc;strokeColor=#b85450;" vertex="1" parent="1">
|
| 29 |
+
<mxGeometry x="1040" y="80" width="160" height="80" as="geometry" />
|
| 30 |
+
</mxCell>
|
| 31 |
+
<mxCell id="10" value="Query Processor" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;" vertex="1" parent="1">
|
| 32 |
+
<mxGeometry x="320" y="240" width="160" height="60" as="geometry" />
|
| 33 |
+
</mxCell>
|
| 34 |
+
<mxCell id="11" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" source="2" target="6">
|
| 35 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
| 36 |
+
<mxPoint x="240" y="110" as="sourcePoint" />
|
| 37 |
+
<mxPoint x="320" y="110" as="targetPoint" />
|
| 38 |
+
</mxGeometry>
|
| 39 |
+
</mxCell>
|
| 40 |
+
<mxCell id="12" value="Raw Text" style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];" vertex="1" connectable="0" parent="11">
|
| 41 |
+
<mxGeometry x="-0.2" y="-1" relative="1" as="geometry">
|
| 42 |
+
<mxPoint x="1" y="1" as="offset" />
|
| 43 |
+
</mxGeometry>
|
| 44 |
+
</mxCell>
|
| 45 |
+
<mxCell id="13" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" source="6" target="7">
|
| 46 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
| 47 |
+
<mxPoint x="480" y="110" as="sourcePoint" />
|
| 48 |
+
<mxPoint x="560" y="110" as="targetPoint" />
|
| 49 |
+
</mxGeometry>
|
| 50 |
+
</mxCell>
|
| 51 |
+
<mxCell id="14" value="Splits into Chunks" style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];" vertex="1" connectable="0" parent="13">
|
| 52 |
+
<mxGeometry x="-0.2" y="-1" relative="1" as="geometry">
|
| 53 |
+
<mxPoint x="1" y="1" as="offset" />
|
| 54 |
+
</mxGeometry>
|
| 55 |
+
</mxCell>
|
| 56 |
+
<mxCell id="15" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" source="7" target="8">
|
| 57 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
| 58 |
+
<mxPoint x="720" y="110" as="sourcePoint" />
|
| 59 |
+
<mxPoint x="800" y="110" as="targetPoint" />
|
| 60 |
+
</mxGeometry>
|
| 61 |
+
</mxCell>
|
| 62 |
+
<mxCell id="16" value="Generates Embeddings" style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];" vertex="1" connectable="0" parent="15">
|
| 63 |
+
<mxGeometry x="-0.2" y="-1" relative="1" as="geometry">
|
| 64 |
+
<mxPoint x="1" y="1" as="offset" />
|
| 65 |
+
</mxGeometry>
|
| 66 |
+
</mxCell>
|
| 67 |
+
<mxCell id="17" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" source="8" target="9">
|
| 68 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
| 69 |
+
<mxPoint x="960" y="110" as="sourcePoint" />
|
| 70 |
+
<mxPoint x="1040" y="120" as="targetPoint" />
|
| 71 |
+
</mxGeometry>
|
| 72 |
+
</mxCell>
|
| 73 |
+
<mxCell id="18" value="Stores Vectors + Metadata" style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];" vertex="1" connectable="0" parent="17">
|
| 74 |
+
<mxGeometry x="-0.2" y="-1" relative="1" as="geometry">
|
| 75 |
+
<mxPoint x="1" y="1" as="offset" />
|
| 76 |
+
</mxGeometry>
|
| 77 |
+
</mxCell>
|
| 78 |
+
<mxCell id="19" value="" style="endArrow=classic;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="2" target="10">
|
| 79 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
| 80 |
+
<mxPoint x="160" y="240" as="sourcePoint" />
|
| 81 |
+
<mxPoint x="400" y="240" as="targetPoint" />
|
| 82 |
+
</mxGeometry>
|
| 83 |
+
</mxCell>
|
| 84 |
+
<mxCell id="20" value="Query Text" style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];" vertex="1" connectable="0" parent="19">
|
| 85 |
+
<mxGeometry x="-0.2" y="-1" relative="1" as="geometry">
|
| 86 |
+
<mxPoint x="1" y="1" as="offset" />
|
| 87 |
+
</mxGeometry>
|
| 88 |
+
</mxCell>
|
| 89 |
+
<mxCell id="21" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" source="10" target="9">
|
| 90 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
| 91 |
+
<mxPoint x="480" y="270" as="sourcePoint" />
|
| 92 |
+
<mxPoint x="1040" y="120" as="targetPoint" />
|
| 93 |
+
</mxGeometry>
|
| 94 |
+
</mxCell>
|
| 95 |
+
<mxCell id="22" value="Vectorized Query" style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];" vertex="1" connectable="0" parent="21">
|
| 96 |
+
<mxGeometry x="-0.2" y="-1" relative="1" as="geometry">
|
| 97 |
+
<mxPoint x="1" y="1" as="offset" />
|
| 98 |
+
</mxGeometry>
|
| 99 |
+
</mxCell>
|
| 100 |
+
<mxCell id="23" value="" style="endArrow=classic;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="1" source="9" target="10">
|
| 101 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
| 102 |
+
<mxPoint x="1120" y="160" as="sourcePoint" />
|
| 103 |
+
<mxPoint x="400" y="240" as="targetPoint" />
|
| 104 |
+
</mxGeometry>
|
| 105 |
+
</mxCell>
|
| 106 |
+
<mxCell id="24" value="Top-K Chunks" style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];" vertex="1" connectable="0" parent="23">
|
| 107 |
+
<mxGeometry x="-0.2" y="-1" relative="1" as="geometry">
|
| 108 |
+
<mxPoint x="1" y="1" as="offset" />
|
| 109 |
+
</mxGeometry>
|
| 110 |
+
</mxCell>
|
| 111 |
+
<mxCell id="25" value="" style="endArrow=classic;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="10" target="2">
|
| 112 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
| 113 |
+
<mxPoint x="400" y="300" as="sourcePoint" />
|
| 114 |
+
<mxPoint x="160" y="170" as="targetPoint" />
|
| 115 |
+
</mxGeometry>
|
| 116 |
+
</mxCell>
|
| 117 |
+
<mxCell id="26" value="Answer" style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];" vertex="1" connectable="0" parent="25">
|
| 118 |
+
<mxGeometry x="-0.2" y="-1" relative="1" as="geometry">
|
| 119 |
+
<mxPoint x="1" y="1" as="offset" />
|
| 120 |
+
</mxGeometry>
|
| 121 |
+
</mxCell>
|
| 122 |
+
<mxCell id="27" value="Port: 6333 (HTTP), 6334 (gRPC)" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" vertex="1" parent="1">
|
| 123 |
+
<mxGeometry x="1040" y="160" width="160" height="20" as="geometry" />
|
| 124 |
+
</mxCell>
|
| 125 |
+
<mxCell id="28" value="Metadata: source, page, chunk_index" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" vertex="1" parent="1">
|
| 126 |
+
<mxGeometry x="1040" y="180" width="160" height="20" as="geometry" />
|
| 127 |
+
</mxCell>
|
| 128 |
+
</root>
|
| 129 |
+
</mxGraphModel>
|
| 130 |
+
</diagram>
|
| 131 |
+
</mxfile>
|