Wauplin HF staff commited on
Commit
79e495d
1 Parent(s): 859a4bd
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -17,7 +17,7 @@ import os
17
  import torch
18
  from tqdm import tqdm
19
  from safetensors.torch import load_file
20
- from huggingface_hub import hf_hub_download
21
 
22
  from concurrent.futures import ThreadPoolExecutor
23
  import uuid
@@ -44,15 +44,16 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
44
  seed = random.randint(0, MAX_SEED)
45
  return seed
46
 
47
- def save_image(img):
48
  unique_name = str(uuid.uuid4()) + '.png'
49
  img.save(unique_name)
 
50
  return unique_name
51
 
52
- def save_images(image_array):
53
  paths = []
54
  with ThreadPoolExecutor() as executor:
55
- paths = list(executor.map(save_image, image_array))
56
  return paths
57
 
58
  def generate(
@@ -64,7 +65,8 @@ def generate(
64
  num_inference_steps: int = 4,
65
  num_images: int = 4,
66
  randomize_seed: bool = False,
67
- progress = gr.Progress(track_tqdm=True)
 
68
  ) -> PIL.Image.Image:
69
  seed = randomize_seed_fn(seed, randomize_seed)
70
  torch.manual_seed(seed)
@@ -79,7 +81,7 @@ def generate(
79
  lcm_origin_steps=50,
80
  output_type="pil",
81
  ).images
82
- paths = save_images(result)
83
  print(time.time() - start_time)
84
  return paths, seed
85
 
@@ -160,6 +162,9 @@ with gr.Blocks(css="style.css") as demo:
160
  visible=False,
161
  )
162
 
 
 
 
163
  gr.Examples(
164
  examples=examples,
165
  inputs=prompt,
 
17
  import torch
18
  from tqdm import tqdm
19
  from safetensors.torch import load_file
20
+ import gradio_user_history as gr_user_history
21
 
22
  from concurrent.futures import ThreadPoolExecutor
23
  import uuid
 
44
  seed = random.randint(0, MAX_SEED)
45
  return seed
46
 
47
+ def save_image(img, profile: gr.OAuthProfile | None, metadata: dict):
48
  unique_name = str(uuid.uuid4()) + '.png'
49
  img.save(unique_name)
50
+ gr_user_history.save_image(label=metadata["prompt"], image=image, profile=profile, metadata=metadata)
51
  return unique_name
52
 
53
+ def save_images(image_array, profile: gr.OAuthProfile | None, metadata: dict):
54
  paths = []
55
  with ThreadPoolExecutor() as executor:
56
+ paths = list(executor.map(save_image, image_array, [profile]*len(image_array), [metadata]*len(image_array)))
57
  return paths
58
 
59
  def generate(
 
65
  num_inference_steps: int = 4,
66
  num_images: int = 4,
67
  randomize_seed: bool = False,
68
+ progress = gr.Progress(track_tqdm=True),
69
+ profile: gr.OAuthProfile | None = None,
70
  ) -> PIL.Image.Image:
71
  seed = randomize_seed_fn(seed, randomize_seed)
72
  torch.manual_seed(seed)
 
81
  lcm_origin_steps=50,
82
  output_type="pil",
83
  ).images
84
+ paths = save_images(result, profile, metadata={"prompt": prompt, "seed": seed, "width": width, "height": height, "guidance_scale": guidance_scale, "num_inference_steps": num_inference_steps})
85
  print(time.time() - start_time)
86
  return paths, seed
87
 
 
162
  visible=False,
163
  )
164
 
165
+ with gr.Accordion("Past generations", open=False):
166
+ gr_user_history.render()
167
+
168
  gr.Examples(
169
  examples=examples,
170
  inputs=prompt,