kadirnar commited on
Commit
a064d1f
1 Parent(s): c7facaa

Update stable_cascade.py

Browse files
Files changed (1) hide show
  1. stable_cascade.py +13 -0
stable_cascade.py CHANGED
@@ -13,6 +13,7 @@ def generate_images(
13
  height=1024,
14
  width=1024,
15
  guidance_scale=4.0,
 
16
  num_images_per_prompt=1,
17
  prior_inference_steps=20,
18
  decoder_inference_steps=10
@@ -30,10 +31,12 @@ def generate_images(
30
  Returns:
31
  - List[PIL.Image]: A list of generated PIL Image objects.
32
  """
 
33
 
34
  # Generate image embeddings using the prior model
35
  prior_output = prior(
36
  prompt=prompt,
 
37
  height=height,
38
  width=width,
39
  negative_prompt=negative_prompt,
@@ -46,6 +49,7 @@ def generate_images(
46
  decoder_output = decoder(
47
  image_embeddings=prior_output.image_embeddings.half(),
48
  prompt=prompt,
 
49
  negative_prompt=negative_prompt,
50
  guidance_scale=0.0, # Guidance scale typically set to 0 for decoder as guidance is applied in the prior
51
  output_type="pil",
@@ -70,6 +74,15 @@ def web_demo():
70
  placeholder="Negative Prompt",
71
  show_label=False,
72
  )
 
 
 
 
 
 
 
 
 
73
  with gr.Row():
74
  with gr.Column():
75
  text2image_num_images_per_prompt = gr.Slider(
 
13
  height=1024,
14
  width=1024,
15
  guidance_scale=4.0,
16
+ seed=42,
17
  num_images_per_prompt=1,
18
  prior_inference_steps=20,
19
  decoder_inference_steps=10
 
31
  Returns:
32
  - List[PIL.Image]: A list of generated PIL Image objects.
33
  """
34
+ generator = torch.Generator(device="cuda").manual_seed(seed)
35
 
36
  # Generate image embeddings using the prior model
37
  prior_output = prior(
38
  prompt=prompt,
39
+ generator=generator,
40
  height=height,
41
  width=width,
42
  negative_prompt=negative_prompt,
 
49
  decoder_output = decoder(
50
  image_embeddings=prior_output.image_embeddings.half(),
51
  prompt=prompt,
52
+ generator=generator,
53
  negative_prompt=negative_prompt,
54
  guidance_scale=0.0, # Guidance scale typically set to 0 for decoder as guidance is applied in the prior
55
  output_type="pil",
 
74
  placeholder="Negative Prompt",
75
  show_label=False,
76
  )
77
+
78
+ text2image_num_images_per_prompt = gr.Slider(
79
+ minimum=1,
80
+ maximum=1000000,
81
+ step=10,
82
+ value=42,
83
+ label="Seed",
84
+ )
85
+
86
  with gr.Row():
87
  with gr.Column():
88
  text2image_num_images_per_prompt = gr.Slider(