hbui committed
Commit 170741d
1 Parent(s): ae6c675

llama-index-update (#1)


- updated llama index demo (b83dc9cf5500e038115cf835e0c466eb12b30905)
- Merge branch 'main' of https://huggingface.co/spaces/hbui/RegBot4.0 (e262fe5388f1bbd46030de0fafcb1f1b1b32b6db)
- updated llm configs and prompts (ab0abe29e798a3f31c482100f4139634171c2fa7)

app.py CHANGED
@@ -1,13 +1,9 @@
-# https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
-
 import os
-
 import openai
 import requests
 import streamlit as st
 
 from utils.util import *
-
 from langchain.memory import ConversationBufferMemory
 
 SAVE_DIR = "uploaded_files"
@@ -17,30 +13,24 @@ os.makedirs(SAVE_DIR, exist_ok=True)
 def init_session_state():
     if "openai_api_key" not in st.session_state:
         st.session_state.openai_api_key = ""
-
     if "uploaded_files" not in st.session_state:
         st.session_state.uploaded_files = os.listdir(SAVE_DIR)
+    if "huggingface_token" not in st.session_state:
+        st.session_state.huggingface_token = ""
 
 
 init_session_state()
 
 st.set_page_config(page_title="RegBotBeta", page_icon="📜🤖")
-
 st.title("Welcome to RegBotBeta2.0")
-st.header("Powered by `LlamaIndex🦙`, `Langchain🦜🔗 ` and `OpenAI API`")
+st.header("Powered by `LlamaIndex🦙`, `Langchain🦜🔗` and `OpenAI API`")
 
 
-def init_session_state():
-    if "huggingface_token" not in st.session_state:
-        st.session_state.huggingface_token = ""
-
-
-init_session_state()
-
 uploaded_files = st.file_uploader(
     "Upload Files",
     accept_multiple_files=True,
     type=["pdf", "docx", "txt", "csv"],
+    label_visibility="hidden",
 )
 
 if uploaded_files:
@@ -48,14 +38,27 @@ if uploaded_files:
         if file not in st.session_state.uploaded_files:
             # add the file to session state
             st.session_state.uploaded_files.append(file.name)
-
             # save the file to the sample_data directory
             with open(os.path.join(SAVE_DIR, file.name), "wb") as f:
                 f.write(file.getbuffer())
-
     st.success("File(s) uploaded successfully!")
 
+
+def delete_file(filename):
+    """Delete file from session state and local filesystem."""
+    if filename in st.session_state.uploaded_files and os.path.exists(
+        os.path.join(SAVE_DIR, filename)
+    ):
+        st.session_state.uploaded_files.remove(filename)
+        os.remove(os.path.join(SAVE_DIR, filename))
+        st.success(f"Deleted {filename}!")
+        st.rerun()
+
+
 if st.session_state.uploaded_files:
     st.write("Uploaded Files:")
-    for i, filename in enumerate(st.session_state.uploaded_files, start=1):
-        st.write(f"{i}. {filename}")
+    for index, filename in enumerate(st.session_state.uploaded_files):
+        col1, col2 = st.columns([4, 1])
+        col1.write(filename)
+        if col2.button("Delete", key=f"delete_{index}"):
+            delete_file(filename)
models/llamaCustom.py CHANGED
@@ -18,11 +18,7 @@ from assets.prompts import custom_prompts
 
 # llama index
 from llama_index.core import (
-    StorageContext,
-    SimpleDirectoryReader,
     VectorStoreIndex,
-    load_index_from_storage,
-    PromptHelper,
     PromptTemplate,
 )
 from llama_index.core.llms import (
@@ -47,14 +43,12 @@ NUM_OUTPUT = 525
 # set maximum chunk overlap
 CHUNK_OVERLAP_RATION = 0.2
 
-# TODO: use the following prompt to format the answer at the end of the context prompt
 ANSWER_FORMAT = """
-Use the following example format for your answer:
+Provide the answer to the user question in the following format:
 [FORMAT]
-Answer:
-The answer to the user question.
+Your answer to the user question above.
 Reference:
-The list of references to the specific sections of the documents that support your answer.
+The list of references (such as page number, title, chapter, section) to the specific sections of the documents that support your answer.
 [END_FORMAT]
 """
 
@@ -184,9 +178,13 @@ class LlamaCustom:
 
     def get_response(self, query_str: str, chat_history: List[ChatMessage]):
         # https://docs.llamaindex.ai/en/stable/module_guides/deploying/chat_engines/
+        # https://docs.llamaindex.ai/en/stable/examples/query_engine/citation_query_engine/
+        # https://docs.llamaindex.ai/en/stable/examples/query_engine/knowledge_graph_rag_query_engine/
         query_engine = self.index.as_query_engine(
-            text_qa_template=PromptTemplate(QUERY_ENGINE_QA_TEMPLATE),
-            refine_template=PromptTemplate(QUERY_ENGINE_REFINE_TEMPLATE),
+            text_qa_template=PromptTemplate(QUERY_ENGINE_QA_TEMPLATE + ANSWER_FORMAT),
+            refine_template=PromptTemplate(
+                QUERY_ENGINE_REFINE_TEMPLATE
+            ),  # passing ANSWER_FORMAT here will not give the desired output, need to use the output parser from llama index?
             verbose=self.verbose,
         )
         # chat_engine = self.index.as_chat_engine(
@@ -196,6 +194,7 @@ class LlamaCustom:
         #     condense_prompt=CHAT_ENGINE_CONDENSE_PROMPT_TEMPLATE,
         #     # verbose=True,
         # )
+
        response = query_engine.query(query_str)
         # response = chat_engine.chat(message=query_str, chat_history=chat_history)
 
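The new QA prompt is plain string concatenation, so it can be sanity-checked outside the app. A minimal sketch (not part of the commit): it imports ANSWER_FORMAT from models/llamaCustom.py and uses a stand-in QA template, since the real QUERY_ENGINE_QA_TEMPLATE lives in the project's prompt assets; {context_str} and {query_str} are the placeholders LlamaIndex expects for text_qa_template.

from llama_index.core import PromptTemplate
from models.llamaCustom import ANSWER_FORMAT

# Stand-in for QUERY_ENGINE_QA_TEMPLATE (illustrative only).
QA_TEMPLATE = (
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Answer the question: {query_str}\n"
)

# Same concatenation the commit performs for text_qa_template.
qa_template = PromptTemplate(QA_TEMPLATE + ANSWER_FORMAT)
print(qa_template.format(context_str="<retrieved text>", query_str="<user question>"))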
models/llamaCustomV2.py ADDED
@@ -0,0 +1,229 @@
+import os
+import time
+from llama_index.core import VectorStoreIndex
+from llama_index.core.query_pipeline import (
+    QueryPipeline,
+    InputComponent,
+    ArgPackComponent,
+)
+from llama_index.core.prompts import PromptTemplate
+from llama_index.llms.openai import OpenAI
+from llama_index.postprocessor.colbert_rerank import ColbertRerank
+from typing import Any, Dict, List, Optional
+from llama_index.core.bridge.pydantic import Field
+from llama_index.core.llms import ChatMessage
+from llama_index.core.query_pipeline import CustomQueryComponent
+from llama_index.core.schema import NodeWithScore
+from llama_index.core.memory import ChatMemoryBuffer
+
+
+llm = OpenAI(
+    model="gpt-3.5-turbo-0125",
+    api_key=os.getenv("OPENAI_API_KEY"),
+)
+
+# First, we create an input component to capture the user query
+input_component = InputComponent()
+
+# Next, we use the LLM to rewrite a user query
+rewrite = (
+    "Please write a query to a semantic search engine using the current conversation.\n"
+    "\n"
+    "\n"
+    "{chat_history_str}"
+    "\n"
+    "\n"
+    "Latest message: {query_str}\n"
+    'Query:"""\n'
+)
+rewrite_template = PromptTemplate(rewrite)
+
+# we will retrieve two times, so we need to pack the retrieved nodes into a single list
+argpack_component = ArgPackComponent()
+
+# then postprocess/rerank with Colbert
+reranker = ColbertRerank(top_n=3)
+
+DEFAULT_CONTEXT_PROMPT = (
+    "Here is some context that may be relevant:\n"
+    "-----\n"
+    "{node_context}\n"
+    "-----\n"
+    "Please write a response to the following question, using the above context:\n"
+    "{query_str}\n"
+    "Please format your response in the following way:\n"
+    "Your answer here.\n"
+    "Reference:\n"
+    " Your references here (e.g. page numbers, titles, etc.).\n"
+)
+
+
+class ResponseWithChatHistory(CustomQueryComponent):
+    llm: OpenAI = Field(..., description="OpenAI LLM")
+    system_prompt: Optional[str] = Field(
+        default=None, description="System prompt to use for the LLM"
+    )
+    context_prompt: str = Field(
+        default=DEFAULT_CONTEXT_PROMPT,
+        description="Context prompt to use for the LLM",
+    )
+
+    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        """Validate component inputs during run_component."""
+        # NOTE: this is OPTIONAL but we show you where to do validation as an example
+        return input
+
+    @property
+    def _input_keys(self) -> set:
+        """Input keys dict."""
+        # NOTE: These are required inputs. If you have optional inputs please override
+        # `optional_input_keys_dict`
+        return {"chat_history", "nodes", "query_str"}
+
+    @property
+    def _output_keys(self) -> set:
+        return {"response"}
+
+    def _prepare_context(
+        self,
+        chat_history: List[ChatMessage],
+        nodes: List[NodeWithScore],
+        query_str: str,
+    ) -> List[ChatMessage]:
+        node_context = ""
+        for idx, node in enumerate(nodes):
+            node_text = node.get_content(metadata_mode="llm")
+            node_context += f"Context Chunk {idx}:\n{node_text}\n\n"
+
+        formatted_context = self.context_prompt.format(
+            node_context=node_context, query_str=query_str
+        )
+        user_message = ChatMessage(role="user", content=formatted_context)
+
+        chat_history.append(user_message)
+
+        if self.system_prompt is not None:
+            chat_history = [
+                ChatMessage(role="system", content=self.system_prompt)
+            ] + chat_history
+
+        return chat_history
+
+    def _run_component(self, **kwargs) -> Dict[str, Any]:
+        """Run the component."""
+        chat_history = kwargs["chat_history"]
+        nodes = kwargs["nodes"]
+        query_str = kwargs["query_str"]
+
+        prepared_context = self._prepare_context(chat_history, nodes, query_str)
+
+        response = llm.chat(prepared_context)
+
+        return {"response": response}
+
+    async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
+        """Run the component asynchronously."""
+        # NOTE: Optional, but async LLM calls are easy to implement
+        chat_history = kwargs["chat_history"]
+        nodes = kwargs["nodes"]
+        query_str = kwargs["query_str"]
+
+        prepared_context = self._prepare_context(chat_history, nodes, query_str)
+
+        response = await llm.achat(prepared_context)
+
+        return {"response": response}
+
+
+class LlamaCustomV2:
+    response_component = ResponseWithChatHistory(
+        llm=llm,
+        system_prompt=(
+            "You are a Q&A system. You will be provided with the previous chat history, "
+            "as well as possibly relevant context, to assist in answering a user message."
+        ),
+    )
+
+    def __init__(self, model_name: str, index: VectorStoreIndex):
+        self.model_name = model_name
+        self.index = index
+        self.retriever = index.as_retriever()
+        self.chat_mode = "condense_plus_context"
+        self.memory = ChatMemoryBuffer.from_defaults()
+        self.verbose = True
+        self._build_pipeline()
+
+    def _build_pipeline(self):
+        self.pipeline = QueryPipeline(
+            modules={
+                "input": input_component,
+                "rewrite_template": rewrite_template,
+                "llm": llm,
+                "rewrite_retriever": self.retriever,
+                "query_retriever": self.retriever,
+                "join": argpack_component,
+                "reranker": reranker,
+                "response_component": self.response_component,
+            },
+            verbose=self.verbose,
+        )
+        # run both retrievers -- once with the hallucinated query, once with the real query
+        self.pipeline.add_link(
+            "input", "rewrite_template", src_key="query_str", dest_key="query_str"
+        )
+        self.pipeline.add_link(
+            "input",
+            "rewrite_template",
+            src_key="chat_history_str",
+            dest_key="chat_history_str",
+        )
+        self.pipeline.add_link("rewrite_template", "llm")
+        self.pipeline.add_link("llm", "rewrite_retriever")
+        self.pipeline.add_link("input", "query_retriever", src_key="query_str")
+
+        # each input to the argpack component needs a dest key -- it can be anything
+        # then, the argpack component will pack all the inputs into a single list
+        self.pipeline.add_link("rewrite_retriever", "join", dest_key="rewrite_nodes")
+        self.pipeline.add_link("query_retriever", "join", dest_key="query_nodes")
+
+        # reranker needs the packed nodes and the query string
+        self.pipeline.add_link("join", "reranker", dest_key="nodes")
+        self.pipeline.add_link(
+            "input", "reranker", src_key="query_str", dest_key="query_str"
+        )
+
+        # synthesizer needs the reranked nodes and query str
+        self.pipeline.add_link("reranker", "response_component", dest_key="nodes")
+        self.pipeline.add_link(
+            "input", "response_component", src_key="query_str", dest_key="query_str"
+        )
+        self.pipeline.add_link(
+            "input",
+            "response_component",
+            src_key="chat_history",
+            dest_key="chat_history",
+        )
+
+    def get_response(self, query_str: str, chat_history: List[ChatMessage]):
+        chat_history = self.memory.get()
+        chat_history_str = "\n".join([str(x) for x in chat_history])
+
+        response = self.pipeline.run(
+            query_str=query_str,
+            chat_history=chat_history,
+            chat_history_str=chat_history_str,
+        )
+
+        user_msg = ChatMessage(role="user", content=query_str)
+        print("user_msg: ", str(user_msg))
+        print("response: ", str(response.message))
+        self.memory.put(user_msg)
+        self.memory.put(response.message)
+
+        return str(response.message)
+
+    def get_stream_response(self, query_str: str, chat_history: List[ChatMessage]):
+        response = self.get_response(query_str=query_str, chat_history=chat_history)
+        for word in response.split():
+            yield word + " "
+            time.sleep(0.05)
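For orientation, a minimal usage sketch of the new LlamaCustomV2 class (not taken from the repo): it assumes OPENAI_API_KEY is exported before the module is imported (the OpenAI client is created at import time), uses uploaded_files/example.pdf as a placeholder document, and ColbertRerank downloads a reranker checkpoint on first use.

import os

os.environ.setdefault("OPENAI_API_KEY", "sk-...")  # placeholder; must be set before the import below

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from models.llamaCustomV2 import LlamaCustomV2

# Build an in-memory index over a local document (placeholder path).
docs = SimpleDirectoryReader(input_files=["uploaded_files/example.pdf"]).load_data()
index = VectorStoreIndex.from_documents(docs)

bot = LlamaCustomV2(model_name="gpt-3.5-turbo-0125", index=index)
for token in bot.get_stream_response("What is this document about?", chat_history=[]):
    print(token, end="", flush=True)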
models/llms.py CHANGED
@@ -35,6 +35,7 @@ def load_llm(model_name: str, source: str = "huggingface"):
         llm_gpt_3_5_turbo_0125 = OpenAI(
             model=model_name,
             api_key=st.session_state.openai_api_key,
+            temperature=0.0,
         )
 
         return llm_gpt_3_5_turbo_0125
@@ -45,6 +46,7 @@ def load_llm(model_name: str, source: str = "huggingface"):
             is_chat_model=True,
             additional_kwargs={"max_new_tokens": 250},
             prompt_key=st.session_state.replicate_api_token,
+            temperature=0.0,
        )
 
         return llm_llama_13b_v2_replicate
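Setting temperature=0.0 makes both the OpenAI and Replicate backends effectively deterministic for repeated queries. Inside the demo page this now amounts to something like the sketch below (the model name and "openai" source string are illustrative; the API key is assumed to already be in st.session_state):

# Sketch only: mirrors how pages/llama_custom_demo.py wires the LLM into LlamaIndex.
llama_llm = load_llm(model_name="gpt-3.5-turbo-0125", source="openai")
Settings.llm = llama_llm  # query engines built afterwards answer with temperature 0.0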
models/vector_database.py ADDED
@@ -0,0 +1,34 @@
+from pinecone import Pinecone, ServerlessSpec
+from llama_index.vector_stores.pinecone import PineconeVectorStore
+from dotenv import load_dotenv
+
+import os
+
+load_dotenv()
+
+# Pinecone Vector Database
+pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
+pc_index_name = "llama-integration-pinecone"
+# pc_index_name = "openai-embeddings"
+pc_indexes = pc.list_indexes()
+
+# Check if the index already exists
+def index_exists(index_name):
+    for index in pc_indexes:
+        if index["name"] == index_name:
+            return True
+    return False
+
+# Create the index if it doesn't exist
+if not index_exists(pc_index_name):
+    pc.create_index(
+        name=pc_index_name,
+        dimension=1536,
+        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
+    )
+
+# Initialize your index
+pinecone_index = pc.Index(pc_index_name)
+
+# Define the vector store
+pinecone_vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
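A sketch of how this store plugs into an index, mirroring the commented-out get_pinecone_index helper in pages/llama_custom_demo.py (assumes PINECONE_API_KEY is set and uses a placeholder document path; dimension=1536 matches the OpenAI embedding size the demo switches to):

from llama_index.core import SimpleDirectoryReader, StorageContext, VectorStoreIndex
from models.vector_database import pinecone_vector_store

# Load a local document and persist its embeddings in the Pinecone index.
docs = SimpleDirectoryReader(input_files=["uploaded_files/example.pdf"]).load_data()
storage_context = StorageContext.from_defaults(vector_store=pinecone_vector_store)
index = VectorStoreIndex.from_documents(docs, storage_context=storage_context)
query_engine = index.as_query_engine()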
pages/llama_custom_demo.py CHANGED
@@ -7,6 +7,9 @@ from typing import List
 from models.llms import load_llm, integrated_llms
 from models.embeddings import hf_embed_model, openai_embed_model
 from models.llamaCustom import LlamaCustom
+from models.llamaCustomV2 import LlamaCustomV2
+
+# from models.vector_database import pinecone_vector_store
 from utils.chatbox import show_previous_messages, show_chat_input
 from utils.util import validate_openai_api_key
 
@@ -30,7 +33,8 @@ VECTOR_STORE_DIR = "vectorStores"
 HF_REPO_ID = "zhtet/RegBotBeta"
 
 # global
-Settings.embed_model = hf_embed_model
+# Settings.embed_model = hf_embed_model
+Settings.embed_model = openai_embed_model
 
 # huggingface api
 hf_api = HfApi()
@@ -62,9 +66,10 @@ def init_session_state():
 
 
 # @st.cache_resource
-def index_docs(
+def get_index(
     filename: str,
 ) -> VectorStoreIndex:
+    """This function loads the index from storage if it exists, otherwise it creates a new index from the document."""
     try:
         index_path = pathlib.Path(f"{VECTOR_STORE_DIR}/{filename.replace('.', '_')}")
         if pathlib.Path.exists(index_path):
@@ -89,6 +94,23 @@ def index_docs(
     return index
 
 
+# def get_pinecone_index(filename: str) -> VectorStoreIndex:
+#     """This function loads the index from Pinecone if it exists, otherwise it creates a new index from the document."""
+#     reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
+#     docs = reader.load_data(show_progress=True)
+#     storage_context = StorageContext.from_defaults(vector_store=pinecone_vector_store)
+#     index = VectorStoreIndex.from_documents(
+#         documents=docs, show_progress=True, storage_context=storage_context
+#     )
+
+#     return index
+
+
+def get_chroma_index(filename: str) -> VectorStoreIndex:
+    """This function loads the index from Chroma if it exists, otherwise it creates a new index from the document."""
+    pass
+
+
 def check_api_key(model_name: str, source: str):
     if source.startswith("openai"):
         if not st.session_state.openai_api_key:
@@ -164,6 +186,13 @@ with tab1:
         label="Choose a file to chat with: ", options=os.listdir(SAVE_DIR)
     )
 
+    if st.button("Clear all api keys"):
+        st.session_state.openai_api_key = ""
+        st.session_state.replicate_api_token = ""
+        st.session_state.hf_token = ""
+        st.success("All API keys cleared!")
+        st.rerun()
+
     if st.button("Submit", key="submit", help="Submit the form"):
         with st.status("Loading ...", expanded=True) as status:
             try:
@@ -176,10 +205,12 @@ with tab1:
                 Settings.llm = llama_llm
 
                 st.write("Processing Data ...")
-                index = index_docs(selected_file)
+                index = get_index(selected_file)
+                # index = get_pinecone_index(selected_file)
 
                 st.write("Finishing Up ...")
                 llama_custom = LlamaCustom(model_name=selected_llm_name, index=index)
+                # llama_custom = LlamaCustomV2(model_name=selected_llm_name, index=index)
                 st.session_state.llama_custom = llama_custom
 
                 status.update(label="Ready to query!", state="complete", expanded=False)
requirements.txt CHANGED
@@ -7,11 +7,14 @@ langchain_pinecone
 openai
 faiss-cpu
 python-dotenv
-streamlit==1.29.0
+streamlit>=1.24.0
 huggingface_hub<0.21.0
 pypdf
 llama-index-llms-huggingface>=0.1.4
 llama-index-embeddings-langchain>=0.1.2
+llama-index-vector-stores-pinecone
+pinecone-client>=3.0.0
 replicate>=0.25.1
 llama-index-llms-replicate
-sentence-transformers>=2.6.1
+sentence-transformers>=2.6.1
+llama-index-postprocessor-colbert-rerank