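"""Gradio RAG question-answering app over an uploaded PDF.

Pipeline: load the PDF with PDFPlumberLoader, split it with SemanticChunker,
embed and index the chunks in a FAISS store (cached on disk), and answer
questions with a ConversationalRetrievalChain backed by a Groq-hosted LLM.

Assumed (unpinned) dependencies: gradio, langchain, langchain-community,
langchain-experimental, langchain-groq, sentence-transformers, faiss-cpu,
pdfplumber.
"""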

import gradio as gr
from langchain_community.document_loaders import PDFPlumberLoader
from langchain_experimental.text_splitter import SemanticChunker
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain.chains import ConversationalRetrievalChain
from langchain_groq import ChatGroq
from pathlib import Path

# Function to load or create a vector store for a given PDF
def load_or_create_vector_store(pdf_path, cache_path=None):
    try:
        # Cache per PDF so a new upload doesn't reuse a stale index
        if cache_path is None:
            cache_path = f"{Path(pdf_path).stem}_vector_store.faiss"

        # Initialize embeddings
        embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

        # Load the cached vector store if it already exists
        if Path(cache_path).exists():
            print("Loading cached vector store...")
            return FAISS.load_local(cache_path, embedder, allow_dangerous_deserialization=True)

        # Load and process PDF
        if not Path(pdf_path).exists():
            raise FileNotFoundError(f"PDF file not found at: {pdf_path}")
        loader = PDFPlumberLoader(pdf_path)
        docs = loader.load()
        if not docs:
            raise ValueError("No content extracted from the PDF.")

        # Split documents
        text_splitter = SemanticChunker(embedder)
        documents = text_splitter.split_documents(docs)

        # Create and save vector store
        vector_store = FAISS.from_documents(documents, embedder)
        vector_store.save_local(cache_path)
        return vector_store
    except Exception as e:
        raise RuntimeError(f"Error processing PDF: {e}") from e

# Initialize language model
def initialize_llm(api_key):
    try:
        return ChatGroq(
            groq_api_key=api_key,
            temperature=0,
            model_name="deepseek-r1-distill-llama-70b"
        )
    except Exception as e:
        raise RuntimeError(f"Error initializing LLM: {e}") from e

# Define prompt template
prompt_template = """
Use the following context to answer the question. If you don't know the answer, say so. Keep the answer to 3 sentences. Always end with "Thanks for asking!"
{context}
Question: {question}
Helpful Answer:
"""
QA_PROMPT = PromptTemplate.from_template(prompt_template)

# Gradio interface function
def query_rag(pdf_file, question, api_key, k=2, chat_history=None):
    try:
        # Avoid sharing a mutable default list across calls
        if chat_history is None:
            chat_history = []
        if not question or not question.strip():
            return "Please enter a valid question.", "", chat_history
        if not api_key:
            return "Please provide a valid Groq API key.", "", chat_history
        if not pdf_file:
            return "Please upload a valid PDF file.", "", chat_history

        # Resolve the uploaded file to a path on disk
        if isinstance(pdf_file, str):  # Gradio passed a file path
            pdf_path = pdf_file
        elif hasattr(pdf_file, "name"):  # Gradio passed a tempfile wrapper
            pdf_path = pdf_file.name
        else:  # Gradio passed raw bytes
            pdf_path = "temp_uploaded_pdf.pdf"
            with open(pdf_path, "wb") as f:
                f.write(pdf_file)

        # Load or create the vector store and build a retriever
        vector_store = load_or_create_vector_store(pdf_path)
        retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": int(k)})

        # Initialize LLM
        llm = initialize_llm(api_key)

        # Set up conversational chain
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            combine_docs_chain_kwargs={"prompt": QA_PROMPT},
            return_source_documents=True
        )
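        # Note: from_llm also wires in a default condense-question step that
        # rewrites follow-up questions using chat_history before retrieval.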

        # Run the query with the accumulated chat history
        result = qa_chain.invoke({"question": question, "chat_history": chat_history})
        answer = result["answer"]
        sources = "\n\n".join(
            doc.page_content[:200] + "..." for doc in result["source_documents"]
        )

        # Update chat history
        chat_history.append((question, answer))

        return answer, sources, chat_history
    except Exception as e:
        return f"Error: {e}", "", chat_history

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# RAG Question-Answering System")
    gr.Markdown("Upload a PDF and ask questions about its content. Provide your Grok API key to proceed.")

    api_key = gr.Textbox(label="Groq API Key", type="password", placeholder="Enter your GROQ_API_KEY")
    pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
    question_input = gr.Textbox(label="Ask a Question", placeholder="e.g., What is a cost function in ML?")
    k_slider = gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Number of Retrieved Documents")
    output = gr.Textbox(label="Answer")
    sources = gr.Textbox(label="Source Documents (First 200 Characters)")
    chat_history = gr.State(value=[])

    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=query_rag,
        inputs=[pdf_input, question_input, api_key, k_slider, chat_history],
        outputs=[output, sources, chat_history]
    )

# Launch the interface
demo.launch()