svjack commited on
Commit
6f13693
1 Parent(s): 2ff9809

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -10
app.py CHANGED
@@ -7,6 +7,31 @@ from PIL import Image
7
  from huggingface_hub import InferenceApi, InferenceClient
8
  from datasets import load_dataset
9
  import pandas as pd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  '''
12
  dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts")
@@ -21,26 +46,25 @@ def get_samples():
21
  prompt_list = prompt_df.sample(n = 10)["Prompt"].map(lambda x: x).values.tolist()
22
  return prompt_list
23
 
24
- def update_models():
25
  client = InferenceClient()
26
  models = client.list_deployed_models()
27
  list_models = models["text-to-image"]
 
 
 
 
28
  return gr.Dropdown.update(choices=list_models)
29
 
30
  def update_prompts():
31
  return gr.Dropdown.update(choices=get_samples())
32
 
33
- def get_params(request: gr.Request):
34
  params = request.query_params
35
  ip = request.client.host
36
  req = {"params": params,
37
  "ip": ip}
38
- return update_models(), update_prompts()
39
-
40
- client = InferenceClient()
41
- models = client.list_deployed_models()
42
- list_models = models["text-to-image"]
43
- list_prompts = get_samples()
44
 
45
  '''
46
  list_models = [
@@ -239,6 +263,12 @@ with gr.Blocks(css=css) as demo:
239
  with gr.Row(elem_id="prompt-container"):
240
  with gr.Column():
241
  btn_refresh = gr.Button(value="Click to get current deployed models and newly Prompt candidates")
 
 
 
 
 
 
242
  #btn_refresh.click(None, js="window.location.reload()")
243
  current_model = gr.Dropdown(label="Current Model", choices=list_models, value=DEFAULT_MODEL,
244
  info = "default model: {}".format(DEFAULT_MODEL)
@@ -276,9 +306,11 @@ with gr.Blocks(css=css) as demo:
276
 
277
  text_button.click(generate_txt2img, inputs=[current_model, text_prompt, negative_prompt, image_style], outputs=image_output)
278
  select_button.click(generate_txt2img, inputs=[current_model, select_prompt, negative_prompt, image_style], outputs=image_output)
279
- btn_refresh.click(update_models, None, current_model)
280
  btn_refresh.click(update_prompts, None, select_prompt)
 
 
281
 
282
- demo.load(get_params, None, [current_model, select_prompt])
283
 
284
  demo.launch(show_api=False)
 
7
  from huggingface_hub import InferenceApi, InferenceClient
8
  from datasets import load_dataset
9
  import pandas as pd
10
+ import re
11
+
12
+ def rank_score(repo_str):
13
+ p_list = re.findall(r"[-_xlv\d]+" ,repo_str.split("/")[-1])
14
+ xl_in_str = any(map(lambda x: "xl" in x, p_list))
15
+ v_in_str = any(map(lambda x: "v" in x and
16
+ any(map(lambda y:
17
+ any(map(lambda z: y.startswith(z), "0123456789"))
18
+ ,x.split("v")))
19
+ , p_list))
20
+ stable_in_str = repo_str.split("/")[-1].lower().startswith("stable")
21
+ score = sum(map(lambda t2: t2[0] * t2[1] ,(zip(*[[stable_in_str, xl_in_str, v_in_str], [1000, 100, 10]]))))
22
+ #return p_list, xl_in_str, v_in_str, stable_in_str, score
23
+ return score
24
+
25
+ def shorten_by(repo_list, by = None):
26
+ if by == "user":
27
+ return sorted(
28
+ pd.DataFrame(pd.Series(repo_list).map(lambda x: (x.split("/")[0], x)).values.tolist()).groupby(0)[1].apply(list).map(lambda x:
29
+ sorted(x, key = rank_score, reverse = True)[0]).values.tolist(),
30
+ key = rank_score, reverse = True
31
+ )
32
+ if by == "model":
33
+ return sorted(repo_list, key = lambda x: rank_score(x), reverse = True)
34
+ return repo_list
35
 
36
  '''
37
  dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts")
 
46
  prompt_list = prompt_df.sample(n = 10)["Prompt"].map(lambda x: x).values.tolist()
47
  return prompt_list
48
 
49
+ def update_models(models_rank_by = "model"):
50
  client = InferenceClient()
51
  models = client.list_deployed_models()
52
  list_models = models["text-to-image"]
53
+ if hasattr(models_rank_by, "value"):
54
+ list_models = shorten_by(list_models, models_rank_by.value)
55
+ else:
56
+ list_models = shorten_by(list_models, models_rank_by)
57
  return gr.Dropdown.update(choices=list_models)
58
 
59
  def update_prompts():
60
  return gr.Dropdown.update(choices=get_samples())
61
 
62
+ def get_params(request: gr.Request, models_rank_by):
63
  params = request.query_params
64
  ip = request.client.host
65
  req = {"params": params,
66
  "ip": ip}
67
+ return update_models(models_rank_by), update_prompts()
 
 
 
 
 
68
 
69
  '''
70
  list_models = [
 
263
  with gr.Row(elem_id="prompt-container"):
264
  with gr.Column():
265
  btn_refresh = gr.Button(value="Click to get current deployed models and newly Prompt candidates")
266
+ models_rank_by = gr.Radio(choices=["model", "user"],
267
+ value="model", label="Models ranked by", elem_id="rank_radio")
268
+
269
+ list_models = update_models(models_rank_by)
270
+ list_prompts = get_samples()
271
+
272
  #btn_refresh.click(None, js="window.location.reload()")
273
  current_model = gr.Dropdown(label="Current Model", choices=list_models, value=DEFAULT_MODEL,
274
  info = "default model: {}".format(DEFAULT_MODEL)
 
306
 
307
  text_button.click(generate_txt2img, inputs=[current_model, text_prompt, negative_prompt, image_style], outputs=image_output)
308
  select_button.click(generate_txt2img, inputs=[current_model, select_prompt, negative_prompt, image_style], outputs=image_output)
309
+ btn_refresh.click(update_models, models_rank_by, current_model)
310
  btn_refresh.click(update_prompts, None, select_prompt)
311
+
312
+ models_rank_by.change(update_models, models_rank_by, current_model)
313
 
314
+ demo.load(get_params, models_rank_by, [current_model, select_prompt])
315
 
316
  demo.launch(show_api=False)