Ashhar commited on
Commit
eb9c1db
1 Parent(s): ced155a

flux client to replicate + words count

Browse files
Files changed (4) hide show
  1. app.py +43 -5
  2. constants.py +1 -0
  3. requirements.txt +1 -1
  4. utils.py +12 -0
app.py CHANGED
@@ -6,7 +6,7 @@ import json
6
  import re
7
  from typing import List, Literal, TypedDict, Tuple
8
  from transformers import AutoTokenizer
9
- from gradio_client import Client
10
  from openai import OpenAI
11
  import anthropic
12
  from groq import Groq
@@ -41,7 +41,7 @@ MODEL_CONFIG: dict[ModelType, ModelConfig] = {
41
  },
42
  "CLAUDE": {
43
  "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
44
- "model": "claude-3-5-sonnet-20240620",
45
  "max_context": 128000,
46
  "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
47
  },
@@ -275,7 +275,8 @@ def __predict():
275
  yield f"{C.EXCEPTION_KEYWORD} | {e}"
276
 
277
 
278
- def __generateImage(prompt: str):
 
279
  fluxClient = Client(
280
  "black-forest-labs/FLUX.1-schnell",
281
  os.environ.get("HF_FLUX_CLIENT_TOKEN")
@@ -293,6 +294,28 @@ def __generateImage(prompt: str):
293
  return result
294
 
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  def __paintImageIfApplicable(
297
  imageContainer: DeltaGenerator,
298
  prompt: str,
@@ -312,7 +335,7 @@ def __paintImageIfApplicable(
312
  unsafe_allow_html=True
313
  )
314
  imgContainer.image(C.IMAGE_LOADER)
315
- (imagePath, seed) = __generateImage(imagePrompt)
316
  imageContainer.image(imagePath)
317
  except Exception as e:
318
  U.pprint(e)
@@ -335,6 +358,19 @@ def __showButtons(options: list):
335
  )
336
 
337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  def __resetButtonState():
339
  st.session_state.buttonValue = ""
340
 
@@ -494,12 +530,14 @@ def mainApp():
494
  elif action:
495
  U.pprint(f"{action=}")
496
  if action == "SHOW_STORY_DATABASE":
497
- time.sleep(0.5)
498
  st.switch_page("pages/popular-stories.py")
499
  # st.code(jsonStr, language="json")
500
  except Exception as e:
501
  U.pprint(e)
502
 
 
 
503
  saveLatestActivity()
504
 
505
 
 
6
  import re
7
  from typing import List, Literal, TypedDict, Tuple
8
  from transformers import AutoTokenizer
9
+ import replicate
10
  from openai import OpenAI
11
  import anthropic
12
  from groq import Groq
 
41
  },
42
  "CLAUDE": {
43
  "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
44
+ "model": "claude-3-5-sonnet-20241022",
45
  "max_context": 128000,
46
  "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
47
  },
 
275
  yield f"{C.EXCEPTION_KEYWORD} | {e}"
276
 
277
 
278
+ def __generateImageHF(prompt: str):
279
+ from gradio_client import Client
280
  fluxClient = Client(
281
  "black-forest-labs/FLUX.1-schnell",
282
  os.environ.get("HF_FLUX_CLIENT_TOKEN")
 
294
  return result
295
 
296
 
297
+ def __generateImage(prompt: str):
298
+ fluxClient = replicate.Client(api_token=os.environ.get("REPLICATE_API_KEY"))
299
+ result = fluxClient.run(
300
+ "black-forest-labs/flux-schnell",
301
+ input={
302
+ "prompt": prompt,
303
+ "seed": 0,
304
+ "go_fast": False,
305
+ "megapixels": "1",
306
+ "num_outputs": 1,
307
+ "aspect_ratio": "4:3",
308
+ "output_format": "webp",
309
+ "output_quality": 80,
310
+ "num_inference_steps": 4
311
+ }
312
+ )
313
+ if result:
314
+ result = result[0]
315
+ U.pprint(f"imageResult={result}")
316
+ return result
317
+
318
+
319
  def __paintImageIfApplicable(
320
  imageContainer: DeltaGenerator,
321
  prompt: str,
 
335
  unsafe_allow_html=True
336
  )
337
  imgContainer.image(C.IMAGE_LOADER)
338
+ imagePath = __generateImage(imagePrompt)
339
  imageContainer.image(imagePath)
340
  except Exception as e:
341
  U.pprint(e)
 
358
  )
359
 
360
 
361
+ def __showWordsCount(response: str):
362
+ wordsCount = len(response.split())
363
+ countClass = "crossed-limit" if wordsCount > C.WORDS_LIMIT else ""
364
+ st.markdown(
365
+ f"""
366
+ <div class="words-count {countClass}">
367
+ {wordsCount} words
368
+ </div>
369
+ """,
370
+ unsafe_allow_html=True
371
+ )
372
+
373
+
374
  def __resetButtonState():
375
  st.session_state.buttonValue = ""
376
 
 
530
  elif action:
531
  U.pprint(f"{action=}")
532
  if action == "SHOW_STORY_DATABASE":
533
+ time.sleep(1)
534
  st.switch_page("pages/popular-stories.py")
535
  # st.code(jsonStr, language="json")
536
  except Exception as e:
537
  U.pprint(e)
538
 
539
+ __showWordsCount(response)
540
+
541
  saveLatestActivity()
542
 
543
 
constants.py CHANGED
@@ -1,6 +1,7 @@
1
  JSON_SEPARATOR = ">>>>"
2
  EXCEPTION_KEYWORD = "<<EXCEPTION>>"
3
  BOOKING_LINK = "https://calendly.com"
 
4
 
5
  SYSTEM_MSG = f"""
6
  => Context:
 
1
  JSON_SEPARATOR = ">>>>"
2
  EXCEPTION_KEYWORD = "<<EXCEPTION>>"
3
  BOOKING_LINK = "https://calendly.com"
4
+ WORDS_LIMIT = 300
5
 
6
  SYSTEM_MSG = f"""
7
  => Context:
requirements.txt CHANGED
@@ -7,4 +7,4 @@ anthropic
7
  supabase
8
  descope
9
  cloudinary
10
-
 
7
  supabase
8
  descope
9
  cloudinary
10
+ replicate
utils.py CHANGED
@@ -125,6 +125,18 @@ def applyCommonStyles():
125
  margin-bottom: 1rem;
126
  }
127
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  </style>
129
  """,
130
  unsafe_allow_html=True
 
125
  margin-bottom: 1rem;
126
  }
127
 
128
+ div.words-count {
129
+ font-size: 0.7rem !important;
130
+ color: green;
131
+ margin-top: -1rem;
132
+ text-align: right;
133
+ }
134
+
135
+ div.words-count.crossed-limit {
136
+ color: red;
137
+ font-size: 1rem !important;
138
+ }
139
+
140
  </style>
141
  """,
142
  unsafe_allow_html=True