kadirnar commited on
Commit
8e43c73
1 Parent(s): 36c2a6f

Delete stable_cascade.py

Browse files
Files changed (1) hide show
  1. stable_cascade.py +0 -153
stable_cascade.py DELETED
@@ -1,153 +0,0 @@
1
- import torch, os
2
- from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
3
- import gradio as gr
4
-
5
- prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16).to("cuda")
6
- decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.float16).to("cuda")
7
-
8
- def generate_images(
9
- prompt="a photo of a girl",
10
- negative_prompt="bad,ugly,deformed",
11
- height=1024,
12
- width=1024,
13
- guidance_scale=4.0,
14
- seed=42,
15
- num_images_per_prompt=1,
16
- prior_inference_steps=20,
17
- decoder_inference_steps=10
18
- ):
19
- """
20
- Generates images based on a given prompt using Stable Diffusion models on CUDA device.
21
- Parameters:
22
- - prompt (str): The prompt to generate images for.
23
- - negative_prompt (str): The negative prompt to guide image generation away from.
24
- - height (int): The height of the generated images.
25
- - width (int): The width of the generated images.
26
- - guidance_scale (float): The scale of guidance for the image generation.
27
- - prior_inference_steps (int): The number of inference steps for the prior model.
28
- - decoder_inference_steps (int): The number of inference steps for the decoder model.
29
- Returns:
30
- - List[PIL.Image]: A list of generated PIL Image objects.
31
- """
32
- generator = torch.Generator(device="cuda").manual_seed(int(seed))
33
-
34
- # Generate image embeddings using the prior model
35
- prior_output = prior(
36
- prompt=prompt,
37
- generator=generator,
38
- height=height,
39
- width=width,
40
- negative_prompt=negative_prompt,
41
- guidance_scale=guidance_scale,
42
- num_images_per_prompt=num_images_per_prompt,
43
- num_inference_steps=prior_inference_steps
44
- )
45
-
46
- # Generate images using the decoder model and the embeddings from the prior model
47
- decoder_output = decoder(
48
- image_embeddings=prior_output.image_embeddings.half(),
49
- prompt=prompt,
50
- generator=generator,
51
- negative_prompt=negative_prompt,
52
- guidance_scale=0.0, # Guidance scale typically set to 0 for decoder as guidance is applied in the prior
53
- output_type="pil",
54
- num_inference_steps=decoder_inference_steps
55
- ).images
56
-
57
- return decoder_output
58
-
59
-
60
- def web_demo():
61
- with gr.Blocks():
62
- with gr.Row():
63
- with gr.Column():
64
- text2image_prompt = gr.Textbox(
65
- lines=1,
66
- placeholder="Prompt",
67
- show_label=False,
68
- )
69
-
70
- text2image_negative_prompt = gr.Textbox(
71
- lines=1,
72
- placeholder="Negative Prompt",
73
- show_label=False,
74
- )
75
-
76
- text2image_seed = gr.Number(
77
- value=42,
78
- label="Seed",
79
- )
80
-
81
- with gr.Row():
82
- with gr.Column():
83
- text2image_num_images_per_prompt = gr.Slider(
84
- minimum=1,
85
- maximum=2,
86
- step=1,
87
- value=1,
88
- label="Number Image",
89
- )
90
-
91
- text2image_height = gr.Slider(
92
- minimum=128,
93
- maximum=1024,
94
- step=32,
95
- value=1024,
96
- label="Image Height",
97
- )
98
-
99
- text2image_width = gr.Slider(
100
- minimum=128,
101
- maximum=1024,
102
- step=32,
103
- value=1024,
104
- label="Image Width",
105
- )
106
- with gr.Row():
107
- with gr.Column():
108
- text2image_guidance_scale = gr.Slider(
109
- minimum=0.1,
110
- maximum=15,
111
- step=0.1,
112
- value=4.0,
113
- label="Guidance Scale",
114
- )
115
- text2image_prior_inference_step = gr.Slider(
116
- minimum=1,
117
- maximum=50,
118
- step=1,
119
- value=20,
120
- label="Prior Inference Step",
121
- )
122
-
123
- text2image_decoder_inference_step = gr.Slider(
124
- minimum=1,
125
- maximum=50,
126
- step=1,
127
- value=10,
128
- label="Decoder Inference Step",
129
- )
130
- text2image_predict = gr.Button(value="Generate Image")
131
-
132
- with gr.Column():
133
- output_image = gr.Gallery(
134
- label="Generated images",
135
- show_label=False,
136
- elem_id="gallery",
137
- ).style(grid=(1, 2), height=300)
138
-
139
- text2image_predict.click(
140
- fn=generate_images,
141
- inputs=[
142
- text2image_prompt,
143
- text2image_negative_prompt,
144
- text2image_height,
145
- text2image_width,
146
- text2image_guidance_scale,
147
- text2image_seed,
148
- text2image_num_images_per_prompt,
149
- text2image_prior_inference_step,
150
- text2image_decoder_inference_step
151
- ],
152
- outputs=output_image,
153
- )