import os
import uuid
from flask import Flask, request, render_template, session, jsonify, Response, stream_with_context
from werkzeug.utils import secure_filename
from rag_processor import create_rag_chain
from typing import Sequence, Any, List
import fitz  # PyMuPDF
import re
import io
import json
from gtts import gTTS
from langchain_core.documents import Document
from langchain_community.document_loaders import TextLoader, Docx2txtLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.retrievers import EnsembleRetriever, ContextualCompressionRetriever
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain_community.retrievers import BM25Retriever
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.storage import InMemoryStore
from sentence_transformers.cross_encoder import CrossEncoder
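# Flask app setup. SECRET_KEY is regenerated on each start, so signed cookie
# sessions do not survive restarts.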
app = Flask(__name__)
app.config['SECRET_KEY'] = os.urandom(24)
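# Human-readable labels for the temperature values the upload form can send.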
TEMPERATURE_LABELS = {
'0.2': 'Precise',
'0.4': 'Confident',
'0.6': 'Balanced',
'0.8': 'Flexible',
'1.0': 'Creative',
}
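# Reranks retrieved chunks with a local cross-encoder: each (query, chunk)
# pair is scored directly and only the top_n chunks are kept.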
class LocalReranker(BaseDocumentCompressor):
model: Any
top_n: int = 5
class Config:
arbitrary_types_allowed = True
def compress_documents(self, documents: Sequence[Document], query: str,
callbacks=None) -> Sequence[Document]:
if not documents:
return []
pairs = [[query, doc.page_content] for doc in documents]
scores = self.model.predict(pairs, show_progress_bar=False)
        doc_scores = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
        top_docs = []
        for doc, score in doc_scores[:self.top_n]:
            doc.metadata['rerank_score'] = float(score)
            top_docs.append(doc)
        return top_docs
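# Parent/child chunking: small child chunks are embedded for precise
# retrieval, while their larger parent chunks are what the LLM finally reads.
# Returns (parent_docs, child_docs, doc_ids) with children linked to parents
# via the 'doc_id' metadata key.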
def create_optimized_parent_child_chunks(all_docs):
    if not all_docs:
        print('CHUNKING: No input documents provided!')
        return ([], [], [])
    parent_splitter = RecursiveCharacterTextSplitter(
        chunk_size=900, chunk_overlap=200,
        separators=['\n\n', '\n', '. ', '! ', '? ', '; ', ', ', ' ', ''])
    child_splitter = RecursiveCharacterTextSplitter(
        chunk_size=350, chunk_overlap=80,
        separators=['\n', '. ', '! ', '? ', '; ', ', ', ' ', ''])
parent_docs = parent_splitter.split_documents(all_docs)
doc_ids = [str(uuid.uuid4()) for _ in parent_docs]
child_docs = []
    for i, parent_doc in enumerate(parent_docs):
        parent_id = doc_ids[i]
        children = child_splitter.split_documents([parent_doc])
        for j, child in enumerate(children):
            child.metadata.update({
                'doc_id': parent_id,
                'chunk_index': j,
                'total_chunks': len(children),
                'is_first_chunk': j == 0,
                'is_last_chunk': j == len(children) - 1,
            })
            if len(children) > 1:
                if j == 0:
                    child.page_content = '[Beginning] ' + child.page_content
                elif j == len(children) - 1:
                    child.page_content = '[Continues...] ' + child.page_content
            child_docs.append(child)
print (f"CHUNKING: Created {len(parent_docs)} parent and {len(child_docs)} child chunks."
)
return (parent_docs, child_docs, doc_ids)
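# Maps reranked child chunks back to their parent documents, scoring each
# parent by how many of its children were retrieved and appending up to three
# matching child excerpts to the parent text.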
def get_context_aware_parents(docs: List[Document], store: InMemoryStore) -> List[Document]:
if not docs:
return []
    parent_scores = {}
    child_content_by_parent = {}
    for doc in docs:
        parent_id = doc.metadata.get('doc_id')
        if parent_id:
            parent_scores[parent_id] = parent_scores.get(parent_id, 0) + 1
            child_content_by_parent.setdefault(parent_id, []).append(doc.page_content)
parent_ids = list(parent_scores.keys())
parents = store.mget(parent_ids)
enhanced_parents = []
    for i, parent in enumerate(parents):
        if parent is not None:
            parent_id = parent_ids[i]
            if parent_id in child_content_by_parent:
                child_excerpts = '\n'.join(child_content_by_parent[parent_id][:3])
                enhanced_content = f"{parent.page_content}\n\nRelevant excerpts:\n{child_excerpts}"
                enhanced_parent = Document(
                    page_content=enhanced_content,
                    metadata={**parent.metadata,
                              'child_relevance_score': parent_scores[parent_id],
                              'matching_children': len(child_content_by_parent[parent_id])})
                enhanced_parents.append(enhanced_parent)
        else:
            print(f"PARENT_FETCH: Parent {parent_ids[i]} not found in store!")
    enhanced_parents.sort(key=lambda p: p.metadata.get('child_relevance_score', 0), reverse=True)
return enhanced_parents
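# Hugging Face Spaces containers only guarantee write access under /tmp.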
is_hf_spaces = bool(os.getenv('SPACE_ID') or os.getenv('SPACES_ZERO_GPU'))
app.config['UPLOAD_FOLDER'] = '/tmp/uploads' if is_hf_spaces else 'uploads'
try:
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"Upload folder ready: {app.config['UPLOAD_FOLDER']}")
except Exception as e:
    print(f"Failed to create upload folder, falling back to /tmp: {e}")
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
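# Per-session state (RAG chain, settings) and chat histories live in process
# memory; they are lost when the server restarts.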
session_data = {}
message_histories = {}
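# Embedding and reranker models are loaded once at startup and shared across
# all requests.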
print('Loading embedding model...')
try:
    EMBEDDING_MODEL = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True})
    print('Embedding model loaded.')
except Exception as e:
    print(f"FATAL: Could not load embedding model. Error: {e}")
    raise
print('Loading reranker model...')
try:
    RERANKER_MODEL = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device='cpu')
    print('Reranker model loaded.')
except Exception as e:
    print(f"FATAL: Could not load reranker model. Error: {e}")
    raise
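# Extracts text page by page with PyMuPDF and raises on failure so that
# upload_files can record the file under failed_files.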
def load_pdf_with_fallback(filepath):
try:
docs = []
with fitz.open(filepath) as pdf_doc:
            for page_num, page in enumerate(pdf_doc):
                text = page.get_text()
                if text.strip():
                    docs.append(Document(
                        page_content=text,
                        metadata={'source': os.path.basename(filepath),
                                  'page': page_num + 1}))
        if docs:
            print(f"Loaded PDF: {os.path.basename(filepath)} - {len(docs)} pages")
            return docs
        else:
            raise ValueError('No text content found in PDF.')
except Exception as e:
print (f"PyMuPDF failed for {filepath}: {e}")
raise
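# '.pdf' maps to a plain function; the other extensions map to LangChain
# loader classes that need .load() to be called.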
LOADER_MAPPING = {'.txt': TextLoader, '.pdf': load_pdf_with_fallback,
'.docx': Docx2txtLoader}
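# Chat histories are created lazily, one per session id.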
def get_session_history(session_id: str) -> ChatMessageHistory:
if session_id not in message_histories:
message_histories[session_id] = ChatMessageHistory()
return message_histories[session_id]
@app.route('/health', methods=['GET'])
def health_check():
return (jsonify({'status': 'healthy'}), 200)
@app.route('/', methods=['GET'])
def index():
return render_template('index.html')
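# Builds the whole RAG pipeline for one upload: load and chunk the files,
# index child chunks in FAISS and BM25, combine both retrievers, rerank with
# the cross-encoder, expand hits to enriched parents, and hand the resulting
# retriever to create_rag_chain (defined in rag_processor).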
@app.route('/upload', methods=['POST'])
def upload_files():
    files = request.files.getlist('file')
    temperature_str = request.form.get('temperature', '0.2')
    try:
        temperature = float(temperature_str)
    except ValueError:
        # Fall back to the default rather than crashing on bad form input.
        temperature_str, temperature = '0.2', 0.2
    model_name = request.form.get('model_name', 'moonshotai/kimi-k2-instruct')
    print(f"UPLOAD: Model: {model_name}, Temp: {temperature}")
    if not files or all(f.filename == '' for f in files):
        return (jsonify({'status': 'error', 'message': 'No selected files.'}), 400)
    all_docs, processed_files, failed_files = [], [], []
    print(f"Processing {len(files)} file(s)...")
for file in files:
if file and file.filename:
filename = secure_filename(file.filename)
filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
try:
file.save(filepath)
file_ext = os.path.splitext(filename)[1].lower()
if file_ext not in LOADER_MAPPING:
raise ValueError('Unsupported file format.')
loader_func = LOADER_MAPPING[file_ext]
                if file_ext == '.pdf':
                    docs = loader_func(filepath)
                else:
                    docs = loader_func(filepath).load()
if not docs:
raise ValueError('No content extracted.')
all_docs.extend(docs)
processed_files.append(filename)
except Exception as e:
print (f"✗ Error processing {filename}: {e}")
failed_files.append(f"{filename} ({e})")
    if not all_docs:
        return (jsonify({
            'status': 'error',
            'message': f"Failed to process all files. Reasons: {', '.join(failed_files)}",
        }), 400)
    print(f"UPLOAD: Processed {len(processed_files)} files.")
    try:
        print('Starting RAG pipeline setup...')
        parent_docs, child_docs, doc_ids = create_optimized_parent_child_chunks(all_docs)
        if not child_docs:
            raise ValueError('No child documents created during chunking.')
        vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)
        store = InMemoryStore()
        store.mset(list(zip(doc_ids, parent_docs)))
        print(f"Indexed {len(child_docs)} document chunks.")
        bm25_retriever = BM25Retriever.from_documents(child_docs)
        bm25_retriever.k = 12
        faiss_retriever = vectorstore.as_retriever(search_kwargs={'k': 12})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever], weights=[0.6, 0.4])
reranker = LocalReranker(model=RERANKER_MODEL, top_n=5)
def get_parents(docs: List[Document]) -> List[Document]:
return get_context_aware_parents(docs, store)
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=reranker, base_retriever=ensemble_retriever)
        final_retriever = compression_retriever | get_parents
        session_id = str(uuid.uuid4())
        rag_chain, api_key_manager = create_rag_chain(
            retriever=final_retriever,
            get_session_history_func=get_session_history,
            model_name=model_name,
            temperature=temperature)
        session_data[session_id] = {
            'chain': rag_chain,
            'model_name': model_name,
            'temperature': temperature,
            'api_key_manager': api_key_manager,
        }
success_msg = f"Processed: {', '.join(processed_files)}"
if failed_files:
success_msg += f". Failed: {', '.join(failed_files)}"
        mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)
        print(f"UPLOAD COMPLETE: Session {session_id} is ready.")
return jsonify({
'status': 'success',
'filename': success_msg,
'session_id': session_id,
'model_name': model_name,
'mode': mode_label,
})
except Exception as e:
import traceback
traceback.print_exc()
return (jsonify({'status': 'error',
'message': f'RAG setup failed: {e}'}), 500)
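# Streams the answer as Server-Sent Events so tokens appear as they are
# generated. GET supports EventSource clients; POST accepts a JSON body.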
@app.route('/chat', methods=['POST', 'GET'])
def chat():
    if request.method == 'GET':
        question = request.args.get('question')
        session_id = request.args.get('session_id')
    else:  # POST; other methods are rejected by the route declaration.
        data = request.get_json(silent=True) or {}
        question = data.get('question')
        session_id = data.get('session_id') or session.get('session_id')
    # str() keeps the log slice safe when no question was supplied.
    print(f"Received {request.method} request for chat: session={session_id}, question={str(question)[:50]}...")
if not question:
error_msg = "Error: No question provided."
print(f"CHAT Validation Error: {error_msg}")
if request.method == 'GET':
def error_stream():
                yield f'data: {json.dumps({"error": error_msg})}\n\n'
return Response(stream_with_context(error_stream()), mimetype='text/event-stream', status=400)
return jsonify({'status': 'error','message': error_msg}), 400
if not session_id or session_id not in session_data:
error_msg = "Error: Invalid session. Please upload documents first."
print(f"CHAT Validation Error: Invalid session {session_id}.")
if request.method == 'GET':
def error_stream():
                yield f'data: {json.dumps({"error": error_msg})}\n\n'
return Response(stream_with_context(error_stream()), mimetype='text/event-stream', status=400)
return jsonify({'status': 'error', 'message': error_msg }), 400
try:
session_info = session_data[session_id]
rag_chain = session_info['chain']
model_name = session_info['model_name']
temperature_float = session_info['temperature']
temperature_str = str(temperature_float)
mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)
print (f"CHAT: Streaming response for session {session_id} (Model: {model_name}, Temp: {temperature_float})...")
        def generate_chunks():
            full_response = ''
            try:
                stream_iterator = rag_chain.stream(
                    {'question': question},
                    config={'configurable': {'session_id': session_id}})
                for chunk in stream_iterator:
                    if isinstance(chunk, str):
                        full_response += chunk
                        # json.dumps escapes quotes, newlines, and control
                        # characters in one place.
                        payload = json.dumps({'token': chunk,
                                              'model_name': model_name,
                                              'mode': mode_label})
                        yield f'data: {payload}\n\n'
                    else:
                        print(f"Received non-string chunk: {type(chunk)}")
                print('CHAT: Streaming finished successfully.')
            except Exception as e:
                print(f"CHAT Error during streaming generation: {e}")
                import traceback
                traceback.print_exc()
                payload = json.dumps({'error': f"Error during response generation: {e}"})
                yield f'data: {payload}\n\n'
return Response(stream_with_context(generate_chunks()), mimetype='text/event-stream')
except Exception as e:
print(f"CHAT Setup Error: {e}")
import traceback
traceback.print_exc()
error_msg = f"Error setting up chat stream: {str(e)}"
        if request.method == 'GET':
            def error_stream():
                yield f'data: {json.dumps({"error": error_msg})}\n\n'
            return Response(stream_with_context(error_stream()), mimetype='text/event-stream', status=500)
        return (jsonify({'status': 'error', 'message': error_msg}), 500)
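# Strips markdown syntax (links, emphasis markers, list bullets, numbering,
# blockquotes) and collapses whitespace so the TTS engine reads plain prose.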
def clean_markdown_for_tts(text: str) -> str:
    # Keep link text but drop the URL, so links are still spoken.
    text = re.sub(r'\[([^\]]*)\]\([^)]*\)', r'\1', text)
text = re.sub(r'[`*_#]', '', text)
text = re.sub(r'^\s*[\-\*\+]\s+', '', text, flags=re.MULTILINE)
text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)
text = re.sub(r'\n+', ' ', text)
text = re.sub(r'\s{2,}', ' ', text)
return text.strip()
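# Synthesizes the cleaned text with gTTS and returns the MP3 bytes directly.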
@app.route('/tts', methods=['POST'])
def text_to_speech():
    data = request.get_json(silent=True) or {}
    text = data.get('text')
    if not text:
        return (jsonify({'status': 'error', 'message': 'No text provided.'}), 400)
try:
clean_text = clean_markdown_for_tts(text)
if not clean_text:
return (jsonify({'status': 'error', 'message': 'No speakable text found.'}), 400)
        tts = gTTS(clean_text, lang='en')
        mp3_fp = io.BytesIO()
        tts.write_to_fp(mp3_fp)
        mp3_fp.seek(0)
        return Response(mp3_fp.read(), mimetype='audio/mpeg')
    except Exception as e:
        print(f"TTS Error: {e}")
        return (jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500)
if __name__ == '__main__':
port = int(os.environ.get('PORT', 7860))
print (f"Starting Flask app on port {port}")
app.run(host='0.0.0.0', port=port, debug=False, threaded=True)