Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import traceback | |
| # ---------------- LangChain (STABLE 0.1.x) ---------------- | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain.prompts import PromptTemplate | |
| from langchain.retrievers import EnsembleRetriever | |
| # Providers | |
| from langchain_groq import ChatGroq | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| # Community | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.document_loaders import ( | |
| PyPDFLoader, | |
| TextLoader, | |
| Docx2txtLoader | |
| ) | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
# ---------------- CONFIG ----------------
# Groq API key is read from the environment secret named "GROQ_API"
# (checked later in process_files; the app degrades with a status message).
GROQ_API_KEY = os.getenv("GROQ_API")

# Prompt that constrains the LLM to answer ONLY from the retrieved context,
# with a fixed refusal sentence when the context lacks the answer.
# {context} and {question} are filled by ConversationalRetrievalChain.
STRICT_PROMPT = PromptTemplate(
    template="""
You are a strict document-based assistant.
Rules:
1. ONLY use the provided context.
2. If the answer is not in the context, say:
"I'm sorry, but the provided documents do not contain information to answer this question."
Context:
{context}
Question: {question}
Answer:
""",
    input_variables=["context", "question"]
)
# ---------------- FILE LOADER ----------------
def load_any(path: str):
    """Load one document file into LangChain Documents, chosen by extension.

    Supports .pdf, .txt (UTF-8) and .docx; any other extension yields an
    empty list so unknown uploads are silently skipped.
    """
    lowered = path.lower()
    if lowered.endswith(".pdf"):
        loader = PyPDFLoader(path)
    elif lowered.endswith(".txt"):
        loader = TextLoader(path, encoding="utf-8")
    elif lowered.endswith(".docx"):
        loader = Docx2txtLoader(path)
    else:
        # Unsupported format: skip rather than raise.
        return []
    return loader.load()
# ---------------- BUILD CHAIN ----------------
def process_files(files, response_length):
    """Build a strict, hybrid-retrieval ConversationalRetrievalChain.

    Parameters
    ----------
    files : list | None
        Uploaded file objects from gr.File; each is coerced to a path
        with str() before loading.
    response_length : int | str
        Maximum completion tokens for the Groq LLM (cast via int()).

    Returns
    -------
    tuple
        (chain, status_message). chain is None on any failure, with the
        reason in status_message.
    """
    if not files:
        return None, "β No files uploaded"
    if not GROQ_API_KEY:
        return None, "β GROQ_API secret not set"
    try:
        docs = []
        for f in files:
            # Gradio may hand back tempfile wrappers; str() yields the path.
            docs.extend(load_any(str(f)))

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=100
        )
        chunks = splitter.split_documents(docs)
        if not chunks:
            # Guard: FAISS/BM25 raise cryptic errors on an empty corpus
            # (e.g. every upload was an unsupported or empty file).
            return None, "β No readable text found in the uploaded files"

        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        faiss_db = FAISS.from_documents(chunks, embeddings)
        faiss_retriever = faiss_db.as_retriever(search_kwargs={"k": 3})
        bm25 = BM25Retriever.from_documents(chunks)
        bm25.k = 3
        # Hybrid retrieval: dense (FAISS) + sparse (BM25), equally weighted.
        retriever = EnsembleRetriever(
            retrievers=[faiss_retriever, bm25],
            weights=[0.5, 0.5]
        )
        llm = ChatGroq(
            groq_api_key=GROQ_API_KEY,
            model="llama-3.3-70b-versatile",
            temperature=0,
            max_tokens=int(response_length)
        )
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer"
        )
        chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            memory=memory,
            combine_docs_chain_kwargs={"prompt": STRICT_PROMPT},
            return_source_documents=True,
            output_key="answer"
        )
        return chain, "β Chatbot built successfully"
    except Exception as e:
        # traceback is already imported at module level; dump the full
        # stack to the Space logs, but surface only repr(e) to the UI.
        traceback.print_exc()
        return None, f"β {repr(e)}"
# ---------------- CHAT ----------------
def chat_function(message, history, chain):
    """Answer one user turn via the retrieval chain, appending source names.

    Returns a warning string when the chain has not been built yet.
    """
    if chain is None:
        return "β οΈ Build the chatbot first"

    result = chain.invoke({
        "question": message,
        "chat_history": history
    })
    reply = result["answer"]

    # Collect the distinct file names backing this answer.
    doc_names = set()
    for doc in result.get("source_documents", []):
        src = doc.metadata.get("source", doc.metadata.get("file_path", "unknown"))
        doc_names.add(os.path.basename(src))

    if doc_names:
        reply = reply + "\n\n---\n**Sources:** " + ", ".join(doc_names)
    return reply
# ---------------- UI ----------------
# FIX: gr.Blocks accepts `theme=`; Blocks.launch() does not — passing
# theme to launch() raises a TypeError at startup.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("Multi-RAG Chatbot")
    # Holds the built ConversationalRetrievalChain between events.
    chain_state = gr.State(None)
    with gr.Row():
        with gr.Column(scale=1):
            files = gr.File(file_count="multiple", label="Upload Documents")
            tokens = gr.Slider(100, 4000, value=1000, step=100, label="Max Tokens")
            build = gr.Button("Build Chatbot", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=2):
            # chain_state is threaded into chat_function as its third argument.
            gr.ChatInterface(
                fn=chat_function,
                additional_inputs=[chain_state]
            )
    build.click(
        process_files,
        inputs=[files, tokens],
        outputs=[chain_state, status]
    )

if __name__ == "__main__":
    demo.launch()