Joshua Sundance Bailey commited on
Commit
0ce4fb3
1 Parent(s): 21eccfc

cleanup & options

Browse files
kubernetes/resources.yaml CHANGED
@@ -71,6 +71,10 @@ spec:
71
  key: LANGCHAIN_API_KEY
72
  - name: LANGCHAIN_PROJECT
73
  value: "langchain-streamlit-demo"
 
 
 
 
74
  securityContext:
75
  runAsNonRoot: true
76
  ---
 
71
  key: LANGCHAIN_API_KEY
72
  - name: LANGCHAIN_PROJECT
73
  value: "langchain-streamlit-demo"
74
+ - name: SHOW_LANGCHAIN_OPTIONS
75
+ value: "False"
76
+ - name: SHOW_AZURE_OPTIONS
77
+ value: "False"
78
  securityContext:
79
  runAsNonRoot: true
80
  ---
langchain-streamlit-demo/app.py CHANGED
@@ -14,29 +14,8 @@ from langchain.schema.retriever import BaseRetriever
14
  from langsmith.client import Client
15
  from streamlit_feedback import streamlit_feedback
16
 
17
- from defaults import (
18
- MODEL_DICT,
19
- SUPPORTED_MODELS,
20
- DEFAULT_MODEL,
21
- DEFAULT_SYSTEM_PROMPT,
22
- MIN_TEMP,
23
- MAX_TEMP,
24
- DEFAULT_TEMP,
25
- MIN_MAX_TOKENS,
26
- MAX_MAX_TOKENS,
27
- DEFAULT_MAX_TOKENS,
28
- DEFAULT_LANGSMITH_PROJECT,
29
- AZURE_DICT,
30
- PROVIDER_KEY_DICT,
31
- OPENAI_API_KEY,
32
- MIN_CHUNK_SIZE,
33
- MAX_CHUNK_SIZE,
34
- DEFAULT_CHUNK_SIZE,
35
- MIN_CHUNK_OVERLAP,
36
- MAX_CHUNK_OVERLAP,
37
- DEFAULT_CHUNK_OVERLAP,
38
- DEFAULT_RETRIEVER_K,
39
- )
40
  from llm_resources import get_runnable, get_llm, get_texts_and_retriever, StreamHandler
41
 
42
  __version__ = "0.0.13"
@@ -81,12 +60,14 @@ RUN_COLLECTOR = RunCollectorCallbackHandler()
81
  @st.cache_data
82
  def get_texts_and_retriever_cacheable_wrapper(
83
  uploaded_file_bytes: bytes,
84
- chunk_size: int = DEFAULT_CHUNK_SIZE,
85
- chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
86
- k: int = DEFAULT_RETRIEVER_K,
 
87
  ) -> Tuple[List[Document], BaseRetriever]:
88
  return get_texts_and_retriever(
89
  uploaded_file_bytes=uploaded_file_bytes,
 
90
  chunk_size=chunk_size,
91
  chunk_overlap=chunk_overlap,
92
  k=k,
@@ -100,14 +81,14 @@ with sidebar:
100
 
101
  model = st.selectbox(
102
  label="Chat Model",
103
- options=SUPPORTED_MODELS,
104
- index=SUPPORTED_MODELS.index(DEFAULT_MODEL),
105
  )
106
 
107
- st.session_state.provider = MODEL_DICT[model]
108
 
109
  provider_api_key = (
110
- PROVIDER_KEY_DICT.get(
111
  st.session_state.provider,
112
  )
113
  or st.text_input(
@@ -130,7 +111,7 @@ with sidebar:
130
  openai_api_key = (
131
  provider_api_key
132
  if st.session_state.provider == "OpenAI"
133
- else OPENAI_API_KEY
134
  or st.sidebar.text_input("OpenAI API Key: ", type="password")
135
  )
136
 
@@ -143,7 +124,7 @@ with sidebar:
143
  k = st.slider(
144
  label="Number of Chunks",
145
  help="How many document chunks will be used for context?",
146
- value=DEFAULT_RETRIEVER_K,
147
  min_value=1,
148
  max_value=10,
149
  )
@@ -151,17 +132,17 @@ with sidebar:
151
  chunk_size = st.slider(
152
  label="Number of Tokens per Chunk",
153
  help="Size of each chunk of text",
154
- min_value=MIN_CHUNK_SIZE,
155
- max_value=MAX_CHUNK_SIZE,
156
- value=DEFAULT_CHUNK_SIZE,
157
  )
158
 
159
  chunk_overlap = st.slider(
160
  label="Chunk Overlap",
161
  help="Number of characters to overlap between chunks",
162
- min_value=MIN_CHUNK_OVERLAP,
163
- max_value=MAX_CHUNK_OVERLAP,
164
- value=DEFAULT_CHUNK_OVERLAP,
165
  )
166
 
167
  chain_type_help_root = (
@@ -198,8 +179,9 @@ with sidebar:
198
  (
199
  st.session_state.texts,
200
  st.session_state.retriever,
201
- ) = get_texts_and_retriever(
202
  uploaded_file_bytes=uploaded_file.getvalue(),
 
203
  chunk_size=chunk_size,
204
  chunk_overlap=chunk_overlap,
205
  k=k,
@@ -216,7 +198,7 @@ with sidebar:
216
  system_prompt = (
217
  st.text_area(
218
  "Custom Instructions",
219
- DEFAULT_SYSTEM_PROMPT,
220
  help="Custom instructions to provide the language model to determine style, personality, etc.",
221
  )
222
  .strip()
@@ -226,84 +208,99 @@ with sidebar:
226
 
227
  temperature = st.slider(
228
  "Temperature",
229
- min_value=MIN_TEMP,
230
- max_value=MAX_TEMP,
231
- value=DEFAULT_TEMP,
232
  help="Higher values give more random results.",
233
  )
234
 
235
  max_tokens = st.slider(
236
  "Max Tokens",
237
- min_value=MIN_MAX_TOKENS,
238
- max_value=MAX_MAX_TOKENS,
239
- value=DEFAULT_MAX_TOKENS,
240
  help="Higher values give longer results.",
241
  )
242
 
243
  # --- LangSmith Options ---
244
- with st.expander("LangSmith Options", expanded=False):
245
- LANGSMITH_API_KEY = st.text_input(
246
- "LangSmith API Key (optional)",
247
- type="password",
248
- value=PROVIDER_KEY_DICT.get("LANGSMITH"),
249
- )
250
-
251
- LANGSMITH_PROJECT = st.text_input(
252
- "LangSmith Project Name",
253
- value=DEFAULT_LANGSMITH_PROJECT or "langchain-streamlit-demo",
254
- )
255
 
256
- if st.session_state.client is None and LANGSMITH_API_KEY:
257
- st.session_state.client = Client(
258
- api_url="https://api.smith.langchain.com",
259
- api_key=LANGSMITH_API_KEY,
 
 
260
  )
261
- st.session_state.ls_tracer = LangChainTracer(
262
- project_name=LANGSMITH_PROJECT,
263
- client=st.session_state.client,
 
264
  )
265
 
266
- # --- Azure Options ---
267
- with st.expander("Azure Options", expanded=False):
268
- AZURE_OPENAI_BASE_URL = st.text_input(
269
- "AZURE_OPENAI_BASE_URL",
270
- value=AZURE_DICT["AZURE_OPENAI_BASE_URL"],
271
  )
272
-
273
- AZURE_OPENAI_API_VERSION = st.text_input(
274
- "AZURE_OPENAI_API_VERSION",
275
- value=AZURE_DICT["AZURE_OPENAI_API_VERSION"],
276
  )
277
 
278
- AZURE_OPENAI_DEPLOYMENT_NAME = st.text_input(
279
- "AZURE_OPENAI_DEPLOYMENT_NAME",
280
- value=AZURE_DICT["AZURE_OPENAI_DEPLOYMENT_NAME"],
281
- )
 
 
 
 
 
 
 
 
 
 
 
282
 
283
- AZURE_OPENAI_API_KEY = st.text_input(
284
- "AZURE_OPENAI_API_KEY",
285
- value=AZURE_DICT["AZURE_OPENAI_API_KEY"],
286
- type="password",
287
- )
288
 
289
- AZURE_OPENAI_MODEL_VERSION = st.text_input(
290
- "AZURE_OPENAI_MODEL_VERSION",
291
- value=AZURE_DICT["AZURE_OPENAI_MODEL_VERSION"],
292
- )
293
 
294
- AZURE_AVAILABLE = all(
295
- [
296
- AZURE_OPENAI_BASE_URL,
297
- AZURE_OPENAI_API_VERSION,
298
- AZURE_OPENAI_DEPLOYMENT_NAME,
299
- AZURE_OPENAI_API_KEY,
300
- AZURE_OPENAI_MODEL_VERSION,
301
- ],
302
- )
 
 
 
 
 
 
 
 
 
 
 
303
 
304
 
305
  # --- LLM Instantiation ---
306
- llm = get_llm(
307
  provider=st.session_state.provider,
308
  model=model,
309
  provider_api_key=provider_api_key,
@@ -384,6 +381,8 @@ if st.session_state.llm:
384
  st.session_state.llm,
385
  st.session_state.retriever,
386
  MEMORY,
 
 
387
  )
388
 
389
  # --- LLM call ---
 
14
  from langsmith.client import Client
15
  from streamlit_feedback import streamlit_feedback
16
 
17
+ from defaults import default_values
18
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  from llm_resources import get_runnable, get_llm, get_texts_and_retriever, StreamHandler
20
 
21
  __version__ = "0.0.13"
 
60
  @st.cache_data
61
  def get_texts_and_retriever_cacheable_wrapper(
62
  uploaded_file_bytes: bytes,
63
+ openai_api_key: str,
64
+ chunk_size: int = default_values.DEFAULT_CHUNK_SIZE,
65
+ chunk_overlap: int = default_values.DEFAULT_CHUNK_OVERLAP,
66
+ k: int = default_values.DEFAULT_RETRIEVER_K,
67
  ) -> Tuple[List[Document], BaseRetriever]:
68
  return get_texts_and_retriever(
69
  uploaded_file_bytes=uploaded_file_bytes,
70
+ openai_api_key=openai_api_key,
71
  chunk_size=chunk_size,
72
  chunk_overlap=chunk_overlap,
73
  k=k,
 
81
 
82
  model = st.selectbox(
83
  label="Chat Model",
84
+ options=default_values.SUPPORTED_MODELS,
85
+ index=default_values.SUPPORTED_MODELS.index(default_values.DEFAULT_MODEL),
86
  )
87
 
88
+ st.session_state.provider = default_values.MODEL_DICT[model]
89
 
90
  provider_api_key = (
91
+ default_values.PROVIDER_KEY_DICT.get(
92
  st.session_state.provider,
93
  )
94
  or st.text_input(
 
111
  openai_api_key = (
112
  provider_api_key
113
  if st.session_state.provider == "OpenAI"
114
+ else default_values.OPENAI_API_KEY
115
  or st.sidebar.text_input("OpenAI API Key: ", type="password")
116
  )
117
 
 
124
  k = st.slider(
125
  label="Number of Chunks",
126
  help="How many document chunks will be used for context?",
127
+ value=default_values.DEFAULT_RETRIEVER_K,
128
  min_value=1,
129
  max_value=10,
130
  )
 
132
  chunk_size = st.slider(
133
  label="Number of Tokens per Chunk",
134
  help="Size of each chunk of text",
135
+ min_value=default_values.MIN_CHUNK_SIZE,
136
+ max_value=default_values.MAX_CHUNK_SIZE,
137
+ value=default_values.DEFAULT_CHUNK_SIZE,
138
  )
139
 
140
  chunk_overlap = st.slider(
141
  label="Chunk Overlap",
142
  help="Number of characters to overlap between chunks",
143
+ min_value=default_values.MIN_CHUNK_OVERLAP,
144
+ max_value=default_values.MAX_CHUNK_OVERLAP,
145
+ value=default_values.DEFAULT_CHUNK_OVERLAP,
146
  )
147
 
148
  chain_type_help_root = (
 
179
  (
180
  st.session_state.texts,
181
  st.session_state.retriever,
182
+ ) = get_texts_and_retriever_cacheable_wrapper(
183
  uploaded_file_bytes=uploaded_file.getvalue(),
184
+ openai_api_key=openai_api_key,
185
  chunk_size=chunk_size,
186
  chunk_overlap=chunk_overlap,
187
  k=k,
 
198
  system_prompt = (
199
  st.text_area(
200
  "Custom Instructions",
201
+ default_values.DEFAULT_SYSTEM_PROMPT,
202
  help="Custom instructions to provide the language model to determine style, personality, etc.",
203
  )
204
  .strip()
 
208
 
209
  temperature = st.slider(
210
  "Temperature",
211
+ min_value=default_values.MIN_TEMP,
212
+ max_value=default_values.MAX_TEMP,
213
+ value=default_values.DEFAULT_TEMP,
214
  help="Higher values give more random results.",
215
  )
216
 
217
  max_tokens = st.slider(
218
  "Max Tokens",
219
+ min_value=default_values.MIN_MAX_TOKENS,
220
+ max_value=default_values.MAX_MAX_TOKENS,
221
+ value=default_values.DEFAULT_MAX_TOKENS,
222
  help="Higher values give longer results.",
223
  )
224
 
225
  # --- LangSmith Options ---
226
+ LANGSMITH_API_KEY = default_values.PROVIDER_KEY_DICT.get("LANGSMITH")
227
+ LANGSMITH_PROJECT = (
228
+ default_values.DEFAULT_LANGSMITH_PROJECT or "langchain-streamlit-demo"
229
+ )
 
 
 
 
 
 
 
230
 
231
+ if default_values.SHOW_LANGSMITH_OPTIONS:
232
+ with st.expander("LangSmith Options", expanded=False):
233
+ LANGSMITH_API_KEY = st.text_input(
234
+ "LangSmith API Key (optional)",
235
+ value=LANGSMITH_API_KEY,
236
+ type="password",
237
  )
238
+
239
+ LANGSMITH_PROJECT = st.text_input(
240
+ "LangSmith Project Name",
241
+ value=LANGSMITH_PROJECT,
242
  )
243
 
244
+ if st.session_state.client is None and LANGSMITH_API_KEY:
245
+ st.session_state.client = Client(
246
+ api_url="https://api.smith.langchain.com",
247
+ api_key=LANGSMITH_API_KEY,
 
248
  )
249
+ st.session_state.ls_tracer = LangChainTracer(
250
+ project_name=LANGSMITH_PROJECT,
251
+ client=st.session_state.client,
 
252
  )
253
 
254
+ # --- Azure Options ---
255
+ AZURE_OPENAI_BASE_URL = default_values.AZURE_DICT["AZURE_OPENAI_BASE_URL"]
256
+ AZURE_OPENAI_API_VERSION = default_values.AZURE_DICT["AZURE_OPENAI_API_VERSION"]
257
+ AZURE_OPENAI_DEPLOYMENT_NAME = default_values.AZURE_DICT[
258
+ "AZURE_OPENAI_DEPLOYMENT_NAME"
259
+ ]
260
+ AZURE_OPENAI_API_KEY = default_values.AZURE_DICT["AZURE_OPENAI_API_KEY"]
261
+ AZURE_OPENAI_MODEL_VERSION = default_values.AZURE_DICT["AZURE_OPENAI_MODEL_VERSION"]
262
+
263
+ if default_values.SHOW_AZURE_OPTIONS:
264
+ with st.expander("Azure Options", expanded=False):
265
+ AZURE_OPENAI_BASE_URL = st.text_input(
266
+ "AZURE_OPENAI_BASE_URL",
267
+ value=AZURE_OPENAI_BASE_URL,
268
+ )
269
 
270
+ AZURE_OPENAI_API_VERSION = st.text_input(
271
+ "AZURE_OPENAI_API_VERSION",
272
+ value=AZURE_OPENAI_API_VERSION,
273
+ )
 
274
 
275
+ AZURE_OPENAI_DEPLOYMENT_NAME = st.text_input(
276
+ "AZURE_OPENAI_DEPLOYMENT_NAME",
277
+ value=AZURE_OPENAI_DEPLOYMENT_NAME,
278
+ )
279
 
280
+ AZURE_OPENAI_API_KEY = st.text_input(
281
+ "AZURE_OPENAI_API_KEY",
282
+ value=AZURE_OPENAI_API_KEY,
283
+ type="password",
284
+ )
285
+
286
+ AZURE_OPENAI_MODEL_VERSION = st.text_input(
287
+ "AZURE_OPENAI_MODEL_VERSION",
288
+ value=AZURE_OPENAI_MODEL_VERSION,
289
+ )
290
+
291
+ AZURE_AVAILABLE = all(
292
+ [
293
+ AZURE_OPENAI_BASE_URL,
294
+ AZURE_OPENAI_API_VERSION,
295
+ AZURE_OPENAI_DEPLOYMENT_NAME,
296
+ AZURE_OPENAI_API_KEY,
297
+ AZURE_OPENAI_MODEL_VERSION,
298
+ ],
299
+ )
300
 
301
 
302
  # --- LLM Instantiation ---
303
+ st.session_state.llm = get_llm(
304
  provider=st.session_state.provider,
305
  model=model,
306
  provider_api_key=provider_api_key,
 
381
  st.session_state.llm,
382
  st.session_state.retriever,
383
  MEMORY,
384
+ chat_prompt,
385
+ prompt,
386
  )
387
 
388
  # --- LLM call ---
langchain-streamlit-demo/defaults.py CHANGED
@@ -1,4 +1,6 @@
1
  import os
 
 
2
 
3
  MODEL_DICT = {
4
  "gpt-3.5-turbo": "OpenAI",
@@ -41,6 +43,12 @@ AZURE_VARS = [
41
 
42
  AZURE_DICT = {v: os.environ.get(v, "") for v in AZURE_VARS}
43
 
 
 
 
 
 
 
44
  PROVIDER_KEY_DICT = {
45
  "OpenAI": os.environ.get("OPENAI_API_KEY", ""),
46
  "Anthropic": os.environ.get("ANTHROPIC_API_KEY", ""),
@@ -60,3 +68,61 @@ MAX_CHUNK_OVERLAP = 10000
60
  DEFAULT_CHUNK_OVERLAP = 0
61
 
62
  DEFAULT_RETRIEVER_K = 4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from collections import namedtuple
3
+
4
 
5
  MODEL_DICT = {
6
  "gpt-3.5-turbo": "OpenAI",
 
43
 
44
  AZURE_DICT = {v: os.environ.get(v, "") for v in AZURE_VARS}
45
 
46
+
47
+ SHOW_LANGSMITH_OPTIONS = (
48
+ os.environ.get("SHOW_LANGSMITH_OPTIONS", "true").lower() == "true"
49
+ )
50
+ SHOW_AZURE_OPTIONS = os.environ.get("SHOW_AZURE_OPTIONS", "true").lower() == "true"
51
+
52
  PROVIDER_KEY_DICT = {
53
  "OpenAI": os.environ.get("OPENAI_API_KEY", ""),
54
  "Anthropic": os.environ.get("ANTHROPIC_API_KEY", ""),
 
68
  DEFAULT_CHUNK_OVERLAP = 0
69
 
70
  DEFAULT_RETRIEVER_K = 4
71
+
72
+ DEFAULT_VALUES = namedtuple(
73
+ "DEFAULT_VALUES",
74
+ [
75
+ "MODEL_DICT",
76
+ "SUPPORTED_MODELS",
77
+ "DEFAULT_MODEL",
78
+ "DEFAULT_SYSTEM_PROMPT",
79
+ "MIN_TEMP",
80
+ "MAX_TEMP",
81
+ "DEFAULT_TEMP",
82
+ "MIN_MAX_TOKENS",
83
+ "MAX_MAX_TOKENS",
84
+ "DEFAULT_MAX_TOKENS",
85
+ "DEFAULT_LANGSMITH_PROJECT",
86
+ "AZURE_VARS",
87
+ "AZURE_DICT",
88
+ "PROVIDER_KEY_DICT",
89
+ "OPENAI_API_KEY",
90
+ "MIN_CHUNK_SIZE",
91
+ "MAX_CHUNK_SIZE",
92
+ "DEFAULT_CHUNK_SIZE",
93
+ "MIN_CHUNK_OVERLAP",
94
+ "MAX_CHUNK_OVERLAP",
95
+ "DEFAULT_CHUNK_OVERLAP",
96
+ "DEFAULT_RETRIEVER_K",
97
+ "SHOW_LANGSMITH_OPTIONS",
98
+ "SHOW_AZURE_OPTIONS",
99
+ ],
100
+ )
101
+
102
+
103
+ default_values = DEFAULT_VALUES(
104
+ MODEL_DICT,
105
+ SUPPORTED_MODELS,
106
+ DEFAULT_MODEL,
107
+ DEFAULT_SYSTEM_PROMPT,
108
+ MIN_TEMP,
109
+ MAX_TEMP,
110
+ DEFAULT_TEMP,
111
+ MIN_MAX_TOKENS,
112
+ MAX_MAX_TOKENS,
113
+ DEFAULT_MAX_TOKENS,
114
+ DEFAULT_LANGSMITH_PROJECT,
115
+ AZURE_VARS,
116
+ AZURE_DICT,
117
+ PROVIDER_KEY_DICT,
118
+ OPENAI_API_KEY,
119
+ MIN_CHUNK_SIZE,
120
+ MAX_CHUNK_SIZE,
121
+ DEFAULT_CHUNK_SIZE,
122
+ MIN_CHUNK_OVERLAP,
123
+ MAX_CHUNK_OVERLAP,
124
+ DEFAULT_CHUNK_OVERLAP,
125
+ DEFAULT_RETRIEVER_K,
126
+ SHOW_LANGSMITH_OPTIONS,
127
+ SHOW_AZURE_OPTIONS,
128
+ )
langchain-streamlit-demo/llm_resources.py CHANGED
@@ -1,9 +1,8 @@
1
  from tempfile import NamedTemporaryFile
2
  from typing import Tuple, List
3
 
4
- from langchain import LLMChain, FAISS
5
  from langchain.callbacks.base import BaseCallbackHandler
6
- from langchain.chains import RetrievalQA
7
  from langchain.chat_models import (
8
  AzureChatOpenAI,
9
  ChatOpenAI,
@@ -15,8 +14,8 @@ from langchain.embeddings import OpenAIEmbeddings
15
  from langchain.retrievers import BM25Retriever, EnsembleRetriever
16
  from langchain.schema import Document, BaseRetriever
17
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
18
 
19
- from app import chat_prompt, prompt, openai_api_key
20
  from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
21
  from qagen import get_rag_qa_gen_chain
22
  from summarize import get_rag_summarization_chain
@@ -28,6 +27,8 @@ def get_runnable(
28
  llm,
29
  retriever,
30
  memory,
 
 
31
  ):
32
  if not use_document_chat:
33
  return LLMChain(
@@ -43,7 +44,7 @@ def get_runnable(
43
  )
44
  elif document_chat_chain_type == "Summarization":
45
  return get_rag_summarization_chain(
46
- prompt,
47
  retriever,
48
  llm,
49
  )
@@ -112,6 +113,7 @@ def get_llm(
112
 
113
  def get_texts_and_retriever(
114
  uploaded_file_bytes: bytes,
 
115
  chunk_size: int = DEFAULT_CHUNK_SIZE,
116
  chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
117
  k: int = DEFAULT_RETRIEVER_K,
 
1
  from tempfile import NamedTemporaryFile
2
  from typing import Tuple, List
3
 
 
4
  from langchain.callbacks.base import BaseCallbackHandler
5
+ from langchain.chains import RetrievalQA, LLMChain
6
  from langchain.chat_models import (
7
  AzureChatOpenAI,
8
  ChatOpenAI,
 
14
  from langchain.retrievers import BM25Retriever, EnsembleRetriever
15
  from langchain.schema import Document, BaseRetriever
16
  from langchain.text_splitter import RecursiveCharacterTextSplitter
17
+ from langchain.vectorstores import FAISS
18
 
 
19
  from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
20
  from qagen import get_rag_qa_gen_chain
21
  from summarize import get_rag_summarization_chain
 
27
  llm,
28
  retriever,
29
  memory,
30
+ chat_prompt,
31
+ summarization_prompt,
32
  ):
33
  if not use_document_chat:
34
  return LLMChain(
 
44
  )
45
  elif document_chat_chain_type == "Summarization":
46
  return get_rag_summarization_chain(
47
+ summarization_prompt,
48
  retriever,
49
  llm,
50
  )
 
113
 
114
  def get_texts_and_retriever(
115
  uploaded_file_bytes: bytes,
116
+ openai_api_key: str,
117
  chunk_size: int = DEFAULT_CHUNK_SIZE,
118
  chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
119
  k: int = DEFAULT_RETRIEVER_K,