llama-index-update (#1)
- updated llama index demo (b83dc9cf5500e038115cf835e0c466eb12b30905)
- Merge branch 'main' of https://huggingface.co/spaces/hbui/RegBot4.0 (e262fe5388f1bbd46030de0fafcb1f1b1b32b6db)
- updated llm configs and prompts (ab0abe29e798a3f31c482100f4139634171c2fa7)
- app.py +21 -18
- models/llamaCustom.py +10 -11
- models/llamaCustomV2.py +229 -0
- models/llms.py +2 -0
- models/vector_database.py +34 -0
- pages/llama_custom_demo.py +34 -3
- requirements.txt +5 -2
app.py
CHANGED
@@ -1,13 +1,9 @@
-# https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
-
 import os
-
 import openai
 import requests
 import streamlit as st
 
 from utils.util import *
-
 from langchain.memory import ConversationBufferMemory
 
 SAVE_DIR = "uploaded_files"
@@ -17,30 +13,24 @@ os.makedirs(SAVE_DIR, exist_ok=True)
 def init_session_state():
     if "openai_api_key" not in st.session_state:
         st.session_state.openai_api_key = ""
-
     if "uploaded_files" not in st.session_state:
         st.session_state.uploaded_files = os.listdir(SAVE_DIR)
+    if "huggingface_token" not in st.session_state:
+        st.session_state.huggingface_token = ""
 
 
 init_session_state()
 
 st.set_page_config(page_title="RegBotBeta", page_icon="📜🤖")
-
 st.title("Welcome to RegBotBeta2.0")
-st.header("Powered by `LlamaIndex🦙`, `Langchain
+st.header("Powered by `LlamaIndex🦙`, `Langchain🦜🔗` and `OpenAI API`")
 
 
-def init_session_state():
-    if "huggingface_token" not in st.session_state:
-        st.session_state.huggingface_token = ""
-
-
-init_session_state()
-
 uploaded_files = st.file_uploader(
     "Upload Files",
     accept_multiple_files=True,
     type=["pdf", "docx", "txt", "csv"],
+    label_visibility="hidden",
 )
 
 if uploaded_files:
@@ -48,14 +38,27 @@ if uploaded_files:
         if file not in st.session_state.uploaded_files:
             # add the file to session state
             st.session_state.uploaded_files.append(file.name)
-
             # save the file to the sample_data directory
             with open(os.path.join(SAVE_DIR, file.name), "wb") as f:
                 f.write(file.getbuffer())
-
     st.success("File(s) uploaded successfully!")
 
+
+def delete_file(filename):
+    """Delete file from session state and local filesystem."""
+    if filename in st.session_state.uploaded_files and os.path.exists(
+        os.path.join(SAVE_DIR, filename)
+    ):
+        st.session_state.uploaded_files.remove(filename)
+        os.remove(os.path.join(SAVE_DIR, filename))
+        st.success(f"Deleted {filename}!")
+        st.rerun()
+
+
 if st.session_state.uploaded_files:
     st.write("Uploaded Files:")
-    for
-    st.
+    for index, filename in enumerate(st.session_state.uploaded_files):
+        col1, col2 = st.columns([4, 1])
+        col1.write(filename)
+        if col2.button("Delete", key=f"delete_{index}"):
+            delete_file(filename)
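Note: the new per-file delete flow pairs each entry in st.session_state.uploaded_files with its own button and a callback that removes the file from both session state and disk before rerunning the script. A minimal standalone sketch of that pattern, assuming only the Streamlit API and a local uploaded_files directory (names here are illustrative):

# Minimal sketch of the list-with-delete-button pattern used in app.py above.
import os

import streamlit as st

SAVE_DIR = "uploaded_files"
os.makedirs(SAVE_DIR, exist_ok=True)

if "uploaded_files" not in st.session_state:
    st.session_state.uploaded_files = os.listdir(SAVE_DIR)


def delete_file(filename: str) -> None:
    """Remove the file from session state and disk, then rerun the script."""
    path = os.path.join(SAVE_DIR, filename)
    if filename in st.session_state.uploaded_files and os.path.exists(path):
        st.session_state.uploaded_files.remove(filename)
        os.remove(path)
        st.success(f"Deleted {filename}!")
        st.rerun()


for index, filename in enumerate(st.session_state.uploaded_files):
    col1, col2 = st.columns([4, 1])  # wide column for the name, narrow one for the button
    col1.write(filename)
    if col2.button("Delete", key=f"delete_{index}"):  # unique key per row
        delete_file(filename)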
models/llamaCustom.py
CHANGED
@@ -18,11 +18,7 @@ from assets.prompts import custom_prompts
 
 # llama index
 from llama_index.core import (
-    StorageContext,
-    SimpleDirectoryReader,
     VectorStoreIndex,
-    load_index_from_storage,
-    PromptHelper,
     PromptTemplate,
 )
 from llama_index.core.llms import (
@@ -47,14 +43,12 @@ NUM_OUTPUT = 525
 # set maximum chunk overlap
 CHUNK_OVERLAP_RATION = 0.2
 
-# TODO: use the following prompt to format the answer at the end of the context prompt
 ANSWER_FORMAT = """
-
+Provide the answer to the user question in the following format:
 [FORMAT]
-
-The answer to the user question.
+Your answer to the user question above.
 Reference:
-The list of references to the specific sections of the documents that support your answer.
+The list of references (such as page number, title, chapter, section) to the specific sections of the documents that support your answer.
 [END_FORMAT]
 """
 
@@ -184,9 +178,13 @@ class LlamaCustom:
 
     def get_response(self, query_str: str, chat_history: List[ChatMessage]):
         # https://docs.llamaindex.ai/en/stable/module_guides/deploying/chat_engines/
+        # https://docs.llamaindex.ai/en/stable/examples/query_engine/citation_query_engine/
+        # https://docs.llamaindex.ai/en/stable/examples/query_engine/knowledge_graph_rag_query_engine/
         query_engine = self.index.as_query_engine(
-            text_qa_template=PromptTemplate(QUERY_ENGINE_QA_TEMPLATE),
-            refine_template=PromptTemplate(
+            text_qa_template=PromptTemplate(QUERY_ENGINE_QA_TEMPLATE + ANSWER_FORMAT),
+            refine_template=PromptTemplate(
+                QUERY_ENGINE_REFINE_TEMPLATE
+            ),  # passing ANSWER_FORMAT here will not give the desired output, need to use the output parser from llama index?
             verbose=self.verbose,
        )
         # chat_engine = self.index.as_chat_engine(
@@ -196,6 +194,7 @@ class LlamaCustom:
         #     condense_prompt=CHAT_ENGINE_CONDENSE_PROMPT_TEMPLATE,
         #     # verbose=True,
         # )
+
         response = query_engine.query(query_str)
         # response = chat_engine.chat(message=query_str, chat_history=chat_history)
 
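Note: the format instructions are appended to the first-pass QA template only; the refine template keeps its original wording. A rough sketch of how the combined prompt is assembled — the real QUERY_ENGINE_QA_TEMPLATE and QUERY_ENGINE_REFINE_TEMPLATE live in assets/prompts, so the QA string below is only a stand-in:

# Illustrative only: the real templates come from assets.prompts.custom_prompts.
from llama_index.core import PromptTemplate

QUERY_ENGINE_QA_TEMPLATE = (  # stand-in for the project's QA prompt
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, answer the query.\n"
    "Query: {query_str}\n"
)

ANSWER_FORMAT = """
Provide the answer to the user question in the following format:
[FORMAT]
Your answer to the user question above.
Reference:
The list of references (such as page number, title, chapter, section) to the specific sections of the documents that support your answer.
[END_FORMAT]
"""

# Concatenating the two strings means every first-pass synthesis call asks the
# LLM for an answer plus a Reference block; the refine pass is left unchanged.
text_qa_template = PromptTemplate(QUERY_ENGINE_QA_TEMPLATE + ANSWER_FORMAT)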
models/llamaCustomV2.py
ADDED
@@ -0,0 +1,229 @@
+import os
+import time
+from llama_index.core import VectorStoreIndex
+from llama_index.core.query_pipeline import (
+    QueryPipeline,
+    InputComponent,
+    ArgPackComponent,
+)
+from llama_index.core.prompts import PromptTemplate
+from llama_index.llms.openai import OpenAI
+from llama_index.postprocessor.colbert_rerank import ColbertRerank
+from typing import Any, Dict, List, Optional
+from llama_index.core.bridge.pydantic import Field
+from llama_index.core.llms import ChatMessage
+from llama_index.core.query_pipeline import CustomQueryComponent
+from llama_index.core.schema import NodeWithScore
+from llama_index.core.memory import ChatMemoryBuffer
+
+
+llm = OpenAI(
+    model="gpt-3.5-turbo-0125",
+    api_key=os.getenv("OPENAI_API_KEY"),
+)
+
+# First, we create an input component to capture the user query
+input_component = InputComponent()
+
+# Next, we use the LLM to rewrite a user query
+rewrite = (
+    "Please write a query to a semantic search engine using the current conversation.\n"
+    "\n"
+    "\n"
+    "{chat_history_str}"
+    "\n"
+    "\n"
+    "Latest message: {query_str}\n"
+    'Query:"""\n'
+)
+rewrite_template = PromptTemplate(rewrite)
+
+# we will retrieve two times, so we need to pack the retrieved nodes into a single list
+argpack_component = ArgPackComponent()
+
+# then postprocess/rerank with Colbert
+reranker = ColbertRerank(top_n=3)
+
+DEFAULT_CONTEXT_PROMPT = (
+    "Here is some context that may be relevant:\n"
+    "-----\n"
+    "{node_context}\n"
+    "-----\n"
+    "Please write a response to the following question, using the above context:\n"
+    "{query_str}\n"
+    "Please formate your response in the following way:\n"
+    "Your answer here.\n"
+    "Reference:\n"
+    "    Your references here (e.g. page numbers, titles, etc.).\n"
+)
+
+
+class ResponseWithChatHistory(CustomQueryComponent):
+    llm: OpenAI = Field(..., description="OpenAI LLM")
+    system_prompt: Optional[str] = Field(
+        default=None, description="System prompt to use for the LLM"
+    )
+    context_prompt: str = Field(
+        default=DEFAULT_CONTEXT_PROMPT,
+        description="Context prompt to use for the LLM",
+    )
+
+    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        """Validate component inputs during run_component."""
+        # NOTE: this is OPTIONAL but we show you where to do validation as an example
+        return input
+
+    @property
+    def _input_keys(self) -> set:
+        """Input keys dict."""
+        # NOTE: These are required inputs. If you have optional inputs please override
+        # `optional_input_keys_dict`
+        return {"chat_history", "nodes", "query_str"}
+
+    @property
+    def _output_keys(self) -> set:
+        return {"response"}
+
+    def _prepare_context(
+        self,
+        chat_history: List[ChatMessage],
+        nodes: List[NodeWithScore],
+        query_str: str,
+    ) -> List[ChatMessage]:
+        node_context = ""
+        for idx, node in enumerate(nodes):
+            node_text = node.get_content(metadata_mode="llm")
+            node_context += f"Context Chunk {idx}:\n{node_text}\n\n"
+
+        formatted_context = self.context_prompt.format(
+            node_context=node_context, query_str=query_str
+        )
+        user_message = ChatMessage(role="user", content=formatted_context)
+
+        chat_history.append(user_message)
+
+        if self.system_prompt is not None:
+            chat_history = [
+                ChatMessage(role="system", content=self.system_prompt)
+            ] + chat_history
+
+        return chat_history
+
+    def _run_component(self, **kwargs) -> Dict[str, Any]:
+        """Run the component."""
+        chat_history = kwargs["chat_history"]
+        nodes = kwargs["nodes"]
+        query_str = kwargs["query_str"]
+
+        prepared_context = self._prepare_context(chat_history, nodes, query_str)
+
+        response = llm.chat(prepared_context)
+
+        return {"response": response}
+
+    async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
+        """Run the component asynchronously."""
+        # NOTE: Optional, but async LLM calls are easy to implement
+        chat_history = kwargs["chat_history"]
+        nodes = kwargs["nodes"]
+        query_str = kwargs["query_str"]
+
+        prepared_context = self._prepare_context(chat_history, nodes, query_str)
+
+        response = await llm.achat(prepared_context)
+
+        return {"response": response}
+
+
+class LlamaCustomV2:
+    response_component = ResponseWithChatHistory(
+        llm=llm,
+        system_prompt=(
+            "You are a Q&A system. You will be provided with the previous chat history, "
+            "as well as possibly relevant context, to assist in answering a user message."
+        ),
+    )
+
+    def __init__(self, model_name: str, index: VectorStoreIndex):
+        self.model_name = model_name
+        self.index = index
+        self.retriever = index.as_retriever()
+        self.chat_mode = "condense_plus_context"
+        self.memory = ChatMemoryBuffer.from_defaults()
+        self.verbose = True
+        self._build_pipeline()
+
+    def _build_pipeline(self):
+        self.pipeline = QueryPipeline(
+            modules={
+                "input": input_component,
+                "rewrite_template": rewrite_template,
+                "llm": llm,
+                "rewrite_retriever": self.retriever,
+                "query_retriever": self.retriever,
+                "join": argpack_component,
+                "reranker": reranker,
+                "response_component": self.response_component,
+            },
+            verbose=self.verbose,
+        )
+        # run both retrievers -- once with the hallucinated query, once with the real query
+        self.pipeline.add_link(
+            "input", "rewrite_template", src_key="query_str", dest_key="query_str"
+        )
+        self.pipeline.add_link(
+            "input",
+            "rewrite_template",
+            src_key="chat_history_str",
+            dest_key="chat_history_str",
+        )
+        self.pipeline.add_link("rewrite_template", "llm")
+        self.pipeline.add_link("llm", "rewrite_retriever")
+        self.pipeline.add_link("input", "query_retriever", src_key="query_str")
+
+        # each input to the argpack component needs a dest key -- it can be anything
+        # then, the argpack component will pack all the inputs into a single list
+        self.pipeline.add_link("rewrite_retriever", "join", dest_key="rewrite_nodes")
+        self.pipeline.add_link("query_retriever", "join", dest_key="query_nodes")
+
+        # reranker needs the packed nodes and the query string
+        self.pipeline.add_link("join", "reranker", dest_key="nodes")
+        self.pipeline.add_link(
+            "input", "reranker", src_key="query_str", dest_key="query_str"
+        )
+
+        # synthesizer needs the reranked nodes and query str
+        self.pipeline.add_link("reranker", "response_component", dest_key="nodes")
+        self.pipeline.add_link(
+            "input", "response_component", src_key="query_str", dest_key="query_str"
+        )
+        self.pipeline.add_link(
+            "input",
+            "response_component",
+            src_key="chat_history",
+            dest_key="chat_history",
+        )
+
+    def get_response(self, query_str: str, chat_history: List[ChatMessage]):
+        chat_history = self.memory.get()
+        char_history_str = "\n".join([str(x) for x in chat_history])
+
+        response = self.pipeline.run(
+            query_str=query_str,
+            chat_history=chat_history,
+            chat_history_str=char_history_str,
+        )
+
+        user_msg = ChatMessage(role="user", content=query_str)
+        print("user_msg: ", str(user_msg))
+        print("response: ", str(response.message))
+        self.memory.put(user_msg)
+        self.memory.put(response.message)
+
+        return str(response.message)
+
+    def get_stream_response(self, query_str: str, chat_history: List[ChatMessage]):
+        response = self.get_response(query_str=query_str, chat_history=chat_history)
+        for word in response.split():
+            yield word + " "
+            time.sleep(0.05)
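Note: the new class wires a query-rewrite step, double retrieval, Colbert reranking, and a chat-history-aware synthesizer into one QueryPipeline. A rough usage sketch, assuming OPENAI_API_KEY is set and an index has already been built elsewhere (for example by get_index in pages/llama_custom_demo.py); the document directory and questions below are made up:

# Hypothetical driver script; file paths and questions are illustrative.
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex

from models.llamaCustomV2 import LlamaCustomV2

docs = SimpleDirectoryReader(input_dir="uploaded_files").load_data()
index = VectorStoreIndex.from_documents(docs)

bot = LlamaCustomV2(model_name="gpt-3.5-turbo-0125", index=index)

# Full answer in one shot; the class keeps its own ChatMemoryBuffer, so the
# chat_history argument is effectively superseded by the internal memory.
print(bot.get_response("What does the document say about filing deadlines?", chat_history=[]))

# Streaming variant: yields the same answer word by word (e.g. for a Streamlit chat UI).
for token in bot.get_stream_response("Any exceptions to those deadlines?", chat_history=[]):
    print(token, end="")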
models/llms.py
CHANGED
@@ -35,6 +35,7 @@ def load_llm(model_name: str, source: str = "huggingface"):
         llm_gpt_3_5_turbo_0125 = OpenAI(
             model=model_name,
             api_key=st.session_state.openai_api_key,
+            temperature=0.0,
         )
 
         return llm_gpt_3_5_turbo_0125
@@ -45,6 +46,7 @@ def load_llm(model_name: str, source: str = "huggingface"):
             is_chat_model=True,
             additional_kwargs={"max_new_tokens": 250},
             prompt_key=st.session_state.replicate_api_token,
+            temperature=0.0,
        )
 
         return llm_llama_13b_v2_replicate
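Note: both constructors now pin temperature to 0.0 so repeated queries over the same context return stable answers. A small sketch of the OpenAI branch — in the real code the key comes from Streamlit session state, an environment variable is used here only for brevity:

# Sketch only; mirrors the OpenAI branch of load_llm with a different key source.
import os

from llama_index.llms.openai import OpenAI

llm = OpenAI(
    model="gpt-3.5-turbo-0125",
    api_key=os.getenv("OPENAI_API_KEY"),
    temperature=0.0,  # deterministic-leaning decoding for reproducible answers
)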
models/vector_database.py
ADDED
@@ -0,0 +1,34 @@
+from pinecone import Pinecone, ServerlessSpec
+from llama_index.vector_stores.pinecone import PineconeVectorStore
+from dotenv import load_dotenv
+
+import os
+
+load_dotenv()
+
+# Pinecone Vector Database
+pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
+pc_index_name = "llama-integration-pinecone"
+# pc_index_name = "openai-embeddings"
+pc_indexes = pc.list_indexes()
+
+# Check if the index already exists
+def index_exists(index_name):
+    for index in pc_indexes:
+        if index["name"] == index_name:
+            return True
+    return False
+
+# Create the index if it doesn't exist
+if not index_exists(pc_index_name):
+    pc.create_index(
+        name=pc_index_name,
+        dimension=1536,
+        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
+    )
+
+# Initialize your index
+pinecone_index = pc.Index(pc_index_name)
+
+# Define the vector store
+pinecone_vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
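Note: this module only builds the store; wiring it into an index happens on the caller's side (currently commented out in pages/llama_custom_demo.py). A sketch of that wiring, assuming an OpenAI embedding model whose 1536-dimension vectors match the index created above; the sample file path is made up:

# Hypothetical caller; the input file is illustrative.
from llama_index.core import SimpleDirectoryReader, StorageContext, VectorStoreIndex

from models.vector_database import pinecone_vector_store

docs = SimpleDirectoryReader(input_files=["uploaded_files/sample.pdf"]).load_data()

# Point the index's storage at Pinecone instead of the local vectorStores directory.
storage_context = StorageContext.from_defaults(vector_store=pinecone_vector_store)
index = VectorStoreIndex.from_documents(
    documents=docs, storage_context=storage_context, show_progress=True
)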
pages/llama_custom_demo.py
CHANGED
@@ -7,6 +7,9 @@ from typing import List
 from models.llms import load_llm, integrated_llms
 from models.embeddings import hf_embed_model, openai_embed_model
 from models.llamaCustom import LlamaCustom
+from models.llamaCustomV2 import LlamaCustomV2
+
+# from models.vector_database import pinecone_vector_store
 from utils.chatbox import show_previous_messages, show_chat_input
 from utils.util import validate_openai_api_key
 
@@ -30,7 +33,8 @@ VECTOR_STORE_DIR = "vectorStores"
 HF_REPO_ID = "zhtet/RegBotBeta"
 
 # global
-Settings.embed_model = hf_embed_model
+# Settings.embed_model = hf_embed_model
+Settings.embed_model = openai_embed_model
 
 # huggingface api
 hf_api = HfApi()
@@ -62,9 +66,10 @@ def init_session_state():
 
 
 # @st.cache_resource
-def index_docs(
+def get_index(
     filename: str,
 ) -> VectorStoreIndex:
+    """This function loads the index from storage if it exists, otherwise it creates a new index from the document."""
     try:
         index_path = pathlib.Path(f"{VECTOR_STORE_DIR}/{filename.replace('.', '_')}")
         if pathlib.Path.exists(index_path):
@@ -89,6 +94,23 @@ def index_docs(
     return index
 
 
+# def get_pinecone_index(filename: str) -> VectorStoreIndex:
+#     """Thie function loads the index from Pinecone if it exists, otherwise it creates a new index from the document."""
+#     reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
+#     docs = reader.load_data(show_progress=True)
+#     storage_context = StorageContext.from_defaults(vector_store=pinecone_vector_store)
+#     index = VectorStoreIndex.from_documents(
+#         documents=docs, show_progress=True, storage_context=storage_context
+#     )
+
+#     return index
+
+
+def get_chroma_index(filename: str) -> VectorStoreIndex:
+    """This function loads the index from Chroma if it exists, otherwise it creates a new index from the document."""
+    pass
+
+
 def check_api_key(model_name: str, source: str):
     if source.startswith("openai"):
         if not st.session_state.openai_api_key:
@@ -164,6 +186,13 @@ with tab1:
         label="Choose a file to chat with: ", options=os.listdir(SAVE_DIR)
     )
 
+    if st.button("Clear all api keys"):
+        st.session_state.openai_api_key = ""
+        st.session_state.replicate_api_token = ""
+        st.session_state.hf_token = ""
+        st.success("All API keys cleared!")
+        st.rerun()
+
     if st.button("Submit", key="submit", help="Submit the form"):
         with st.status("Loading ...", expanded=True) as status:
             try:
@@ -176,10 +205,12 @@ with tab1:
                 Settings.llm = llama_llm
 
                 st.write("Processing Data ...")
-                index =
+                index = get_index(selected_file)
+                # index = get_pinecone_index(selected_file)
 
                 st.write("Finishing Up ...")
                 llama_custom = LlamaCustom(model_name=selected_llm_name, index=index)
+                # llama_custom = LlamaCustomV2(model_name=selected_llm_name, index=index)
                 st.session_state.llama_custom = llama_custom
 
                 status.update(label="Ready to query!", state="complete", expanded=False)
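Note: get_index now carries a docstring, but most of its body falls outside this hunk. A minimal sketch of the load-or-build pattern the docstring describes — the persist-directory naming follows VECTOR_STORE_DIR above, everything else is illustrative rather than the exact body from the repo:

# Sketch of the persist-or-rebuild logic; not the verbatim function from the repo.
import pathlib

from llama_index.core import (
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)

SAVE_DIR = "uploaded_files"
VECTOR_STORE_DIR = "vectorStores"


def get_index(filename: str) -> VectorStoreIndex:
    """Load the index from storage if it exists, otherwise build it from the document."""
    index_path = pathlib.Path(f"{VECTOR_STORE_DIR}/{filename.replace('.', '_')}")
    if index_path.exists():
        # Reuse the previously persisted vector store for this file.
        storage_context = StorageContext.from_defaults(persist_dir=str(index_path))
        return load_index_from_storage(storage_context)

    # First time seeing this file: embed it and persist the result for next run.
    docs = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"]).load_data()
    index = VectorStoreIndex.from_documents(docs, show_progress=True)
    index.storage_context.persist(persist_dir=str(index_path))
    return index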
requirements.txt
CHANGED
@@ -7,11 +7,14 @@ langchain_pinecone
 openai
 faiss-cpu
 python-dotenv
-streamlit
+streamlit>=1.24.0
 huggingface_hub<0.21.0
 pypdf
 llama-index-llms-huggingface>=0.1.4
 llama-index-embeddings-langchain>=0.1.2
+llama-index-vector-stores-pinecone
+pinecone-client>=3.0.0
 replicate>=0.25.1
 llama-index-llms-replicate
-sentence-transformers>=2.6.1
+sentence-transformers>=2.6.1
+llama-index-postprocessor-colbert-rerank