dsr1-demo / app.py
Nick White
ADD initial files
aa1c44a
import streamlit as st
import os
import gc
import base64
import tempfile
import uuid
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# ----------------------------
# 1) LLM LOADING
# ----------------------------
@st.cache_resource
def load_llm():
    """Load the DeepSeek-R1 checkpoint and wrap it for use with LlamaIndex.

    The tokenizer and weights are fetched from the Hugging Face Hub under
    ``deepseek-ai/DeepSeek-R1``. Weights are loaded with bitsandbytes 4-bit
    quantization and ``device_map="auto"`` so they are placed across the
    available devices automatically.

    The function is decorated with ``st.cache_resource`` so the returned
    object is reused across Streamlit reruns instead of being rebuilt.

    Returns:
        HuggingFaceLLM: LlamaIndex wrapper around the loaded model and
        tokenizer, configured to sample at temperature 0.7 and to emit at
        most 512 new tokens per completion.
    """
    # Local import: the file-level transformers import only brings in
    # AutoTokenizer and AutoModelForCausalLM.
    from transformers import BitsAndBytesConfig

    model_id = "deepseek-ai/DeepSeek-R1"

    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
    )

    # Quantization settings go through an explicit config object; passing
    # `load_in_4bit=True` straight to from_pretrained() is deprecated.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        device_map="auto",  # auto-shard across all available GPUs
        quantization_config=quantization_config,
        torch_dtype=torch.float16,
    )

    # HuggingFaceLLM has no `temperature` or `streaming` constructor
    # arguments. Sampling options are forwarded to `model.generate()` via
    # `generate_kwargs`; `do_sample=True` is required for the temperature to
    # take effect. Streaming is requested where the query engine is built.
    return HuggingFaceLLM(
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        generate_kwargs={"do_sample": True, "temperature": 0.7},
    )
# ----------------------------
# 2) STREAMLIT + INDEX SETUP
# ----------------------------
# First run for this browser session: mint a session id and create the empty
# cache that later holds one query engine per "<session id>-<file name>" key.
if "id" not in st.session_state:
    st.session_state["id"] = uuid.uuid4()
    st.session_state["file_cache"] = {}
def reset_chat():
    """Empty the stored chat history and run a garbage-collection pass."""
    st.session_state["messages"] = list()
    gc.collect()
def display_pdf(file):
    """Render an inline preview of a PDF in the current Streamlit container.

    The PDF bytes are base64-encoded and embedded as a ``data:`` URL inside
    an ``<iframe>``, which is emitted through ``st.markdown`` with raw HTML
    enabled.

    Args:
        file: Binary file-like object containing the PDF. It must provide
            ``read()``; when it is seekable it is rewound first.
    """
    # Read from the start of the stream. Without the rewind, a handle that
    # has already been consumed would produce an empty (blank) preview.
    seekable = getattr(file, "seekable", None)
    if callable(seekable) and seekable():
        file.seek(0)

    st.markdown("### PDF Preview")
    base64_pdf = base64.b64encode(file.read()).decode("utf-8")
    # The HTML lines stay unindented inside the literal so Markdown does not
    # treat them as an indented code block.
    pdf_display = f"""
<iframe src="data:application/pdf;base64,{base64_pdf}"
width="400" height="100%"
style="height:100vh; width:100%">
</iframe>
"""
    st.markdown(pdf_display, unsafe_allow_html=True)
# Sidebar: upload a PDF, build (or reuse) its query engine, and preview it.
with st.sidebar:
    st.header("Add your documents!")
    uploaded_file = st.file_uploader("Choose a `.pdf` file", type="pdf")
    if uploaded_file:
        try:
            # The cache key is per session and per file name; the chat
            # handler rebuilds the same key to look the engine up again.
            file_key = f"{st.session_state.id}-{uploaded_file.name}"
            if file_key in st.session_state.get('file_cache', {}):
                # Already indexed in this session: skip the temp-file write
                # and the indexing pass entirely.
                query_engine = st.session_state.file_cache[file_key]
            else:
                st.write("Indexing your document...")
                with tempfile.TemporaryDirectory() as temp_dir:
                    # The file name is supplied by the client. Keep only its
                    # final path component so it cannot point outside
                    # temp_dir (e.g. "../../somewhere.pdf").
                    safe_name = os.path.basename(uploaded_file.name) or "upload.pdf"
                    file_path = os.path.join(temp_dir, safe_name)
                    with open(file_path, "wb") as f:
                        f.write(uploaded_file.getvalue())

                    # Check the written file itself; the directory always
                    # exists inside this context manager.
                    if not os.path.isfile(file_path):
                        st.error("Could not find the file. Please reupload.")
                        st.stop()

                    loader = SimpleDirectoryReader(
                        input_dir=temp_dir,
                        required_exts=[".pdf"],
                        recursive=True
                    )
                    docs = loader.load_data()

                    # Load the HF-based LLM (DeepSeek-R1)
                    llm = load_llm()

                    # HuggingFace embeddings for the vector store
                    embed_model = HuggingFaceEmbedding(
                        model_name="answerdotai/ModernBERT-large",
                        trust_remote_code=True
                    )

                    # Bundle the LLM and the embedding model for the index
                    service_context = ServiceContext.from_defaults(
                        llm=llm,
                        embed_model=embed_model
                    )

                    # Build the index
                    index = VectorStoreIndex.from_documents(
                        docs,
                        service_context=service_context,
                        show_progress=True
                    )
                    query_engine = index.as_query_engine(streaming=True)

                    # Custom QA prompt: answer from the retrieved context
                    # only, and admit when the answer is unknown.
                    qa_prompt_tmpl_str = (
                        "Context information is below.\n"
                        "---------------------\n"
                        "{context_str}\n"
                        "---------------------\n"
                        "Given the context info above, provide a concise answer.\n"
                        "If you don't know, say 'I don't know'.\n"
                        "Query: {query_str}\n"
                        "Answer: "
                    )
                    qa_prompt = PromptTemplate(qa_prompt_tmpl_str)
                    query_engine.update_prompts(
                        {"response_synthesizer:text_qa_template": qa_prompt}
                    )

                    # Store in session state so reruns reuse the engine
                    st.session_state.file_cache[file_key] = query_engine

            st.success("Ready to Chat!")
            display_pdf(uploaded_file)
        except Exception as e:
            # Top-level UI boundary: report the failure to the user and halt
            # this script run instead of showing a raw traceback.
            st.error(f"An error occurred: {e}")
            st.stop()
# Page header: a wide title column beside a narrow "clear chat" button.
title_col, clear_col = st.columns([6, 1])
title_col.markdown("# RAG with DeepSeek-R1 (700B)")
clear_col.button("Clear ↺", on_click=reset_chat)

# A session without history yet gets an empty message list.
if "messages" not in st.session_state:
    reset_chat()

# Replay the stored conversation, one chat bubble per message.
for past_message in st.session_state.messages:
    st.chat_message(past_message["role"]).markdown(past_message["content"])
# Handle a newly submitted question, if there is one.
prompt = st.chat_input("Ask a question about your PDF...")
if prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").markdown(prompt)

    # Find the query engine built for the currently uploaded file, if any.
    query_engine = None
    if uploaded_file:
        file_key = f"{st.session_state.id}-{uploaded_file.name}"
        query_engine = st.session_state.file_cache.get(file_key)

    if query_engine:
        with st.chat_message("assistant"):
            placeholder = st.empty()
            full_response = ""
            # Stream the answer, redrawing the bubble with a cursor glyph
            # after each chunk arrives.
            for piece in query_engine.query(prompt).response_gen:
                full_response += piece
                placeholder.markdown(full_response + "▌")
            # Final redraw without the cursor.
            placeholder.markdown(full_response)
        st.session_state.messages.append({"role": "assistant", "content": full_response})
    else:
        # Nothing indexed yet: reply with a fixed hint instead of querying.
        answer = "No documents indexed. Please upload a PDF first."
        st.session_state.messages.append({"role": "assistant", "content": answer})
        st.chat_message("assistant").markdown(answer)