Add support for Gemini 1.5 Flash via Gemini API
- app.py +122 -110
- global_config.py +18 -3
- helpers/llm_helper.py +50 -31
- requirements.txt +1 -1
- strings.json +2 -1
app.py
CHANGED

```diff
@@ -5,7 +5,6 @@ import datetime
 import logging
 import pathlib
 import random
-import sys
 import tempfile
 from typing import List, Union
 
@@ -17,9 +16,6 @@ from langchain_community.chat_message_histories import StreamlitChatMessageHistory
 from langchain_core.messages import HumanMessage
 from langchain_core.prompts import ChatPromptTemplate
 
-sys.path.append('..')
-sys.path.append('../..')
-
 from global_config import GlobalConfig
 from helpers import llm_helper, pptx_helper, text_helper
 
@@ -54,6 +50,60 @@ def _get_prompt_template(is_refinement: bool) -> str:
     return template
 
 
+def are_all_inputs_valid(
+        user_prompt: str,
+        selected_provider: str,
+        selected_model: str,
+        user_key: str,
+) -> bool:
+    """
+    Validate user input and LLM selection.
+
+    :param user_prompt: The prompt.
+    :param selected_provider: The LLM provider.
+    :param selected_model: Name of the model.
+    :param user_key: User-provided API key.
+    :return: `True` if all inputs "look" OK; `False` otherwise.
+    """
+
+    if not text_helper.is_valid_prompt(user_prompt):
+        handle_error(
+            'Not enough information provided!'
+            ' Please be a little more descriptive and type a few words'
+            ' with a few characters :)',
+            False
+        )
+        return False
+
+    if not selected_provider or not selected_model:
+        handle_error('No valid LLM provider and/or model name found!', False)
+        return False
+
+    if not llm_helper.is_valid_llm_provider_model(selected_provider, selected_model, user_key):
+        handle_error(
+            'The LLM settings do not look correct. Make sure that an API key/access token'
+            ' is provided if the selected LLM requires it.',
+            False
+        )
+        return False
+
+    return True
+
+
+def handle_error(error_msg: str, should_log: bool):
+    """
+    Display an error message in the app.
+
+    :param error_msg: The error message to be displayed.
+    :param should_log: If `True`, log the message.
+    """
+
+    if should_log:
+        logger.error(error_msg)
+
+    st.error(error_msg)
+
+
 APP_TEXT = _load_strings()
 
 # Session variables
 
@@ -80,11 +130,8 @@ with st.sidebar:
     llm_provider_to_use = st.sidebar.selectbox(
         label='2: Select an LLM to use:',
         options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
-        index=…,
-        help=(
-            'LLM provider codes:\n\n'
-            '- **[hf]**: Hugging Face Inference Endpoint\n'
-        ),
+        index=GlobalConfig.DEFAULT_MODEL_INDEX,
+        help=GlobalConfig.LLM_PROVIDER_HELP,
     ).split(' ')[0]
 
     # The API key/access token
 
@@ -123,53 +170,28 @@ def set_up_chat_ui():
     with st.expander('Usage Instructions'):
         st.markdown(GlobalConfig.CHAT_USAGE_INSTRUCTIONS)
 
-    st.info(
-        'If you like SlideDeck AI, please consider leaving a heart ❤️ on the'
-        ' [Hugging Face Space](https://huggingface.co/spaces/barunsaha/slide-deck-ai/) or'
-        ' a star ⭐ on [GitHub](https://github.com/barun-saha/slide-deck-ai).'
-        ' Your [feedback](https://forms.gle/JECFBGhjvSj7moBx9) is appreciated.'
-    )
-
-    # view_messages = st.expander('View the messages in the session state')
-
-    st.chat_message('ai').write(
-        random.choice(APP_TEXT['ai_greetings'])
-    )
+    st.info(APP_TEXT['like_feedback'])
+    st.chat_message('ai').write(random.choice(APP_TEXT['ai_greetings']))
 
     history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
-
-    if _is_it_refinement():
-        template = _get_prompt_template(is_refinement=True)
-    else:
-        template = _get_prompt_template(is_refinement=False)
-
-    prompt_template = ChatPromptTemplate.from_template(template)
+    prompt_template = ChatPromptTemplate.from_template(
+        _get_prompt_template(
+            is_refinement=_is_it_refinement()
+        )
+    )
 
     # Since Streamlit app reloads at every interaction, display the chat history
     # from the save session state
     for msg in history.messages:
-        msg_type = msg.type
-        if msg_type == 'user':
-            st.chat_message(msg_type).write(msg.content)
-        else:
-            st.chat_message(msg_type).code(msg.content, language='json')
+        st.chat_message(msg.type).code(msg.content, language='json')
 
     if prompt := st.chat_input(
        placeholder=APP_TEXT['chat_placeholder'],
        max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
     ):
-        if not text_helper.is_valid_prompt(prompt):
-            st.error(
-                'Not enough information provided!'
-                ' Please be a little more descriptive and type a few words'
-                ' with a few characters :)'
-            )
-            return
-
         provider, llm_name = llm_helper.get_provider_model(llm_provider_to_use)
 
-        if not provider or not llm_name:
-            st.error('No valid LLM provider and/or model name found!')
+        if not are_all_inputs_valid(prompt, provider, llm_name, api_key_token):
             return
 
@@ -178,72 +200,76 @@ def set_up_chat_ui():
         )
         st.chat_message('user').write(prompt)
 
-        user_messages = _get_user_messages()
-        user_messages.append(prompt)
-        list_of_msgs = [
-            f'{idx + 1}. {msg}' for idx, msg in enumerate(user_messages)
-        ]
-        list_of_msgs = '\n'.join(list_of_msgs)
-
         if _is_it_refinement():
+            user_messages = _get_user_messages()
+            user_messages.append(prompt)
+            list_of_msgs = [
+                f'{idx + 1}. {msg}' for idx, msg in enumerate(user_messages)
+            ]
             formatted_template = prompt_template.format(
                 **{
-                    'instructions': list_of_msgs,
+                    'instructions': '\n'.join(list_of_msgs),
                     'previous_content': _get_last_response(),
                 }
             )
         else:
-            formatted_template = prompt_template.format(
-                **{
-                    'question': prompt,
-                }
-            )
+            formatted_template = prompt_template.format(**{'question': prompt})
 
         progress_bar = st.progress(0, 'Preparing to call LLM...')
         response = ''
 
         try:
-            llm = llm_helper.get_hf_endpoint(
-                repo_id=llm_name,
-                max_new_tokens=GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'],
-                api_key=api_key_token.strip(),
-            )
-
-            for _ in llm.stream(formatted_template):
-                response += _
-
-                # Update the progress bar
-                progress_bar.progress(
-                    min(len(response) / GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'], 0.95),
-                    text='Streaming content...this might take a while...'
-                )
+            llm = llm_helper.get_langchain_llm(
+                provider=provider,
+                model=llm_name,
+                max_new_tokens=GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'],
+                api_key=api_key_token.strip(),
+            )
+
+            if not llm:
+                handle_error(
+                    'Failed to create an LLM instance! Make sure that you have selected the'
+                    ' correct model from the dropdown list and have provided correct API key'
+                    ' or access token.',
+                    False
+                )
+                return
+
+            for _ in llm.stream(formatted_template):
+                response += _
+
+                # Update the progress bar with an approx progress percentage
+                progress_bar.progress(
+                    min(
+                        len(response) / GlobalConfig.VALID_MODELS[
+                            llm_provider_to_use
+                        ]['max_new_tokens'],
+                        0.95
+                    ),
+                    text='Streaming content...this might take a while...'
+                )
         except requests.exceptions.ConnectionError:
-            msg = (
+            handle_error(
                 'A connection error occurred while streaming content from the LLM endpoint.'
                 ' Unfortunately, the slide deck cannot be generated. Please try again later.'
-                ' Alternatively, try selecting a different LLM from the dropdown list.'
+                ' Alternatively, try selecting a different LLM from the dropdown list.',
+                True
             )
-            logger.error(msg)
-            st.error(msg)
             return
         except huggingface_hub.errors.ValidationError as ve:
-            msg = (
+            handle_error(
                 f'An error occurred while trying to generate the content: {ve}'
-                '\nPlease try again with a significantly shorter input text.'
+                '\nPlease try again with a significantly shorter input text.',
+                True
             )
-            logger.error(msg)
-            st.error(msg)
             return
         except Exception as ex:
-            msg = (
+            handle_error(
                 f'An unexpected error occurred while generating the content: {ex}'
                 '\nPlease try again later, possibly with different inputs.'
-                ' Alternatively, try selecting a different LLM from the dropdown list.'
+                ' Alternatively, try selecting a different LLM from the dropdown list.',
+                True
            )
-            logger.error(msg)
-            st.error(msg)
             return
 
         history.add_user_message(prompt)
 
@@ -252,25 +278,20 @@ def set_up_chat_ui():
         # The content has been generated as JSON
         # There maybe trailing ``` at the end of the response -- remove them
         # To be careful: ``` may be part of the content as well when code is generated
-        response_cleaned = text_helper.get_clean_json(response)
-
+        response = text_helper.get_clean_json(response)
         logger.info(
-            'Cleaned JSON …',
-            len(response), len(response_cleaned)
+            'Cleaned JSON length: %d', len(response)
         )
-        # logger.debug('Cleaned JSON: %s', response_cleaned)
 
         # Now create the PPT file
         progress_bar.progress(
             GlobalConfig.LLM_PROGRESS_MAX,
             text='Finding photos online and generating the slide deck...'
         )
-        path = generate_slide_deck(response_cleaned)
         progress_bar.progress(1.0, text='Done!')
-
         st.chat_message('ai').code(response, language='json')
 
-        if path:
+        if path := generate_slide_deck(response):
             _display_download_button(path)
 
         logger.info(
 
@@ -291,44 +312,35 @@ def generate_slide_deck(json_str: str) -> Union[pathlib.Path, None]:
    try:
         parsed_data = json5.loads(json_str)
     except ValueError:
-        st.error(
-            'Encountered error while parsing JSON...will fix it and retry'
-        )
-        logger.error(
-            'Caught ValueError: trying again after repairing JSON...'
+        handle_error(
+            'Encountered error while parsing JSON...will fix it and retry',
+            True
         )
         try:
             parsed_data = json5.loads(text_helper.fix_malformed_json(json_str))
         except ValueError:
-            st.error(
+            handle_error(
                 'Encountered an error again while fixing JSON...'
                 'the slide deck cannot be created, unfortunately ☹'
-                '\nPlease try again later.'
+                '\nPlease try again later.',
+                True
             )
-            logger.error(
-                'Caught ValueError: failed to repair JSON!'
-            )
-
             return None
     except RecursionError:
-        st.error(
-            'Encountered …'
+        handle_error(
+            'Encountered a recursion error while parsing JSON...'
             'the slide deck cannot be created, unfortunately ☹'
-            '\nPlease try again later.'
+            '\nPlease try again later.',
+            True
         )
-        logger.error('Caught RecursionError while parsing JSON. Cannot generate the slide deck!')
-
         return None
     except Exception:
-        st.error(
+        handle_error(
             'Encountered an error while parsing JSON...'
             'the slide deck cannot be created, unfortunately ☹'
-            '\nPlease try again later.'
-        )
-        logger.error(
-            'Caught ValueError: failed to parse JSON!'
+            '\nPlease try again later.',
+            True
         )
-
         return None
 
     if DOWNLOAD_FILE_KEY in st.session_state:
```
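The streaming loop in `set_up_chat_ui()` drives the progress bar with a simple heuristic: the length of the text accumulated so far, divided by the selected model's `max_new_tokens` budget, capped at 0.95 so the bar never reports completion while chunks are still arriving. A minimal sketch of that pattern outside Streamlit (`fake_stream` and `MAX_NEW_TOKENS` are illustrative stand-ins, not part of the app):

```python
from typing import Iterator

MAX_NEW_TOKENS = 8192  # stand-in for GlobalConfig.VALID_MODELS[...]['max_new_tokens']


def fake_stream() -> Iterator[str]:
    # Stand-in for llm.stream(formatted_template)
    yield from ('{"title": "Demo", ', '"slides": [', '...]}')


response = ''
for chunk in fake_stream():
    response += chunk
    # Approximate progress: characters received vs. the token budget,
    # capped so the bar never claims completion mid-stream
    progress = min(len(response) / MAX_NEW_TOKENS, 0.95)
    print(f'{progress:.2%} - {len(response)} characters received')
```

The ratio mixes characters with tokens, so it is only a rough indicator; the 0.95 cap is what keeps the bar honest until `progress_bar.progress(1.0, ...)` runs after the loop.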
global_config.py
CHANGED

```diff
@@ -17,17 +17,32 @@ class GlobalConfig:
     A data class holding the configurations.
     """
 
-    …
+    PROVIDER_HUGGING_FACE = 'hf'
+    PROVIDER_GOOGLE_GEMINI = 'gg'
+    VALID_PROVIDERS = {PROVIDER_HUGGING_FACE, PROVIDER_GOOGLE_GEMINI}
     VALID_MODELS = {
+        '[gg]gemini-1.5-flash-002': {
+            'description': 'faster response',
+            'max_new_tokens': 8192,
+            'paid': True,
+        },
         '[hf]mistralai/Mistral-7B-Instruct-v0.2': {
             'description': 'faster, shorter',
-            'max_new_tokens': 8192
+            'max_new_tokens': 8192,
+            'paid': False,
         },
         '[hf]mistralai/Mistral-Nemo-Instruct-2407': {
             'description': 'longer response',
-            'max_new_tokens': …
+            'max_new_tokens': 10240,
+            'paid': False,
         },
     }
+    LLM_PROVIDER_HELP = (
+        'LLM provider codes:\n\n'
+        '- **[gg]**: Google Gemini API\n'
+        '- **[hf]**: Hugging Face Inference Endpoint\n'
+    )
+    DEFAULT_MODEL_INDEX = 1
     LLM_MODEL_TEMPERATURE = 0.2
     LLM_MODEL_MIN_OUTPUT_LENGTH = 100
     LLM_MODEL_MAX_INPUT_LENGTH = 400  # characters
```
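Each `VALID_MODELS` key embeds the provider code in square brackets before the model name, and the sidebar recovers the key from the displayed `key (description)` string via `.split(' ')[0]`. A hedged sketch of the bracket parsing that `llm_helper.get_provider_model()` performs (the regex here is an assumption; the real implementation lives unchanged in `helpers/llm_helper.py`):

```python
import re
from typing import Tuple


def parse_provider_model(provider_model: str) -> Tuple[str, str]:
    # Assumed format: '[<provider code>]<model name>'
    match = re.match(r'\[(.+?)\](.+)', provider_model)
    if match:
        return match.group(1), match.group(2)
    return '', ''


print(parse_provider_model('[gg]gemini-1.5-flash-002'))
# ('gg', 'gemini-1.5-flash-002')
print(parse_provider_model('[hf]mistralai/Mistral-7B-Instruct-v0.2'))
# ('hf', 'mistralai/Mistral-7B-Instruct-v0.2')
```

Since the Gemini entry sits first in the dictionary, `DEFAULT_MODEL_INDEX = 1` keeps the free Mistral-7B model as the default selection in the dropdown.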
helpers/llm_helper.py
CHANGED

```diff
@@ -1,13 +1,18 @@
+"""
+Helper functions to access LLMs.
+"""
 import logging
 import re
+import sys
 from typing import Tuple, Union
 
 import requests
 from requests.adapters import HTTPAdapter
 from urllib3.util import Retry
-
 from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
-from langchain_core.language_models import LLM
+from langchain_core.language_models import BaseLLM
+
+sys.path.append('..')
 
 from global_config import GlobalConfig
 
@@ -49,30 +54,26 @@ def get_provider_model(provider_model: str) -> Tuple[str, str]:
     return '', ''
 
 
-def get_hf_endpoint(repo_id: str, max_new_tokens: int, api_key: str = ''):
+def is_valid_llm_provider_model(provider: str, model: str, api_key: str) -> bool:
     """
-    …
+    Verify whether LLM settings are proper.
+    This function does not verify whether `api_key` is correct. It only confirms that the key has
+    at least five characters. Key verification is done when the LLM is created.
+
+    :param provider: Name of the LLM provider.
+    :param model: Name of the model.
+    :param api_key: The API key or access token.
+    :return: `True` if the settings "look" OK; `False` otherwise.
     """
 
-    …
-
-    return HuggingFaceEndpoint(
-        repo_id=repo_id,
-        max_new_tokens=max_new_tokens,
-        top_k=40,
-        top_p=0.95,
-        temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
-        repetition_penalty=1.03,
-        streaming=True,
-        huggingfacehub_api_token=api_key or GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
-        return_full_text=False,
-        stop_sequences=['</s>'],
-    )
+    if not provider or not model or provider not in GlobalConfig.VALID_PROVIDERS:
+        return False
+
+    if provider in [GlobalConfig.PROVIDER_GOOGLE_GEMINI, ]:
+        if not api_key or len(api_key) < 5:
+            return False
+
+    return True
 
 
 def get_langchain_llm(
@@ -80,22 +81,19 @@ def get_langchain_llm(
         model: str,
         max_new_tokens: int,
         api_key: str = ''
-) -> Union[LLM, None]:
+) -> Union[BaseLLM, None]:
     """
     Get an LLM based on the provider and model specified.
 
     :param provider: The LLM provider. Valid values are `hf` for Hugging Face.
-    :param model:
-    :param max_new_tokens:
-    :param api_key:
-    :return:
+    :param model: The name of the LLM.
+    :param max_new_tokens: The maximum number of tokens to generate.
+    :param api_key: API key or access token to use.
+    :return: An instance of the LLM or `None` in case of any error.
     """
-    if not provider or not model or provider not in GlobalConfig.VALID_PROVIDERS:
-        return None
 
-    if provider == 'hf':
+    if provider == GlobalConfig.PROVIDER_HUGGING_FACE:
         logger.debug('Getting LLM via HF endpoint: %s', model)
-
         return HuggingFaceEndpoint(
             repo_id=model,
             max_new_tokens=max_new_tokens,
 
@@ -109,6 +107,27 @@ def get_langchain_llm(
             stop_sequences=['</s>'],
         )
 
+    if provider == GlobalConfig.PROVIDER_GOOGLE_GEMINI:
+        from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
+        from langchain_google_genai import GoogleGenerativeAI
+
+        return GoogleGenerativeAI(
+            model=model,
+            temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
+            max_tokens=max_new_tokens,
+            timeout=None,
+            max_retries=2,
+            google_api_key=api_key,
+            safety_settings={
+                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT:
+                    HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT:
+                    HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
+            }
+        )
+
     return None
```
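Because both branches return a LangChain `BaseLLM`, the caller can stream from Gemini and Hugging Face models identically. A hedged usage sketch of the new provider path (assumes `langchain-google-genai` from requirements.txt is installed and a valid Gemini API key is exported as `GOOGLE_API_KEY`):

```python
import os

from helpers import llm_helper

# Split '[gg]gemini-1.5-flash-002' into provider code and model name
provider, model = llm_helper.get_provider_model('[gg]gemini-1.5-flash-002')
api_key = os.environ['GOOGLE_API_KEY']  # assumed to hold a valid Gemini API key

# Cheap sanity check first: the Gemini provider requires a non-trivial key
if llm_helper.is_valid_llm_provider_model(provider, model, api_key):
    llm = llm_helper.get_langchain_llm(
        provider=provider,
        model=model,
        max_new_tokens=8192,
        api_key=api_key,
    )
    # BaseLLM.stream() yields plain text chunks for string-output models
    for chunk in llm.stream('Outline a five-slide deck on renewable energy as JSON.'):
        print(chunk, end='')
```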
requirements.txt
CHANGED

```diff
@@ -10,6 +10,7 @@ pydantic==2.9.1
 langchain~=0.3.7
 langchain-core~=0.3.0
 langchain-community==0.3.0
+langchain-google-genai==2.0.6
 streamlit~=1.38.0
 
 python-pptx
 
@@ -19,7 +20,6 @@ requests~=2.32.3
 
 transformers~=4.44.0
 torch==2.4.0
-langchain-community
 
 urllib3~=2.2.1
 lxml~=4.9.3
```
strings.json
CHANGED

```diff
@@ -33,5 +33,6 @@
     "Looks like you have a looming deadline. Can I help you get started with your slide deck?",
     "Hello! What topic do you have on your mind today?"
   ],
-  "chat_placeholder": "Write the topic or instructions here"
+  "chat_placeholder": "Write the topic or instructions here",
+  "like_feedback": "If you like SlideDeck AI, please consider leaving a heart ❤\uFE0F on the [Hugging Face Space](https://huggingface.co/spaces/barunsaha/slide-deck-ai/) or a star ⭐ on [GitHub](https://github.com/barun-saha/slide-deck-ai). Your [feedback](https://forms.gle/JECFBGhjvSj7moBx9) is appreciated."
 }
```