#!/usr/bin/env python
import os
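# Redirect HOME and every Hugging Face / Gradio cache and flagging path to /tmp,
# so the app only writes to a location that is reliably writable inside the Space container.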
os.environ["HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp/.cache"
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"
os.environ["HF_METRICS_CACHE"] = "/tmp/huggingface/metrics"
os.environ["GRADIO_FLAGGING_DIR"] = "/tmp/flagged"
import shutil
import json
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.configuration_llama import LlamaConfig
from huggingface_hub import hf_hub_download
import gradio as gr
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma, FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFacePipeline
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.document_loaders import PyPDFLoader, TextLoader, UnstructuredWordDocumentLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from tempfile import mkdtemp
# Device config
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device => {device}")
# Load local LLaMA model
hf_token = os.environ.get("HF_TOKEN")
model_id = "ChienChung/my-llama-1b"
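# Download config.json on its own first so rope_scaling can be patched before the weights are loaded.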
config_path = hf_hub_download(
    repo_id=model_id,
    filename="config.json",
    use_auth_token=hf_token,
    cache_dir="/tmp/huggingface"
)
with open(config_path, "r", encoding="utf-8") as f:
    config_dict = json.load(f)
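# Newer LLaMA checkpoints may ship a rope_scaling block with extra fields that
# older transformers releases reject; collapse it to the legacy {"type", "factor"} form.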
if "rope_scaling" in config_dict:
config_dict["rope_scaling"] = {
"type": "dynamic",
"factor": config_dict["rope_scaling"].get("factor", 32.0)
}
model_config = LlamaConfig.from_dict(config_dict)
model_config.trust_remote_code = True
print("Loading Llama model...")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=model_config,
    trust_remote_code=True,
    use_auth_token=hf_token,
    cache_dir="/tmp/huggingface"
)
model.to(device)
print("Model loaded!")
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
    use_auth_token=hf_token,
    cache_dir="/tmp/huggingface"
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("Tokenizer loaded!")
query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto" if device != "cpu" else None,
    do_sample=False,
    temperature=0.0,
    max_new_tokens=200,
    return_full_text=False
)
# Load embeddings and vector store
print("Loading Chroma DB for Biden Speech...")
if not os.path.exists("/tmp/chroma_db"):
shutil.copytree("./chroma_db", "/tmp/chroma_db")
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
vectordb = Chroma(persist_directory="/tmp/chroma_db", embedding_function=embeddings)
retriever = vectordb.as_retriever()
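# Default retriever settings: plain top-k similarity search over the embedded chunks
# (k defaults to 4 in this LangChain version unless overridden).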
# Prompt template for consistent answers
custom_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template="""You are a helpful AI assistant. Use only the text from the context below to answer the user's question.
If the answer is not in the context, say "No relevant info found."
If the question cannot be answered from the context, say "No relevant info found."
Return only the final answer in one to three sentences.
Do not restate the question or context.
Do not include these instructions in your final output.
Context:
{context}
Question: {question}
Answer:
"""
)
# === Tab 1: Local LLaMA RAG ===
llm_local = HuggingFacePipeline(pipeline=query_pipeline)
qa_local = RetrievalQA.from_chain_type(
    llm=llm_local,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=False,
    chain_type_kwargs={"prompt": custom_prompt}
)
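# Defensive post-processing: if the model restates the prompt up to "Answer:",
# keep only the text after the first "Answer:" marker.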
def rag_llama_qa(query):
    output = qa_local.run(query)
    lower_text = output.lower()
    idx = lower_text.find("answer:")
    if idx != -1:
        return output[idx + len("answer:"):].strip()
    return output
# === Tab 2: GPT-4 + Chroma (same retriever as Tab 1) ===
openai_api_key = os.environ.get("OPENAI_API_KEY")
llm_gpt4 = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.2, openai_api_key=openai_api_key)
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
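# ConversationBufferMemory keeps the full chat history so follow-up questions in
# this tab can refer back to earlier turns.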
qa_gpt = ConversationalRetrievalChain.from_llm(
    llm=llm_gpt4,
    retriever=retriever,
    memory=memory,
    combine_docs_chain_kwargs={"prompt": custom_prompt}
)
def rag_gpt4_qa(query):
    result = qa_gpt.run(query)
    return result
# === Tab 3: User Uploaded Docs (GPT-4 RAG) ===
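# Per-request flow: locate or save the uploaded file, load it with the matching loader,
# split it into overlapping chunks, embed them into a temporary FAISS index, and answer
# the query with a one-off RetrievalQA chain.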
def upload_and_chat(file, query):
    # Gradio may pass a path string, a dict with file metadata, or a file-like
    # object with a save() method; normalise all three to a local file path.
    if isinstance(file, str):
        file_path = file  # already a path to the uploaded file
    elif isinstance(file, dict):
        file_path = file.get("name", None)
        if file_path is None:
            return "Could not determine the uploaded file path."
    elif hasattr(file, "save"):
        temp_dir = mkdtemp()
        file_path = os.path.join(temp_dir, file.name)
        file.save(file_path)
    else:
        return "Unsupported file format."
    # Pick a loader based on the file extension.
    if file_path.lower().endswith(".pdf"):
        loader = PyPDFLoader(file_path)
    elif file_path.lower().endswith(".docx"):
        loader = UnstructuredWordDocumentLoader(file_path)
    else:
        loader = TextLoader(file_path)
    docs = loader.load()
    chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
    db = FAISS.from_documents(chunks, embeddings)
    retriever = db.as_retriever()
    qa_temp = RetrievalQA.from_chain_type(
        llm=llm_gpt4,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={"prompt": custom_prompt}
    )
    return qa_temp.run(query)
# === Gradio UI ===
demo_description = """
**Context**:
This demo uses a Retrieval-Augmented Generation (RAG) system based on
Biden’s 2023 State of the Union Address.
All responses are grounded in this document.
If no relevant information is found in the document, the system will say "No relevant info found."
**Sample Questions**:
1. What were the main topics regarding infrastructure in this speech?
2. How does the speech address the competition with China?
3. What does Biden say about job growth in the past two years?
4. Does the speech mention anything about Social Security or Medicare?
5. What does the speech propose regarding Big Tech or online privacy?
*Note: The LLaMA module generates responses based solely on the current query without follow-up memory or chat history management.*
Feel free to ask any question related to Biden’s 2023 State of the Union Address.
"""
demo_description2 = """
**Context**:
This demo uses a Retrieval-Augmented Generation (RAG) system based on
Biden’s 2023 State of the Union Address.
All responses are grounded in this document.
If no relevant information is found in the document, the system will say "No relevant info found."
**Sample Questions**:
1. What were the main topics regarding infrastructure in this speech?
2. How does the speech address the competition with China?
3. What does Biden say about job growth in the past two years?
4. Does the speech mention anything about Social Security or Medicare?
5. What does the speech propose regarding Big Tech or online privacy?
*Note: The GPT module supports follow-up questions with conversation history management, enabling more interactive and context-aware discussions.*
Feel free to ask any question related to Biden’s 2023 State of the Union Address.
"""
demo_description3 = """
**Context**:
This demo uses a Retrieval-Augmented Generation (RAG) system based on
your uploaded document.
All responses are grounded in this document.
If no relevant information is found in the document, the system will say "No relevant info found."
*Note: Each question is answered independently from the uploaded document; this tab does not keep conversation history between questions.*
Feel free to ask any question related to your document.
"""
demo = gr.TabbedInterface(
    interface_list=[
        gr.Interface(fn=rag_llama_qa, inputs="text", outputs="text", title="Biden Q&A (LLaMA)", allow_flagging="never", description=demo_description),
        gr.Interface(fn=rag_gpt4_qa, inputs="text", outputs="text", title="Biden Q&A (GPT-4)", allow_flagging="never", description=demo_description2),
        gr.Interface(fn=upload_and_chat, inputs=[gr.File(label="Upload PDF, TXT, or DOCX"), gr.Textbox(label="Ask a question")], outputs="text", title="Your Docs Q&A (Upload + GPT-4)", allow_flagging="never", description=demo_description3)
    ],
    tab_names=[
        "Biden Q&A (LLaMA)",
        "Biden Q&A (GPT-4)",
        "Your Docs Q&A (Upload + GPT-4)"
    ],
    title="RAG – LLaMA + GPT-4"
)
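# Bind to 0.0.0.0:7860, the default Gradio port that Spaces serves.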
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860) |