Commit 4d84406 (parent: 525b190), committed by Wauplin (HF staff)

Upload folder using huggingface_hub

Files changed (4):
  1. README.md +3 -1
  2. app.py +45 -21
  3. gallery_history.py +121 -0
  4. requirements.txt +2 -1
README.md CHANGED
@@ -4,10 +4,12 @@ emoji: 🌍
 colorFrom: indigo
 colorTo: red
 sdk: gradio
-sdk_version: 3.39.0
+sdk_version: 3.44.2
 app_file: app.py
 pinned: false
 license: mit
+duplicated_from: warp-ai/Wuerstchen
+hf_oauth: true
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
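Note: the `hf_oauth: true` flag added above turns on "Sign in with Hugging Face" for this Space. With Gradio >= 3.41, an event handler that declares a `gr.OAuthProfile` parameter then receives the logged-in user's profile automatically (or None for anonymous visitors), which is what the new gallery_history.py below relies on. A minimal sketch, not part of this commit; `greet` is a hypothetical handler:

from typing import Optional

import gradio as gr


def greet(profile: Optional[gr.OAuthProfile] = None) -> str:
    # Gradio injects the OAuth profile automatically when `hf_oauth: true` is set;
    # the parameter stays None for anonymous visitors.
    if profile is None:
        return "Not logged in."
    return f"Hello {profile['preferred_username']}!"


with gr.Blocks() as demo:
    gr.LoginButton()
    greeting = gr.Textbox(label="Greeting")
    demo.load(greet, inputs=None, outputs=greeting)

demo.launch()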
app.py CHANGED
@@ -9,7 +9,10 @@ from diffusers.utils import numpy_to_pil
 from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
 from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
 from previewer.modules import Previewer
-os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+from gallery_history import fetch_gallery_history, show_gallery_history
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 DESCRIPTION = "# Würstchen"
 DESCRIPTION += "\n<p style=\"text-align: center\"><a href='https://huggingface.co/warp-ai/wuerstchen' target='_blank'>Würstchen</a> is a new fast and efficient high resolution text-to-image architecture and model</p>"
@@ -26,8 +29,12 @@ PREVIEW_IMAGES = True
 dtype = torch.float16
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 if torch.cuda.is_available():
-    prior_pipeline = WuerstchenPriorPipeline.from_pretrained("warp-ai/wuerstchen-prior", torch_dtype=dtype)
-    decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained("warp-ai/wuerstchen", torch_dtype=dtype)
+    prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
+        "warp-ai/wuerstchen-prior", torch_dtype=dtype
+    )
+    decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
+        "warp-ai/wuerstchen", torch_dtype=dtype
+    )
     if ENABLE_CPU_OFFLOAD:
         prior_pipeline.enable_model_cpu_offload()
         decoder_pipeline.enable_model_cpu_offload()
@@ -36,18 +43,27 @@ if torch.cuda.is_available():
         decoder_pipeline.to(device)
 
     if USE_TORCH_COMPILE:
-        prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
-        decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="reduce-overhead", fullgraph=True)
-
+        prior_pipeline.prior = torch.compile(
+            prior_pipeline.prior, mode="reduce-overhead", fullgraph=True
+        )
+        decoder_pipeline.decoder = torch.compile(
+            decoder_pipeline.decoder, mode="reduce-overhead", fullgraph=True
+        )
+
     if PREVIEW_IMAGES:
         previewer = Previewer()
-        previewer.load_state_dict(torch.load("previewer/text2img_wurstchen_b_v1_previewer_100k.pt")["state_dict"])
+        previewer.load_state_dict(
+            torch.load("previewer/text2img_wurstchen_b_v1_previewer_100k.pt")[
+                "state_dict"
+            ]
+        )
         previewer.eval().requires_grad_(False).to(device).to(dtype)
 
         def callback_prior(i, t, latents):
            output = previewer(latents)
            output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
            return output
+
     else:
         previewer = None
         callback_prior = None
@@ -96,7 +112,7 @@ def generate(
            if isinstance(r, list):
                yield r
        prior_output = r
-
+
     decoder_output = decoder_pipeline(
         image_embeddings=prior_output.image_embeddings,
         prompt=prompt,
@@ -209,19 +225,21 @@ with gr.Blocks(css="style.css") as demo:
         cache_examples=CACHE_EXAMPLES,
     )
 
+    history = show_gallery_history()
+
     inputs = [
-        prompt,
-        negative_prompt,
-        seed,
-        width,
-        height,
-        prior_num_inference_steps,
-        # prior_timesteps,
-        prior_guidance_scale,
-        decoder_num_inference_steps,
-        # decoder_timesteps,
-        decoder_guidance_scale,
-        num_images_per_prompt,
+        prompt,
+        negative_prompt,
+        seed,
+        width,
+        height,
+        prior_num_inference_steps,
+        # prior_timesteps,
+        prior_guidance_scale,
+        decoder_num_inference_steps,
+        # decoder_timesteps,
+        decoder_guidance_scale,
+        num_images_per_prompt,
     ]
     prompt.submit(
         fn=randomize_seed_fn,
@@ -234,6 +252,8 @@ with gr.Blocks(css="style.css") as demo:
         inputs=inputs,
         outputs=result,
         api_name="run",
+    ).then(
+        fn=fetch_gallery_history, inputs=[prompt, result], outputs=history
     )
     negative_prompt.submit(
         fn=randomize_seed_fn,
@@ -246,6 +266,8 @@ with gr.Blocks(css="style.css") as demo:
         inputs=inputs,
         outputs=result,
         api_name=False,
+    ).then(
+        fn=fetch_gallery_history, inputs=[prompt, result], outputs=history
     )
     run_button.click(
         fn=randomize_seed_fn,
@@ -258,7 +280,9 @@ with gr.Blocks(css="style.css") as demo:
         inputs=inputs,
         outputs=result,
         api_name=False,
+    ).then(
+        fn=fetch_gallery_history, inputs=[prompt, result], outputs=history
     )
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch()
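Each generation trigger above (prompt.submit, negative_prompt.submit, run_button.click) now chains a `.then()` step so the per-user history gallery refreshes once the images are ready. A minimal sketch of that wiring, not part of this commit; `generate` is a hypothetical stand-in for the Würstchen prior/decoder call:

import gradio as gr

from gallery_history import fetch_gallery_history, show_gallery_history


def generate(prompt: str) -> list:
    # Stand-in for the real prior/decoder pipeline call: returns a list of images.
    return []


with gr.Blocks() as demo:
    prompt = gr.Text(label="Prompt")
    run_button = gr.Button("Run")
    result = gr.Gallery(label="Result")
    history = show_gallery_history()

    # The first event produces the images; the chained event appends them to the
    # logged-in user's history and refreshes the gallery.
    run_button.click(
        fn=generate, inputs=[prompt], outputs=result
    ).then(
        fn=fetch_gallery_history, inputs=[prompt, result], outputs=history
    )

demo.launch()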
gallery_history.py ADDED
@@ -0,0 +1,121 @@
+"""
+How to use:
+1. Create a Space with a Persistent Storage attached. The filesystem will be available under `/data`.
+2. Add `hf_oauth: true` to the Space metadata (README.md). Make sure to have Gradio>=3.41.0 configured.
+3. Add `HISTORY_FOLDER` as a Space variable (e.g. `"/data/history"`).
+4. Add `filelock` as a dependency in `requirements.txt`.
+5. Add the history gallery to your Gradio app:
+    a. Add imports: `from gallery_history import fetch_gallery_history, show_gallery_history`
+    b. Add `history = show_gallery_history()` within the `gr.Blocks` context.
+    c. Add `.then(fn=fetch_gallery_history, inputs=[prompt, result], outputs=history)` on the generate event.
+"""
+import json
+import os
+import shutil
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+from uuid import uuid4
+
+import gradio as gr
+from filelock import FileLock
+
+_folder = os.environ.get("HISTORY_FOLDER")
+if _folder is None:
+    print(
+        "'HISTORY_FOLDER' environment variable not set. User history will be saved "
+        "locally and will be lost when the Space instance is restarted."
+    )
+    _folder = Path(__file__).parent / "history"
+HISTORY_FOLDER_PATH = Path(_folder)
+
+IMAGES_FOLDER_PATH = HISTORY_FOLDER_PATH / "images"
+IMAGES_FOLDER_PATH.mkdir(parents=True, exist_ok=True)
+
+
+def show_gallery_history():
+    gr.Markdown("## Past images\n\nYou must be logged in to activate it.")
+    with gr.Column():
+        with gr.Row():
+            gr.LoginButton()
+            gr.LogoutButton()
+        gallery = gr.Gallery(
+            label="Past images",
+            show_label=True,
+            elem_id="gallery",
+            object_fit="contain",
+            height="auto",
+            preview=True,
+            show_share_button=True,
+            show_download_button=True,
+        )
+    gallery.attach_load_event(fetch_gallery_history, every=None)
+    return gallery
+
+
+def fetch_gallery_history(
+    prompt: Optional[str] = None,
+    result: Optional[Dict] = None,
+    user: Optional[gr.OAuthProfile] = None,
+):
+    if user is None:
+        return []
+    try:
+        if prompt is not None and result is not None:  # None values mean no new images
+            return _update_user_history(
+                user["preferred_username"], [(item["name"], prompt) for item in result]
+            )
+        else:
+            return _read_user_history(user["preferred_username"])
+    except Exception as e:
+        raise gr.Error(f"Error while fetching history: {e}") from e
+
+
+####################
+# Internal helpers #
+####################
+
+
+def _read_user_history(username: str) -> List[Tuple[str, str]]:
+    """Return saved history for that user."""
+    with _user_lock(username):
+        path = _user_history_path(username)
+        if path.exists():
+            return json.loads(path.read_text())
+        return []  # No history yet
+
+
+def _update_user_history(
+    username: str, new_images: List[Tuple[str, str]]
+) -> List[Tuple[str, str]]:
+    """Update history for that user and return it."""
+    with _user_lock(username):
+        # Read existing
+        path = _user_history_path(username)
+        if path.exists():
+            images = json.loads(path.read_text())
+        else:
+            images = []  # No history yet
+
+        # Copy images to persistent folder
+        for src_path, prompt in new_images:
+            images.append((_copy_image(src_path), prompt))
+
+        # Save and return
+        path.write_text(json.dumps(images))
+        return images
+
+
+def _user_history_path(username: str) -> Path:
+    return HISTORY_FOLDER_PATH / f"{username}.json"
+
+
+def _user_lock(username: str) -> FileLock:
+    """Ensure history is not corrupted if concurrent calls."""
+    return FileLock(f"{_user_history_path(username)}.lock")
+
+
+def _copy_image(src: str) -> str:
+    """Copy image to the persistent storage."""
+    dst = IMAGES_FOLDER_PATH / f"{uuid4().hex}_{Path(src).name}"  # keep file ext
+    shutil.copyfile(src, dst)
+    return str(dst)
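For orientation, the module stores one JSON file per user under `HISTORY_FOLDER`, plus copies of the generated images in an images/ subfolder; each JSON entry is an (image path, prompt) pair that gr.Gallery renders as (image, caption). A small illustrative sketch of that stored shape, with hypothetical file names:

import json
from pathlib import Path

# Hypothetical content of <HISTORY_FOLDER>/jane_doe.json after two generations,
# i.e. what _update_user_history writes and _read_user_history returns.
history = [
    ["/data/history/images/0d4f9c1e_image_0.png", "an astronaut riding a horse"],
    ["/data/history/images/7b2a5e10_image_0.png", "a painting of a fox in autumn"],
]

# Round-trip exactly as gallery_history.py does (write_text/json.dumps, read_text/json.loads).
Path("jane_doe.json").write_text(json.dumps(history))
print(json.loads(Path("jane_doe.json").read_text()))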
requirements.txt CHANGED
@@ -5,4 +5,5 @@ invisible-watermark==0.2.0
 Pillow==10.0.0
 torch==2.0.1
 transformers==4.32.1
-compel
+compel
+filelock