m4r4k0s23 committed
Commit 9db3da3 · verified · 1 Parent(s): a85e632

Update app.py

Files changed (1)
  1. app.py +225 -159
app.py CHANGED
@@ -1,113 +1,122 @@
import gradio as gr
import numpy as np
import random
-
- from diffusers import DiffusionPipeline
- from peft import PeftModel, PeftConfig
import torch

- device = "cuda" if torch.cuda.is_available() else "cpu"

- # Model list including your LoRA model
- MODEL_LIST = [
-     "CompVis/stable-diffusion-v1-4",
-     "stabilityai/sdxl-turbo",
-     "runwayml/stable-diffusion-v1-5",
-     "stabilityai/stable-diffusion-2-1",
-     "m4r4k0s23/hw5_lora_raccoon",
- ]

if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

- # Cache to avoid re-initializing pipelines repeatedly
- model_cache = {}
-
- def load_pipeline(model_id: str):
-     """
-     Loads or retrieves a cached DiffusionPipeline.
-
-     If the chosen model is your LoRA adapter, then load the base model
-     (CompVis/stable-diffusion-v1-4) and apply the LoRA weights.
-     """
-     if model_id in model_cache:
-         return model_cache[model_id]
-
-     if model_id == "m4r4k0s23/hw5_lora_raccoon":
-         # Use the specified base model for your LoRA adapter.
-         base_model = "CompVis/stable-diffusion-v1-4"
-         pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch_dtype)
-         # Load the LoRA weights
-         pipe.unet = PeftModel.from_pretrained(
-             pipe.unet,
-             model_id,
-             subfolder="unet",
-             torch_dtype=torch_dtype
-         )
-         pipe.text_encoder = PeftModel.from_pretrained(
-             pipe.text_encoder,
-             model_id,
-             subfolder="text_encoder",
-             torch_dtype=torch_dtype
-         )
-     else:
-         pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
-
-     pipe.to(device)
-     model_cache[model_id] = pipe
-     return pipe
-
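The removed `load_pipeline` above wraps the UNet and text encoder in PEFT adapters and caches the pipeline per model id. For comparison only, diffusers also has a built-in LoRA loader that covers the same use case; this is a minimal sketch, assuming the adapter repo ships weights in the diffusers LoRA format rather than the PEFT `unet/` and `text_encoder/` sub-folders used here:

```python
import torch
from diffusers import StableDiffusionPipeline

# Base model plus LoRA weights via diffusers' own loader (no PEFT wrapping needed).
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("m4r4k0s23/hw5_lora_raccoon")  # assumes diffusers-format LoRA files

# The LoRA influence can then be set per call instead of patching the UNet:
image = pipe(
    "a raccoon astronaut, detailed, 8k",
    cross_attention_kwargs={"scale": 0.8},  # LoRA scale
).images[0]
```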
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

def infer(
-     model_id,
    prompt,
    negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
-     guidance_scale,
-     num_inference_steps,
-     lora_scale,  # New parameter for adjusting LoRA scale
-     progress=gr.Progress(track_tqdm=True),
- ):
-     # Load the pipeline for the chosen model
-     pipe = load_pipeline(model_id)
-
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator(device=device).manual_seed(seed)
-
-     # If using the LoRA model, update the LoRA scale if supported.
-     if model_id == "m4r4k0s23/hw5_lora_raccoon":
-         # This assumes your pipeline's unet has a method to update the LoRA scale.
-         if hasattr(pipe.unet, "set_lora_scale"):
-             pipe.unet.set_lora_scale(lora_scale)
        else:
-             print("Warning: LoRA scale adjustment method not found on UNet.")
-
-     image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-     ).images[0]
-
-     return image, seed
-
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]

css = """
#col-container {
@@ -116,55 +125,131 @@ css = """
}
"""

- with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")

        with gr.Row():
-             # Dropdown to select the model from Hugging Face
-             model_id = gr.Dropdown(
-                 label="Model",
-                 choices=MODEL_LIST,
-                 value=MODEL_LIST[0],  # Default model
-             )
-
-         with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-
-             run_button = gr.Button("Run", scale=0, variant="primary")
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
                max_lines=1,
-                 placeholder="Enter a negative prompt",
            )

-             seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
-                 value=42,  # Default seed
            )

-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
-                     value=1024,
                )

                height = gr.Slider(
@@ -172,54 +257,35 @@ with gr.Blocks(css=css) as demo:
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
-                     value=1024,
                )
-
-             with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=20.0,
-                     step=0.5,
-                     value=7.0,
-                 )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=100,
-                     step=1,
-                     value=20,
-                 )
-
-             # New slider for LoRA scale.
-             lora_scale = gr.Slider(
-                 label="LoRA Scale",
-                 minimum=0.0,
-                 maximum=2.0,
-                 step=0.1,
-                 value=1.0,
-                 info="Adjust the influence of the LoRA weights",
-             )
-
-         gr.Examples(examples=examples, inputs=[prompt])
    gr.on(
-         triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
-             model_id,
            prompt,
            negative_prompt,
-             seed,
-             randomize_seed,
            width,
            height,
-             guidance_scale,
            num_inference_steps,
-             lora_scale,  # Pass the new slider value
        ],
-         outputs=[result, seed],
    )

if __name__ == "__main__":
-     demo.launch()
app.py (updated file):

import gradio as gr
import numpy as np
import random
+ import os
import torch
+ from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline
+ from diffusers.utils import load_image
+ from peft import PeftModel, LoraConfig

+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"

if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
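The `# @spaces.GPU` comment a few lines below marks where a ZeroGPU Space would request a GPU per call; uncommenting it would also need an `import spaces` at the top of the file, which this commit does not add. A minimal sketch of how that decorator is normally applied (the `duration` value here is only illustrative):

```python
import spaces  # provided by the Hugging Face Spaces runtime on ZeroGPU hardware

@spaces.GPU(duration=120)  # hold a GPU for up to ~120 s per call
def infer(prompt, negative_prompt, *args, **kwargs):
    ...  # unchanged body; CUDA is available while the decorated call runs
```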

+
+ # @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    prompt,
    negative_prompt,
+     width=512,
+     height=512,
+     model_id=model_id_default,
+     seed=42,
+     guidance_scale=7.0,
+     lora_scale=1.0,
+     num_inference_steps=20,
+     controlnet_checkbox=False,
+     controlnet_strength=0.0,
+     controlnet_mode="edge_detection",
+     controlnet_image=None,
+     ip_adapter_checkbox=False,
+     ip_adapter_scale=0.0,
+     ip_adapter_image=None,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+
+     unet_sub_dir = "unet"
+     text_encoder_sub_dir = "text_encoder"
+
+     if model_id is None:
+         raise ValueError("Please specify the base model name or path")
+
+     generator = torch.Generator(device).manual_seed(seed)
+     params = {'prompt': prompt,
+               'negative_prompt': negative_prompt,
+               'guidance_scale': guidance_scale,
+               'num_inference_steps': num_inference_steps,
+               'width': width,
+               'height': height,
+               'generator': generator
+               }
+
+     if controlnet_checkbox:
+         if controlnet_mode == "depth_map":
+             controlnet = ControlNetModel.from_pretrained(
+                 "lllyasviel/sd-controlnet-depth",
+                 cache_dir="./models_cache",
+                 torch_dtype=torch_dtype
+             )
+         elif controlnet_mode == "pose_estimation":
+             controlnet = ControlNetModel.from_pretrained(
+                 "lllyasviel/sd-controlnet-openpose",
+                 cache_dir="./models_cache",
+                 torch_dtype=torch_dtype
+             )
+         elif controlnet_mode == "normal_map":
+             controlnet = ControlNetModel.from_pretrained(
+                 "lllyasviel/sd-controlnet-normal",
+                 cache_dir="./models_cache",
+                 torch_dtype=torch_dtype
+             )
+         elif controlnet_mode == "scribbles":
+             controlnet = ControlNetModel.from_pretrained(
+                 "lllyasviel/sd-controlnet-scribble",
+                 cache_dir="./models_cache",
+                 torch_dtype=torch_dtype
+             )
        else:
+             controlnet = ControlNetModel.from_pretrained(
+                 "lllyasviel/sd-controlnet-canny",
+                 cache_dir="./models_cache",
+                 torch_dtype=torch_dtype
+             )
+         pipe = StableDiffusionControlNetPipeline.from_pretrained(model_id,
+                                                                  controlnet=controlnet,
+                                                                  torch_dtype=torch_dtype,
+                                                                  safety_checker=None).to(device)
+         params['image'] = controlnet_image
+         params['controlnet_conditioning_scale'] = float(controlnet_strength)
+     else:
+         pipe = StableDiffusionPipeline.from_pretrained(model_id,
+                                                        torch_dtype=torch_dtype,
+                                                        safety_checker=None).to(device)
+
+     pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
+     pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir)
+
+     pipe.unet.load_state_dict({k: lora_scale*v for k, v in pipe.unet.state_dict().items()})
+     pipe.text_encoder.load_state_dict({k: lora_scale*v for k, v in pipe.text_encoder.state_dict().items()})
+
+     if torch_dtype in (torch.float16, torch.bfloat16):
+         pipe.unet.half()
+         pipe.text_encoder.half()
+
+     if ip_adapter_checkbox:
+         pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
+         pipe.set_ip_adapter_scale(ip_adapter_scale)
+         params['ip_adapter_image'] = ip_adapter_image
+
+     pipe.to(device)
+
+     return pipe(**params).images[0]
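One note on the `lora_scale` handling in `infer`: the two `load_state_dict` calls multiply every tensor in the state dict by `lora_scale`, which rescales the frozen base weights as well as the adapter. If the intent is to scale only the adapter's contribution, a sketch of one way to do it (a hypothetical helper, not part of this commit; scaling only the `lora_B` matrices keeps the update `B @ A` linear in the scale):

```python
def apply_lora_scale(module, scale: float):
    # PEFT stores LoRA factors as lora_A / lora_B; scaling lora_B alone scales
    # the adapter delta by `scale` and leaves the frozen base weights untouched.
    state = module.state_dict()
    state = {k: (v * scale if "lora_B" in k else v) for k, v in state.items()}
    module.load_state_dict(state)

apply_lora_scale(pipe.unet, lora_scale)
apply_lora_scale(pipe.text_encoder, lora_scale)
```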

css = """
#col-container {
}
"""

+ def controlnet_params(show_extra):
+     return gr.update(visible=show_extra)
+
+ with gr.Blocks(css=css, fill_height=True) as demo:
    with gr.Column(elem_id="col-container"):
+         gr.Markdown(" # Text-to-Image demo")

        with gr.Row():
+             model_id = gr.Textbox(
+                 label="Model ID",
                max_lines=1,
+                 placeholder="Enter model id",
+                 value=model_id_default,
            )

+         prompt = gr.Textbox(
+             label="Prompt",
+             max_lines=1,
+             placeholder="Enter your prompt",
+         )
+
+         negative_prompt = gr.Textbox(
+             label="Negative prompt",
+             max_lines=1,
+             placeholder="Enter your negative prompt",
+         )
+
+         with gr.Row():
+             seed = gr.Number(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
+                 value=42,
+             )
+
+             guidance_scale = gr.Slider(
+                 label="Guidance scale",
+                 minimum=0.0,
+                 maximum=30.0,
+                 step=0.1,
+                 value=7.0,  # Replace with defaults that work for your model
+             )
+         with gr.Row():
+             lora_scale = gr.Slider(
+                 label="LoRA scale",
+                 minimum=0.0,
+                 maximum=1.0,
+                 step=0.01,
+                 value=1.0,
            )

+             num_inference_steps = gr.Slider(
+                 label="Number of inference steps",
+                 minimum=1,
+                 maximum=100,
+                 step=1,
+                 value=20,  # Replace with defaults that work for your model
+             )
+         with gr.Row():
+             controlnet_checkbox = gr.Checkbox(
+                 label="ControlNet",
+                 value=False
+             )
+             with gr.Column(visible=False) as controlnet_params:
+                 controlnet_strength = gr.Slider(
+                     label="ControlNet conditioning scale",
+                     minimum=0.0,
+                     maximum=1.0,
+                     step=0.01,
+                     value=1.0,
+                 )
+                 controlnet_mode = gr.Dropdown(
+                     label="ControlNet mode",
+                     choices=["edge_detection",
+                              "depth_map",
+                              "pose_estimation",
+                              "normal_map",
+                              "scribbles"],
+                     value="edge_detection",
+                     max_choices=1
+                 )
+                 controlnet_image = gr.Image(
+                     label="ControlNet condition image",
+                     type="pil",
+                     format="png"
+                 )
+             controlnet_checkbox.change(
+                 fn=lambda x: gr.Row.update(visible=x),
+                 inputs=controlnet_checkbox,
+                 outputs=controlnet_params
+             )
+ )
220
 
221
+ with gr.Row():
222
+ ip_adapter_checkbox = gr.Checkbox(
223
+ label="IPAdapter",
224
+ value=False
225
+ )
226
+ with gr.Column(visible=False) as ip_adapter_params:
227
+ ip_adapter_scale = gr.Slider(
228
+ label="IPAdapter scale",
229
+ minimum=0.0,
230
+ maximum=1.0,
231
+ step=0.01,
232
+ value=1.0,
233
+ )
234
+ ip_adapter_image = gr.Image(
235
+ label="IPAdapter condition image",
236
+ type="pil"
237
+ )
238
+ ip_adapter_checkbox.change(
239
+ fn=lambda x: gr.Row.update(visible=x),
240
+ inputs=ip_adapter_checkbox,
241
+ outputs=ip_adapter_params
242
+ )
243
+
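Both `.change` callbacks above return `gr.Row.update(visible=x)` although the target is a `gr.Column`, and the per-component `update` helpers were removed in Gradio 4.x. If the Space runs on Gradio 4, the generic `gr.update` (the same call used by the `controlnet_params` helper defined near the top of the UI code) is the supported form; a sketch:

```python
controlnet_checkbox.change(
    fn=lambda checked: gr.update(visible=checked),  # generic update works for any component
    inputs=controlnet_checkbox,
    outputs=controlnet_params,
)
```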
+         with gr.Accordion("Optional Settings", open=False):
+
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
+                     value=512,  # Replace with defaults that work for your model
                )

                height = gr.Slider(
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
+                     value=512,  # Replace with defaults that work for your model
                )
+
+         run_button = gr.Button("Run", scale=0, variant="primary")
+         result = gr.Image(label="Result", show_label=False)
+
    gr.on(
+         triggers=[run_button.click],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            width,
            height,
+             model_id,
+             seed,
+             guidance_scale,
+             lora_scale,
            num_inference_steps,
+             controlnet_checkbox,
+             controlnet_strength,
+             controlnet_mode,
+             controlnet_image,
+             ip_adapter_checkbox,
+             ip_adapter_scale,
+             ip_adapter_image,
        ],
+         outputs=[result],
    )

if __name__ == "__main__":
+     demo.launch()
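For a quick sanity check of the new `infer` signature outside the Gradio UI, a local sketch (it assumes the PEFT adapter folders `unet/` and `text_encoder/` that `infer` loads are present in the working directory):

```python
# Local smoke test: call infer() directly instead of launching the UI.
image = infer(
    prompt="a raccoon reading a book, studio lighting",
    negative_prompt="blurry, low quality",
    width=512,
    height=512,
    seed=123,
    guidance_scale=7.0,
    lora_scale=0.8,
    num_inference_steps=20,
)
image.save("sample.png")
```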