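"""Royal Document Assistant.

A Gradio app that lets users upload a ZIP of PDFs, indexes them with FAISS and
HuggingFace sentence embeddings, and answers questions in the voice of a
historical royal figure using NVIDIA's Llama 3.3 70B endpoint through a
LangChain RetrievalQA chain. Responses are also spoken aloud with gTTS.

Assumed dependencies (inferred from the imports below): gradio, gtts, openai,
langchain, langchain-community, faiss-cpu (or faiss-gpu), sentence-transformers, pypdf.
"""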
import tempfile
import zipfile
from typing import Any, List, Optional

import gradio as gr
from gtts import gTTS
from openai import OpenAI as OpenAIClient

from langchain.chains import RetrievalQA
from langchain.llms.base import LLM
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS


class LlamaLLM(LLM):
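    """Minimal LangChain LLM wrapper around NVIDIA's OpenAI-compatible Llama 3.3 chat endpoint."""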
    client: Any = None

    def __init__(self, api_key: str):
        super().__init__()
        self.client = OpenAIClient(
            base_url="https://integrate.api.nvidia.com/v1",
            api_key=api_key
        )

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        completion = self.client.chat.completions.create(
            model="meta/llama-3.3-70b-instruct",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            top_p=0.7,
            max_tokens=1024,
        )
        return completion.choices[0].message.content

    @property
    def _llm_type(self) -> str:
        return "Llama 3.3"

def process_pdfs(zip_file, api_key):
    """Process uploaded ZIP file containing PDFs"""
    print("Processing ZIP file...")
    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"Extracting ZIP to temporary directory: {temp_dir}")
        # gr.File(type="filepath") passes a path string; fall back to .name for file-like objects
        zip_path = zip_file if isinstance(zip_file, str) else zip_file.name
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(temp_dir)

        print("Loading PDFs...")
        loader = DirectoryLoader(temp_dir, glob="**/*.pdf", loader_cls=PyPDFLoader)
        documents = loader.load()

        if not documents:
            raise ValueError("No PDF files found in the uploaded ZIP")

        print(f"Loaded {len(documents)} documents.")
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        texts = text_splitter.split_documents(documents)

        print("Creating embeddings...")
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vectorstore = FAISS.from_documents(texts, embeddings)

        memory = ConversationBufferMemory()
        llm = LlamaLLM(api_key=api_key)
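        # "stuff" chain type: retrieved chunks are stuffed into a single prompt for the LLM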
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=vectorstore.as_retriever(),
            memory=memory,
        )

        print("PDF processing complete.")
        return qa_chain, memory

def generate_audio(text: str) -> Optional[str]:
    """Generate audio from text using gTTS"""
    try:
        tts = gTTS(text=text, lang='en')
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
        tts.save(temp_file.name)
        return temp_file.name
    except Exception as e:
        print(f"Audio generation error: {e}")
        return None

def chat_response(query, qa_chain, memory):
    print(f"Generating response for query: {query}")
    try:
        raw_response = qa_chain.invoke({"query": query})
        context = raw_response["result"] if isinstance(raw_response, dict) else raw_response
        print(f"Raw response: {context}")

        royal_prompt = f"""
        Respond as a historical royal figure mentioned in the query.
        Use first-person perspective and be gender-specific.
        Respond in the query's language. Be authoritative but polite.
        Use only context information. If unsure, respond as a monarch would.
        Context: {context}
        Previous conversation: {memory.buffer}
        Query: {query}
        Royal Response:"""

        # Access the underlying LLM through the chain's combine_documents step
        final_response = qa_chain.combine_documents_chain.llm_chain.llm.invoke(royal_prompt)
        print(f"Final response: {final_response}")
        memory.save_context({'input': query}, {'output': final_response})
        return final_response, generate_audio(final_response)
    except Exception as e:
        print(f"Error in chat_response: {e}")
        raise gr.Error(f"Error generating response: {e}")

with gr.Blocks() as demo:
    gr.Markdown("""
    # πŸ‘‘ Royal Document Assistant  
    <small>This agent can help you with any historical material in a fun and engaging experience, including text and voice responses. But first, visit NVIDIA LLaMA 3.3 70B and get your API key..</small>
    """)
    
    qa_chain = gr.State()
    memory = gr.State()

    with gr.Row():
        with gr.Column():
            api_key_input = gr.Textbox(label="Enter your NVIDIA API Key", type="password")
            zip_upload = gr.File(label="Upload ZIP of PDFs", type="filepath")
            load_btn = gr.Button("Process Documents")
            load_status = gr.Markdown()

    with gr.Row(visible=False) as chat_row:
        with gr.Column():
            chat_input = gr.Textbox(label="Ask the Royal Assistant")
            chat_output = gr.Textbox(label="Response", interactive=False)
            audio_output = gr.Audio(label="Spoken Response", type="filepath")
            submit_btn = gr.Button("Ask")

    def load_docs(zip_file, api_key):
        try:
            chain, mem = process_pdfs(zip_file, api_key)
            return (
                gr.update(visible=True),
                chain,
                mem,
                "βœ… Documents processed! You may now ask questions"
            )
        except Exception as e:
            return (
                gr.update(visible=False),
                None,
                None,
                f"❌ Error processing documents: {str(e)}"
            )

    def ask_question(query, qa_chain, memory):
        if not qa_chain or not memory:
            raise gr.Error("Please process documents first!")
        try:
            response, audio = chat_response(query, qa_chain, memory)
            return response, audio
        except Exception as e:
            print(f"Error in ask_question: {e}")
            return f"Error: {str(e)}", None

    load_btn.click(
        load_docs,
        inputs=[zip_upload, api_key_input],
        outputs=[chat_row, qa_chain, memory, load_status]
    )

    submit_btn.click(
        ask_question,
        inputs=[chat_input, qa_chain, memory],
        outputs=[chat_output, audio_output]
    )

if __name__ == "__main__":
    demo.launch(share=True)