import streamlit as st
import os
import gc
import base64
import tempfile
import uuid

from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.prompts import PromptTemplate
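# NOTE: these llama_index import paths follow the pre-0.10 ("legacy") API,
# where ServiceContext still exists; on llama-index >= 0.10 the same classes
# live under llama_index.core and ServiceContext is superseded by Settings,
# so pin an older release or adapt the imports accordingly.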

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# ----------------------------
# 1) LLM LOADING
# ----------------------------
@st.cache_resource
def load_llm():
    """
    Load the DeepSeek-R1 model (~671B parameters) from Hugging Face,
    using 4-bit quantization and auto device mapping.
    """
    model_id = "deepseek-ai/DeepSeek-R1"
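    # NOTE: even at 4-bit, the full ~671B-parameter checkpoint needs on the
    # order of 350-400 GB of GPU memory, so device_map="auto" will shard the
    # weights across several GPUs.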

    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, 
        trust_remote_code=True
    )

    # model in 4-bit
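    # NOTE: recent transformers releases prefer
    # quantization_config=BitsAndBytesConfig(load_in_4bit=True) over the bare
    # load_in_4bit flag (which may emit a deprecation warning); either way,
    # 4-bit loading requires the bitsandbytes package and CUDA GPUs.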
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        device_map="auto",        # auto-shard across all available GPUs
        load_in_4bit=True,        # bitsandbytes 4-bit quantization
        torch_dtype=torch.float16
    )

    # wrap with LlamaIndex's HuggingFaceLLM
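    # NOTE: sampling parameters go through generate_kwargs (HuggingFaceLLM has
    # no top-level temperature/streaming arguments); streaming itself is
    # enabled later via index.as_query_engine(streaming=True).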
    llm = HuggingFaceLLM(
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        generate_kwargs={"temperature": 0.7, "do_sample": True},
    )
    return llm

# ----------------------------
# 2) STREAMLIT + INDEX SETUP
# ----------------------------
if "id" not in st.session_state:
    st.session_state.id = uuid.uuid4()
    st.session_state.file_cache = {}

def reset_chat():
    st.session_state.messages = []
    gc.collect()

def display_pdf(file):
    st.markdown("### PDF Preview")
    base64_pdf = base64.b64encode(file.read()).decode("utf-8")
    pdf_display = f"""
    <iframe src="data:application/pdf;base64,{base64_pdf}" 
            width="400" height="100%"
            style="height:100vh; width:100%">
    </iframe>
    """
    st.markdown(pdf_display, unsafe_allow_html=True)

# Sidebar for file upload
with st.sidebar:
    st.header("Add your documents!")
    
    uploaded_file = st.file_uploader("Choose a `.pdf` file", type="pdf")

    if uploaded_file:
        try:
            # Indexing the doc
            with tempfile.TemporaryDirectory() as temp_dir:
                file_path = os.path.join(temp_dir, uploaded_file.name)
                with open(file_path, "wb") as f:
                    f.write(uploaded_file.getvalue())
                
                file_key = f"{st.session_state.id}-{uploaded_file.name}"
                st.write("Indexing your document...")

                if file_key not in st.session_state.get('file_cache', {}):
                    if os.path.exists(temp_dir):
                        loader = SimpleDirectoryReader(
                            input_dir=temp_dir,
                            required_exts=[".pdf"],
                            recursive=True
                        )
                    else:
                        st.error("Could not find the file. Please reupload.")
                        st.stop()
                    
                    docs = loader.load_data()

                    # Load the HF-based LLM (DeepSeek-R1)
                    llm = load_llm()
                    
                    # HuggingFace Embeddings for the VectorStore
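                    # NOTE: ModernBERT-large is a raw encoder rather than a
                    # model trained for sentence embeddings; HuggingFaceEmbedding
                    # pools its token outputs, which works, but a dedicated
                    # embedding model (e.g. BAAI/bge-large-en-v1.5) may
                    # retrieve better.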
                    embed_model = HuggingFaceEmbedding(
                        model_name="answerdotai/ModernBERT-large",
                        trust_remote_code=True
                    )

                    # create a service context
                    service_context = ServiceContext.from_defaults(
                        llm=llm,
                        embed_model=embed_model
                    )

                    # build the index
                    index = VectorStoreIndex.from_documents(
                        docs,
                        service_context=service_context,
                        show_progress=True
                    )

                    query_engine = index.as_query_engine(streaming=True)
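                    # with streaming=True, query() returns a streaming response
                    # whose response_gen yields text chunks; the chat loop at
                    # the bottom of the file consumes that generator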

                    # custom QA prompt
                    qa_prompt_tmpl_str = (
                        "Context information is below.\n"
                        "---------------------\n"
                        "{context_str}\n"
                        "---------------------\n"
                        "Given the context info above, provide a concise answer.\n"
                        "If you don't know, say 'I don't know'.\n"
                        "Query: {query_str}\n"
                        "Answer: "
                    )
                    qa_prompt = PromptTemplate(qa_prompt_tmpl_str)
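                    # "response_synthesizer:text_qa_template" targets the prompt
                    # the response synthesizer fills with the retrieved chunks
                    # ({context_str}) and the user question ({query_str})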
                    query_engine.update_prompts(
                        {"response_synthesizer:text_qa_template": qa_prompt}
                    )

                    # store in session state
                    st.session_state.file_cache[file_key] = query_engine
                else:
                    query_engine = st.session_state.file_cache[file_key]

                st.success("Ready to Chat!")
                display_pdf(uploaded_file)

        except Exception as e:
            st.error(f"An error occurred: {e}")
            st.stop()

col1, col2 = st.columns([6, 1])
with col1:
    st.markdown("# RAG with DeepSeek-R1 (700B)")

with col2:
    st.button("Clear ↺", on_click=reset_chat)

# Initialize chat if needed
if "messages" not in st.session_state:
    reset_chat()

# Render past messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Chat input
if prompt := st.chat_input("Ask a question about your PDF..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Retrieve the engine
    if uploaded_file:
        file_key = f"{st.session_state.id}-{uploaded_file.name}"
        query_engine = st.session_state.file_cache.get(file_key)
    else:
        query_engine = None
    
    # If no docs, just return a quick message
    if not query_engine:
        answer = "No documents indexed. Please upload a PDF first."
        st.session_state.messages.append({"role": "assistant", "content": answer})
        with st.chat_message("assistant"):
            st.markdown(answer)
    else:
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""

            # Streaming generator from LlamaIndex
            streaming_response = query_engine.query(prompt)
            for chunk in streaming_response.response_gen:
                full_response += chunk
                message_placeholder.markdown(full_response + "▌")

            message_placeholder.markdown(full_response)

        st.session_state.messages.append({"role": "assistant", "content": full_response})