patrickvonplaten Wauplin HF staff commited on
Commit
b3ccbc7
1 Parent(s): 859a4bd

[Feature] Add User history (#4)

Browse files

- [WIP] (79e495da9831d76276d041da2ff4b866e61ff9bf)
- Update requirements.txt (7cad1d75d372c67c3780ada03a362a5b8715eafa)
- Update README.md (bfc94ddc82953279eea44615bb0d963e229a28a4)


Co-authored-by: Lucain Pouget <Wauplin@users.noreply.huggingface.co>

Files changed (3) hide show
  1. README.md +2 -0
  2. app.py +11 -6
  3. requirements.txt +2 -0
README.md CHANGED
@@ -9,6 +9,8 @@ app_file: app.py
9
  license: mit
10
  pinned: false
11
  suggested_hardware: a10g-small
 
 
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
9
  license: mit
10
  pinned: false
11
  suggested_hardware: a10g-small
12
+ suggested_storage: small
13
+ hf_oauth: true
14
  ---
15
 
16
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -17,7 +17,7 @@ import os
17
  import torch
18
  from tqdm import tqdm
19
  from safetensors.torch import load_file
20
- from huggingface_hub import hf_hub_download
21
 
22
  from concurrent.futures import ThreadPoolExecutor
23
  import uuid
@@ -44,15 +44,16 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
44
  seed = random.randint(0, MAX_SEED)
45
  return seed
46
 
47
- def save_image(img):
48
  unique_name = str(uuid.uuid4()) + '.png'
49
  img.save(unique_name)
 
50
  return unique_name
51
 
52
- def save_images(image_array):
53
  paths = []
54
  with ThreadPoolExecutor() as executor:
55
- paths = list(executor.map(save_image, image_array))
56
  return paths
57
 
58
  def generate(
@@ -64,7 +65,8 @@ def generate(
64
  num_inference_steps: int = 4,
65
  num_images: int = 4,
66
  randomize_seed: bool = False,
67
- progress = gr.Progress(track_tqdm=True)
 
68
  ) -> PIL.Image.Image:
69
  seed = randomize_seed_fn(seed, randomize_seed)
70
  torch.manual_seed(seed)
@@ -79,7 +81,7 @@ def generate(
79
  lcm_origin_steps=50,
80
  output_type="pil",
81
  ).images
82
- paths = save_images(result)
83
  print(time.time() - start_time)
84
  return paths, seed
85
 
@@ -160,6 +162,9 @@ with gr.Blocks(css="style.css") as demo:
160
  visible=False,
161
  )
162
 
 
 
 
163
  gr.Examples(
164
  examples=examples,
165
  inputs=prompt,
 
17
  import torch
18
  from tqdm import tqdm
19
  from safetensors.torch import load_file
20
+ import gradio_user_history as gr_user_history
21
 
22
  from concurrent.futures import ThreadPoolExecutor
23
  import uuid
 
44
  seed = random.randint(0, MAX_SEED)
45
  return seed
46
 
47
+ def save_image(img, profile: gr.OAuthProfile | None, metadata: dict):
48
  unique_name = str(uuid.uuid4()) + '.png'
49
  img.save(unique_name)
50
+ gr_user_history.save_image(label=metadata["prompt"], image=image, profile=profile, metadata=metadata)
51
  return unique_name
52
 
53
+ def save_images(image_array, profile: gr.OAuthProfile | None, metadata: dict):
54
  paths = []
55
  with ThreadPoolExecutor() as executor:
56
+ paths = list(executor.map(save_image, image_array, [profile]*len(image_array), [metadata]*len(image_array)))
57
  return paths
58
 
59
  def generate(
 
65
  num_inference_steps: int = 4,
66
  num_images: int = 4,
67
  randomize_seed: bool = False,
68
+ progress = gr.Progress(track_tqdm=True),
69
+ profile: gr.OAuthProfile | None = None,
70
  ) -> PIL.Image.Image:
71
  seed = randomize_seed_fn(seed, randomize_seed)
72
  torch.manual_seed(seed)
 
81
  lcm_origin_steps=50,
82
  output_type="pil",
83
  ).images
84
+ paths = save_images(result, profile, metadata={"prompt": prompt, "seed": seed, "width": width, "height": height, "guidance_scale": guidance_scale, "num_inference_steps": num_inference_steps})
85
  print(time.time() - start_time)
86
  return paths, seed
87
 
 
162
  visible=False,
163
  )
164
 
165
+ with gr.Accordion("Past generations", open=False):
166
+ gr_user_history.render()
167
+
168
  gr.Examples(
169
  examples=examples,
170
  inputs=prompt,
requirements.txt CHANGED
@@ -6,3 +6,5 @@ Pillow
6
  torch==2.0.1
7
  transformers
8
  opencv-python
 
 
 
6
  torch==2.0.1
7
  transformers
8
  opencv-python
9
+
10
+ git+https://huggingface.co/spaces/Wauplin/gradio-user-history