Arylwen commited on
Commit
c0cd1dc
·
1 Parent(s): 62cb359

mlk8s v 0.0.1

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .env
2
+ __pycache__
app.py CHANGED
@@ -1,4 +1,299 @@
1
  import streamlit as st
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x*x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
 
3
+ import os
4
+ import re
5
+ import sys
6
+ import logging
7
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
8
+ logger = logging.getLogger(__name__)
9
+
10
+ from dotenv import load_dotenv
11
+ load_dotenv()
12
+
13
+ for key in st.session_state.keys():
14
+ #del st.session_state[key]
15
+ print(f'session state entry: {key} {st.session_state[key]}')
16
+
17
+ __spaces__ = os.environ.get('__SPACES__')
18
+
19
+ if __spaces__:
20
+ from kron.persistence.dynamodb_request_log import get_request_log;
21
+ st.session_state.request_log = get_request_log()
22
+
23
+ #third party service access
24
+ #hf inference api
25
+ hf_api_key = os.environ['HF_TOKEN']
26
+ ch_api_key = os.environ['COHERE_TOKEN']
27
+ bs_api_key = os.environ['BASETEN_TOKEN']
28
+
29
+ index_model = "Writer/camel-5b-hf"
30
+ INDEX_NAME = f"{index_model.replace('/', '-')}-default-no-coref"
31
+ persist_path = f"storage/{INDEX_NAME}"
32
+ MAX_LENGTH = 1024
33
+
34
+ import baseten
35
+ @st.cache_resource
36
+ def set_baseten_key(bs_api_key):
37
+ baseten.login(bs_api_key)
38
+
39
+ set_baseten_key(bs_api_key)
40
+
41
+ from llama_index import StorageContext
42
+ from llama_index import ServiceContext
43
+ from llama_index import load_index_from_storage
44
+ from llama_index.langchain_helpers.text_splitter import SentenceSplitter
45
+ from llama_index.node_parser import SimpleNodeParser
46
+ from llama_index import LLMPredictor
47
+
48
+ from langchain import HuggingFaceHub
49
+ from langchain.llms.cohere import Cohere
50
+ from langchain.llms import Baseten
51
+
52
+ import tiktoken
53
+
54
+ import openai
55
+ #extensions to llama_index to support openai compatible endpoints, e.g. llama-api
56
+ from kron.llm_predictor.KronOpenAILLM import KronOpenAI
57
+ #baseten deployment expects a specific request format
58
+ from kron.llm_predictor.KronBasetenCamelLLM import KronBasetenCamelLLM
59
+ from kron.llm_predictor.KronLLMPredictor import KronLLMPredictor
60
+
61
+ #writer/camel uses endoftext
62
+ from llama_index.utils import globals_helper
63
+ enc = tiktoken.get_encoding("gpt2")
64
+ tokenizer = lambda text: enc.encode(text, allowed_special={"<|endoftext|>"})
65
+ globals_helper._tokenizer = tokenizer
66
+
67
+
68
+ def set_openai_local():
69
+ openai.api_key = os.environ['LOCAL_OPENAI_API_KEY']
70
+ openai.api_base = os.environ['LOCAL_OPENAI_API_BASE']
71
+ os.environ['OPENAI_API_KEY'] = os.environ['LOCAL_OPENAI_API_KEY']
72
+ os.environ['OPENAI_API_BASE'] = os.environ['LOCAL_OPENAI_API_BASE']
73
+
74
+ def set_openai():
75
+ openai.api_key = os.environ['DAVINCI_OPENAI_API_KEY']
76
+ openai.api_base = os.environ['DAVINCI_OPENAI_API_BASE']
77
+ os.environ['OPENAI_API_KEY'] = os.environ['DAVINCI_OPENAI_API_KEY']
78
+ os.environ['OPENAI_API_BASE'] = os.environ['DAVINCI_OPENAI_API_BASE']
79
+
80
+ def get_hf_predictor(query_model):
81
+ # no embeddings for now
82
+ set_openai_local()
83
+ llm=HuggingFaceHub(repo_id=query_model, task="text-generation",
84
+ model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH},
85
+ huggingfacehub_api_token=hf_api_key)
86
+ llm_predictor = LLMPredictor(llm)
87
+ return llm_predictor
88
+
89
+ def get_cohere_predictor(query_model):
90
+ # no embeddings for now
91
+ set_openai_local()
92
+ llm=Cohere(model='command', temperature = 0.01,
93
+ # model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH},
94
+ cohere_api_key=ch_api_key)
95
+ llm_predictor = LLMPredictor(llm)
96
+ return llm_predictor
97
+
98
+ def get_baseten_predictor(query_model):
99
+ # no embeddings for now
100
+ set_openai_local()
101
+ llm=KronBasetenCamelLLM(model='3yd1ke3', temperature = 0.01,
102
+ # model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH, 'repetition_penalty':1.07},
103
+ model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH, 'frequency_penalty':1},
104
+ cohere_api_key=ch_api_key)
105
+ llm_predictor = LLMPredictor(llm)
106
+ return llm_predictor
107
+
108
+ def get_kron_openai_predictor(query_model):
109
+ # define LLM
110
+ llm=KronOpenAI(temperature=0.01, model=query_model)
111
+ llm.max_tokens = MAX_LENGTH
112
+ llm_predictor = KronLLMPredictor(llm)
113
+ return llm_predictor
114
+
115
+ def get_servce_context(llm_predictor):
116
+ # define TextSplitter
117
+ text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=48, paragraph_separator='\n')
118
+ #define NodeParser
119
+ node_parser = SimpleNodeParser(text_splitter=text_splitter)
120
+ #define ServiceContext
121
+ service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, node_parser=node_parser)
122
+ return service_context
123
+
124
+ def get_index(service_context, persist_path):
125
+ print(f'Loading index from {persist_path}')
126
+ # rebuild storage context
127
+ storage_context = StorageContext.from_defaults(persist_dir=persist_path)
128
+ # load index
129
+ index = load_index_from_storage(storage_context=storage_context,
130
+ service_context=service_context,
131
+ max_triplets_per_chunk=2,
132
+ show_progress = False)
133
+ return index
134
+
135
+ def get_query_engine(index):
136
+ #writer/camel does not understand the refine prompt
137
+ RESPONSE_MODE = 'accumulate'
138
+ query_engine = index.as_query_engine(response_mode = RESPONSE_MODE)
139
+ return query_engine
140
+
141
+ def load_query_engine(llm_predictor, persist_path):
142
+ service_context = get_servce_context(llm_predictor)
143
+ index = get_index(service_context, persist_path)
144
+ print(f'No query engine for {persist_path}; creating')
145
+ query_engine = get_query_engine(index)
146
+ return query_engine
147
+
148
+ @st.cache_resource
149
+ def build_kron_query_engine(query_model, persist_path):
150
+ llm_predictor = get_kron_openai_predictor(query_model)
151
+ query_engine = load_query_engine(llm_predictor, persist_path)
152
+ return query_engine
153
+
154
+ @st.cache_resource
155
+ def build_hf_query_engine(query_model, persist_path):
156
+ llm_predictor = get_hf_predictor(query_model)
157
+ query_engine = load_query_engine(llm_predictor, persist_path)
158
+ return query_engine
159
+
160
+ @st.cache_resource
161
+ def build_cohere_query_engine(query_model, persist_path):
162
+ llm_predictor = get_cohere_predictor(query_model)
163
+ query_engine = load_query_engine(llm_predictor, persist_path)
164
+ return query_engine
165
+
166
+ @st.cache_resource
167
+ def build_baseten_query_engine(query_model, persist_path):
168
+ llm_predictor = get_baseten_predictor(query_model)
169
+ query_engine = load_query_engine(llm_predictor, persist_path)
170
+ return query_engine
171
+
172
+ def format_response(answer):
173
+ # Replace any eventual --
174
+ dashes = r'(\-{2,50})'
175
+ answer.response = re.sub(dashes, '', answer.response)
176
+ return answer.response or "None"
177
+
178
+ def clear_question(query_model):
179
+ if not ('prev_model' in st.session_state) or (('prev_model' in st.session_state) and (st.session_state.prev_model != query_model)) :
180
+ if 'prev_model' in st.session_state:
181
+ print(f'clearing question {st.session_state.prev_model} {query_model}')
182
+ else:
183
+ print(f'clearing question None {query_model}')
184
+ if('question_input' in st.session_state):
185
+ st.session_state.question = st.session_state.question_input
186
+ st.session_state.question_input = ''
187
+ st.session_state.question_answered = False
188
+ st.session_state.answer = ''
189
+ st.session_state.prev_model = query_model
190
+
191
+
192
+ initial_query = ''
193
+ #st.session_state.prev_model = None
194
+
195
+ if 'question' not in st.session_state:
196
+ st.session_state.question = ''
197
+
198
+ if __spaces__ :
199
+ answer_model = st.radio(
200
+ "Choose the model used for inference:",
201
+ ('baseten/Camel-5b', 'cohere/command','hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003') #TODO start hf inference container on demand
202
+ # ('cohere/command','hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003')
203
+ )
204
+ else :
205
+ answer_model = st.radio(
206
+ "Choose the model used for inference:",
207
+ ('Local-Camel', 'HF-TKI', 'hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003')
208
+ )
209
+
210
+ if answer_model == 'openai/text-davinci-003':
211
+ print(answer_model)
212
+ query_model = 'text-davinci-003'
213
+ clear_question(query_model)
214
+ set_openai()
215
+ query_engine = build_kron_query_engine(query_model, persist_path)
216
+ elif answer_model == 'hf/tiiuae/falcon-7b-instruct':
217
+ print(answer_model)
218
+ query_model = 'tiiuae/falcon-7b-instruct'
219
+ clear_question(query_model)
220
+ query_engine = build_hf_query_engine(query_model, persist_path)
221
+ elif answer_model == 'cohere/command':
222
+ print(answer_model)
223
+ query_model = 'cohere/command'
224
+ clear_question(query_model)
225
+ query_engine = build_cohere_query_engine(query_model, persist_path)
226
+ elif answer_model == 'baseten/Camel-5b':
227
+ print(answer_model)
228
+ query_model = 'baseten/Camel-5b'
229
+ clear_question(query_model)
230
+ query_engine = build_baseten_query_engine(query_model, persist_path)
231
+ elif answer_model == 'Local-Camel':
232
+ query_model = 'Writer/camel-5b-hf'
233
+ print(answer_model)
234
+ clear_question(query_model)
235
+ set_openai_local()
236
+ query_engine = build_kron_query_engine(query_model, persist_path)
237
+ elif answer_model == 'HF-TKI':
238
+ query_model = 'allenai/tk-instruct-3b-def-pos-neg-expl'
239
+ clear_question(query_model)
240
+ query_engine = build_hf_query_engine(query_model, persist_path)
241
+ else:
242
+ print('This is a bug.')
243
+
244
+ # to clear input box
245
+ def submit():
246
+ st.session_state.question = st.session_state.question_input
247
+ st.session_state.question_input = ''
248
+ st.session_state.question_answered = False
249
+
250
+ #def submit_rating(query_model, req, resp):
251
+ # print(f'query model {query_model}')
252
+ # if 'answer_rating' in st.session_state:
253
+ # print(f'rating {st.session_state.answer_rating}')
254
+
255
+ st.write(f'Model, question, answer and rating are logged to help with the improvement of this application.')
256
+ question = st.text_input("Enter a question, e.g. What benchmarks can we use for QA?", key='question_input', on_change=submit )
257
+
258
+ # answer_str = None
259
+ if(st.session_state.question):
260
+ col1, col2 = st.columns([2, 2])
261
+ with col1:
262
+ st.write(f'Answering: {st.session_state.question} with {query_model}.')
263
+
264
+ try :
265
+ if not st.session_state.question_answered:
266
+ answer = query_engine.query(st.session_state.question)
267
+ st.session_state.answer = answer
268
+ st.session_state.question_answered = True
269
+ else:
270
+ answer = st.session_state.answer
271
+ answer_str = format_response(answer)
272
+ st.write(answer_str)
273
+ with col1:
274
+ if answer_str:
275
+ st.write(f' Please rate this answer.')
276
+ with col2:
277
+ from streamlit_star_rating import st_star_rating
278
+ stars = st_star_rating("", maxValue=5, defaultValue=3, key="answer_rating",
279
+ # customCSS = "div {background-color: red;}"
280
+ # on_change = submit_rating(query_model, st.session_state.question, answer_str)
281
+ )
282
+ print(f"------stars {stars}")
283
+ except Exception as e:
284
+ print(e)
285
+ answer_str = str(e)
286
+ st.session_state.answer_rating = -1
287
+ finally:
288
+ if 'question' in st.session_state:
289
+ req = st.session_state.question
290
+ #st.session_state.question = ''
291
+ if(__spaces__):
292
+ #request_log = get_request_log()
293
+ st.session_state.request_log.add_request_log_entry(query_model, req, answer_str, st.session_state.answer_rating)
294
+
295
+ # if "answer_rating" in st.session_state:
296
+ # if(__spaces__):
297
+ # print('time to log the rating')
298
+ # #request_log = get_request_log()
299
+ # st.session_state.request_log.add_request_log_entry(query_model, req, answer_str, st.session_state.answer_rating)
kron/__init__.py ADDED
File without changes
kron/indices/knowledge_graph/KronKnowledgeGraphIndex.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
3
+
4
+ from llama_index import KnowledgeGraphIndex
5
+ from llama_index.data_structs.data_structs import KG
6
+ from llama_index.indices.service_context import ServiceContext
7
+ from llama_index.prompts.prompts import KnowledgeGraphPrompt
8
+ from llama_index.storage.storage_context import StorageContext
9
+ from llama_index.schema import BaseNode
10
+
11
+ class KronKnowledgeGraphIndex(KnowledgeGraphIndex):
12
+ def __init__(
13
+ self,
14
+ nodes: Optional[Sequence[BaseNode]] = None,
15
+ index_struct: Optional[KG] = None,
16
+ service_context: Optional[ServiceContext] = None,
17
+ storage_context: Optional[StorageContext] = None,
18
+ kg_triple_extract_template: Optional[KnowledgeGraphPrompt] = None,
19
+ max_triplets_per_chunk: int = 10,
20
+ include_embeddings: bool = False,
21
+ **kwargs: Any,
22
+ ) -> None:
23
+ super().__init__(
24
+ nodes,
25
+ index_struct,
26
+ service_context,
27
+ storage_context,
28
+ kg_triple_extract_template,
29
+ max_triplets_per_chunk,
30
+ include_embeddings,
31
+ kwargs
32
+ )
33
+
34
+ def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
35
+ """Extract keywords from text."""
36
+ #response, _ = self._service_context.llm_predictor.predict(
37
+ response = self._service_context.llm_predictor.predict(
38
+ self.kg_triple_extract_template,
39
+ text=text,
40
+ )
41
+ return self._kron_parse_triplet_response(response)
42
+
43
+ @staticmethod
44
+ def _kron_parse_triplet_response(response: str) -> List[Tuple[str, str, str]]:
45
+ print("_kron_parse_triplet_response")
46
+ knowledge_strs = response.strip().split("\n")
47
+ results = []
48
+ for text in knowledge_strs:
49
+ text = text.strip() #triples might not start at the begining of the line
50
+ #text = text.replace('<|endoftext|>', '')
51
+ #useful triplets are before <|endoftext|>
52
+ text = text.split("<|endoftext|>")[0]
53
+ if text == "" or text[0] != "(":
54
+ # skip empty lines and non-triplets
55
+ continue
56
+ tokens = text[1:-1].split(",")
57
+ if len(tokens) != 3:
58
+ continue
59
+ subj, pred, obj = tokens
60
+ results.append((subj.strip(), pred.strip(), obj.strip()))
61
+ return results
kron/kg/__init__.py ADDED
File without changes
kron/llm_predictor/KronBasetenCamelLLM.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, List
2
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
3
+ from langchain.llms import Baseten
4
+
5
+ class KronBasetenCamelLLM(Baseten):
6
+ def _call(
7
+ self,
8
+ prompt: str,
9
+ stop: Optional[List[str]] = None,
10
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
11
+ **kwargs: Any,
12
+ ) -> str:
13
+ """Call to Baseten deployed model endpoint."""
14
+ try:
15
+ import baseten
16
+ except ImportError as exc:
17
+ raise ImportError(
18
+ "Could not import Baseten Python package. "
19
+ "Please install it with `pip install baseten`."
20
+ ) from exc
21
+
22
+ # get the model and version
23
+ try:
24
+ model = baseten.deployed_model_version_id(self.model)
25
+ response = model.predict({"instruction": prompt, **kwargs})
26
+ except baseten.common.core.ApiError:
27
+ model = baseten.deployed_model_id(self.model)
28
+ response = model.predict({"instruction": prompt, **kwargs})
29
+
30
+ response_txt = response['completion']
31
+ #print(f'\n********{response_txt}')
32
+ return response_txt
kron/llm_predictor/KronLLMPredictor.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Any, Generator, Optional, Protocol, Tuple, runtime_checkable
3
+
4
+ from llama_index import LLMPredictor
5
+ from llama_index.llms.utils import LLMType
6
+ from llama_index.callbacks.base import CallbackManager
7
+
8
+ from kron.llm_predictor.utils import kron_resolve_llm
9
+
10
+ class KronLLMPredictor(LLMPredictor):
11
+ """LLM predictor class.
12
+
13
+ Wrapper around an LLMChain from Langchain.
14
+
15
+ Args:
16
+ llm (Optional[langchain.llms.base.LLM]): LLM from Langchain to use
17
+ for predictions. Defaults to OpenAI's text-davinci-003 model.
18
+ Please see `Langchain's LLM Page
19
+ <https://langchain.readthedocs.io/en/latest/modules/llms.html>`_
20
+ for more details.
21
+
22
+ retry_on_throttling (bool): Whether to retry on rate limit errors.
23
+ Defaults to true.
24
+
25
+ cache (Optional[langchain.cache.BaseCache]) : use cached result for LLM
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ llm: Optional[LLMType] = None,
31
+ callback_manager: Optional[CallbackManager] = None,
32
+ ) -> None:
33
+ """Initialize params."""
34
+ self._llm = kron_resolve_llm(llm)
35
+ self.callback_manager = callback_manager or CallbackManager([])
36
+
37
+
38
+
39
+
kron/llm_predictor/KronLangChainLLM.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from llama_index.bridge.langchain import BaseLanguageModel, BaseChatModel
2
+ from llama_index.llms.langchain import LangChainLLM
3
+ from llama_index.bridge.langchain import OpenAI, ChatOpenAI
4
+
5
+ from llama_index.llms.base import LLMMetadata
6
+
7
+ from kron.llm_predictor.openai_utils import kron_openai_modelname_to_contextsize
8
+
9
+ def is_chat_model(llm: BaseLanguageModel) -> bool:
10
+ return isinstance(llm, BaseChatModel)
11
+
12
+ class KronLangChainLLM(LangChainLLM):
13
+ """Adapter for a LangChain LLM."""
14
+
15
+ def __init__(self, llm: BaseLanguageModel) -> None:
16
+ super().__init__(llm)
17
+
18
+
19
+ @property
20
+ def metadata(self) -> LLMMetadata:
21
+ is_chat_model_ = is_chat_model(self.llm)
22
+ if isinstance(self.llm, OpenAI):
23
+ return LLMMetadata(
24
+ context_window=kron_openai_modelname_to_contextsize(self.llm.model_name),
25
+ num_output=self.llm.max_tokens,
26
+ is_chat_model=is_chat_model_ ,
27
+ )
28
+ elif isinstance(self.llm, ChatOpenAI):
29
+ return LLMMetadata(
30
+ context_window=kron_openai_modelname_to_contextsize(self.llm.model_name),
31
+ num_output=self.llm.max_tokens or -1,
32
+ is_chat_model=is_chat_model_ ,
33
+ )
34
+ else:
35
+ return super().metadata()
kron/llm_predictor/KronOpenAILLM.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Awaitable, Callable, Dict, Optional, Sequence
2
+
3
+ from llama_index.bridge.langchain import BaseLanguageModel, BaseChatModel
4
+ from llama_index.llms.langchain import LangChainLLM
5
+ from llama_index.llms.openai import OpenAI
6
+
7
+ from llama_index.llms.base import (
8
+ LLM,
9
+ ChatMessage,
10
+ ChatResponse,
11
+ ChatResponseAsyncGen,
12
+ ChatResponseGen,
13
+ CompletionResponse,
14
+ CompletionResponseAsyncGen,
15
+ CompletionResponseGen,
16
+ LLMMetadata,
17
+ )
18
+
19
+ from kron.llm_predictor.openai_utils import kron_openai_modelname_to_contextsize
20
+
21
+ class KronOpenAI(OpenAI):
22
+
23
+ @property
24
+ def metadata(self) -> LLMMetadata:
25
+ return LLMMetadata(
26
+ context_window=kron_openai_modelname_to_contextsize(self.model),
27
+ num_output=self.max_tokens or -1,
28
+ is_chat_model=self._is_chat_model,
29
+ )
30
+
31
+ def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
32
+ #print("KronOpenAI complete called")
33
+ response = super().complete(prompt, **kwargs)
34
+ text = response.text
35
+ text = text.strip() #triples might not start at the begining of the line
36
+ #useful triplets are before <|endoftext|>
37
+ text = text.split("<|endoftext|>")[0]
38
+ response.text = text
39
+ return response
kron/llm_predictor/__init__.py ADDED
File without changes
kron/llm_predictor/openai_utils.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LOCAL_MODELS = {
2
+ "Writer/camel-5b-hf": 2048,
3
+ "mosaicml/mpt-7b-instruct": 2048,
4
+ "mosaicml/mpt-30b-instruct": 8192,
5
+ }
6
+
7
+ GPT4_MODELS = {
8
+ # stable model names:
9
+ # resolves to gpt-4-0314 before 2023-06-27,
10
+ # resolves to gpt-4-0613 after
11
+ "gpt-4": 8192,
12
+ "gpt-4-32k": 32768,
13
+ # 0613 models (function calling):
14
+ # https://openai.com/blog/function-calling-and-other-api-updates
15
+ "gpt-4-0613": 8192,
16
+ "gpt-4-32k-0613": 32768,
17
+ # 0314 models
18
+ "gpt-4-0314": 8192,
19
+ "gpt-4-32k-0314": 32768,
20
+ }
21
+
22
+ AZURE_TURBO_MODELS = {
23
+ "gpt-35-turbo-16k": 16384,
24
+ "gpt-35-turbo": 4096,
25
+ }
26
+
27
+ TURBO_MODELS = {
28
+ # stable model names:
29
+ # resolves to gpt-3.5-turbo-0301 before 2023-06-27,
30
+ # resolves to gpt-3.5-turbo-0613 after
31
+ "gpt-3.5-turbo": 4096,
32
+ # resolves to gpt-3.5-turbo-16k-0613
33
+ "gpt-3.5-turbo-16k": 16384,
34
+ # 0613 models (function calling):
35
+ # https://openai.com/blog/function-calling-and-other-api-updates
36
+ "gpt-3.5-turbo-0613": 4096,
37
+ "gpt-3.5-turbo-16k-0613": 16384,
38
+ # 0301 models
39
+ "gpt-3.5-turbo-0301": 4096,
40
+ }
41
+
42
+ GPT3_5_MODELS = {
43
+ "text-davinci-003": 4097,
44
+ "text-davinci-002": 4097,
45
+ }
46
+
47
+ GPT3_MODELS = {
48
+ "text-ada-001": 2049,
49
+ "text-babbage-001": 2040,
50
+ "text-curie-001": 2049,
51
+ "ada": 2049,
52
+ "babbage": 2049,
53
+ "curie": 2049,
54
+ "davinci": 2049,
55
+ }
56
+
57
+ ALL_AVAILABLE_MODELS = {
58
+ **GPT4_MODELS,
59
+ **TURBO_MODELS,
60
+ **GPT3_5_MODELS,
61
+ **GPT3_MODELS,
62
+ **LOCAL_MODELS,
63
+ }
64
+
65
+ CHAT_MODELS = {
66
+ **GPT4_MODELS,
67
+ **TURBO_MODELS,
68
+ **AZURE_TURBO_MODELS,
69
+ }
70
+
71
+
72
+ DISCONTINUED_MODELS = {
73
+ "code-davinci-002": 8001,
74
+ "code-davinci-001": 8001,
75
+ "code-cushman-002": 2048,
76
+ "code-cushman-001": 2048,
77
+ }
78
+
79
+
80
+ def kron_openai_modelname_to_contextsize(modelname: str) -> int:
81
+ """Calculate the maximum number of tokens possible to generate for a model.
82
+
83
+ Args:
84
+ modelname: The modelname we want to know the context size for.
85
+
86
+ Returns:
87
+ The maximum context size
88
+
89
+ Example:
90
+ .. code-block:: python
91
+
92
+ max_tokens = openai.modelname_to_contextsize("text-davinci-003")
93
+
94
+ Modified from:
95
+ https://github.com/hwchase17/langchain/blob/master/langchain/llms/openai.py
96
+ """
97
+ # handling finetuned models
98
+ if "ft-" in modelname:
99
+ modelname = modelname.split(":")[0]
100
+
101
+ if modelname in DISCONTINUED_MODELS:
102
+ raise ValueError(
103
+ f"OpenAI model {modelname} has been discontinued. "
104
+ "Please choose another model."
105
+ )
106
+
107
+ context_size = ALL_AVAILABLE_MODELS.get(modelname, None)
108
+
109
+ if context_size is None:
110
+ raise ValueError(
111
+ f"Unknown model: {modelname}. Please provide a valid OpenAI model name."
112
+ "Known models are: " + ", ".join(ALL_AVAILABLE_MODELS.keys())
113
+ )
114
+
115
+ return context_size
kron/llm_predictor/utils.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union
2
+ from llama_index.llms.base import LLM
3
+ from langchain.base_language import BaseLanguageModel
4
+
5
+ from kron.llm_predictor.KronLangChainLLM import KronLangChainLLM
6
+ from llama_index.llms.openai import OpenAI
7
+
8
+ from llama_index.llms.utils import LLMType
9
+
10
+
11
+ def kron_resolve_llm(llm: Optional[LLMType] = None) -> LLM:
12
+ if isinstance(llm, BaseLanguageModel):
13
+ # NOTE: if it's a langchain model, wrap it in a LangChainLLM
14
+ return KronLangChainLLM(llm=llm)
15
+
16
+ return llm or OpenAI()
kron/persistence/dynamodb_request_log.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #dynamodb access#
2
+ from datetime import datetime
3
+ import boto3
4
+ from botocore.exceptions import ClientError
5
+
6
+ import logging
7
+ logger = logging.getLogger(__name__)
8
+
9
+ session = boto3.Session(
10
+ # aws_access_key_id=AWS_ACCESS_KEY_ID,
11
+ # aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
12
+ )
13
+ dynamodb = session.resource('dynamodb')
14
+
15
+ class RequestLog:
16
+ """Encapsulates an Amazon DynamoDB table of request data."""
17
+ def __init__(self, dyn_resource):
18
+ """
19
+ :param dyn_resource: A Boto3 DynamoDB resource.
20
+ """
21
+ self.dyn_resource = dyn_resource
22
+ self.table = None
23
+
24
+ def exists(self, table_name):
25
+ """
26
+ Determines whether a table exists. As a side effect, stores the table in
27
+ a member variable.
28
+
29
+ :param table_name: The name of the table to check.
30
+ :return: True when the table exists; otherwise, False.
31
+ """
32
+ try:
33
+ table = self.dyn_resource.Table(table_name)
34
+ table.load()
35
+ exists = True
36
+ except ClientError as err:
37
+ if err.response['Error']['Code'] == 'ResourceNotFoundException':
38
+ exists = False
39
+ else:
40
+ logger.error(
41
+ "Couldn't check for existence of %s. Here's why: %s: %s",
42
+ table_name,
43
+ err.response['Error']['Code'], err.response['Error']['Message'])
44
+ raise
45
+ else:
46
+ self.table = table
47
+ return exists
48
+
49
+ def create_table(self, table_name):
50
+ """
51
+ Creates an Amazon DynamoDB table that can be used to store request data.
52
+ The table uses the release year of the movie as the partition key and the
53
+ title as the sort key.
54
+
55
+ :param table_name: The name of the table to create.
56
+ :return: The newly created table.
57
+ """
58
+ try:
59
+ self.table = self.dyn_resource.create_table(
60
+ TableName=table_name,
61
+ KeySchema=[
62
+ {'AttributeName': 'model', 'KeyType': 'HASH'}, # Partition key
63
+ {'AttributeName': 'timestamp', 'KeyType': 'RANGE'} # Sort key
64
+ ],
65
+ AttributeDefinitions=[
66
+ {'AttributeName': 'model', 'AttributeType': 'S'},
67
+ {'AttributeName': 'timestamp', 'AttributeType': 'S'},
68
+ # {'AttributeName': 'request', 'AttributeType': 'S'},
69
+ # {'AttributeName': 'response', 'AttributeType': 'S'}
70
+ ],
71
+ ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10})
72
+ self.table.wait_until_exists()
73
+ except ClientError as err:
74
+ logger.error(
75
+ "Couldn't create table %s. Here's why: %s: %s", table_name,
76
+ err.response['Error']['Code'], err.response['Error']['Message'])
77
+ raise
78
+ else:
79
+ return self.table
80
+
81
+ def log_request(self, req_timestamp_str, model, request_str, response_str, rating = 0):
82
+ """
83
+ Log a request to the table.
84
+
85
+ # TODO
86
+ :param title: The title of the movie.
87
+ :param year: The release year of the movie.
88
+ :param plot: The plot summary of the movie.
89
+ :param rating: The quality rating of the movie.
90
+ """
91
+ try:
92
+ self.table.put_item(
93
+ Item={
94
+ 'timestamp': req_timestamp_str,
95
+ 'model': model,
96
+ 'request': request_str,
97
+ 'response': response_str,
98
+ 'rating': rating,
99
+ }
100
+ )
101
+ except ClientError as err:
102
+ logger.error(
103
+ "Couldn't add request log %s to table %s. Here's why: %s: %s",
104
+ model, self.table.name,
105
+ err.response['Error']['Code'], err.response['Error']['Message'])
106
+ raise
107
+
108
+ def add_request_log_entry(self, query_model, req, resp, rating=0):
109
+ """
110
+ Logs the cuurent model, req and response
111
+ """
112
+ today = datetime.now()
113
+ # Get current ISO 8601 datetime in string format
114
+ iso_date = today.isoformat()
115
+ self.log_request(iso_date, query_model, req, resp, rating)
116
+
117
+ table_name = 'hf-spaces-request-log'
118
+
119
+ def get_request_log():
120
+ request_log = RequestLog(dynamodb)
121
+ request_log_exists = request_log.exists(table_name)
122
+ if not request_log_exists:
123
+ print(f"\nCreating table {table_name}...")
124
+ request_log.create_table(table_name)
125
+ print(f"\nCreated table {request_log.table.name}.")
126
+ return request_log
127
+
128
+ #def add_request_log_entry(request_log, query_model, req, resp, rating=0):
129
+ # today = datetime.now()
130
+ # # Get current ISO 8601 datetime in string format
131
+ # iso_date = today.isoformat()
132
+ # request_log.log_request(iso_date, query_model, req, resp, rating)
kron/prompts/kg_prompts.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Set of default prompts."""
2
+
3
+ from llama_index.prompts.base import Prompt
4
+ from llama_index.prompts.prompt_type import PromptType
5
+
6
+ ############################################
7
+ # Tree
8
+ ############################################
9
+
10
+ DEFAULT_SUMMARY_PROMPT_TMPL = (
11
+ "Write a summary of the following. Try to use only the "
12
+ "information provided. "
13
+ "Try to include as many key details as possible.\n"
14
+ "\n"
15
+ "\n"
16
+ "{context_str}\n"
17
+ "\n"
18
+ "\n"
19
+ 'SUMMARY:"""\n'
20
+ )
21
+
22
+ DEFAULT_SUMMARY_PROMPT = Prompt(
23
+ DEFAULT_SUMMARY_PROMPT_TMPL, prompt_type=PromptType.SUMMARY
24
+ )
25
+
26
+ # insert prompts
27
+ DEFAULT_INSERT_PROMPT_TMPL = (
28
+ "Context information is below. It is provided in a numbered list "
29
+ "(1 to {num_chunks}),"
30
+ "where each item in the list corresponds to a summary.\n"
31
+ "---------------------\n"
32
+ "{context_list}"
33
+ "---------------------\n"
34
+ "Given the context information, here is a new piece of "
35
+ "information: {new_chunk_text}\n"
36
+ "Answer with the number corresponding to the summary that should be updated. "
37
+ "The answer should be the number corresponding to the "
38
+ "summary that is most relevant to the question.\n"
39
+ )
40
+ DEFAULT_INSERT_PROMPT = Prompt(
41
+ DEFAULT_INSERT_PROMPT_TMPL, prompt_type=PromptType.TREE_INSERT
42
+ )
43
+
44
+
45
+ # # single choice
46
+ DEFAULT_QUERY_PROMPT_TMPL = (
47
+ "Some choices are given below. It is provided in a numbered list "
48
+ "(1 to {num_chunks}),"
49
+ "where each item in the list corresponds to a summary.\n"
50
+ "---------------------\n"
51
+ "{context_list}"
52
+ "\n---------------------\n"
53
+ "Using only the choices above and not prior knowledge, return "
54
+ "the choice that is most relevant to the question: '{query_str}'\n"
55
+ "Provide choice in the following format: 'ANSWER: <number>' and explain why "
56
+ "this summary was selected in relation to the question.\n"
57
+ )
58
+ DEFAULT_QUERY_PROMPT = Prompt(
59
+ DEFAULT_QUERY_PROMPT_TMPL, prompt_type=PromptType.TREE_SELECT
60
+ )
61
+
62
+ # multiple choice
63
+ DEFAULT_QUERY_PROMPT_MULTIPLE_TMPL = (
64
+ "Some choices are given below. It is provided in a numbered "
65
+ "list (1 to {num_chunks}), "
66
+ "where each item in the list corresponds to a summary.\n"
67
+ "---------------------\n"
68
+ "{context_list}"
69
+ "\n---------------------\n"
70
+ "Using only the choices above and not prior knowledge, return the top choices "
71
+ "(no more than {branching_factor}, ranked by most relevant to least) that "
72
+ "are most relevant to the question: '{query_str}'\n"
73
+ "Provide choices in the following format: 'ANSWER: <numbers>' and explain why "
74
+ "these summaries were selected in relation to the question.\n"
75
+ )
76
+ DEFAULT_QUERY_PROMPT_MULTIPLE = Prompt(
77
+ DEFAULT_QUERY_PROMPT_MULTIPLE_TMPL, prompt_type=PromptType.TREE_SELECT_MULTIPLE
78
+ )
79
+
80
+
81
+ DEFAULT_REFINE_PROMPT_TMPL = (
82
+ "The original question is as follows: {query_str}\n"
83
+ "We have provided an existing answer: {existing_answer}\n"
84
+ "We have the opportunity to refine the existing answer "
85
+ "(only if needed) with some more context below.\n"
86
+ "------------\n"
87
+ "{context_msg}\n"
88
+ "------------\n"
89
+ "Given the new context, refine the original answer to better "
90
+ "answer the question. "
91
+ "If the context isn't useful, return the original answer."
92
+ )
93
+ DEFAULT_REFINE_PROMPT = Prompt(
94
+ DEFAULT_REFINE_PROMPT_TMPL, prompt_type=PromptType.REFINE
95
+ )
96
+
97
+
98
+ DEFAULT_TEXT_QA_PROMPT_TMPL = (
99
+ "Context information is below.\n"
100
+ "---------------------\n"
101
+ "{context_str}\n"
102
+ "---------------------\n"
103
+ "Given the context information and not prior knowledge, "
104
+ "answer the question: {query_str}\n"
105
+ )
106
+ DEFAULT_TEXT_QA_PROMPT = Prompt(
107
+ DEFAULT_TEXT_QA_PROMPT_TMPL, prompt_type=PromptType.QUESTION_ANSWER
108
+ )
109
+
110
+
111
+ ############################################
112
+ # Keyword Table
113
+ ############################################
114
+
115
+ DEFAULT_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
116
+ "Some text is provided below. Given the text, extract up to {max_keywords} "
117
+ "keywords from the text. Avoid stopwords."
118
+ "---------------------\n"
119
+ "{text}\n"
120
+ "---------------------\n"
121
+ "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
122
+ )
123
+ DEFAULT_KEYWORD_EXTRACT_TEMPLATE = Prompt(
124
+ DEFAULT_KEYWORD_EXTRACT_TEMPLATE_TMPL, prompt_type=PromptType.KEYWORD_EXTRACT
125
+ )
126
+
127
+
128
+ # NOTE: the keyword extraction for queries can be the same as
129
+ # the one used to build the index, but here we tune it to see if performance is better.
130
+ DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
131
+ "A question is provided below. Given the question, extract up to {max_keywords} "
132
+ "keywords from the text. Focus on extracting the keywords that we can use "
133
+ "to best lookup answers to the question. Avoid stopwords.\n"
134
+ "---------------------\n"
135
+ "{question}\n"
136
+ "---------------------\n"
137
+ "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
138
+ )
139
+ DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE = Prompt(
140
+ DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL,
141
+ prompt_type=PromptType.QUERY_KEYWORD_EXTRACT,
142
+ )
143
+
144
+
145
+ ############################################
146
+ # Structured Store
147
+ ############################################
148
+
149
+ DEFAULT_SCHEMA_EXTRACT_TMPL = (
150
+ "We wish to extract relevant fields from an unstructured text chunk into "
151
+ "a structured schema. We first provide the unstructured text, and then "
152
+ "we provide the schema that we wish to extract. "
153
+ "-----------text-----------\n"
154
+ "{text}\n"
155
+ "-----------schema-----------\n"
156
+ "{schema}\n"
157
+ "---------------------\n"
158
+ "Given the text and schema, extract the relevant fields from the text in "
159
+ "the following format: "
160
+ "field1: <value>\nfield2: <value>\n...\n\n"
161
+ "If a field is not present in the text, don't include it in the output."
162
+ "If no fields are present in the text, return a blank string.\n"
163
+ "Fields: "
164
+ )
165
+ DEFAULT_SCHEMA_EXTRACT_PROMPT = Prompt(
166
+ DEFAULT_SCHEMA_EXTRACT_TMPL, prompt_type=PromptType.SCHEMA_EXTRACT
167
+ )
168
+
169
+ # NOTE: taken from langchain and adapted
170
+ # https://tinyurl.com/b772sd77
171
+ DEFAULT_TEXT_TO_SQL_TMPL = (
172
+ "Given an input question, first create a syntactically correct {dialect} "
173
+ "query to run, then look at the results of the query and return the answer. "
174
+ "You can order the results by a relevant column to return the most "
175
+ "interesting examples in the database.\n"
176
+ "Never query for all the columns from a specific table, only ask for a "
177
+ "few relevant columns given the question.\n"
178
+ "Pay attention to use only the column names that you can see in the schema "
179
+ "description. "
180
+ "Be careful to not query for columns that do not exist. "
181
+ "Pay attention to which column is in which table. "
182
+ "Also, qualify column names with the table name when needed.\n"
183
+ "Use the following format:\n"
184
+ "Question: Question here\n"
185
+ "SQLQuery: SQL Query to run\n"
186
+ "SQLResult: Result of the SQLQuery\n"
187
+ "Answer: Final answer here\n"
188
+ "Only use the tables listed below.\n"
189
+ "{schema}\n"
190
+ "Question: {query_str}\n"
191
+ "SQLQuery: "
192
+ )
193
+
194
+ DEFAULT_TEXT_TO_SQL_PROMPT = Prompt(
195
+ DEFAULT_TEXT_TO_SQL_TMPL,
196
+ stop_token="\nSQLResult:",
197
+ prompt_type=PromptType.TEXT_TO_SQL,
198
+ )
199
+
200
+
201
+ # NOTE: by partially filling schema, we can reduce to a QuestionAnswer prompt
202
+ # that we can feed to ur table
203
+ DEFAULT_TABLE_CONTEXT_TMPL = (
204
+ "We have provided a table schema below. "
205
+ "---------------------\n"
206
+ "{schema}\n"
207
+ "---------------------\n"
208
+ "We have also provided context information below. "
209
+ "{context_str}\n"
210
+ "---------------------\n"
211
+ "Given the context information and the table schema, "
212
+ "give a response to the following task: {query_str}"
213
+ )
214
+
215
+ DEFAULT_TABLE_CONTEXT_QUERY = (
216
+ "Provide a high-level description of the table, "
217
+ "as well as a description of each column in the table. "
218
+ "Provide answers in the following format:\n"
219
+ "TableDescription: <description>\n"
220
+ "Column1Description: <description>\n"
221
+ "Column2Description: <description>\n"
222
+ "...\n\n"
223
+ )
224
+
225
+ DEFAULT_TABLE_CONTEXT_PROMPT = Prompt(
226
+ DEFAULT_TABLE_CONTEXT_TMPL, prompt_type=PromptType.TABLE_CONTEXT
227
+ )
228
+
229
+ # NOTE: by partially filling schema, we can reduce to a RefinePrompt
230
+ # that we can feed to ur table
231
+ DEFAULT_REFINE_TABLE_CONTEXT_TMPL = (
232
+ "We have provided a table schema below. "
233
+ "---------------------\n"
234
+ "{schema}\n"
235
+ "---------------------\n"
236
+ "We have also provided some context information below. "
237
+ "{context_msg}\n"
238
+ "---------------------\n"
239
+ "Given the context information and the table schema, "
240
+ "give a response to the following task: {query_str}\n"
241
+ "We have provided an existing answer: {existing_answer}\n"
242
+ "Given the new context, refine the original answer to better "
243
+ "answer the question. "
244
+ "If the context isn't useful, return the original answer."
245
+ )
246
+ DEFAULT_REFINE_TABLE_CONTEXT_PROMPT = Prompt(
247
+ DEFAULT_REFINE_TABLE_CONTEXT_TMPL, prompt_type=PromptType.TABLE_CONTEXT
248
+ )
249
+
250
+
251
+ ############################################
252
+ # Knowledge-Graph Table
253
+ ############################################
254
+
255
+ KRON_KG_TRIPLET_EXTRACT_TMPL = (
256
+ "Below is an instruction that describes a task, paired with an input that provides further context. "
257
+ "Write a response that appropriately completes the request.\n\n"
258
+ "### Instruction:\n"
259
+ "Some text is provided below. Given the text, extract up to {max_knowledge_triplets} knowledge triplets in the form of "
260
+ "(subject, predicate, object).\n\n"
261
+ "### Input: \n"
262
+ "Text: Alice is Bob's mother. \n"
263
+ "Triplets: \n"
264
+ " (Alice, is mother of, Bob) \n"
265
+ "Text: Philz is a coffee shop founded in Berkeley in 1982. \n"
266
+ "Triplets: \n"
267
+ " (Philz, is, coffee shop) \n"
268
+ " (Philz, founded in, Berkeley) \n"
269
+ " (Philz, founded in, 1982) \n"
270
+ "Text: This small and colorful book is for children. \n"
271
+ "Triplets: \n"
272
+ " (book, is for, children)\n"
273
+ " (book, is, small and colorful) \n"
274
+ " (small book, is for, children) \n"
275
+ " (this small book, is for, children) \n"
276
+ "Text: We saw these dwellings, brightly painted cottages, shining in the sun. \n"
277
+ "Triplets: \n"
278
+ " (dwellings, are, brightly painted cottages) \n"
279
+ " (cottages, shine in, the sun) \n"
280
+ "--------------------- \n"
281
+ "### Text: {text} \n\n"
282
+ "### Triplets: "
283
+ )
284
+
285
+ KRON_KG_TRIPLET_EXTRACT_PROMPT = Prompt(
286
+ KRON_KG_TRIPLET_EXTRACT_TMPL, prompt_type=PromptType.KNOWLEDGE_TRIPLET_EXTRACT
287
+ )
288
+
289
+ ############################################
290
+ # HYDE
291
+ ##############################################
292
+
293
+ HYDE_TMPL = (
294
+ "Please write a passage to answer the question\n"
295
+ "Try to include as many key details as possible.\n"
296
+ "\n"
297
+ "\n"
298
+ "{context_str}\n"
299
+ "\n"
300
+ "\n"
301
+ 'Passage:"""\n'
302
+ )
303
+
304
+ DEFAULT_HYDE_PROMPT = Prompt(HYDE_TMPL, prompt_type=PromptType.SUMMARY)
305
+
306
+
307
+ ############################################
308
+ # Simple Input
309
+ ############################################
310
+
311
+ DEFAULT_SIMPLE_INPUT_TMPL = "{query_str}"
312
+ DEFAULT_SIMPLE_INPUT_PROMPT = Prompt(
313
+ DEFAULT_SIMPLE_INPUT_TMPL, prompt_type=PromptType.SIMPLE_INPUT
314
+ )
315
+
316
+
317
+ ############################################
318
+ # Pandas
319
+ ############################################
320
+
321
+ DEFAULT_PANDAS_TMPL = (
322
+ "You are working with a pandas dataframe in Python.\n"
323
+ "The name of the dataframe is `df`.\n"
324
+ "This is the result of `print(df.head())`:\n"
325
+ "{df_str}\n\n"
326
+ "Here is the input query: {query_str}.\n"
327
+ "Given the df information and the input query, please follow "
328
+ "these instructions:\n"
329
+ "{instruction_str}"
330
+ "Output:\n"
331
+ )
332
+
333
+ DEFAULT_PANDAS_PROMPT = Prompt(DEFAULT_PANDAS_TMPL, prompt_type=PromptType.PANDAS)
334
+
335
+
336
+ ############################################
337
+ # JSON Path
338
+ ############################################
339
+
340
+ DEFAULT_JSON_PATH_TMPL = (
341
+ "We have provided a JSON schema below:\n"
342
+ "{schema}\n"
343
+ "Given a task, respond with a JSON Path query that "
344
+ "can retrieve data from a JSON value that matches the schema.\n"
345
+ "Task: {query_str}\n"
346
+ "JSONPath: "
347
+ )
348
+
349
+ DEFAULT_JSON_PATH_PROMPT = Prompt(
350
+ DEFAULT_JSON_PATH_TMPL, prompt_type=PromptType.JSON_PATH
351
+ )
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #local
2
+ #conda create -n mlk8s python=3.9.15 -y
3
+ #conda activate mlk8s
4
+ #pip install --upgrade streamlit
5
+ #pip install --upgrade huggingface_hub
6
+ #pip install -r requirements.txt
7
+ #streamlit run appname.py
8
+
9
+ torch
10
+ transformers
11
+ llama_index
12
+ pyvis
13
+ nltk
14
+ python-dotenv
15
+ cohere
16
+ baseten
17
+ st-star-rating
18
+ amazon-dax-client>=1.1.7
19
+ boto3>=1.26.79
20
+ pytest>=7.2.1
21
+ requests>=2.28.2
storage/Writer-camel-5b-hf-default-no-coref/graph_store.json ADDED
The diff for this file is too large to render. See raw diff
 
storage/Writer-camel-5b-hf-default-no-coref/index_store.json ADDED
The diff for this file is too large to render. See raw diff
 
storage/Writer-camel-5b-hf-default-no-coref/vector_store.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"embedding_dict": {}, "text_id_to_ref_doc_id": {}}