halimbahae committed
Commit a881166
1 Parent(s): 1335f41

Update app.py

Files changed (1)
  1. app.py +33 -77
app.py CHANGED
@@ -1,30 +1,20 @@
 import os
 import gradio as gr
 from huggingface_hub import InferenceClient
+from langchain_community.vectorstores import Chroma
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+from langchain_community.document_loaders import PyPDFLoader, UnstructuredFileLoader, CSVLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.vectorstores import Chroma
-from langchain.embeddings import HuggingFaceBgeEmbeddings
-from langchain.document_loaders import PyPDFLoader, UnstructuredFileLoader, CSVLoader
-from langchain.chains import RetrievalQA
-from langchain.prompts import PromptTemplate
 
-# # Initialize the Zephyr client
-# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-from huggingface_hub import InferenceClient
-
-# Access the Hugging Face token from environment variables
+# Load Hugging Face API token
 HF_API_TOKEN = os.getenv("HF_API_TOKEN")
-
 if not HF_API_TOKEN:
     raise ValueError("Hugging Face API token is not set in environment variables.")
 
-# Initialize the client with the token
+# Initialize Zephyr client
 client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=HF_API_TOKEN)
 
-
-
-# Function to load documents based on file type
+# Load documents based on file type
 def load_documents(file_path):
     if file_path.endswith(".pdf"):
         loader = PyPDFLoader(file_path)
@@ -34,54 +24,50 @@ def load_documents(file_path):
         loader = CSVLoader(file_path)
     else:
         raise ValueError("Unsupported file format")
-
     documents = loader.load()
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
     return text_splitter.split_documents(documents)
 
-# Function to create or update vector store
+# Create vector store
 def create_vector_store(documents, persist_dir="vector_db"):
     embeddings = HuggingFaceBgeEmbeddings(
-        model_name="BAAI/bge-large-en",
-        model_kwargs={"device": "cpu"}
+        model_name="BAAI/bge-large-en",
+        model_kwargs={"device": "cpu"},
     )
     vector_store = Chroma.from_documents(documents, embeddings, persist_directory=persist_dir)
     return vector_store
 
-# Function to handle user queries
-def respond(message, history, system_message, max_tokens, temperature, top_p, retriever):
+# Initialize retriever and vector store
+persist_dir = "vector_db"
+retriever = None  # Will be dynamically updated
+
+# Handle queries and uploads
+def handle_query(message, history, system_message, max_tokens, temperature, top_p, file=None):
+    global retriever
+    if file:  # If a file is uploaded, process it
+        documents = load_documents(file.name)
+        vector_store = create_vector_store(documents, persist_dir)
+        retriever = vector_store.as_retriever()
+    if not retriever:
+        return "No documents uploaded yet. Please upload a file first."
+
     # Retrieve relevant context
     relevant_docs = retriever.get_relevant_documents(message)
    context = "\n".join([doc.page_content for doc in relevant_docs])
-
-    # Format the prompt
-    prompt_template = """
-    Use the following context to answer the user's question.
-    If you don't know the answer, say "I don't know."
-
+
+    # Build the prompt
+    prompt = f"""
+    Use the following context to answer the user's question.
     Context:
     {context}
-
+
     Question:
-    {question}
-
-    Answer:
-    """
-    formatted_prompt = prompt_template.format(context=context, question=message)
-
-    # Build conversational history
-    messages = [{"role": "system", "content": system_message}]
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-    messages.append({"role": "user", "content": formatted_prompt})
-
-    # Stream response from Zephyr
+    {message}
+    Answer:"""
+
     response = ""
     for msg in client.chat_completion(
-        messages=messages,
+        messages=[{"role": "system", "content": system_message}, {"role": "user", "content": prompt}],
         max_tokens=max_tokens,
         stream=True,
         temperature=temperature,
@@ -91,37 +77,9 @@ def respond(message, history, system_message, max_tokens, temperature, top_p, retriever):
         response += token
         yield response
 
-# Initialize the vector store
-persist_dir = "vector_db"
-retriever = None  # Will be initialized dynamically
-
-def handle_query(message, history, system_message, max_tokens, temperature, top_p, file=None):
-    global retriever
-    if file:  # Process uploaded file
-        documents = load_documents(file.name)
-        vector_store = create_vector_store(documents, persist_dir)
-        retriever = vector_store.as_retriever()
-    if not retriever:
-        return "No documents have been uploaded yet. Please upload a file to provide context."
-    return respond(message, history, system_message, max_tokens, temperature, top_p, retriever)
-
 # Gradio app setup
-# demo = gr.ChatInterface(
-#     fn=handle_query,
-#     additional_inputs=[
-#         gr.File(label="Upload File", type="file"),
-#         gr.Textbox(value="You are a knowledgeable assistant.", label="System Message"),
-#         gr.Slider(1, 2048, step=1, value=512, label="Max Tokens"),
-#         gr.Slider(0.1, 4.0, step=0.1, value=0.7, label="Temperature"),
-#         gr.Slider(0.1, 1.0, step=0.05, value=0.95, label="Top-p"),
-#     ],
-#     outputs="text",
-#     title="RAG with Zephyr-7B",
-#     description="A Retrieval-Augmented Generation chatbot powered by Zephyr-7B and Chroma vector database.",
-# )
-
 demo = gr.Interface(
-    fn=handle_uploaded_file,  # Handle uploaded files
+    fn=handle_query,
     inputs=[
         gr.File(label="Upload Document"),
         gr.Textbox(value="You are a knowledgeable assistant.", label="System Message"),
@@ -135,7 +93,5 @@ demo = gr.Interface(
     description="Upload documents and ask questions using RAG.",
 )
 
-
-
 if __name__ == "__main__":
     demo.launch()
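Note on the import changes: the commit moves Chroma, HuggingFaceBgeEmbeddings, and the document loaders from the deprecated langchain.* paths to their langchain_community.* equivalents, and drops the unused RetrievalQA and PromptTemplate imports. A quick way to confirm the new paths resolve is the minimal smoke test below, assuming langchain-community is listed in the Space's requirements.txt:

# Smoke test for the migrated import paths. Assumes the langchain-community
# package is installed; Chroma and the BGE embeddings additionally need
# chromadb and sentence-transformers at runtime.
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.document_loaders import PyPDFLoader, UnstructuredFileLoader, CSVLoader

print("langchain_community imports resolved")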
 
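Note on the streaming loop: the hunks elide the lines between temperature=temperature, and response += token, so how the token is pulled out of each streamed message is not shown. With huggingface_hub's InferenceClient, streamed chat-completion chunks are typically consumed as in the minimal sketch below; the chunk and token names and the placeholder token string are illustrative, not the commit's exact code.

from huggingface_hub import InferenceClient

client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token="hf_xxx")  # placeholder token

response = ""
for chunk in client.chat_completion(
    messages=[{"role": "user", "content": "What is retrieval-augmented generation?"}],
    max_tokens=64,
    stream=True,
    temperature=0.7,
    top_p=0.95,
):
    token = chunk.choices[0].delta.content  # each chunk carries the newly generated piece
    if token:  # delta.content can be None on some chunks
        response += token
print(response)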
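One point worth double-checking in the new wiring: gr.Interface passes input values to fn positionally, in the order the components appear in inputs, so the component list has to line up with handle_query's parameter order (message, history, system_message, ... versus the File-first inputs shown in the hunk above). Below is a minimal sketch of that mechanism with a hypothetical three-input function, not the app's real handler:

import gradio as gr

# Hypothetical stand-in for handle_query: the parameter order must match
# the order of the components in inputs, since gr.Interface passes input
# values positionally.
def answer(file, system_message, question):
    received = "a file" if file is not None else "no file"
    return f"[{system_message}] ({received}) You asked: {question}"

demo = gr.Interface(
    fn=answer,
    inputs=[
        gr.File(label="Upload Document"),
        gr.Textbox(value="You are a knowledgeable assistant.", label="System Message"),
        gr.Textbox(label="Question"),
    ],
    outputs="text",
    title="Positional input wiring demo",
)

if __name__ == "__main__":
    demo.launch()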