Ashhar
commited on
Commit
·
59488ed
1
Parent(s):
e2d2e30
added loaders + image generation
Browse files- app.py +96 -10
- balls.svg +3 -0
- bars_loader.svg +9 -0
- requirements.txt +2 -1
- ripple.svg +7 -0
app.py
CHANGED
@@ -5,7 +5,9 @@ import pytz
|
|
5 |
import time
|
6 |
import json
|
7 |
import re
|
|
|
8 |
from transformers import AutoTokenizer
|
|
|
9 |
|
10 |
from dotenv import load_dotenv
|
11 |
load_dotenv()
|
@@ -46,7 +48,7 @@ Keep options to less than 9
|
|
46 |
# Tier 1: Story Creation
|
47 |
You initiate the storytelling process through a series of engaging prompts:
|
48 |
Story Origin:
|
49 |
-
Asks users to choose between personal anecdotes or adapting a well-known
|
50 |
|
51 |
Story Use Case:
|
52 |
Asks users to define the purpose of building a story (e.g., profile story, for social media content).
|
@@ -120,6 +122,8 @@ Note that the final story should include twist, turns and events that make it re
|
|
120 |
|
121 |
USER_ICON = "man.png"
|
122 |
AI_ICON = "Kommuneity.png"
|
|
|
|
|
123 |
START_MSG = "I want to create a story 😊"
|
124 |
|
125 |
st.set_page_config(
|
@@ -144,13 +148,53 @@ pprint("\n")
|
|
144 |
|
145 |
|
146 |
def __isInvalidResponse(response: str):
|
|
|
147 |
if len(re.findall(r'\n[a-z]', response)) > 3:
|
148 |
return True
|
149 |
|
|
|
|
|
|
|
|
|
|
|
150 |
if ('\n{\n "options"' in response) and (JSON_SEPARATOR not in response):
|
151 |
return True
|
152 |
|
153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
def __resetButtonState():
|
155 |
st.session_state["buttonValue"] = ""
|
156 |
|
@@ -171,7 +215,10 @@ if "startMsg" not in st.session_state:
|
|
171 |
|
172 |
def predict(prompt):
|
173 |
historyFormatted = [{"role": "system", "content": SYSTEM_MSG}]
|
174 |
-
historyFormatted.extend(
|
|
|
|
|
|
|
175 |
historyFormatted.append({"role": "user", "content": prompt })
|
176 |
contextSize = countTokens(str(historyFormatted))
|
177 |
pprint(f"{contextSize=}")
|
@@ -192,18 +239,37 @@ def predict(prompt):
|
|
192 |
yield chunkContent
|
193 |
|
194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
st.title("Kommuneity Story Creator 📖")
|
196 |
-
if not st.session_state.startMsg:
|
197 |
st.button(START_MSG, on_click=lambda: __setStartMsg(START_MSG))
|
198 |
|
199 |
for message in st.session_state.messages:
|
200 |
role = message["role"]
|
201 |
content = message["content"]
|
|
|
202 |
avatar = AI_ICON if role == "assistant" else USER_ICON
|
203 |
with st.chat_message(role, avatar=avatar):
|
204 |
st.markdown(content)
|
|
|
|
|
205 |
|
206 |
-
if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_state
|
207 |
__resetButtonState()
|
208 |
__setStartMsg("")
|
209 |
|
@@ -213,27 +279,30 @@ if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_s
|
|
213 |
st.session_state.messages.append({"role": "user", "content": prompt })
|
214 |
|
215 |
with st.chat_message("assistant", avatar=AI_ICON):
|
216 |
-
|
217 |
|
218 |
-
def
|
219 |
response = ""
|
|
|
|
|
220 |
responseGenerator = predict(prompt)
|
221 |
|
222 |
for chunk in responseGenerator:
|
223 |
response += chunk
|
224 |
if __isInvalidResponse(response):
|
|
|
225 |
return
|
226 |
|
227 |
if JSON_SEPARATOR not in response:
|
228 |
-
|
229 |
|
230 |
return response
|
231 |
|
232 |
-
response =
|
233 |
while not response:
|
234 |
pprint("Empty response. Retrying..")
|
235 |
time.sleep(0.5)
|
236 |
-
response =
|
237 |
|
238 |
pprint(f"{response=}")
|
239 |
|
@@ -242,9 +311,22 @@ if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_s
|
|
242 |
pprint(f"Selected: {optionLabel}")
|
243 |
|
244 |
responseParts = response.split(JSON_SEPARATOR)
|
|
|
|
|
245 |
if len(responseParts) > 1:
|
246 |
[response, jsonStr] = responseParts
|
247 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
248 |
try:
|
249 |
json.loads(jsonStr)
|
250 |
jsonObj = json.loads(jsonStr)
|
@@ -260,4 +342,8 @@ if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_s
|
|
260 |
except Exception as e:
|
261 |
pprint(e)
|
262 |
|
263 |
-
|
|
|
|
|
|
|
|
|
|
5 |
import time
|
6 |
import json
|
7 |
import re
|
8 |
+
from typing import List
|
9 |
from transformers import AutoTokenizer
|
10 |
+
from gradio_client import Client
|
11 |
|
12 |
from dotenv import load_dotenv
|
13 |
load_dotenv()
|
|
|
48 |
# Tier 1: Story Creation
|
49 |
You initiate the storytelling process through a series of engaging prompts:
|
50 |
Story Origin:
|
51 |
+
Asks users to choose between personal anecdotes or adapting a well-known story (creating a story database here of well-known finctional stories to choose from).
|
52 |
|
53 |
Story Use Case:
|
54 |
Asks users to define the purpose of building a story (e.g., profile story, for social media content).
|
|
|
122 |
|
123 |
USER_ICON = "man.png"
|
124 |
AI_ICON = "Kommuneity.png"
|
125 |
+
IMAGE_LOADER = "ripple.svg"
|
126 |
+
TEXT_LOADER = "balls.svg"
|
127 |
START_MSG = "I want to create a story 😊"
|
128 |
|
129 |
st.set_page_config(
|
|
|
148 |
|
149 |
|
150 |
def __isInvalidResponse(response: str):
|
151 |
+
# new line followed by small case char
|
152 |
if len(re.findall(r'\n[a-z]', response)) > 3:
|
153 |
return True
|
154 |
|
155 |
+
# lot of repeating words
|
156 |
+
if re.findall(r'\b(\w+)(\s+\1){2,}\b', response) == 0:
|
157 |
+
return True
|
158 |
+
|
159 |
+
# json response without json separator
|
160 |
if ('\n{\n "options"' in response) and (JSON_SEPARATOR not in response):
|
161 |
return True
|
162 |
|
163 |
|
164 |
+
def __matchingKeywordsCount(keywords: List[str], text: str):
|
165 |
+
return sum([
|
166 |
+
1 if keyword in text else 0
|
167 |
+
for keyword in keywords
|
168 |
+
])
|
169 |
+
|
170 |
+
|
171 |
+
def __isStringNumber(s: str) -> bool:
|
172 |
+
try:
|
173 |
+
float(s)
|
174 |
+
return True
|
175 |
+
except ValueError:
|
176 |
+
return False
|
177 |
+
|
178 |
+
|
179 |
+
def __getImageGenerationPrompt(prompt: str, response: str):
|
180 |
+
responseLower = response.lower()
|
181 |
+
if (
|
182 |
+
__matchingKeywordsCount(
|
183 |
+
["adapt", "profile", "social media", "purpose", "use case"],
|
184 |
+
responseLower
|
185 |
+
) > 2
|
186 |
+
and not __isStringNumber(prompt)
|
187 |
+
and prompt.lower() in responseLower
|
188 |
+
):
|
189 |
+
return f'a scene from (({prompt})). Include main character'
|
190 |
+
|
191 |
+
if __matchingKeywordsCount(
|
192 |
+
["Tier 2", "Tier-2"],
|
193 |
+
response
|
194 |
+
) > 0:
|
195 |
+
return f"photo of a scene from this text: {response}"
|
196 |
+
|
197 |
+
|
198 |
def __resetButtonState():
|
199 |
st.session_state["buttonValue"] = ""
|
200 |
|
|
|
215 |
|
216 |
def predict(prompt):
|
217 |
historyFormatted = [{"role": "system", "content": SYSTEM_MSG}]
|
218 |
+
historyFormatted.extend([
|
219 |
+
{"role": message["role"], "content": message["content"]}
|
220 |
+
for message in st.session_state.messages
|
221 |
+
])
|
222 |
historyFormatted.append({"role": "user", "content": prompt })
|
223 |
contextSize = countTokens(str(historyFormatted))
|
224 |
pprint(f"{contextSize=}")
|
|
|
239 |
yield chunkContent
|
240 |
|
241 |
|
242 |
+
def generateImage(prompt: str):
|
243 |
+
pprint(f"imagePrompt={prompt}")
|
244 |
+
client = Client("black-forest-labs/FLUX.1-schnell")
|
245 |
+
result = client.predict(
|
246 |
+
prompt=prompt,
|
247 |
+
seed=0,
|
248 |
+
randomize_seed=True,
|
249 |
+
width=1152,
|
250 |
+
height=896,
|
251 |
+
num_inference_steps=4,
|
252 |
+
api_name="/infer"
|
253 |
+
)
|
254 |
+
pprint(f"imageResult={result}")
|
255 |
+
return result
|
256 |
+
|
257 |
+
|
258 |
st.title("Kommuneity Story Creator 📖")
|
259 |
+
if not (st.session_state["buttonValue"] or st.session_state["startMsg"]):
|
260 |
st.button(START_MSG, on_click=lambda: __setStartMsg(START_MSG))
|
261 |
|
262 |
for message in st.session_state.messages:
|
263 |
role = message["role"]
|
264 |
content = message["content"]
|
265 |
+
imagePath = message.get("image")
|
266 |
avatar = AI_ICON if role == "assistant" else USER_ICON
|
267 |
with st.chat_message(role, avatar=avatar):
|
268 |
st.markdown(content)
|
269 |
+
if imagePath:
|
270 |
+
st.image(imagePath)
|
271 |
|
272 |
+
if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_state["startMsg"]):
|
273 |
__resetButtonState()
|
274 |
__setStartMsg("")
|
275 |
|
|
|
279 |
st.session_state.messages.append({"role": "user", "content": prompt })
|
280 |
|
281 |
with st.chat_message("assistant", avatar=AI_ICON):
|
282 |
+
responseContainer = st.empty()
|
283 |
|
284 |
+
def __printAndGetResponse():
|
285 |
response = ""
|
286 |
+
# responseContainer.markdown(".....")
|
287 |
+
responseContainer.image(TEXT_LOADER)
|
288 |
responseGenerator = predict(prompt)
|
289 |
|
290 |
for chunk in responseGenerator:
|
291 |
response += chunk
|
292 |
if __isInvalidResponse(response):
|
293 |
+
pprint(f"{response=}")
|
294 |
return
|
295 |
|
296 |
if JSON_SEPARATOR not in response:
|
297 |
+
responseContainer.markdown(response)
|
298 |
|
299 |
return response
|
300 |
|
301 |
+
response = __printAndGetResponse()
|
302 |
while not response:
|
303 |
pprint("Empty response. Retrying..")
|
304 |
time.sleep(0.5)
|
305 |
+
response = __printAndGetResponse()
|
306 |
|
307 |
pprint(f"{response=}")
|
308 |
|
|
|
311 |
pprint(f"Selected: {optionLabel}")
|
312 |
|
313 |
responseParts = response.split(JSON_SEPARATOR)
|
314 |
+
|
315 |
+
jsonStr = None
|
316 |
if len(responseParts) > 1:
|
317 |
[response, jsonStr] = responseParts
|
318 |
|
319 |
+
imagePath = None
|
320 |
+
try:
|
321 |
+
imagePrompt = __getImageGenerationPrompt(prompt, response)
|
322 |
+
if imagePrompt:
|
323 |
+
imageContainer = st.empty().image(IMAGE_LOADER)
|
324 |
+
(imagePath, seed) = generateImage(imagePrompt)
|
325 |
+
imageContainer.image(imagePath)
|
326 |
+
except Exception as e:
|
327 |
+
pprint(e)
|
328 |
+
|
329 |
+
if jsonStr:
|
330 |
try:
|
331 |
json.loads(jsonStr)
|
332 |
jsonObj = json.loads(jsonStr)
|
|
|
342 |
except Exception as e:
|
343 |
pprint(e)
|
344 |
|
345 |
+
st.session_state.messages.append({
|
346 |
+
"role": "assistant",
|
347 |
+
"content": response,
|
348 |
+
"image": imagePath,
|
349 |
+
})
|
balls.svg
ADDED
bars_loader.svg
ADDED
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
python-dotenv
|
2 |
groq
|
3 |
-
transformers
|
|
|
|
1 |
python-dotenv
|
2 |
groq
|
3 |
+
transformers
|
4 |
+
gradio_client
|
ripple.svg
ADDED