# rag/app.py — RAG chatbot Hugging Face Space (author: AnwinMJ, commit 893a06e)
import os
import gradio as gr
import tempfile
from typing import List, Optional
import shutil
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from langchain.chains import RetrievalQA
from langchain.llms.base import LLM
from groq import Groq
# ---- Custom GroqLLM class using LangChain LLM base ----
class GroqLLM(LLM):
    """Minimal LangChain LLM wrapper around the Groq chat-completions API."""

    # Groq model identifier used for completions.
    model: str = "llama3-8b-8192"
    # API key; resolved lazily from the GROQ_API_KEY env var (HF secret) when
    # left as None, so the secret is read at call time rather than import time.
    api_key: Optional[str] = None
    # Sampling temperature forwarded to the API.
    temperature: float = 0.7

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
        """Send `prompt` to Groq and return the assistant's reply text.

        Args:
            prompt: Full prompt text produced by the chain.
            stop: Optional stop sequences; forwarded to the API.

        Raises:
            ValueError: if no API key is configured.
        """
        # Fail early with a clear message instead of an opaque auth error.
        key = self.api_key or os.environ.get("GROQ_API_KEY")
        if not key:
            raise ValueError("GROQ_API_KEY is not set.")
        client = Groq(api_key=key)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ],
            temperature=self.temperature,
            # BUG FIX: the original ignored `stop`, which breaks chains that
            # depend on stop sequences. Forward it when provided.
            stop=stop,
        )
        return response.choices[0].message.content

    @property
    def _llm_type(self) -> str:
        return "groq-llm"
# Module-level cache holding the retriever built from the most recently
# processed PDF (None until process_pdf succeeds once).
rag_context = dict(retriever=None)
# ---- Step 1: Upload & Embed PDF ----
def process_pdf(file):
    """Load an uploaded PDF, split it, embed the chunks, and cache a retriever.

    Args:
        file: Gradio upload object (exposing `.name` as a path) or, with newer
            Gradio versions, a plain path string.

    Returns:
        A status message string for the UI.
    """
    if file is None:
        return "❌ Please upload a PDF."
    # Newer Gradio versions may pass the temp-file path directly as a str.
    src_path = file if isinstance(file, str) else file.name
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            # Copy to a controlled location with a predictable name before loading.
            temp_pdf_path = os.path.join(temp_dir, "uploaded.pdf")
            shutil.copy(src_path, temp_pdf_path)
            loader = PyPDFLoader(temp_pdf_path)
            documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        chunks = text_splitter.split_documents(documents)
        embedding = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        # BUG FIX: the original persisted the Chroma index into the
        # TemporaryDirectory, which is deleted as soon as the `with` block
        # exits — leaving the cached retriever pointing at vanished files.
        # Keep the index in memory instead (no persist_directory).
        vectorstore = Chroma.from_documents(chunks, embedding)
        rag_context["retriever"] = vectorstore.as_retriever()
        return "✅ PDF processed and ready. Ask your questions!"
    except Exception as e:
        return f"❌ Failed to load PDF: {e}"
# ---- Step 2: Ask questions to the RAG chain ----
def ask_question(query):
    """Answer a question against the cached PDF retriever via a RetrievalQA chain.

    Args:
        query: The user's question text.

    Returns:
        Markdown-formatted answer string, or a user-facing error message.
    """
    retriever = rag_context.get("retriever")
    if retriever is None:
        return "❌ Please upload and process a PDF first."
    # Robustness: don't spend an API call on an empty question.
    if not query or not query.strip():
        return "❌ Please enter a question."
    try:
        llm = GroqLLM()
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            retriever=retriever,
            return_source_documents=True,
        )
        result = qa_chain({"query": query})
        return f"### Answer:\n{result['result']}"
    except Exception as e:
        # Surface API/chain failures (e.g. missing GROQ_API_KEY) in the UI
        # instead of an unhandled traceback.
        return f"❌ Failed to answer: {e}"
# ---- Gradio UI ----
# Layout: a row with the PDF uploader + process button, a status box below,
# then a question box, answer button, and a Markdown answer area.
with gr.Blocks() as demo:
    gr.Markdown(
        # Fixed mojibake in the original header text ("πŸ“š" → 📚).
        "# 📚 RAG Chatbot with Groq & LangChain\n"
        "Upload a PDF, then ask questions about it!"
    )
    with gr.Row():
        pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
        upload_btn = gr.Button("Process PDF")
    upload_status = gr.Textbox(label="Status", interactive=False)
    upload_btn.click(process_pdf, inputs=pdf_input, outputs=upload_status)
    query_input = gr.Textbox(label="Ask a question")
    ask_btn = gr.Button("Get Answer")
    answer_output = gr.Markdown()
    ask_btn.click(ask_question, inputs=query_input, outputs=answer_output)

demo.launch()