#!/usr/bin/env python
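# Gradio demo: Retrieval-Augmented Generation (RAG) over Biden's 2023 State of
# the Union address, served via a local LLaMA pipeline, an OpenAI chat model,
# and an upload-your-own-document tab.
#
# Cache and config paths are redirected to /tmp before the heavy imports run,
# so the app can also run where the home directory is not writable
# (e.g. inside a Hugging Face Space container).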
import os

os.environ["HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp/.cache"
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"
os.environ["HF_METRICS_CACHE"] = "/tmp/huggingface/metrics"
os.environ["GRADIO_FLAGGING_DIR"] = "/tmp/flagged"

import shutil
import json
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.configuration_llama import LlamaConfig
from huggingface_hub import hf_hub_download
import gradio as gr

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma, FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFacePipeline
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.document_loaders import PyPDFLoader, TextLoader, UnstructuredWordDocumentLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from tempfile import mkdtemp

# Device config
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device => {device}")

# Load local LLaMA model
hf_token = os.environ.get("HF_TOKEN")
model_id = "ChienChung/my-llama-1b"

config_path = hf_hub_download(
    repo_id=model_id, 
    filename="config.json", 
    use_auth_token=hf_token, 
    cache_dir="/tmp/huggingface"
)
with open(config_path, "r", encoding="utf-8") as f:
    config_dict = json.load(f)

# Some checkpoints ship a rope_scaling config that older transformers releases
# do not accept; keep only the scaling factor and fall back to "dynamic" scaling.
if "rope_scaling" in config_dict:
    config_dict["rope_scaling"] = {
        "type": "dynamic",
        "factor": config_dict["rope_scaling"].get("factor", 32.0)
    }

model_config = LlamaConfig.from_dict(config_dict)
model_config.trust_remote_code = True

print("Loading Llama model...")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=model_config,
    trust_remote_code=True,
    use_auth_token=hf_token,
    cache_dir="/tmp/huggingface"
)
model.to(device)
print("Model loaded!")

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
    use_auth_token=hf_token,
    cache_dir="/tmp/huggingface"
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("Tokenizer loaded!")

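# Greedy decoding (do_sample=False) with a 200-token cap keeps answers short and
# repeatable; return_full_text=False returns only the generated completion.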
query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto" if device != "cpu" else None,
    do_sample=False,
    temperature=0.0,
    max_new_tokens=200,
    return_full_text=False
)

# Load embeddings and vector store
print("Loading Chroma DB for Biden Speech...")
if not os.path.exists("/tmp/chroma_db"):
    shutil.copytree("./chroma_db", "/tmp/chroma_db")

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
vectordb = Chroma(persist_directory="/tmp/chroma_db", embedding_function=embeddings)
retriever = vectordb.as_retriever()

# Prompt template for consistent answers
custom_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template="""You are a helpful AI assistant. Use only the text from the context below to answer the user's question.
If the answer is not in the context, or the question is unrelated to it, say "No relevant info found."

Return only the final answer in one to three sentences.
Do not restate the question or context. 
Do not include these instructions in your final output.

Context:
{context}

Question: {question}

Answer:
"""
)

# === Tab 1: Local LLaMA RAG ===
llm_local = HuggingFacePipeline(pipeline=query_pipeline)
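# chain_type="stuff" simply concatenates the retrieved chunks into the
# prompt's {context} slot.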
qa_local = RetrievalQA.from_chain_type(
    llm=llm_local,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=False,
    chain_type_kwargs={"prompt": custom_prompt}
)

def rag_llama_qa(query):
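    # The pipeline can echo part of the prompt; if an "Answer:" marker shows up
    # in the output, keep only the text that follows it.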
    output = qa_local.run(query)
    lower_text = output.lower()
    idx = lower_text.find("answer:")
    if idx != -1:
        return output[idx + len("answer:"):].strip()
    return output

# === Tab 2: GPT-4 + Chroma (using the same retriever as Tab 1) ===
openai_api_key = os.environ.get("OPENAI_API_KEY")
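# Note: the UI tabs are labelled "GPT-4", but this chain currently calls gpt-3.5-turbo.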
llm_gpt4 = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.2, openai_api_key=openai_api_key)
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
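# The conversational chain pulls chat history from the buffer memory, so
# follow-up questions can refer back to earlier turns.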
qa_gpt = ConversationalRetrievalChain.from_llm(
    llm=llm_gpt4,
    retriever=retriever,
    memory=memory,
    combine_docs_chain_kwargs={"prompt": custom_prompt}
)

def rag_gpt4_qa(query):
    result = qa_gpt.run(query)
    return result

# === Tab 3: User Uploaded Docs (GPT-4 RAG) ===
def upload_and_chat(file, query):
    # Gradio may hand the upload over as a plain path string, a dict with a
    # "name" key, or a file-like object exposing a save() method.
    if isinstance(file, str):
        file_path = file  # already a path on disk
    elif isinstance(file, dict):
        file_path = file.get("name", None)
        if file_path is None:
            return "Could not determine the path of the uploaded file."
    elif hasattr(file, "save"):
        temp_dir = mkdtemp()
        file_path = os.path.join(temp_dir, file.name)
        file.save(file_path)
    else:
        return "Unsupported upload format."

    if file_path.lower().endswith(".pdf"):
        loader = PyPDFLoader(file_path)
    elif file_path.lower().endswith(".docx"):
        loader = UnstructuredWordDocumentLoader(file_path)
    else:
        loader = TextLoader(file_path)
        
    docs = loader.load()
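    # Build a fresh FAISS index for this upload; it is rebuilt on every call
    # and never persisted.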
    chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
    db = FAISS.from_documents(chunks, embeddings)
    retriever = db.as_retriever()

    qa_temp = RetrievalQA.from_chain_type(
        llm=llm_gpt4,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={"prompt": custom_prompt}
    )
    return qa_temp.run(query)

# === Gradio UI ===
demo_description = """
**Context**:
This demo uses a Retrieval-Augmented Generation (RAG) system based on 
Biden’s 2023 State of the Union Address. 
All responses are grounded in this document. 
If no relevant information is found in the document, the system will say "No relevant info found."

**Sample Questions**:
1. What were the main topics regarding infrastructure in this speech?
2. How does the speech address the competition with China?
3. What does Biden say about job growth in the past two years?
4. Does the speech mention anything about Social Security or Medicare?
5. What does the speech propose regarding Big Tech or online privacy?

*Note: The LLaMA module generates responses based solely on the current query without follow-up memory or chat history management.*

Feel free to ask any question related to Biden’s 2023 State of the Union Address.
"""
demo_description2 = """
**Context**:
This demo uses a Retrieval-Augmented Generation (RAG) system based on 
Biden’s 2023 State of the Union Address. 
All responses are grounded in this document. 
If no relevant information is found in the document, the system will say "No relevant info found."

**Sample Questions**:
1. What were the main topics regarding infrastructure in this speech?
2. How does the speech address the competition with China?
3. What does Biden say about job growth in the past two years?
4. Does the speech mention anything about Social Security or Medicare?
5. What does the speech propose regarding Big Tech or online privacy?

*Note: The GPT module supports follow-up questions with conversation history management, enabling more interactive and context-aware discussions.*

Feel free to ask any question related to Biden’s 2023 State of the Union Address.
"""
demo_description3 = """
**Context**:
This demo uses a Retrieval-Augmented Generation (RAG) system based on 
your uploaded document. 
All responses are grounded in this document. 
If no relevant information is found in the document, the system will say "No relevant info found."

*Note: This tab answers each question independently from the uploaded document; it does not keep conversation history between questions.*

Feel free to ask any question related to your document.
"""
demo = gr.TabbedInterface(
    interface_list=[
        gr.Interface(fn=rag_llama_qa, inputs="text", outputs="text", title="Biden Q&A (LLaMA)", allow_flagging="never", description=demo_description),
        gr.Interface(fn=rag_gpt4_qa, inputs="text", outputs="text", title="Biden Q&A (GPT-4)", allow_flagging="never", description=demo_description2),
        gr.Interface(fn=upload_and_chat, inputs=[gr.File(label="Upload PDF, TXT, or DOCX"), gr.Textbox(label="Ask a question")], outputs="text", title="Your Docs Q&A (Upload + GPT-4)", allow_flagging="never", description=demo_description3)
    ],
    tab_names=[
        "Biden Q&A (LLaMA)",
        "Biden Q&A (GPT-4)",
        "Your Docs Q&A (Upload + GPT-4)"
    ],
    title="RAG – LLaMA + GPT-4"
)

if __name__ == "__main__":
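    # Listen on all interfaces; 7860 is Gradio's default port and the one
    # Hugging Face Spaces expects.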
    demo.launch(server_name="0.0.0.0", server_port=7860)