mlk8s v 0.0.1
- .gitignore +2 -0
- app.py +297 -2
- kron/__init__.py +0 -0
- kron/indices/knowledge_graph/KronKnowledgeGraphIndex.py +61 -0
- kron/kg/__init__.py +0 -0
- kron/llm_predictor/KronBasetenCamelLLM.py +32 -0
- kron/llm_predictor/KronLLMPredictor.py +39 -0
- kron/llm_predictor/KronLangChainLLM.py +35 -0
- kron/llm_predictor/KronOpenAILLM.py +39 -0
- kron/llm_predictor/__init__.py +0 -0
- kron/llm_predictor/openai_utils.py +115 -0
- kron/llm_predictor/utils.py +16 -0
- kron/persistence/dynamodb_request_log.py +132 -0
- kron/prompts/kg_prompts.py +351 -0
- requirements.txt +21 -0
- storage/Writer-camel-5b-hf-default-no-coref/graph_store.json +0 -0
- storage/Writer-camel-5b-hf-default-no-coref/index_store.json +0 -0
- storage/Writer-camel-5b-hf-default-no-coref/vector_store.json +1 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
+.env
+__pycache__
app.py
CHANGED
@@ -1,4 +1,299 @@
 import streamlit as st
 
-
-
+import os
+import re
+import sys
+import logging
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+from dotenv import load_dotenv
+load_dotenv()
+
+for key in st.session_state.keys():
+    #del st.session_state[key]
+    print(f'session state entry: {key} {st.session_state[key]}')
+
+__spaces__ = os.environ.get('__SPACES__')
+
+if __spaces__:
+    from kron.persistence.dynamodb_request_log import get_request_log
+    st.session_state.request_log = get_request_log()
+
+# third party service access
+# hf inference api
+hf_api_key = os.environ['HF_TOKEN']
+ch_api_key = os.environ['COHERE_TOKEN']
+bs_api_key = os.environ['BASETEN_TOKEN']
+
+index_model = "Writer/camel-5b-hf"
+INDEX_NAME = f"{index_model.replace('/', '-')}-default-no-coref"
+persist_path = f"storage/{INDEX_NAME}"
+MAX_LENGTH = 1024
+
+import baseten
+@st.cache_resource
+def set_baseten_key(bs_api_key):
+    baseten.login(bs_api_key)
+
+set_baseten_key(bs_api_key)
+
+from llama_index import StorageContext
+from llama_index import ServiceContext
+from llama_index import load_index_from_storage
+from llama_index.langchain_helpers.text_splitter import SentenceSplitter
+from llama_index.node_parser import SimpleNodeParser
+from llama_index import LLMPredictor
+
+from langchain import HuggingFaceHub
+from langchain.llms.cohere import Cohere
+from langchain.llms import Baseten
+
+import tiktoken
+
+import openai
+# extensions to llama_index to support openai compatible endpoints, e.g. llama-api
+from kron.llm_predictor.KronOpenAILLM import KronOpenAI
+# baseten deployment expects a specific request format
+from kron.llm_predictor.KronBasetenCamelLLM import KronBasetenCamelLLM
+from kron.llm_predictor.KronLLMPredictor import KronLLMPredictor
+
+# writer/camel uses endoftext
+from llama_index.utils import globals_helper
+enc = tiktoken.get_encoding("gpt2")
+tokenizer = lambda text: enc.encode(text, allowed_special={"<|endoftext|>"})
+globals_helper._tokenizer = tokenizer
+
+
+def set_openai_local():
+    openai.api_key = os.environ['LOCAL_OPENAI_API_KEY']
+    openai.api_base = os.environ['LOCAL_OPENAI_API_BASE']
+    os.environ['OPENAI_API_KEY'] = os.environ['LOCAL_OPENAI_API_KEY']
+    os.environ['OPENAI_API_BASE'] = os.environ['LOCAL_OPENAI_API_BASE']
+
+def set_openai():
+    openai.api_key = os.environ['DAVINCI_OPENAI_API_KEY']
+    openai.api_base = os.environ['DAVINCI_OPENAI_API_BASE']
+    os.environ['OPENAI_API_KEY'] = os.environ['DAVINCI_OPENAI_API_KEY']
+    os.environ['OPENAI_API_BASE'] = os.environ['DAVINCI_OPENAI_API_BASE']
+
+def get_hf_predictor(query_model):
+    # no embeddings for now
+    set_openai_local()
+    llm = HuggingFaceHub(repo_id=query_model, task="text-generation",
+                         model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH},
+                         huggingfacehub_api_token=hf_api_key)
+    llm_predictor = LLMPredictor(llm)
+    return llm_predictor
+
+def get_cohere_predictor(query_model):
+    # no embeddings for now
+    set_openai_local()
+    llm = Cohere(model='command', temperature=0.01,
+                 # model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH},
+                 cohere_api_key=ch_api_key)
+    llm_predictor = LLMPredictor(llm)
+    return llm_predictor
+
+def get_baseten_predictor(query_model):
+    # no embeddings for now
+    set_openai_local()
+    llm = KronBasetenCamelLLM(model='3yd1ke3', temperature=0.01,
+                              # model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH, 'repetition_penalty': 1.07},
+                              model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH, 'frequency_penalty': 1},
+                              cohere_api_key=ch_api_key)
+    llm_predictor = LLMPredictor(llm)
+    return llm_predictor
+
+def get_kron_openai_predictor(query_model):
+    # define LLM
+    llm = KronOpenAI(temperature=0.01, model=query_model)
+    llm.max_tokens = MAX_LENGTH
+    llm_predictor = KronLLMPredictor(llm)
+    return llm_predictor
+
+def get_servce_context(llm_predictor):
+    # define TextSplitter
+    text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=48, paragraph_separator='\n')
+    # define NodeParser
+    node_parser = SimpleNodeParser(text_splitter=text_splitter)
+    # define ServiceContext
+    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, node_parser=node_parser)
+    return service_context
+
+def get_index(service_context, persist_path):
+    print(f'Loading index from {persist_path}')
+    # rebuild storage context
+    storage_context = StorageContext.from_defaults(persist_dir=persist_path)
+    # load index
+    index = load_index_from_storage(storage_context=storage_context,
+                                    service_context=service_context,
+                                    max_triplets_per_chunk=2,
+                                    show_progress=False)
+    return index
+
+def get_query_engine(index):
+    # writer/camel does not understand the refine prompt
+    RESPONSE_MODE = 'accumulate'
+    query_engine = index.as_query_engine(response_mode=RESPONSE_MODE)
+    return query_engine
+
+def load_query_engine(llm_predictor, persist_path):
+    service_context = get_servce_context(llm_predictor)
+    index = get_index(service_context, persist_path)
+    print(f'No query engine for {persist_path}; creating')
+    query_engine = get_query_engine(index)
+    return query_engine
+
+@st.cache_resource
+def build_kron_query_engine(query_model, persist_path):
+    llm_predictor = get_kron_openai_predictor(query_model)
+    query_engine = load_query_engine(llm_predictor, persist_path)
+    return query_engine
+
+@st.cache_resource
+def build_hf_query_engine(query_model, persist_path):
+    llm_predictor = get_hf_predictor(query_model)
+    query_engine = load_query_engine(llm_predictor, persist_path)
+    return query_engine
+
+@st.cache_resource
+def build_cohere_query_engine(query_model, persist_path):
+    llm_predictor = get_cohere_predictor(query_model)
+    query_engine = load_query_engine(llm_predictor, persist_path)
+    return query_engine
+
+@st.cache_resource
+def build_baseten_query_engine(query_model, persist_path):
+    llm_predictor = get_baseten_predictor(query_model)
+    query_engine = load_query_engine(llm_predictor, persist_path)
+    return query_engine
+
+def format_response(answer):
+    # remove any runs of dashes from the response
+    dashes = r'(\-{2,50})'
+    answer.response = re.sub(dashes, '', answer.response)
+    return answer.response or "None"
+
+def clear_question(query_model):
+    if not ('prev_model' in st.session_state) or (('prev_model' in st.session_state) and (st.session_state.prev_model != query_model)):
+        if 'prev_model' in st.session_state:
+            print(f'clearing question {st.session_state.prev_model} {query_model}')
+        else:
+            print(f'clearing question None {query_model}')
+        if 'question_input' in st.session_state:
+            st.session_state.question = st.session_state.question_input
+            st.session_state.question_input = ''
+            st.session_state.question_answered = False
+            st.session_state.answer = ''
+        st.session_state.prev_model = query_model
+
+
+initial_query = ''
+#st.session_state.prev_model = None
+
+if 'question' not in st.session_state:
+    st.session_state.question = ''
+
+if __spaces__:
+    answer_model = st.radio(
+        "Choose the model used for inference:",
+        ('baseten/Camel-5b', 'cohere/command', 'hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003')  # TODO start hf inference container on demand
+        # ('cohere/command', 'hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003')
+    )
+else:
+    answer_model = st.radio(
+        "Choose the model used for inference:",
+        ('Local-Camel', 'HF-TKI', 'hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003')
+    )
+
+if answer_model == 'openai/text-davinci-003':
+    print(answer_model)
+    query_model = 'text-davinci-003'
+    clear_question(query_model)
+    set_openai()
+    query_engine = build_kron_query_engine(query_model, persist_path)
+elif answer_model == 'hf/tiiuae/falcon-7b-instruct':
+    print(answer_model)
+    query_model = 'tiiuae/falcon-7b-instruct'
+    clear_question(query_model)
+    query_engine = build_hf_query_engine(query_model, persist_path)
+elif answer_model == 'cohere/command':
+    print(answer_model)
+    query_model = 'cohere/command'
+    clear_question(query_model)
+    query_engine = build_cohere_query_engine(query_model, persist_path)
+elif answer_model == 'baseten/Camel-5b':
+    print(answer_model)
+    query_model = 'baseten/Camel-5b'
+    clear_question(query_model)
+    query_engine = build_baseten_query_engine(query_model, persist_path)
+elif answer_model == 'Local-Camel':
+    query_model = 'Writer/camel-5b-hf'
+    print(answer_model)
+    clear_question(query_model)
+    set_openai_local()
+    query_engine = build_kron_query_engine(query_model, persist_path)
+elif answer_model == 'HF-TKI':
+    query_model = 'allenai/tk-instruct-3b-def-pos-neg-expl'
+    clear_question(query_model)
+    query_engine = build_hf_query_engine(query_model, persist_path)
+else:
+    print('This is a bug.')
+
+# to clear input box
+def submit():
+    st.session_state.question = st.session_state.question_input
+    st.session_state.question_input = ''
+    st.session_state.question_answered = False
+
+#def submit_rating(query_model, req, resp):
+#    print(f'query model {query_model}')
+#    if 'answer_rating' in st.session_state:
+#        print(f'rating {st.session_state.answer_rating}')
+
+st.write(f'Model, question, answer and rating are logged to help with the improvement of this application.')
+question = st.text_input("Enter a question, e.g. What benchmarks can we use for QA?", key='question_input', on_change=submit)
+
+# answer_str = None
+if st.session_state.question:
+    col1, col2 = st.columns([2, 2])
+    with col1:
+        st.write(f'Answering: {st.session_state.question} with {query_model}.')
+
+    try:
+        if not st.session_state.question_answered:
+            answer = query_engine.query(st.session_state.question)
+            st.session_state.answer = answer
+            st.session_state.question_answered = True
+        else:
+            answer = st.session_state.answer
+        answer_str = format_response(answer)
+        st.write(answer_str)
+        with col1:
+            if answer_str:
+                st.write(f' Please rate this answer.')
+        with col2:
+            from streamlit_star_rating import st_star_rating
+            stars = st_star_rating("", maxValue=5, defaultValue=3, key="answer_rating",
+                                   # customCSS = "div {background-color: red;}"
+                                   # on_change = submit_rating(query_model, st.session_state.question, answer_str)
+                                   )
+            print(f"------stars {stars}")
+    except Exception as e:
+        print(e)
+        answer_str = str(e)
+        st.session_state.answer_rating = -1
+    finally:
+        if 'question' in st.session_state:
+            req = st.session_state.question
+            #st.session_state.question = ''
+        if __spaces__:
+            #request_log = get_request_log()
+            st.session_state.request_log.add_request_log_entry(query_model, req, answer_str, st.session_state.answer_rating)
+
+# if "answer_rating" in st.session_state:
+#     if(__spaces__):
+#         print('time to log the rating')
+#         #request_log = get_request_log()
+#         st.session_state.request_log.add_request_log_entry(query_model, req, answer_str, st.session_state.answer_rating)
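Illustrative usage sketch (not part of the commit): with the persisted index shipped under storage/ and the local OpenAI-compatible endpoint environment variables set, the query path above could be exercised directly outside Streamlit; the question string is just an example.

llm_predictor = get_kron_openai_predictor('Writer/camel-5b-hf')
engine = load_query_engine(llm_predictor, 'storage/Writer-camel-5b-hf-default-no-coref')
answer = engine.query('What benchmarks can we use for QA?')
print(format_response(answer))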
kron/__init__.py
ADDED
File without changes
kron/indices/knowledge_graph/KronKnowledgeGraphIndex.py
ADDED
@@ -0,0 +1,61 @@
+import logging
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+
+from llama_index import KnowledgeGraphIndex
+from llama_index.data_structs.data_structs import KG
+from llama_index.indices.service_context import ServiceContext
+from llama_index.prompts.prompts import KnowledgeGraphPrompt
+from llama_index.storage.storage_context import StorageContext
+from llama_index.schema import BaseNode
+
+class KronKnowledgeGraphIndex(KnowledgeGraphIndex):
+    def __init__(
+        self,
+        nodes: Optional[Sequence[BaseNode]] = None,
+        index_struct: Optional[KG] = None,
+        service_context: Optional[ServiceContext] = None,
+        storage_context: Optional[StorageContext] = None,
+        kg_triple_extract_template: Optional[KnowledgeGraphPrompt] = None,
+        max_triplets_per_chunk: int = 10,
+        include_embeddings: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            nodes,
+            index_struct,
+            service_context,
+            storage_context,
+            kg_triple_extract_template,
+            max_triplets_per_chunk,
+            include_embeddings,
+            kwargs
+        )
+
+    def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
+        """Extract knowledge triplets from text."""
+        #response, _ = self._service_context.llm_predictor.predict(
+        response = self._service_context.llm_predictor.predict(
+            self.kg_triple_extract_template,
+            text=text,
+        )
+        return self._kron_parse_triplet_response(response)
+
+    @staticmethod
+    def _kron_parse_triplet_response(response: str) -> List[Tuple[str, str, str]]:
+        print("_kron_parse_triplet_response")
+        knowledge_strs = response.strip().split("\n")
+        results = []
+        for text in knowledge_strs:
+            text = text.strip()  # triplets might not start at the beginning of the line
+            #text = text.replace('<|endoftext|>', '')
+            # useful triplets are before <|endoftext|>
+            text = text.split("<|endoftext|>")[0]
+            if text == "" or text[0] != "(":
+                # skip empty lines and non-triplets
+                continue
+            tokens = text[1:-1].split(",")
+            if len(tokens) != 3:
+                continue
+            subj, pred, obj = tokens
+            results.append((subj.strip(), pred.strip(), obj.strip()))
+        return results
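For illustration only (not part of the commit), the parser above turns a raw completion into clean triplets; the sample completion text is made up.

sample = "(Alice, is mother of, Bob)\n (Philz, founded in, Berkeley)<|endoftext|> noise"
KronKnowledgeGraphIndex._kron_parse_triplet_response(sample)
# -> [('Alice', 'is mother of', 'Bob'), ('Philz', 'founded in', 'Berkeley')]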
kron/kg/__init__.py
ADDED
File without changes
kron/llm_predictor/KronBasetenCamelLLM.py
ADDED
@@ -0,0 +1,32 @@
+from typing import Any, Optional, List
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms import Baseten
+
+class KronBasetenCamelLLM(Baseten):
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call to Baseten deployed model endpoint."""
+        try:
+            import baseten
+        except ImportError as exc:
+            raise ImportError(
+                "Could not import Baseten Python package. "
+                "Please install it with `pip install baseten`."
+            ) from exc
+
+        # get the model and version
+        try:
+            model = baseten.deployed_model_version_id(self.model)
+            response = model.predict({"instruction": prompt, **kwargs})
+        except baseten.common.core.ApiError:
+            model = baseten.deployed_model_id(self.model)
+            response = model.predict({"instruction": prompt, **kwargs})
+
+        response_txt = response['completion']
+        #print(f'\n********{response_txt}')
+        return response_txt
kron/llm_predictor/KronLLMPredictor.py
ADDED
@@ -0,0 +1,39 @@
+
+from typing import Any, Generator, Optional, Protocol, Tuple, runtime_checkable
+
+from llama_index import LLMPredictor
+from llama_index.llms.utils import LLMType
+from llama_index.callbacks.base import CallbackManager
+
+from kron.llm_predictor.utils import kron_resolve_llm
+
+class KronLLMPredictor(LLMPredictor):
+    """LLM predictor class.
+
+    Wrapper around an LLMChain from Langchain.
+
+    Args:
+        llm (Optional[langchain.llms.base.LLM]): LLM from Langchain to use
+            for predictions. Defaults to OpenAI's text-davinci-003 model.
+            Please see `Langchain's LLM Page
+            <https://langchain.readthedocs.io/en/latest/modules/llms.html>`_
+            for more details.
+
+        retry_on_throttling (bool): Whether to retry on rate limit errors.
+            Defaults to true.
+
+        cache (Optional[langchain.cache.BaseCache]): use cached result for LLM
+    """
+
+    def __init__(
+        self,
+        llm: Optional[LLMType] = None,
+        callback_manager: Optional[CallbackManager] = None,
+    ) -> None:
+        """Initialize params."""
+        self._llm = kron_resolve_llm(llm)
+        self.callback_manager = callback_manager or CallbackManager([])
+
+
+
+
kron/llm_predictor/KronLangChainLLM.py
ADDED
@@ -0,0 +1,35 @@
+from llama_index.bridge.langchain import BaseLanguageModel, BaseChatModel
+from llama_index.llms.langchain import LangChainLLM
+from llama_index.bridge.langchain import OpenAI, ChatOpenAI
+
+from llama_index.llms.base import LLMMetadata
+
+from kron.llm_predictor.openai_utils import kron_openai_modelname_to_contextsize
+
+def is_chat_model(llm: BaseLanguageModel) -> bool:
+    return isinstance(llm, BaseChatModel)
+
+class KronLangChainLLM(LangChainLLM):
+    """Adapter for a LangChain LLM."""
+
+    def __init__(self, llm: BaseLanguageModel) -> None:
+        super().__init__(llm)
+
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        is_chat_model_ = is_chat_model(self.llm)
+        if isinstance(self.llm, OpenAI):
+            return LLMMetadata(
+                context_window=kron_openai_modelname_to_contextsize(self.llm.model_name),
+                num_output=self.llm.max_tokens,
+                is_chat_model=is_chat_model_,
+            )
+        elif isinstance(self.llm, ChatOpenAI):
+            return LLMMetadata(
+                context_window=kron_openai_modelname_to_contextsize(self.llm.model_name),
+                num_output=self.llm.max_tokens or -1,
+                is_chat_model=is_chat_model_,
+            )
+        else:
+            return super().metadata()
kron/llm_predictor/KronOpenAILLM.py
ADDED
@@ -0,0 +1,39 @@
+from typing import Any, Awaitable, Callable, Dict, Optional, Sequence
+
+from llama_index.bridge.langchain import BaseLanguageModel, BaseChatModel
+from llama_index.llms.langchain import LangChainLLM
+from llama_index.llms.openai import OpenAI
+
+from llama_index.llms.base import (
+    LLM,
+    ChatMessage,
+    ChatResponse,
+    ChatResponseAsyncGen,
+    ChatResponseGen,
+    CompletionResponse,
+    CompletionResponseAsyncGen,
+    CompletionResponseGen,
+    LLMMetadata,
+)
+
+from kron.llm_predictor.openai_utils import kron_openai_modelname_to_contextsize
+
+class KronOpenAI(OpenAI):
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        return LLMMetadata(
+            context_window=kron_openai_modelname_to_contextsize(self.model),
+            num_output=self.max_tokens or -1,
+            is_chat_model=self._is_chat_model,
+        )
+
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        #print("KronOpenAI complete called")
+        response = super().complete(prompt, **kwargs)
+        text = response.text
+        text = text.strip()  # triplets might not start at the beginning of the line
+        # useful triplets are before <|endoftext|>
+        text = text.split("<|endoftext|>")[0]
+        response.text = text
+        return response
kron/llm_predictor/__init__.py
ADDED
File without changes
kron/llm_predictor/openai_utils.py
ADDED
@@ -0,0 +1,115 @@
+LOCAL_MODELS = {
+    "Writer/camel-5b-hf": 2048,
+    "mosaicml/mpt-7b-instruct": 2048,
+    "mosaicml/mpt-30b-instruct": 8192,
+}
+
+GPT4_MODELS = {
+    # stable model names:
+    #   resolves to gpt-4-0314 before 2023-06-27,
+    #   resolves to gpt-4-0613 after
+    "gpt-4": 8192,
+    "gpt-4-32k": 32768,
+    # 0613 models (function calling):
+    #   https://openai.com/blog/function-calling-and-other-api-updates
+    "gpt-4-0613": 8192,
+    "gpt-4-32k-0613": 32768,
+    # 0314 models
+    "gpt-4-0314": 8192,
+    "gpt-4-32k-0314": 32768,
+}
+
+AZURE_TURBO_MODELS = {
+    "gpt-35-turbo-16k": 16384,
+    "gpt-35-turbo": 4096,
+}
+
+TURBO_MODELS = {
+    # stable model names:
+    #   resolves to gpt-3.5-turbo-0301 before 2023-06-27,
+    #   resolves to gpt-3.5-turbo-0613 after
+    "gpt-3.5-turbo": 4096,
+    # resolves to gpt-3.5-turbo-16k-0613
+    "gpt-3.5-turbo-16k": 16384,
+    # 0613 models (function calling):
+    #   https://openai.com/blog/function-calling-and-other-api-updates
+    "gpt-3.5-turbo-0613": 4096,
+    "gpt-3.5-turbo-16k-0613": 16384,
+    # 0301 models
+    "gpt-3.5-turbo-0301": 4096,
+}
+
+GPT3_5_MODELS = {
+    "text-davinci-003": 4097,
+    "text-davinci-002": 4097,
+}
+
+GPT3_MODELS = {
+    "text-ada-001": 2049,
+    "text-babbage-001": 2040,
+    "text-curie-001": 2049,
+    "ada": 2049,
+    "babbage": 2049,
+    "curie": 2049,
+    "davinci": 2049,
+}
+
+ALL_AVAILABLE_MODELS = {
+    **GPT4_MODELS,
+    **TURBO_MODELS,
+    **GPT3_5_MODELS,
+    **GPT3_MODELS,
+    **LOCAL_MODELS,
+}
+
+CHAT_MODELS = {
+    **GPT4_MODELS,
+    **TURBO_MODELS,
+    **AZURE_TURBO_MODELS,
+}
+
+
+DISCONTINUED_MODELS = {
+    "code-davinci-002": 8001,
+    "code-davinci-001": 8001,
+    "code-cushman-002": 2048,
+    "code-cushman-001": 2048,
+}
+
+
+def kron_openai_modelname_to_contextsize(modelname: str) -> int:
+    """Calculate the maximum number of tokens possible to generate for a model.
+
+    Args:
+        modelname: The modelname we want to know the context size for.
+
+    Returns:
+        The maximum context size
+
+    Example:
+        .. code-block:: python
+
+            max_tokens = openai.modelname_to_contextsize("text-davinci-003")
+
+    Modified from:
+        https://github.com/hwchase17/langchain/blob/master/langchain/llms/openai.py
+    """
+    # handling finetuned models
+    if "ft-" in modelname:
+        modelname = modelname.split(":")[0]
+
+    if modelname in DISCONTINUED_MODELS:
+        raise ValueError(
+            f"OpenAI model {modelname} has been discontinued. "
+            "Please choose another model."
+        )
+
+    context_size = ALL_AVAILABLE_MODELS.get(modelname, None)
+
+    if context_size is None:
+        raise ValueError(
+            f"Unknown model: {modelname}. Please provide a valid OpenAI model name."
+            "Known models are: " + ", ".join(ALL_AVAILABLE_MODELS.keys())
+        )
+
+    return context_size
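Illustrative lookup (not part of the commit), showing why the LOCAL_MODELS entry above matters when sizing prompts for the locally served camel model:

kron_openai_modelname_to_contextsize("Writer/camel-5b-hf")  # -> 2048
kron_openai_modelname_to_contextsize("text-davinci-003")    # -> 4097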
kron/llm_predictor/utils.py
ADDED
@@ -0,0 +1,16 @@
+from typing import Optional, Union
+from llama_index.llms.base import LLM
+from langchain.base_language import BaseLanguageModel
+
+from kron.llm_predictor.KronLangChainLLM import KronLangChainLLM
+from llama_index.llms.openai import OpenAI
+
+from llama_index.llms.utils import LLMType
+
+
+def kron_resolve_llm(llm: Optional[LLMType] = None) -> LLM:
+    if isinstance(llm, BaseLanguageModel):
+        # NOTE: if it's a langchain model, wrap it in a LangChainLLM
+        return KronLangChainLLM(llm=llm)
+
+    return llm or OpenAI()
kron/persistence/dynamodb_request_log.py
ADDED
@@ -0,0 +1,132 @@
+# dynamodb access
+from datetime import datetime
+import boto3
+from botocore.exceptions import ClientError
+
+import logging
+logger = logging.getLogger(__name__)
+
+session = boto3.Session(
+    # aws_access_key_id=AWS_ACCESS_KEY_ID,
+    # aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
+)
+dynamodb = session.resource('dynamodb')
+
+class RequestLog:
+    """Encapsulates an Amazon DynamoDB table of request data."""
+    def __init__(self, dyn_resource):
+        """
+        :param dyn_resource: A Boto3 DynamoDB resource.
+        """
+        self.dyn_resource = dyn_resource
+        self.table = None
+
+    def exists(self, table_name):
+        """
+        Determines whether a table exists. As a side effect, stores the table in
+        a member variable.
+
+        :param table_name: The name of the table to check.
+        :return: True when the table exists; otherwise, False.
+        """
+        try:
+            table = self.dyn_resource.Table(table_name)
+            table.load()
+            exists = True
+        except ClientError as err:
+            if err.response['Error']['Code'] == 'ResourceNotFoundException':
+                exists = False
+            else:
+                logger.error(
+                    "Couldn't check for existence of %s. Here's why: %s: %s",
+                    table_name,
+                    err.response['Error']['Code'], err.response['Error']['Message'])
+                raise
+        else:
+            self.table = table
+        return exists
+
+    def create_table(self, table_name):
+        """
+        Creates an Amazon DynamoDB table that can be used to store request data.
+        The table uses the model name as the partition key and the request
+        timestamp as the sort key.
+
+        :param table_name: The name of the table to create.
+        :return: The newly created table.
+        """
+        try:
+            self.table = self.dyn_resource.create_table(
+                TableName=table_name,
+                KeySchema=[
+                    {'AttributeName': 'model', 'KeyType': 'HASH'},      # Partition key
+                    {'AttributeName': 'timestamp', 'KeyType': 'RANGE'}  # Sort key
+                ],
+                AttributeDefinitions=[
+                    {'AttributeName': 'model', 'AttributeType': 'S'},
+                    {'AttributeName': 'timestamp', 'AttributeType': 'S'},
+                    # {'AttributeName': 'request', 'AttributeType': 'S'},
+                    # {'AttributeName': 'response', 'AttributeType': 'S'}
+                ],
+                ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10})
+            self.table.wait_until_exists()
+        except ClientError as err:
+            logger.error(
+                "Couldn't create table %s. Here's why: %s: %s", table_name,
+                err.response['Error']['Code'], err.response['Error']['Message'])
+            raise
+        else:
+            return self.table
+
+    def log_request(self, req_timestamp_str, model, request_str, response_str, rating=0):
+        """
+        Log a request to the table.
+
+        :param req_timestamp_str: ISO 8601 timestamp of the request.
+        :param model: The model used to answer the request.
+        :param request_str: The question asked.
+        :param response_str: The answer returned.
+        :param rating: The user rating for the answer.
+        """
+        try:
+            self.table.put_item(
+                Item={
+                    'timestamp': req_timestamp_str,
+                    'model': model,
+                    'request': request_str,
+                    'response': response_str,
+                    'rating': rating,
+                }
+            )
+        except ClientError as err:
+            logger.error(
+                "Couldn't add request log %s to table %s. Here's why: %s: %s",
+                model, self.table.name,
+                err.response['Error']['Code'], err.response['Error']['Message'])
+            raise
+
+    def add_request_log_entry(self, query_model, req, resp, rating=0):
+        """
+        Logs the current model, request and response.
+        """
+        today = datetime.now()
+        # Get current ISO 8601 datetime in string format
+        iso_date = today.isoformat()
+        self.log_request(iso_date, query_model, req, resp, rating)
+
+table_name = 'hf-spaces-request-log'
+
+def get_request_log():
+    request_log = RequestLog(dynamodb)
+    request_log_exists = request_log.exists(table_name)
+    if not request_log_exists:
+        print(f"\nCreating table {table_name}...")
+        request_log.create_table(table_name)
+        print(f"\nCreated table {request_log.table.name}.")
+    return request_log
+
+#def add_request_log_entry(request_log, query_model, req, resp, rating=0):
+#    today = datetime.now()
+#    # Get current ISO 8601 datetime in string format
+#    iso_date = today.isoformat()
+#    request_log.log_request(iso_date, query_model, req, resp, rating)
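Illustrative usage (not part of the commit), assuming boto3 can reach DynamoDB with valid AWS credentials; the question and answer strings are made up:

request_log = get_request_log()  # creates the 'hf-spaces-request-log' table on first run
request_log.add_request_log_entry('Writer/camel-5b-hf', 'What is mlk8s?', 'A sample answer', rating=4)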
kron/prompts/kg_prompts.py
ADDED
@@ -0,0 +1,351 @@
+"""Set of default prompts."""
+
+from llama_index.prompts.base import Prompt
+from llama_index.prompts.prompt_type import PromptType
+
+############################################
+# Tree
+############################################
+
+DEFAULT_SUMMARY_PROMPT_TMPL = (
+    "Write a summary of the following. Try to use only the "
+    "information provided. "
+    "Try to include as many key details as possible.\n"
+    "\n"
+    "\n"
+    "{context_str}\n"
+    "\n"
+    "\n"
+    'SUMMARY:"""\n'
+)
+
+DEFAULT_SUMMARY_PROMPT = Prompt(
+    DEFAULT_SUMMARY_PROMPT_TMPL, prompt_type=PromptType.SUMMARY
+)
+
+# insert prompts
+DEFAULT_INSERT_PROMPT_TMPL = (
+    "Context information is below. It is provided in a numbered list "
+    "(1 to {num_chunks}),"
+    "where each item in the list corresponds to a summary.\n"
+    "---------------------\n"
+    "{context_list}"
+    "---------------------\n"
+    "Given the context information, here is a new piece of "
+    "information: {new_chunk_text}\n"
+    "Answer with the number corresponding to the summary that should be updated. "
+    "The answer should be the number corresponding to the "
+    "summary that is most relevant to the question.\n"
+)
+DEFAULT_INSERT_PROMPT = Prompt(
+    DEFAULT_INSERT_PROMPT_TMPL, prompt_type=PromptType.TREE_INSERT
+)
+
+
+# # single choice
+DEFAULT_QUERY_PROMPT_TMPL = (
+    "Some choices are given below. It is provided in a numbered list "
+    "(1 to {num_chunks}),"
+    "where each item in the list corresponds to a summary.\n"
+    "---------------------\n"
+    "{context_list}"
+    "\n---------------------\n"
+    "Using only the choices above and not prior knowledge, return "
+    "the choice that is most relevant to the question: '{query_str}'\n"
+    "Provide choice in the following format: 'ANSWER: <number>' and explain why "
+    "this summary was selected in relation to the question.\n"
+)
+DEFAULT_QUERY_PROMPT = Prompt(
+    DEFAULT_QUERY_PROMPT_TMPL, prompt_type=PromptType.TREE_SELECT
+)
+
+# multiple choice
+DEFAULT_QUERY_PROMPT_MULTIPLE_TMPL = (
+    "Some choices are given below. It is provided in a numbered "
+    "list (1 to {num_chunks}), "
+    "where each item in the list corresponds to a summary.\n"
+    "---------------------\n"
+    "{context_list}"
+    "\n---------------------\n"
+    "Using only the choices above and not prior knowledge, return the top choices "
+    "(no more than {branching_factor}, ranked by most relevant to least) that "
+    "are most relevant to the question: '{query_str}'\n"
+    "Provide choices in the following format: 'ANSWER: <numbers>' and explain why "
+    "these summaries were selected in relation to the question.\n"
+)
+DEFAULT_QUERY_PROMPT_MULTIPLE = Prompt(
+    DEFAULT_QUERY_PROMPT_MULTIPLE_TMPL, prompt_type=PromptType.TREE_SELECT_MULTIPLE
+)
+
+
+DEFAULT_REFINE_PROMPT_TMPL = (
+    "The original question is as follows: {query_str}\n"
+    "We have provided an existing answer: {existing_answer}\n"
+    "We have the opportunity to refine the existing answer "
+    "(only if needed) with some more context below.\n"
+    "------------\n"
+    "{context_msg}\n"
+    "------------\n"
+    "Given the new context, refine the original answer to better "
+    "answer the question. "
+    "If the context isn't useful, return the original answer."
+)
+DEFAULT_REFINE_PROMPT = Prompt(
+    DEFAULT_REFINE_PROMPT_TMPL, prompt_type=PromptType.REFINE
+)
+
+
+DEFAULT_TEXT_QA_PROMPT_TMPL = (
+    "Context information is below.\n"
+    "---------------------\n"
+    "{context_str}\n"
+    "---------------------\n"
+    "Given the context information and not prior knowledge, "
+    "answer the question: {query_str}\n"
+)
+DEFAULT_TEXT_QA_PROMPT = Prompt(
+    DEFAULT_TEXT_QA_PROMPT_TMPL, prompt_type=PromptType.QUESTION_ANSWER
+)
+
+
+############################################
+# Keyword Table
+############################################
+
+DEFAULT_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
+    "Some text is provided below. Given the text, extract up to {max_keywords} "
+    "keywords from the text. Avoid stopwords."
+    "---------------------\n"
+    "{text}\n"
+    "---------------------\n"
+    "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
+)
+DEFAULT_KEYWORD_EXTRACT_TEMPLATE = Prompt(
+    DEFAULT_KEYWORD_EXTRACT_TEMPLATE_TMPL, prompt_type=PromptType.KEYWORD_EXTRACT
+)
+
+
+# NOTE: the keyword extraction for queries can be the same as
+# the one used to build the index, but here we tune it to see if performance is better.
+DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
+    "A question is provided below. Given the question, extract up to {max_keywords} "
+    "keywords from the text. Focus on extracting the keywords that we can use "
+    "to best lookup answers to the question. Avoid stopwords.\n"
+    "---------------------\n"
+    "{question}\n"
+    "---------------------\n"
+    "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
+)
+DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE = Prompt(
+    DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL,
+    prompt_type=PromptType.QUERY_KEYWORD_EXTRACT,
+)
+
+
+############################################
+# Structured Store
+############################################
+
+DEFAULT_SCHEMA_EXTRACT_TMPL = (
+    "We wish to extract relevant fields from an unstructured text chunk into "
+    "a structured schema. We first provide the unstructured text, and then "
+    "we provide the schema that we wish to extract. "
+    "-----------text-----------\n"
+    "{text}\n"
+    "-----------schema-----------\n"
+    "{schema}\n"
+    "---------------------\n"
+    "Given the text and schema, extract the relevant fields from the text in "
+    "the following format: "
+    "field1: <value>\nfield2: <value>\n...\n\n"
+    "If a field is not present in the text, don't include it in the output."
+    "If no fields are present in the text, return a blank string.\n"
+    "Fields: "
+)
+DEFAULT_SCHEMA_EXTRACT_PROMPT = Prompt(
+    DEFAULT_SCHEMA_EXTRACT_TMPL, prompt_type=PromptType.SCHEMA_EXTRACT
+)
+
+# NOTE: taken from langchain and adapted
+# https://tinyurl.com/b772sd77
+DEFAULT_TEXT_TO_SQL_TMPL = (
+    "Given an input question, first create a syntactically correct {dialect} "
+    "query to run, then look at the results of the query and return the answer. "
+    "You can order the results by a relevant column to return the most "
+    "interesting examples in the database.\n"
+    "Never query for all the columns from a specific table, only ask for a "
+    "few relevant columns given the question.\n"
+    "Pay attention to use only the column names that you can see in the schema "
+    "description. "
+    "Be careful to not query for columns that do not exist. "
+    "Pay attention to which column is in which table. "
+    "Also, qualify column names with the table name when needed.\n"
+    "Use the following format:\n"
+    "Question: Question here\n"
+    "SQLQuery: SQL Query to run\n"
+    "SQLResult: Result of the SQLQuery\n"
+    "Answer: Final answer here\n"
+    "Only use the tables listed below.\n"
+    "{schema}\n"
+    "Question: {query_str}\n"
+    "SQLQuery: "
+)
+
+DEFAULT_TEXT_TO_SQL_PROMPT = Prompt(
+    DEFAULT_TEXT_TO_SQL_TMPL,
+    stop_token="\nSQLResult:",
+    prompt_type=PromptType.TEXT_TO_SQL,
+)
+
+
+# NOTE: by partially filling schema, we can reduce to a QuestionAnswer prompt
+# that we can feed to our table
+DEFAULT_TABLE_CONTEXT_TMPL = (
+    "We have provided a table schema below. "
+    "---------------------\n"
+    "{schema}\n"
+    "---------------------\n"
+    "We have also provided context information below. "
+    "{context_str}\n"
+    "---------------------\n"
+    "Given the context information and the table schema, "
+    "give a response to the following task: {query_str}"
+)
+
+DEFAULT_TABLE_CONTEXT_QUERY = (
+    "Provide a high-level description of the table, "
+    "as well as a description of each column in the table. "
+    "Provide answers in the following format:\n"
+    "TableDescription: <description>\n"
+    "Column1Description: <description>\n"
+    "Column2Description: <description>\n"
+    "...\n\n"
+)
+
+DEFAULT_TABLE_CONTEXT_PROMPT = Prompt(
+    DEFAULT_TABLE_CONTEXT_TMPL, prompt_type=PromptType.TABLE_CONTEXT
+)
+
+# NOTE: by partially filling schema, we can reduce to a RefinePrompt
+# that we can feed to our table
+DEFAULT_REFINE_TABLE_CONTEXT_TMPL = (
+    "We have provided a table schema below. "
+    "---------------------\n"
+    "{schema}\n"
+    "---------------------\n"
+    "We have also provided some context information below. "
+    "{context_msg}\n"
+    "---------------------\n"
+    "Given the context information and the table schema, "
+    "give a response to the following task: {query_str}\n"
+    "We have provided an existing answer: {existing_answer}\n"
+    "Given the new context, refine the original answer to better "
+    "answer the question. "
+    "If the context isn't useful, return the original answer."
+)
+DEFAULT_REFINE_TABLE_CONTEXT_PROMPT = Prompt(
+    DEFAULT_REFINE_TABLE_CONTEXT_TMPL, prompt_type=PromptType.TABLE_CONTEXT
+)
+
+
+############################################
+# Knowledge-Graph Table
+############################################
+
+KRON_KG_TRIPLET_EXTRACT_TMPL = (
+    "Below is an instruction that describes a task, paired with an input that provides further context. "
+    "Write a response that appropriately completes the request.\n\n"
+    "### Instruction:\n"
+    "Some text is provided below. Given the text, extract up to {max_knowledge_triplets} knowledge triplets in the form of "
+    "(subject, predicate, object).\n\n"
+    "### Input: \n"
+    "Text: Alice is Bob's mother. \n"
+    "Triplets: \n"
+    " (Alice, is mother of, Bob) \n"
+    "Text: Philz is a coffee shop founded in Berkeley in 1982. \n"
+    "Triplets: \n"
+    " (Philz, is, coffee shop) \n"
+    " (Philz, founded in, Berkeley) \n"
+    " (Philz, founded in, 1982) \n"
+    "Text: This small and colorful book is for children. \n"
+    "Triplets: \n"
+    " (book, is for, children)\n"
+    " (book, is, small and colorful) \n"
+    " (small book, is for, children) \n"
+    " (this small book, is for, children) \n"
+    "Text: We saw these dwellings, brightly painted cottages, shining in the sun. \n"
+    "Triplets: \n"
+    " (dwellings, are, brightly painted cottages) \n"
+    " (cottages, shine in, the sun) \n"
+    "--------------------- \n"
+    "### Text: {text} \n\n"
+    "### Triplets: "
+)
+
+KRON_KG_TRIPLET_EXTRACT_PROMPT = Prompt(
+    KRON_KG_TRIPLET_EXTRACT_TMPL, prompt_type=PromptType.KNOWLEDGE_TRIPLET_EXTRACT
+)
+
+############################################
+# HYDE
+##############################################
+
+HYDE_TMPL = (
+    "Please write a passage to answer the question\n"
+    "Try to include as many key details as possible.\n"
+    "\n"
+    "\n"
+    "{context_str}\n"
+    "\n"
+    "\n"
+    'Passage:"""\n'
+)
+
+DEFAULT_HYDE_PROMPT = Prompt(HYDE_TMPL, prompt_type=PromptType.SUMMARY)
+
+
+############################################
+# Simple Input
+############################################
+
+DEFAULT_SIMPLE_INPUT_TMPL = "{query_str}"
+DEFAULT_SIMPLE_INPUT_PROMPT = Prompt(
+    DEFAULT_SIMPLE_INPUT_TMPL, prompt_type=PromptType.SIMPLE_INPUT
+)
+
+
+############################################
+# Pandas
+############################################
+
+DEFAULT_PANDAS_TMPL = (
+    "You are working with a pandas dataframe in Python.\n"
+    "The name of the dataframe is `df`.\n"
+    "This is the result of `print(df.head())`:\n"
+    "{df_str}\n\n"
+    "Here is the input query: {query_str}.\n"
+    "Given the df information and the input query, please follow "
+    "these instructions:\n"
+    "{instruction_str}"
+    "Output:\n"
+)
+
+DEFAULT_PANDAS_PROMPT = Prompt(DEFAULT_PANDAS_TMPL, prompt_type=PromptType.PANDAS)
+
+
+############################################
+# JSON Path
+############################################
+
+DEFAULT_JSON_PATH_TMPL = (
+    "We have provided a JSON schema below:\n"
+    "{schema}\n"
+    "Given a task, respond with a JSON Path query that "
+    "can retrieve data from a JSON value that matches the schema.\n"
+    "Task: {query_str}\n"
+    "JSONPath: "
+)
+
+DEFAULT_JSON_PATH_PROMPT = Prompt(
+    DEFAULT_JSON_PATH_TMPL, prompt_type=PromptType.JSON_PATH
+)
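Illustrative only (not part of the commit): the instruction-style extraction template above is filled in like this before being sent to the model; the sample text is made up.

prompt_text = KRON_KG_TRIPLET_EXTRACT_TMPL.format(
    max_knowledge_triplets=2,
    text="Camel-5b is an instruction-tuned model released by Writer.",
)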
requirements.txt
ADDED
@@ -0,0 +1,21 @@
+#local
+#conda create -n mlk8s python=3.9.15 -y
+#conda activate mlk8s
+#pip install --upgrade streamlit
+#pip install --upgrade huggingface_hub
+#pip install -r requirements.txt
+#streamlit run appname.py
+
+torch
+transformers
+llama_index
+pyvis
+nltk
+python-dotenv
+cohere
+baseten
+st-star-rating
+amazon-dax-client>=1.1.7
+boto3>=1.26.79
+pytest>=7.2.1
+requests>=2.28.2
storage/Writer-camel-5b-hf-default-no-coref/graph_store.json
ADDED
The diff for this file is too large to render.
storage/Writer-camel-5b-hf-default-no-coref/index_store.json
ADDED
The diff for this file is too large to render.
storage/Writer-camel-5b-hf-default-no-coref/vector_store.json
ADDED
@@ -0,0 +1 @@
+{"embedding_dict": {}, "text_id_to_ref_doc_id": {}}