"""Flask front-end for a document-QA RAG pipeline: upload, chat (SSE streaming), and TTS."""

import io
import json
import os
import re
import uuid
from typing import Any, List, Sequence

import fitz  # PyMuPDF
from flask import (Flask, Response, jsonify, render_template, request,
                   session, stream_with_context)
from gtts import gTTS
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import Docx2txtLoader, TextLoader
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers.cross_encoder import CrossEncoder
from werkzeug.utils import secure_filename

from rag_processor import create_rag_chain

app = Flask(__name__)
app.config['SECRET_KEY'] = os.urandom(24)

TEMPERATURE_LABELS = {
    '0.2': 'Precise',
    '0.4': 'Confident',
    '0.6': 'Balanced',
    '0.8': 'Flexible',
    '1.0': 'Creative',
}

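# Cross-encoder reranker that runs locally. Unlike the bi-encoder embedding
# model, it scores each (query, document) pair jointly, which is slower but
# noticeably better at ordering the candidates it is given.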
class LocalReranker(BaseDocumentCompressor):
    model: Any
    top_n: int = 5

    class Config:
        arbitrary_types_allowed = True

    def compress_documents(self, documents: Sequence[Document], query: str,
                           callbacks=None) -> Sequence[Document]:
        if not documents:
            return []
        pairs = [[query, doc.page_content] for doc in documents]
        scores = self.model.predict(pairs, show_progress_bar=False)
        doc_scores = list(zip(documents, scores))
        sorted_doc_scores = sorted(doc_scores, key=lambda x: x[1], reverse=True)
        top_docs = []
        for doc, score in sorted_doc_scores[:self.top_n]:
            doc.metadata['rerank_score'] = float(score)
            top_docs.append(doc)
        return top_docs


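# Parent/child ("small-to-big") chunking: small child chunks are embedded for
# precise retrieval, while their larger parents carry the surrounding context
# that is ultimately handed to the LLM.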
def create_optimized_parent_child_chunks(all_docs):
    if not all_docs:
        print('CHUNKING: No input documents provided!')
        return [], [], []

    parent_splitter = RecursiveCharacterTextSplitter(
        chunk_size=900, chunk_overlap=200,
        separators=['\n\n', '\n', '. ', '! ', '? ', '; ', ', ', ' ', ''])
    child_splitter = RecursiveCharacterTextSplitter(
        chunk_size=350, chunk_overlap=80,
        separators=['\n', '. ', '! ', '? ', '; ', ', ', ' ', ''])

    parent_docs = parent_splitter.split_documents(all_docs)
    doc_ids = [str(uuid.uuid4()) for _ in parent_docs]
    child_docs = []

    for i, parent_doc in enumerate(parent_docs):
        parent_id = doc_ids[i]
        children = child_splitter.split_documents([parent_doc])
        for j, child in enumerate(children):
            child.metadata.update({
                'doc_id': parent_id,
                'chunk_index': j,
                'total_chunks': len(children),
                'is_first_chunk': j == 0,
                'is_last_chunk': j == len(children) - 1,
            })
            if len(children) > 1:
                if j == 0:
                    child.page_content = '[Beginning] ' + child.page_content
                elif j == len(children) - 1:
                    child.page_content = '[Continues...] ' + child.page_content
            child_docs.append(child)

    print(f"CHUNKING: Created {len(parent_docs)} parent and {len(child_docs)} child chunks.")
    return parent_docs, child_docs, doc_ids


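# Collapse retrieved child chunks back to their parent documents, ranking
# parents by how many of their children matched the query, and appending the
# matching child excerpts so the LLM sees both broad and specific context.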
def get_context_aware_parents(docs: List[Document], store: InMemoryStore) -> List[Document]:
    if not docs:
        return []
    parent_scores, child_content_by_parent = {}, {}
    for doc in docs:
        parent_id = doc.metadata.get('doc_id')
        if parent_id:
            parent_scores[parent_id] = parent_scores.get(parent_id, 0) + 1
            if parent_id not in child_content_by_parent:
                child_content_by_parent[parent_id] = []
            child_content_by_parent[parent_id].append(doc.page_content)

    parent_ids = list(parent_scores.keys())
    parents = store.mget(parent_ids)
    enhanced_parents = []

    for i, parent in enumerate(parents):
        if parent is not None:
            parent_id = parent_ids[i]
            if parent_id in child_content_by_parent:
                child_excerpts = '\n'.join(child_content_by_parent[parent_id][:3])
                enhanced_content = f"{parent.page_content}\n\nRelevant excerpts:\n{child_excerpts}"
                enhanced_parent = Document(
                    page_content=enhanced_content,
                    metadata={**parent.metadata,
                              'child_relevance_score': parent_scores[parent_id],
                              'matching_children': len(child_content_by_parent[parent_id])})
                enhanced_parents.append(enhanced_parent)
        else:
            print(f"PARENT_FETCH: Parent {parent_ids[i]} not found in store!")

    enhanced_parents.sort(key=lambda p: p.metadata.get('child_relevance_score', 0), reverse=True)
    return enhanced_parents


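# On Hugging Face Spaces the app directory may be read-only, so uploads are
# written under /tmp there; locally a relative 'uploads' folder is used.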
is_hf_spaces = bool(os.getenv('SPACE_ID') or os.getenv('SPACES_ZERO_GPU'))
app.config['UPLOAD_FOLDER'] = '/tmp/uploads' if is_hf_spaces else 'uploads'

try:
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"Upload folder ready: {app.config['UPLOAD_FOLDER']}")
except Exception as e:
    print(f"Failed to create upload folder, falling back to /tmp: {e}")
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# Per-session RAG chains and chat histories, held in process memory.
session_data = {}
message_histories = {}

print('Loading embedding model...')
try:
    EMBEDDING_MODEL = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True})
    print('Embedding model loaded.')
except Exception as e:
    print(f"FATAL: Could not load embedding model. Error: {e}")
    raise

print('Loading reranker model...')
try:
    RERANKER_MODEL = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device='cpu')
    print('Reranker model loaded.')
except Exception as e:
    print(f"FATAL: Could not load reranker model. Error: {e}")
    raise


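# PyMuPDF page-by-page extraction; pages without extractable text are skipped
# so scanned or empty pages do not produce blank documents.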
def load_pdf_with_fallback(filepath):
    try:
        docs = []
        with fitz.open(filepath) as pdf_doc:
            for page_num, page in enumerate(pdf_doc):
                text = page.get_text()
                if text.strip():
                    docs.append(Document(
                        page_content=text,
                        metadata={'source': os.path.basename(filepath),
                                  'page': page_num + 1}))
        if docs:
            print(f"Loaded PDF: {os.path.basename(filepath)} - {len(docs)} pages")
            return docs
        raise ValueError('No text content found in PDF.')
    except Exception as e:
        print(f"PyMuPDF failed for {filepath}: {e}")
        raise


LOADER_MAPPING = {'.txt': TextLoader, '.pdf': load_pdf_with_fallback,
                  '.docx': Docx2txtLoader}


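# One ChatMessageHistory per session id; create_rag_chain uses this to give
# the conversational chain its memory.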
def get_session_history(session_id: str) -> ChatMessageHistory:
    if session_id not in message_histories:
        message_histories[session_id] = ChatMessageHistory()
    return message_histories[session_id]


@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({'status': 'healthy'}), 200


@app.route('/', methods=['GET'])
def index():
    return render_template('index.html')


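# Upload flow: save files -> load & chunk -> index children in FAISS ->
# hybrid BM25+vector retrieval -> cross-encoder rerank -> expand to parents ->
# build a per-session RAG chain.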
@app.route('/upload', methods=['POST'])
def upload_files():
    files = request.files.getlist('file')

    temperature_str = request.form.get('temperature', '0.2')
    temperature = float(temperature_str)
    model_name = request.form.get('model_name', 'moonshotai/kimi-k2-instruct')
    print(f"UPLOAD: Model: {model_name}, Temp: {temperature}")

    if not files or all(f.filename == '' for f in files):
        return jsonify({'status': 'error', 'message': 'No selected files.'}), 400

    all_docs, processed_files, failed_files = [], [], []
    print(f"Processing {len(files)} file(s)...")
    for file in files:
        if file and file.filename:
            filename = secure_filename(file.filename)
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            try:
                file.save(filepath)
                file_ext = os.path.splitext(filename)[1].lower()
                if file_ext not in LOADER_MAPPING:
                    raise ValueError('Unsupported file format.')
                loader_func = LOADER_MAPPING[file_ext]
                # PDFs map to a plain function; the other loaders are LangChain
                # classes whose instances expose .load().
                docs = (loader_func(filepath) if file_ext == '.pdf'
                        else loader_func(filepath).load())
                if not docs:
                    raise ValueError('No content extracted.')
                all_docs.extend(docs)
                processed_files.append(filename)
            except Exception as e:
                print(f"✗ Error processing {filename}: {e}")
                failed_files.append(f"{filename} ({e})")

    if not all_docs:
        return jsonify({
            'status': 'error',
            'message': f"Failed to process all files. Reasons: {', '.join(failed_files)}",
        }), 400

    print(f"UPLOAD: Processed {len(processed_files)} files.")
    try:
        print('Starting RAG pipeline setup...')
        parent_docs, child_docs, doc_ids = create_optimized_parent_child_chunks(all_docs)
        if not child_docs:
            raise ValueError('No child documents created during chunking.')

        # Index child chunks in FAISS; keep parents in a docstore keyed by id.
        vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)
        store = InMemoryStore()
        store.mset(list(zip(doc_ids, parent_docs)))
        print(f"Indexed {len(child_docs)} document chunks.")

        # Hybrid retrieval: lexical BM25 (weight 0.6) plus dense FAISS (0.4).
        bm25_retriever = BM25Retriever.from_documents(child_docs)
        bm25_retriever.k = 12
        faiss_retriever = vectorstore.as_retriever(search_kwargs={'k': 12})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever], weights=[0.6, 0.4])
        reranker = LocalReranker(model=RERANKER_MODEL, top_n=5)

        def get_parents(docs: List[Document]) -> List[Document]:
            return get_context_aware_parents(docs, store)

        # Rerank the ensemble's candidates, then swap children for parents.
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=reranker, base_retriever=ensemble_retriever)
        final_retriever = compression_retriever | get_parents

        session_id = str(uuid.uuid4())
        rag_chain, api_key_manager = create_rag_chain(
            retriever=final_retriever,
            get_session_history_func=get_session_history,
            model_name=model_name,
            temperature=temperature)

        session_data[session_id] = {
            'chain': rag_chain,
            'model_name': model_name,
            'temperature': temperature,
            'api_key_manager': api_key_manager,
        }

        success_msg = f"Processed: {', '.join(processed_files)}"
        if failed_files:
            success_msg += f". Failed: {', '.join(failed_files)}"

        mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)

        print(f"UPLOAD COMPLETE: Session {session_id} is ready.")

        return jsonify({
            'status': 'success',
            'filename': success_msg,
            'session_id': session_id,
            'model_name': model_name,
            'mode': mode_label,
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'status': 'error', 'message': f'RAG setup failed: {e}'}), 500


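# Chat endpoint. GET supports EventSource clients (params in the query
# string); POST accepts JSON. Responses stream back as server-sent events.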
@app.route('/chat', methods=['POST', 'GET'])
def chat():
    if request.method == 'GET':
        question = request.args.get('question')
        session_id = request.args.get('session_id')
        print(f"Received GET request for chat: session={session_id}, question={(question or '')[:50]}...")
    elif request.method == 'POST':
        data = request.get_json(silent=True) or {}
        question = data.get('question')
        session_id = data.get('session_id') or session.get('session_id')
        print(f"Received POST request for chat: session={session_id}, question={(question or '')[:50]}...")
    else:
        return jsonify({'status': 'error', 'message': 'Method not allowed'}), 405

    if not question:
        error_msg = 'Error: No question provided.'
        print(f"CHAT Validation Error: {error_msg}")
        if request.method == 'GET':
            def error_stream():
                yield f'data: {json.dumps({"error": error_msg})}\n\n'
            return Response(stream_with_context(error_stream()),
                            mimetype='text/event-stream', status=400)
        return jsonify({'status': 'error', 'message': error_msg}), 400

    if not session_id or session_id not in session_data:
        error_msg = 'Error: Invalid session. Please upload documents first.'
        print(f"CHAT Validation Error: Invalid session {session_id}.")
        if request.method == 'GET':
            def error_stream():
                yield f'data: {json.dumps({"error": error_msg})}\n\n'
            return Response(stream_with_context(error_stream()),
                            mimetype='text/event-stream', status=400)
        return jsonify({'status': 'error', 'message': error_msg}), 400

    try:
        session_info = session_data[session_id]
        rag_chain = session_info['chain']
        model_name = session_info['model_name']
        temperature_float = session_info['temperature']
        temperature_str = str(temperature_float)
        mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)

        print(f"CHAT: Streaming response for session {session_id} "
              f"(Model: {model_name}, Temp: {temperature_float})...")

        def generate_chunks():
            full_response = ''
            try:
                stream_iterator = rag_chain.stream(
                    {'question': question},
                    config={'configurable': {'session_id': session_id}})

                for chunk in stream_iterator:
                    if isinstance(chunk, str):
                        full_response += chunk
                        # json.dumps handles all escaping (quotes, newlines,
                        # control characters) for the SSE payload.
                        payload = json.dumps({'token': chunk,
                                              'model_name': model_name,
                                              'mode': mode_label})
                        yield f'data: {payload}\n\n'
                    else:
                        print(f"Received non-string chunk: {type(chunk)}")

                print('CHAT: Streaming finished successfully.')

            except Exception as e:
                print(f"CHAT Error during streaming generation: {e}")
                import traceback
                traceback.print_exc()
                payload = json.dumps({'error': f"Error during response generation: {e}"})
                yield f'data: {payload}\n\n'

        return Response(stream_with_context(generate_chunks()), mimetype='text/event-stream')

    except Exception as e:
        print(f"CHAT Setup Error: {e}")
        import traceback
        traceback.print_exc()
        error_msg = f"Error setting up chat stream: {str(e)}"
        if request.method == 'GET':
            def error_stream():
                yield f'data: {json.dumps({"error": error_msg})}\n\n'
            return Response(stream_with_context(error_stream()),
                            mimetype='text/event-stream', status=500)
        return jsonify({'status': 'error', 'message': error_msg}), 500


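# Strip markdown syntax (links, emphasis, list markers, blockquotes) so the
# TTS engine reads plain sentences instead of formatting characters.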
def clean_markdown_for_tts(text: str) -> str:
    # Keep the link text but drop the URL so links are still spoken.
    text = re.sub(r'\[([^\]]*)\]\([^)]*\)', r'\1', text)
    text = re.sub(r'[`*_#]', '', text)
    text = re.sub(r'^\s*[\-\*\+]\s+', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)
    text = re.sub(r'\n+', ' ', text)
    text = re.sub(r'\s{2,}', ' ', text)
    return text.strip()


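# Synthesize speech with gTTS; the MP3 is assembled in memory (BytesIO) and
# returned directly, so nothing is written to disk.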
@app.route('/tts', methods=['POST'])
def text_to_speech():
    data = request.get_json(silent=True) or {}
    text = data.get('text')
    if not text:
        return jsonify({'status': 'error', 'message': 'No text provided.'}), 400
    try:
        clean_text = clean_markdown_for_tts(text)
        if not clean_text:
            return jsonify({'status': 'error', 'message': 'No speakable text found.'}), 400

        tts = gTTS(clean_text, lang='en')
        mp3_fp = io.BytesIO()
        tts.write_to_fp(mp3_fp)
        mp3_fp.seek(0)
        return Response(mp3_fp, mimetype='audio/mpeg')
    except Exception as e:
        print(f"TTS Error: {e}")
        return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    print(f"Starting Flask app on port {port}")
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True)