Ashhar commited on
Commit
59488ed
·
1 Parent(s): e2d2e30

added loaders + image generation

Browse files
Files changed (5) hide show
  1. app.py +96 -10
  2. balls.svg +3 -0
  3. bars_loader.svg +9 -0
  4. requirements.txt +2 -1
  5. ripple.svg +7 -0
app.py CHANGED
@@ -5,7 +5,9 @@ import pytz
5
  import time
6
  import json
7
  import re
 
8
  from transformers import AutoTokenizer
 
9
 
10
  from dotenv import load_dotenv
11
  load_dotenv()
@@ -46,7 +48,7 @@ Keep options to less than 9
46
  # Tier 1: Story Creation
47
  You initiate the storytelling process through a series of engaging prompts:
48
  Story Origin:
49
- Asks users to choose between personal anecdotes or adapting a well-known tale (creating a story database here of well-known stories to choose from).
50
 
51
  Story Use Case:
52
  Asks users to define the purpose of building a story (e.g., profile story, for social media content).
@@ -120,6 +122,8 @@ Note that the final story should include twist, turns and events that make it re
120
 
121
  USER_ICON = "man.png"
122
  AI_ICON = "Kommuneity.png"
 
 
123
  START_MSG = "I want to create a story 😊"
124
 
125
  st.set_page_config(
@@ -144,13 +148,53 @@ pprint("\n")
144
 
145
 
146
  def __isInvalidResponse(response: str):
 
147
  if len(re.findall(r'\n[a-z]', response)) > 3:
148
  return True
149
 
 
 
 
 
 
150
  if ('\n{\n "options"' in response) and (JSON_SEPARATOR not in response):
151
  return True
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def __resetButtonState():
155
  st.session_state["buttonValue"] = ""
156
 
@@ -171,7 +215,10 @@ if "startMsg" not in st.session_state:
171
 
172
  def predict(prompt):
173
  historyFormatted = [{"role": "system", "content": SYSTEM_MSG}]
174
- historyFormatted.extend(st.session_state.messages)
 
 
 
175
  historyFormatted.append({"role": "user", "content": prompt })
176
  contextSize = countTokens(str(historyFormatted))
177
  pprint(f"{contextSize=}")
@@ -192,18 +239,37 @@ def predict(prompt):
192
  yield chunkContent
193
 
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  st.title("Kommuneity Story Creator 📖")
196
- if not st.session_state.startMsg:
197
  st.button(START_MSG, on_click=lambda: __setStartMsg(START_MSG))
198
 
199
  for message in st.session_state.messages:
200
  role = message["role"]
201
  content = message["content"]
 
202
  avatar = AI_ICON if role == "assistant" else USER_ICON
203
  with st.chat_message(role, avatar=avatar):
204
  st.markdown(content)
 
 
205
 
206
- if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_state.startMsg):
207
  __resetButtonState()
208
  __setStartMsg("")
209
 
@@ -213,27 +279,30 @@ if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_s
213
  st.session_state.messages.append({"role": "user", "content": prompt })
214
 
215
  with st.chat_message("assistant", avatar=AI_ICON):
216
- placeholder = st.empty()
217
 
218
- def getResponse():
219
  response = ""
 
 
220
  responseGenerator = predict(prompt)
221
 
222
  for chunk in responseGenerator:
223
  response += chunk
224
  if __isInvalidResponse(response):
 
225
  return
226
 
227
  if JSON_SEPARATOR not in response:
228
- placeholder.markdown(response)
229
 
230
  return response
231
 
232
- response = getResponse()
233
  while not response:
234
  pprint("Empty response. Retrying..")
235
  time.sleep(0.5)
236
- response = getResponse()
237
 
238
  pprint(f"{response=}")
239
 
@@ -242,9 +311,22 @@ if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_s
242
  pprint(f"Selected: {optionLabel}")
243
 
244
  responseParts = response.split(JSON_SEPARATOR)
 
 
245
  if len(responseParts) > 1:
246
  [response, jsonStr] = responseParts
247
 
 
 
 
 
 
 
 
 
 
 
 
248
  try:
249
  json.loads(jsonStr)
250
  jsonObj = json.loads(jsonStr)
@@ -260,4 +342,8 @@ if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_s
260
  except Exception as e:
261
  pprint(e)
262
 
263
- st.session_state.messages.append({"role": "assistant", "content": response})
 
 
 
 
 
5
  import time
6
  import json
7
  import re
8
+ from typing import List
9
  from transformers import AutoTokenizer
10
+ from gradio_client import Client
11
 
12
  from dotenv import load_dotenv
13
  load_dotenv()
 
48
  # Tier 1: Story Creation
49
  You initiate the storytelling process through a series of engaging prompts:
50
  Story Origin:
51
+ Asks users to choose between personal anecdotes or adapting a well-known story (creating a story database here of well-known finctional stories to choose from).
52
 
53
  Story Use Case:
54
  Asks users to define the purpose of building a story (e.g., profile story, for social media content).
 
122
 
123
  USER_ICON = "man.png"
124
  AI_ICON = "Kommuneity.png"
125
+ IMAGE_LOADER = "ripple.svg"
126
+ TEXT_LOADER = "balls.svg"
127
  START_MSG = "I want to create a story 😊"
128
 
129
  st.set_page_config(
 
148
 
149
 
150
  def __isInvalidResponse(response: str):
151
+ # new line followed by small case char
152
  if len(re.findall(r'\n[a-z]', response)) > 3:
153
  return True
154
 
155
+ # lot of repeating words
156
+ if re.findall(r'\b(\w+)(\s+\1){2,}\b', response) == 0:
157
+ return True
158
+
159
+ # json response without json separator
160
  if ('\n{\n "options"' in response) and (JSON_SEPARATOR not in response):
161
  return True
162
 
163
 
164
+ def __matchingKeywordsCount(keywords: List[str], text: str):
165
+ return sum([
166
+ 1 if keyword in text else 0
167
+ for keyword in keywords
168
+ ])
169
+
170
+
171
+ def __isStringNumber(s: str) -> bool:
172
+ try:
173
+ float(s)
174
+ return True
175
+ except ValueError:
176
+ return False
177
+
178
+
179
+ def __getImageGenerationPrompt(prompt: str, response: str):
180
+ responseLower = response.lower()
181
+ if (
182
+ __matchingKeywordsCount(
183
+ ["adapt", "profile", "social media", "purpose", "use case"],
184
+ responseLower
185
+ ) > 2
186
+ and not __isStringNumber(prompt)
187
+ and prompt.lower() in responseLower
188
+ ):
189
+ return f'a scene from (({prompt})). Include main character'
190
+
191
+ if __matchingKeywordsCount(
192
+ ["Tier 2", "Tier-2"],
193
+ response
194
+ ) > 0:
195
+ return f"photo of a scene from this text: {response}"
196
+
197
+
198
  def __resetButtonState():
199
  st.session_state["buttonValue"] = ""
200
 
 
215
 
216
  def predict(prompt):
217
  historyFormatted = [{"role": "system", "content": SYSTEM_MSG}]
218
+ historyFormatted.extend([
219
+ {"role": message["role"], "content": message["content"]}
220
+ for message in st.session_state.messages
221
+ ])
222
  historyFormatted.append({"role": "user", "content": prompt })
223
  contextSize = countTokens(str(historyFormatted))
224
  pprint(f"{contextSize=}")
 
239
  yield chunkContent
240
 
241
 
242
+ def generateImage(prompt: str):
243
+ pprint(f"imagePrompt={prompt}")
244
+ client = Client("black-forest-labs/FLUX.1-schnell")
245
+ result = client.predict(
246
+ prompt=prompt,
247
+ seed=0,
248
+ randomize_seed=True,
249
+ width=1152,
250
+ height=896,
251
+ num_inference_steps=4,
252
+ api_name="/infer"
253
+ )
254
+ pprint(f"imageResult={result}")
255
+ return result
256
+
257
+
258
  st.title("Kommuneity Story Creator 📖")
259
+ if not (st.session_state["buttonValue"] or st.session_state["startMsg"]):
260
  st.button(START_MSG, on_click=lambda: __setStartMsg(START_MSG))
261
 
262
  for message in st.session_state.messages:
263
  role = message["role"]
264
  content = message["content"]
265
+ imagePath = message.get("image")
266
  avatar = AI_ICON if role == "assistant" else USER_ICON
267
  with st.chat_message(role, avatar=avatar):
268
  st.markdown(content)
269
+ if imagePath:
270
+ st.image(imagePath)
271
 
272
+ if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_state["startMsg"]):
273
  __resetButtonState()
274
  __setStartMsg("")
275
 
 
279
  st.session_state.messages.append({"role": "user", "content": prompt })
280
 
281
  with st.chat_message("assistant", avatar=AI_ICON):
282
+ responseContainer = st.empty()
283
 
284
+ def __printAndGetResponse():
285
  response = ""
286
+ # responseContainer.markdown(".....")
287
+ responseContainer.image(TEXT_LOADER)
288
  responseGenerator = predict(prompt)
289
 
290
  for chunk in responseGenerator:
291
  response += chunk
292
  if __isInvalidResponse(response):
293
+ pprint(f"{response=}")
294
  return
295
 
296
  if JSON_SEPARATOR not in response:
297
+ responseContainer.markdown(response)
298
 
299
  return response
300
 
301
+ response = __printAndGetResponse()
302
  while not response:
303
  pprint("Empty response. Retrying..")
304
  time.sleep(0.5)
305
+ response = __printAndGetResponse()
306
 
307
  pprint(f"{response=}")
308
 
 
311
  pprint(f"Selected: {optionLabel}")
312
 
313
  responseParts = response.split(JSON_SEPARATOR)
314
+
315
+ jsonStr = None
316
  if len(responseParts) > 1:
317
  [response, jsonStr] = responseParts
318
 
319
+ imagePath = None
320
+ try:
321
+ imagePrompt = __getImageGenerationPrompt(prompt, response)
322
+ if imagePrompt:
323
+ imageContainer = st.empty().image(IMAGE_LOADER)
324
+ (imagePath, seed) = generateImage(imagePrompt)
325
+ imageContainer.image(imagePath)
326
+ except Exception as e:
327
+ pprint(e)
328
+
329
+ if jsonStr:
330
  try:
331
  json.loads(jsonStr)
332
  jsonObj = json.loads(jsonStr)
 
342
  except Exception as e:
343
  pprint(e)
344
 
345
+ st.session_state.messages.append({
346
+ "role": "assistant",
347
+ "content": response,
348
+ "image": imagePath,
349
+ })
balls.svg ADDED
bars_loader.svg ADDED
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  python-dotenv
2
  groq
3
- transformers
 
 
1
  python-dotenv
2
  groq
3
+ transformers
4
+ gradio_client
ripple.svg ADDED