multimodalart committed on
Commit 0963421
Parent: 4bacf14

Update app.py

Files changed (1)
  1. app.py +7 -58
app.py CHANGED
@@ -8,7 +8,6 @@ from typing import List
 from diffusers.utils import numpy_to_pil
 from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
 from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
-import spaces
 from previewer.modules import Previewer
 import user_history
 
@@ -24,14 +23,12 @@ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") != "0
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
 USE_TORCH_COMPILE = False
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
-PREVIEW_IMAGES = True
 
 dtype = torch.bfloat16
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 if torch.cuda.is_available():
-    prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype)#.to(device)
-    decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
-
+    prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype).to(device)
+    decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype).to(device)
     if ENABLE_CPU_OFFLOAD:
         prior_pipeline.enable_model_cpu_offload()
         decoder_pipeline.enable_model_cpu_offload()
@@ -43,19 +40,6 @@ if torch.cuda.is_available():
     prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
     decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="max-autotune", fullgraph=True)
 
-    if PREVIEW_IMAGES:
-        previewer = Previewer()
-        previewer_state_dict = torch.load("previewer/previewer_v1_100k.pt", map_location=torch.device('cpu'))["state_dict"]
-        previewer.load_state_dict(previewer_state_dict)
-        def callback_prior(i, t, latents):
-            output = previewer(latents)
-            output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
-            return output
-        callback_steps = 1
-    else:
-        previewer = None
-        callback_prior = None
-        callback_steps = None
 else:
     prior_pipeline = None
     decoder_pipeline = None
@@ -66,7 +50,6 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
         seed = random.randint(0, MAX_SEED)
     return seed
 
-@spaces.GPU
 def generate(
     prompt: str,
     negative_prompt: str = "",
@@ -82,12 +65,8 @@ def generate(
     num_images_per_prompt: int = 2,
     profile: gr.OAuthProfile | None = None,
 ) -> PIL.Image.Image:
-    previewer.eval().requires_grad_(False).to(device).to(dtype)
-    prior_pipeline.to(device)
-    decoder_pipeline.to(device)
 
     generator = torch.Generator().manual_seed(seed)
-    print("prior_num_inference_steps: ", prior_num_inference_steps)
     prior_output = prior_pipeline(
         prompt=prompt,
         height=height,
@@ -98,17 +77,8 @@ def generate(
         guidance_scale=prior_guidance_scale,
         num_images_per_prompt=num_images_per_prompt,
         generator=generator,
-        callback=callback_prior,
-        callback_steps=callback_steps
     )
 
-    if PREVIEW_IMAGES:
-        for _ in range(len(DEFAULT_STAGE_C_TIMESTEPS)):
-            r = next(prior_output)
-            if isinstance(r, list):
-                yield r[0]
-        prior_output = r
-
     decoder_output = decoder_pipeline(
         image_embeddings=prior_output.image_embeddings,
         prompt=prompt,
@@ -120,25 +90,7 @@ def generate(
         output_type="pil",
     ).images
 
-    #Save images
-    for image in decoder_output:
-        user_history.save_image(
-            profile=profile,
-            image=image,
-            label=prompt,
-            metadata={
-                "negative_prompt": negative_prompt,
-                "seed": seed,
-                "width": width,
-                "height": height,
-                "prior_guidance_scale": prior_guidance_scale,
-                "decoder_num_inference_steps": decoder_num_inference_steps,
-                "decoder_guidance_scale": decoder_guidance_scale,
-                "num_images_per_prompt": num_images_per_prompt,
-            },
-        )
-
-    yield decoder_output[0]
+    return decoder_output[0]
 
 
 examples = [
@@ -270,11 +222,8 @@ with gr.Blocks() as demo:
         api_name="run",
     )
 
-with gr.Blocks(css="style.css") as demo_with_history:
-    with gr.Tab("App"):
-        demo.render()
-    with gr.Tab("Past generations"):
-        user_history.render()
-
+with gr.Blocks(css="style.css") as local_demo:
+    demo.render()
+
 if __name__ == "__main__":
-    demo_with_history.queue(max_size=20).launch()
+    local_demo.queue(max_size=20).launch()
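
For reference, the prior -> decoder flow that the updated app.py now relies on can be exercised outside the Space with only the diffusers pipelines shown in the diff. The following is a minimal sketch under that assumption; the prompt, image size, step counts, and guidance values are illustrative placeholders rather than the Space's defaults.

# Minimal sketch of the two-stage Stable Cascade flow used by the updated app.py.
# Model IDs and pipeline calls mirror the diff above; prompt and numeric
# settings below are illustrative only.
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=dtype
).to(device)
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=dtype
).to(device)

prompt = "an astronaut riding a horse"  # placeholder prompt

# Stage C: text prompt -> image embeddings
prior_output = prior(
    prompt=prompt,
    height=1024,
    width=1024,
    num_inference_steps=20,
    guidance_scale=4.0,
    num_images_per_prompt=1,
    generator=torch.Generator().manual_seed(0),
)

# Stage B: image embeddings -> PIL image
images = decoder(
    image_embeddings=prior_output.image_embeddings,
    prompt=prompt,
    num_inference_steps=10,
    guidance_scale=0.0,
    output_type="pil",
).images
images[0].save("stable_cascade.png")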