"""Flask RAG web app.

Upload documents (/upload), build a hybrid parent/child retrieval pipeline
(BM25 + FAISS -> cross-encoder rerank -> parent fetch), then stream chat
answers over server-sent events (/chat) and synthesize speech (/tts).
"""

import io
import json
import os
import re
import uuid
from typing import Any, List, Sequence

import fitz  # PyMuPDF
from flask import (Flask, Response, jsonify, render_template, request,
                   session, stream_with_context)
from gtts import gTTS
from werkzeug.utils import secure_filename

from langchain.retrievers import (ContextualCompressionRetriever,
                                  EnsembleRetriever)
from langchain.retrievers.document_compressors.base import \
    BaseDocumentCompressor
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import Docx2txtLoader, TextLoader
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers.cross_encoder import CrossEncoder

from rag_processor import create_rag_chain

app = Flask(__name__)
app.config['SECRET_KEY'] = os.urandom(24)

# Human-readable labels for the temperature values sent by the UI slider.
TEMPERATURE_LABELS = {
    '0.2': 'Precise',
    '0.4': 'Confident',
    '0.6': 'Balanced',
    '0.8': 'Flexible',
    '1.0': 'Creative',
}


def _sse(payload: dict) -> str:
    """Format *payload* as one server-sent-event data frame.

    json.dumps guarantees correct escaping of quotes, backslashes and
    control characters (the previous hand-rolled .replace() chains missed
    several cases and one of them replaced the letter 'n' instead of '\\n').
    """
    return f"data: {json.dumps(payload)}\n\n"


class LocalReranker(BaseDocumentCompressor):
    """Document compressor that reranks candidates with a local cross-encoder."""

    model: Any  # sentence-transformers CrossEncoder (arbitrary type, see Config)
    top_n: int = 5  # number of documents to keep after reranking

    class Config:
        arbitrary_types_allowed = True

    def compress_documents(self, documents: Sequence[Document], query: str,
                           callbacks=None) -> Sequence[Document]:
        """Score (query, doc) pairs and return the top_n docs, best first.

        Each returned doc gets its score stored in metadata['rerank_score'].
        Returns [] for empty input.
        """
        if not documents:
            return []
        pairs = [[query, doc.page_content] for doc in documents]
        scores = self.model.predict(pairs, show_progress_bar=False)
        ranked = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
        top_docs = []
        for doc, score in ranked[:self.top_n]:
            doc.metadata['rerank_score'] = float(score)
            top_docs.append(doc)
        return top_docs


def create_optimized_parent_child_chunks(all_docs):
    """Split docs into large 'parent' chunks and small, searchable 'child' chunks.

    Children carry their parent's generated id in metadata ('doc_id') plus
    positional metadata, so retrieval hits on children can be mapped back to
    their parent for fuller context.

    Returns:
        (parent_docs, child_docs, doc_ids); all three empty for empty input.
    """
    if not all_docs:
        print('CHUNKING: No input documents provided!')
        return ([], [], [])

    parent_splitter = RecursiveCharacterTextSplitter(
        chunk_size=900, chunk_overlap=200,
        separators=['\n\n', '\n', '. ', '! ', '? ', '; ', ', ', ' ', ''])
    child_splitter = RecursiveCharacterTextSplitter(
        chunk_size=350, chunk_overlap=80,
        separators=['\n', '. ', '! ', '? ', '; ', ', ', ' ', ''])

    parent_docs = parent_splitter.split_documents(all_docs)
    doc_ids = [str(uuid.uuid4()) for _ in parent_docs]

    child_docs = []
    for i, parent_doc in enumerate(parent_docs):
        parent_id = doc_ids[i]
        children = child_splitter.split_documents([parent_doc])
        for j, child in enumerate(children):
            child.metadata.update({
                'doc_id': parent_id,
                'chunk_index': j,
                'total_chunks': len(children),
                'is_first_chunk': j == 0,
                'is_last_chunk': j == len(children) - 1,
            })
            # Positional markers give the reranker a hint that a chunk is
            # only part of a larger passage.
            if len(children) > 1:
                if j == 0:
                    child.page_content = '[Beginning] ' + child.page_content
                elif j == len(children) - 1:
                    child.page_content = '[Continues...] ' + child.page_content
            child_docs.append(child)

    print(f"CHUNKING: Created {len(parent_docs)} parent and {len(child_docs)} child chunks.")
    return (parent_docs, child_docs, doc_ids)


def get_context_aware_parents(docs: List[Document], store: InMemoryStore) -> List[Document]:
    """Map retrieved child chunks back to parent docs, enriched with excerpts.

    Each parent's content is augmented with up to 3 of its matching child
    excerpts, and parents are sorted by how many of their children matched
    ('child_relevance_score').
    """
    if not docs:
        return []

    parent_scores = {}            # parent_id -> number of matching children
    child_content_by_parent = {}  # parent_id -> list of child page_contents
    for doc in docs:
        parent_id = doc.metadata.get('doc_id')
        if parent_id:
            parent_scores[parent_id] = parent_scores.get(parent_id, 0) + 1
            child_content_by_parent.setdefault(parent_id, []).append(doc.page_content)

    parent_ids = list(parent_scores.keys())
    parents = store.mget(parent_ids)

    enhanced_parents = []
    for i, parent in enumerate(parents):
        if parent is None:
            print(f"PARENT_FETCH: Parent {parent_ids[i]} not found in store!")
            continue
        parent_id = parent_ids[i]
        if parent_id in child_content_by_parent:
            child_excerpts = '\n'.join(child_content_by_parent[parent_id][:3])
            enhanced_content = f"{parent.page_content}\n\nRelevant excerpts:\n{child_excerpts}"
            parent = Document(
                page_content=enhanced_content,
                metadata={
                    **parent.metadata,
                    'child_relevance_score': parent_scores[parent_id],
                    'matching_children': len(child_content_by_parent[parent_id]),
                })
        enhanced_parents.append(parent)

    enhanced_parents.sort(key=lambda p: p.metadata.get('child_relevance_score', 0),
                          reverse=True)
    return enhanced_parents


# --- Module-level setup: upload folder, session stores, shared models -------

# HF Spaces containers only allow writes under /tmp.
is_hf_spaces = bool(os.getenv('SPACE_ID') or os.getenv('SPACES_ZERO_GPU'))
app.config['UPLOAD_FOLDER'] = '/tmp/uploads' if is_hf_spaces else 'uploads'
try:
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"Upload folder ready: {app.config['UPLOAD_FOLDER']}")
except Exception as e:
    print(f"Failed to create upload folder, falling back to /tmp: {e}")
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# In-memory per-session state. NOTE(review): both grow without bound for the
# process lifetime; consider eviction if sessions are long-lived.
session_data = {}       # session_id -> {'chain', 'model_name', 'temperature', 'api_key_manager'}
message_histories = {}  # session_id -> ChatMessageHistory

print('Loading embedding model...')
try:
    EMBEDDING_MODEL = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True})
    print('Embedding model loaded.')
except Exception as e:
    print(f"FATAL: Could not load embedding model. Error: {e}")
    raise e

print('Loading reranker model...')
try:
    RERANKER_MODEL = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device='cpu')
    print('Reranker model loaded.')
except Exception as e:
    print(f"FATAL: Could not load reranker model. Error: {e}")
    raise e


def load_pdf_with_fallback(filepath):
    """Load a PDF with PyMuPDF, one Document per non-empty page (1-based 'page').

    Raises ValueError if the PDF contains no extractable text; re-raises any
    PyMuPDF failure after logging.
    """
    try:
        docs = []
        with fitz.open(filepath) as pdf_doc:
            for page_num, page in enumerate(pdf_doc):
                text = page.get_text()
                if text.strip():
                    docs.append(Document(
                        page_content=text,
                        metadata={'source': os.path.basename(filepath),
                                  'page': page_num + 1}))
        if docs:
            print(f"Loaded PDF: {os.path.basename(filepath)} - {len(docs)} pages")
            return docs
        else:
            raise ValueError('No text content found in PDF.')
    except Exception as e:
        print(f"PyMuPDF failed for {filepath}: {e}")
        raise


# .pdf maps to a plain function (returns docs directly); the others are
# LangChain loader classes whose instances need .load() called.
LOADER_MAPPING = {'.txt': TextLoader, '.pdf': load_pdf_with_fallback,
                  '.docx': Docx2txtLoader}


def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return (creating if needed) the chat history for *session_id*."""
    if session_id not in message_histories:
        message_histories[session_id] = ChatMessageHistory()
    return message_histories[session_id]


@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe."""
    return (jsonify({'status': 'healthy'}), 200)


@app.route('/', methods=['GET'])
def index():
    """Serve the single-page UI."""
    return render_template('index.html')


@app.route('/upload', methods=['POST'])
def upload_files():
    """Ingest uploaded files and build a per-session RAG chain.

    Form fields: 'file' (multi), 'temperature', 'model_name'.
    Returns JSON with a new 'session_id' on success; 400/500 with a message
    on failure. Files that fail to parse are reported but don't abort the
    upload as long as at least one file succeeds.
    """
    files = request.files.getlist('file')
    temperature_str = request.form.get('temperature', '0.2')
    temperature = float(temperature_str)
    model_name = request.form.get('model_name', 'moonshotai/kimi-k2-instruct')
    print(f"UPLOAD: Model: {model_name}, Temp: {temperature}")

    if not files or all(f.filename == '' for f in files):
        return (jsonify({'status': 'error', 'message': 'No selected files.'}), 400)

    all_docs, processed_files, failed_files = [], [], []
    print(f"Processing {len(files)} file(s)...")
    for file in files:
        if not (file and file.filename):
            continue
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        try:
            file.save(filepath)
            file_ext = os.path.splitext(filename)[1].lower()
            if file_ext not in LOADER_MAPPING:
                raise ValueError('Unsupported file format.')
            loader_func = LOADER_MAPPING[file_ext]
            # PDF loader is a plain function; the others are loader classes.
            docs = (loader_func(filepath) if file_ext == '.pdf'
                    else loader_func(filepath).load())
            if not docs:
                raise ValueError('No content extracted.')
            all_docs.extend(docs)
            processed_files.append(filename)
        except Exception as e:
            # Report the actual filename (was a hard-coded "(unknown)").
            print(f"✗ Error processing {filename}: {e}")
            failed_files.append(f"{filename} ({e})")
        finally:
            # Content is now in memory; don't leave the temp file behind.
            try:
                if os.path.exists(filepath):
                    os.remove(filepath)
            except OSError:
                pass  # best-effort cleanup only

    if not all_docs:
        return (jsonify({
            'status': 'error',
            'message': f"Failed to process all files. Reasons: {', '.join(failed_files)}",
        }), 400)

    print(f"UPLOAD: Processed {len(processed_files)} files.")
    try:
        print('Starting RAG pipeline setup...')
        parent_docs, child_docs, doc_ids = create_optimized_parent_child_chunks(all_docs)
        if not child_docs:
            raise ValueError('No child documents created during chunking.')

        vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)
        store = InMemoryStore()
        store.mset(list(zip(doc_ids, parent_docs)))
        print(f"Indexed {len(child_docs)} document chunks.")

        # Hybrid retrieval: lexical (BM25) + dense (FAISS), fused by weight.
        bm25_retriever = BM25Retriever.from_documents(child_docs)
        bm25_retriever.k = 12
        faiss_retriever = vectorstore.as_retriever(search_kwargs={'k': 12})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever], weights=[0.6, 0.4])

        reranker = LocalReranker(model=RERANKER_MODEL, top_n=5)

        def get_parents(docs: List[Document]) -> List[Document]:
            # Closure binds this session's parent store.
            return get_context_aware_parents(docs, store)

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=reranker, base_retriever=ensemble_retriever)
        # LCEL pipe: rerank children, then swap in enriched parents.
        final_retriever = compression_retriever | get_parents

        session_id = str(uuid.uuid4())
        rag_chain, api_key_manager = create_rag_chain(
            retriever=final_retriever,
            get_session_history_func=get_session_history,
            model_name=model_name,
            temperature=temperature)
        session_data[session_id] = {
            'chain': rag_chain,
            'model_name': model_name,
            'temperature': temperature,
            'api_key_manager': api_key_manager,
        }

        success_msg = f"Processed: {', '.join(processed_files)}"
        if failed_files:
            success_msg += f". Failed: {', '.join(failed_files)}"
        mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)
        print(f"UPLOAD COMPLETE: Session {session_id} is ready.")
        return jsonify({
            'status': 'success',
            'filename': success_msg,
            'session_id': session_id,
            'model_name': model_name,
            'mode': mode_label,
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return (jsonify({'status': 'error', 'message': f'RAG setup failed: {e}'}), 500)


@app.route('/chat', methods=['POST', 'GET'])
def chat():
    """Stream a RAG answer as server-sent events.

    GET (EventSource): 'question' and 'session_id' as query params; errors
    are delivered as SSE frames so the client's event handler sees them.
    POST: JSON body; errors are plain JSON responses.
    """
    if request.method == 'GET':
        question = request.args.get('question')
        session_id = request.args.get('session_id')
    elif request.method == 'POST':
        data = request.get_json()
        question = data.get('question')
        session_id = data.get('session_id') or session.get('session_id')
    else:
        return (jsonify({'status': 'error', 'message': 'Method not allowed'}), 405)
    # Guard the preview: question may be None, and None[:50] would raise.
    preview = (question or '')[:50]
    print(f"Received {request.method} request for chat: session={session_id}, question={preview}...")

    if not question:
        error_msg = "Error: No question provided."
        print(f"CHAT Validation Error: {error_msg}")
        if request.method == 'GET':
            def error_stream():
                yield _sse({'error': error_msg})
            return Response(stream_with_context(error_stream()),
                            mimetype='text/event-stream', status=400)
        return jsonify({'status': 'error', 'message': error_msg}), 400

    if not session_id or session_id not in session_data:
        error_msg = "Error: Invalid session. Please upload documents first."
        print(f"CHAT Validation Error: Invalid session {session_id}.")
        if request.method == 'GET':
            def error_stream():
                yield _sse({'error': error_msg})
            return Response(stream_with_context(error_stream()),
                            mimetype='text/event-stream', status=400)
        return jsonify({'status': 'error', 'message': error_msg}), 400

    try:
        session_info = session_data[session_id]
        rag_chain = session_info['chain']
        model_name = session_info['model_name']
        temperature_float = session_info['temperature']
        mode_label = TEMPERATURE_LABELS.get(str(temperature_float), str(temperature_float))
        print(f"CHAT: Streaming response for session {session_id} "
              f"(Model: {model_name}, Temp: {temperature_float})...")

        def generate_chunks():
            full_response = ''
            try:
                stream_iterator = rag_chain.stream(
                    {'question': question},
                    config={'configurable': {'session_id': session_id}})
                for chunk in stream_iterator:
                    if isinstance(chunk, str):
                        full_response += chunk
                        yield _sse({'token': chunk,
                                    'model_name': model_name,
                                    'mode': mode_label})
                    else:
                        print(f"Received non-string chunk: {type(chunk)}")
                print('CHAT: Streaming finished successfully.')
            except Exception as e:
                print(f"CHAT Error during streaming generation: {e}")
                import traceback
                traceback.print_exc()
                yield _sse({'error': f"Error during response generation: {e}"})

        return Response(stream_with_context(generate_chunks()),
                        mimetype='text/event-stream')
    except Exception as e:
        print(f"CHAT Setup Error: {e}")
        import traceback
        traceback.print_exc()
        error_msg = f"Error setting up chat stream: {str(e)}"
        if request.method == 'GET':
            def error_stream():
                # json.dumps escapes correctly (the old code replaced the
                # letter 'n' instead of '\n', mangling messages).
                yield _sse({'error': error_msg})
            return Response(stream_with_context(error_stream()),
                            mimetype='text/event-stream', status=500)
        return (jsonify({'status': 'error', 'message': error_msg}), 500)


def clean_markdown_for_tts(text: str) -> str:
    """Strip Markdown syntax so TTS reads prose, not markup."""
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)                       # links
    text = re.sub(r'[`*_#]', '', text)                               # inline marks
    text = re.sub(r'^\s*[\-\*\+]\s+', '', text, flags=re.MULTILINE)  # bullets
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)     # numbered lists
    text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)         # blockquotes
    text = re.sub(r'\n+', ' ', text)                                 # newlines -> spaces
    text = re.sub(r'\s{2,}', ' ', text)                              # collapse spaces
    return text.strip()


@app.route('/tts', methods=['POST'])
def text_to_speech():
    """Convert JSON {'text': ...} to spoken MP3 audio via gTTS."""
    data = request.get_json()
    text = data.get('text')
    if not text:
        return (jsonify({'status': 'error', 'message': 'No text provided.'}), 400)
    try:
        clean_text = clean_markdown_for_tts(text)
        if not clean_text:
            return (jsonify({'status': 'error',
                             'message': 'No speakable text found.'}), 400)
        tts = gTTS(clean_text, lang='en')
        mp3_fp = io.BytesIO()
        tts.write_to_fp(mp3_fp)
        mp3_fp.seek(0)
        # Send the raw bytes rather than the BytesIO object (Flask would
        # otherwise iterate it line-by-line).
        return Response(mp3_fp.read(), mimetype='audio/mpeg')
    except Exception as e:
        print(f"TTS Error: {e}")
        return (jsonify({'status': 'error',
                         'message': 'Failed to generate audio.'}), 500)


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    print(f"Starting Flask app on port {port}")
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True)