import streamlit as st

import os
import gc
import base64
import tempfile
import uuid

from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.prompts import PromptTemplate

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

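# NOTE: the imports above follow the pre-0.10 llama_index layout (top-level package and
# ServiceContext); newer releases moved these modules under llama_index.core and replaced
# ServiceContext with Settings. Besides streamlit and torch, the script also needs
# transformers, accelerate (for device_map="auto") and bitsandbytes (for 4-bit loading).
# Exact package versions are an assumption here, not something stated by the original code.
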
@st.cache_resource
def load_llm():
    """
    Load the DeepSeek-R1 model (671B parameters) from Hugging Face,
    using 4-bit quantization and automatic device mapping.
    Cached with st.cache_resource so the weights are loaded only once per session.
    """
    model_id = "deepseek-ai/DeepSeek-R1"

    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True
    )

    # 4-bit quantization (via bitsandbytes) shrinks the memory footprint;
    # device_map="auto" spreads the weights across the available devices.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        device_map="auto",
        load_in_4bit=True,
        torch_dtype=torch.float16
    )

    # Wrap the model for llama_index. Sampling options go through generate_kwargs;
    # streaming is requested later on the query engine rather than here.
    llm = HuggingFaceLLM(
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        generate_kwargs={"temperature": 0.7, "do_sample": True}
    )
    return llm

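# NOTE: even with 4-bit quantization, the 671B-parameter DeepSeek-R1 requires hundreds of
# GB of GPU memory, so load_llm() only works on a large multi-GPU node. For a single-GPU
# test, swapping model_id for a distilled checkpoint such as
# "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" is a reasonable substitute (an assumption,
# not part of the original setup).
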
# Per-session identifier and a cache of query engines, keyed by session id + file name.
if "id" not in st.session_state:
    st.session_state.id = uuid.uuid4()
    st.session_state.file_cache = {}


def reset_chat():
    st.session_state.messages = []
    gc.collect()

def display_pdf(file):
    st.markdown("### PDF Preview")
    base64_pdf = base64.b64encode(file.read()).decode("utf-8")

    # Embed the uploaded PDF as a base64 data URI inside an iframe.
    pdf_display = f"""
        <iframe src="data:application/pdf;base64,{base64_pdf}"
                width="100%" height="100%"
                style="height:100vh; width:100%">
        </iframe>
    """
    st.markdown(pdf_display, unsafe_allow_html=True)

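# Sidebar: upload a PDF, index it once with llama_index, and cache the resulting query
# engine in session state so Streamlit re-runs don't re-index the same file.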
with st.sidebar:
    st.header("Add your documents!")

    uploaded_file = st.file_uploader("Choose a `.pdf` file", type="pdf")

    if uploaded_file:
        try:
            # Write the upload into a temporary directory so SimpleDirectoryReader can ingest it.
            with tempfile.TemporaryDirectory() as temp_dir:
                file_path = os.path.join(temp_dir, uploaded_file.name)
                with open(file_path, "wb") as f:
                    f.write(uploaded_file.getvalue())

                # One query engine per session and per uploaded file.
                file_key = f"{st.session_state.id}-{uploaded_file.name}"
                st.write("Indexing your document...")

                if file_key not in st.session_state.get('file_cache', {}):
                    if os.path.exists(temp_dir):
                        loader = SimpleDirectoryReader(
                            input_dir=temp_dir,
                            required_exts=[".pdf"],
                            recursive=True
                        )
                    else:
                        st.error("Could not find the file. Please reupload.")
                        st.stop()

                    docs = loader.load_data()
                    # Local LLM (cached by st.cache_resource) and embedding model.
                    llm = load_llm()

                    embed_model = HuggingFaceEmbedding(
                        model_name="answerdotai/ModernBERT-large",
                        trust_remote_code=True
                    )

                    # ServiceContext bundles the LLM and embedding model for the index.
                    service_context = ServiceContext.from_defaults(
                        llm=llm,
                        embed_model=embed_model
                    )

                    index = VectorStoreIndex.from_documents(
                        docs,
                        service_context=service_context,
                        show_progress=True
                    )

                    # Streaming query engine so answers can be rendered token by token.
                    query_engine = index.as_query_engine(streaming=True)

                    # Custom QA prompt used by the response synthesizer.
                    qa_prompt_tmpl_str = (
                        "Context information is below.\n"
                        "---------------------\n"
                        "{context_str}\n"
                        "---------------------\n"
                        "Given the context info above, provide a concise answer.\n"
                        "If you don't know, say 'I don't know'.\n"
                        "Query: {query_str}\n"
                        "Answer: "
                    )
                    qa_prompt = PromptTemplate(qa_prompt_tmpl_str)
                    query_engine.update_prompts(
                        {"response_synthesizer:text_qa_template": qa_prompt}
                    )

                    st.session_state.file_cache[file_key] = query_engine
                else:
                    query_engine = st.session_state.file_cache[file_key]

                st.success("Ready to Chat!")
                display_pdf(uploaded_file)

        except Exception as e:
            st.error(f"An error occurred: {e}")
            st.stop()

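# Main layout: app title on the left, a button to clear the conversation on the right.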
col1, col2 = st.columns([6, 1])

with col1:
    st.markdown("# RAG with DeepSeek-R1 (671B)")

with col2:
    st.button("Clear ↺", on_click=reset_chat)

# Initialize the chat history on first load.
if "messages" not in st.session_state:
    reset_chat()

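# Chat loop: replay the stored history, then take the next question and stream the answer.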
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Ask a question about your PDF..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Look up the query engine built for the currently uploaded file.
    if uploaded_file:
        file_key = f"{st.session_state.id}-{uploaded_file.name}"
        query_engine = st.session_state.file_cache.get(file_key)
    else:
        query_engine = None

    if not query_engine:
        answer = "No documents indexed. Please upload a PDF first."
        st.session_state.messages.append({"role": "assistant", "content": answer})
        with st.chat_message("assistant"):
            st.markdown(answer)
    else:
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""

            # Stream the answer chunk by chunk, updating the placeholder as it arrives.
            streaming_response = query_engine.query(prompt)
            for chunk in streaming_response.response_gen:
                full_response += chunk
                message_placeholder.markdown(full_response + "▌")

            message_placeholder.markdown(full_response)

        st.session_state.messages.append({"role": "assistant", "content": full_response})