Dominik Hintersdorf commited on
Commit
25a58c7
β€’
1 Parent(s): 22110aa

added error message when name does not match prompts

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -354,7 +354,15 @@ def get_majority_predictions(predictions: pd.Series, values_only=False, counts_o
354
 
355
  def on_submit_btn_click(model_name, true_name, prompts, images):
356
  # assert that the name is in the prompts
357
- assert prompts.iloc[0].str.contains(true_name).sum() == len(prompts.T)
 
 
 
 
 
 
 
 
358
 
359
  # calculate the image embeddings
360
  img_embeddings = calculate_image_embeddings(model_name, images)
@@ -542,8 +550,7 @@ with block as demo:
542
  with gr.Column():
543
  model_dd = gr.Dropdown(label="CLIP Model", choices=list(MODELS.keys()),
544
  value=list(MODELS.keys())[0])
545
- true_name = gr.Textbox(label='Name of Person (make sure it matches the prompts):', lines=1, value=DEFAULT_INITIAL_NAME,
546
- every=5)
547
  prompts = gr.Dataframe(
548
  value=[[x.format(DEFAULT_INITIAL_NAME) for x in PROMPTS]],
549
  label='Prompts Used (hold shift to scroll sideways):',
 
354
 
355
  def on_submit_btn_click(model_name, true_name, prompts, images):
356
  # assert that the name is in the prompts
357
+ if not prompts.iloc[0].str.contains(true_name).sum() == len(prompts.T):
358
+ return None, None, """<br>
359
+ <div class="error-message" style="background-color: #fce4e4; border: 1px solid #fcc2c3; padding: 20px 30px; border-radius: var(--radius-lg);">
360
+ <span class="error-text" style="color: #cc0033; font-weight: bold;">
361
+ The given name does not match the name in the prompts. Sometimes the UI is responding slow.
362
+ Please retype the name and check that it is inserted fully into the prompts.
363
+ </span>
364
+ </div>
365
+ """
366
 
367
  # calculate the image embeddings
368
  img_embeddings = calculate_image_embeddings(model_name, images)
 
550
  with gr.Column():
551
  model_dd = gr.Dropdown(label="CLIP Model", choices=list(MODELS.keys()),
552
  value=list(MODELS.keys())[0])
553
+ true_name = gr.Textbox(label='Name of Person (make sure it matches the prompts):', lines=1, value=DEFAULT_INITIAL_NAME)
 
554
  prompts = gr.Dataframe(
555
  value=[[x.format(DEFAULT_INITIAL_NAME) for x in PROMPTS]],
556
  label='Prompts Used (hold shift to scroll sideways):',