SemaSci committed on
Commit
4267dd3
·
verified ·
1 Parent(s): c8f6784

Update app.py


Completed homework assignment 6 (ДЗ6)

Files changed (1)
  1. app.py +203 -204
app.py CHANGED
@@ -1,77 +1,21 @@
  import gradio as gr
  import numpy as np
  import random
-
- # import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
+ import os
  import torch
-
+ from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline
+ from diffusers.utils import load_image
  from peft import PeftModel, LoraConfig
- import os

- def get_lora_sd_pipeline(
-     ckpt_dir='./lora_logos',
-     base_model_name_or_path=None,
-     dtype=torch.float16,
-     adapter_name="default"
- ):
-
-     unet_sub_dir = os.path.join(ckpt_dir, "unet")
-     text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
-
-     if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
-         config = LoraConfig.from_pretrained(text_encoder_sub_dir)
-         base_model_name_or_path = config.base_model_name_or_path
-
-     if base_model_name_or_path is None:
-         raise ValueError("Please specify the base model name or path")
-
-     pipe = DiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype)
-     before_params = pipe.unet.parameters()
-     pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
-     pipe.unet.set_adapter(adapter_name)
-     after_params = pipe.unet.parameters()
-     print("Parameters changed:", any(torch.any(b != a) for b, a in zip(before_params, after_params)))
-
-     if os.path.exists(text_encoder_sub_dir):
-         pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)
-
-     if dtype in (torch.float16, torch.bfloat16):
-         pipe.unet.half()
-         pipe.text_encoder.half()
-
-     return pipe
-
- def process_prompt(prompt, tokenizer, text_encoder, max_length=77):
-     tokens = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"]
-     chunks = [tokens[:, i:i + max_length] for i in range(0, tokens.shape[1], max_length)]
-
-     with torch.no_grad():
-         embeds = [text_encoder(chunk.to(text_encoder.device))[0] for chunk in chunks]
-
-     return torch.cat(embeds, dim=1)
-
- def align_embeddings(prompt_embeds, negative_prompt_embeds):
-     max_length = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
-     return torch.nn.functional.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
-            torch.nn.functional.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))

  device = "cuda" if torch.cuda.is_available() else "cpu"
- #model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
- model_id_default = "sd-legacy/stable-diffusion-v1-5"
- model_dropdown = ['stabilityai/sdxl-turbo', 'CompVis/stable-diffusion-v1-4', 'sd-legacy/stable-diffusion-v1-5']
-
- model_lora_default = "lora_pussinboots_logos"
- model_lora_dropdown = ['lora_lady_and_cats_logos', 'lora_pussinboots_logos']
+ model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"

  if torch.cuda.is_available():
      torch_dtype = torch.float16
  else:
      torch_dtype = torch.float32

- # pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- # pipe = pipe.to(device)
-
  MAX_SEED = np.iinfo(np.int32).max
  MAX_IMAGE_SIZE = 1024

@@ -80,82 +24,99 @@ MAX_IMAGE_SIZE = 1024
  def infer(
      prompt,
      negative_prompt,
-     randomize_seed,
      width=512,
      height=512,
-     model_repo_id=model_id_default,
+     model_id=model_id_default,
      seed=42,
-     guidance_scale=7,
+     guidance_scale=7.0,
+     lora_scale=1.0,
      num_inference_steps=20,
-     model_lora_id=model_lora_default,
-     lora_scale=0.5,
-     progress=gr.Progress(track_tqdm=True),
- ):
-
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     # don't rebuild the pipe unconditionally any more
-     #pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-     #pipe = pipe.to(device)
-
-     # rebuild the pipe only when a non-default model is selected
-     if model_repo_id != model_id_default:
-         pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype).to(device)
-         prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
-         negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
-         prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
+     controlnet_checkbox=False,
+     controlnet_strength=0.0,
+     controlnet_mode="edge_detection",
+     controlnet_image=None,
+     ip_adapter_checkbox=False,
+     ip_adapter_scale=0.0,
+     ip_adapter_image=None,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     ckpt_dir = './model_output'
+     unet_sub_dir = os.path.join(ckpt_dir, "unet")
+     text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
+
+     if model_id is None:
+         raise ValueError("Please specify the base model name or path")
+
+     generator = torch.Generator(device).manual_seed(seed)
+     params = {'prompt': prompt,
+               'negative_prompt': negative_prompt,
+               'guidance_scale': guidance_scale,
+               'num_inference_steps': num_inference_steps,
+               'width': width,
+               'height': height,
+               'generator': generator
+               }
+
+     if controlnet_checkbox:
+         if controlnet_mode == "depth_map":
+             controlnet = ControlNetModel.from_pretrained(
+                 "lllyasviel/sd-controlnet-depth",
+                 cache_dir="./models_cache",
+                 torch_dtype=torch_dtype
+             )
+         elif controlnet_mode == "pose_estimation":
+             controlnet = ControlNetModel.from_pretrained(
+                 "lllyasviel/sd-controlnet-openpose",
+                 cache_dir="./models_cache",
+                 torch_dtype=torch_dtype
+             )
+         elif controlnet_mode == "normal_map":
+             controlnet = ControlNetModel.from_pretrained(
+                 "lllyasviel/sd-controlnet-normal",
+                 cache_dir="./models_cache",
+                 torch_dtype=torch_dtype
+             )
+         elif controlnet_mode == "scribbles":
+             controlnet = ControlNetModel.from_pretrained(
+                 "lllyasviel/sd-controlnet-scribble",
+                 cache_dir="./models_cache",
+                 torch_dtype=torch_dtype
+             )
+         else:
+             controlnet = ControlNetModel.from_pretrained(
+                 "lllyasviel/sd-controlnet-canny",
+                 cache_dir="./models_cache",
+                 torch_dtype=torch_dtype
+             )
+         pipe = StableDiffusionControlNetPipeline.from_pretrained(model_id,
+                                                                  controlnet=controlnet,
+                                                                  torch_dtype=torch_dtype,
+                                                                  safety_checker=None).to(device)
+         params['image'] = controlnet_image
+         params['controlnet_conditioning_scale'] = float(controlnet_strength)
      else:
-         # add the LoRA
-         #pipe = get_lora_sd_pipeline(ckpt_dir='./lora_lady_and_cats_logos', base_model_name_or_path=model_id_default, dtype=torch_dtype).to(device)
-         pipe = get_lora_sd_pipeline(ckpt_dir='./'+model_lora_id, base_model_name_or_path=model_id_default, dtype=torch_dtype).to(device)
-         prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
-         negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
-         prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
-         print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
-         print(f"LoRA scale applied: {lora_scale}")
-         pipe.fuse_lora(lora_scale=lora_scale)
+         pipe = StableDiffusionPipeline.from_pretrained(model_id,
+                                                        torch_dtype=torch_dtype,
+                                                        safety_checker=None).to(device)

-     # replace the plain prompt-based pipe call
-     #image = pipe(
-     #    prompt=prompt,
-     #    negative_prompt=negative_prompt,
-     #    guidance_scale=guidance_scale,
-     #    num_inference_steps=num_inference_steps,
-     #    width=width,
-     #    height=height,
-     #    generator=generator,
-     #).images[0]
-
-     # with a pipe call that uses the prompt embeddings
-     params = {
-         'prompt_embeds': prompt_embeds,
-         'negative_prompt_embeds': negative_prompt_embeds,
-         'guidance_scale': guidance_scale,
-         'num_inference_steps': num_inference_steps,
-         'width': width,
-         'height': height,
-         'generator': generator,
-     }
-
-     return pipe(**params).images[0], seed
-
-     # return image, seed
-
- examples = [
-     "Puss in Boots wearing a sombrero crosses the Grand Canyon on a tightrope with a guitar.",
-     "A cat is playing a song called ""About the Cat"" on an accordion by the sea at sunset. The sun is quickly setting behind the horizon, and the light is fading.",
-     "A cat walks through the grass on the streets of an abandoned city. The camera view is always focused on the cat's face.",
-     "A young lady in a Russian embroidered kaftan is sitting on a beautiful carved veranda, holding a cup to her mouth and drinking tea from the cup. With her other hand, the girl holds a saucer. The cup and saucer are painted with gzhel. Next to the girl on the table stands a samovar, and steam can be seen above it.",
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
+     pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
+     pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir)
+
+     pipe.unet.load_state_dict({k: lora_scale*v for k, v in pipe.unet.state_dict().items()})
+     pipe.text_encoder.load_state_dict({k: lora_scale*v for k, v in pipe.text_encoder.state_dict().items()})
+
+     if torch_dtype in (torch.float16, torch.bfloat16):
+         pipe.unet.half()
+         pipe.text_encoder.half()
+
+     if ip_adapter_checkbox:
+         pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
+         pipe.set_ip_adapter_scale(ip_adapter_scale)
+         params['ip_adapter_image'] = ip_adapter_image
+
+     pipe.to(device)
+
+     return pipe(**params).images[0]

  css = """
  #col-container {
@@ -164,58 +125,124 @@ css = """
  }
  """

- with gr.Blocks(css=css) as demo:
+ def controlnet_params(show_extra):
+     return gr.update(visible=show_extra)
+
+ with gr.Blocks(css=css, fill_height=True) as demo:
      with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image SemaSci Template")
+         gr.Markdown(" # Text-to-Image demo")

          with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
+             model_id = gr.Textbox(
+                 label="Model ID",
                  max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-
-             run_button = gr.Button("Run", scale=0, variant="primary")
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-             # model_repo_id = gr.Text(
-             #     label="Model Id",
-             #     max_lines=1,
-             #     placeholder="Choose model",
-             #     visible=True,
-             #     value=model_repo_id,
-             # )
-             model_repo_id = gr.Dropdown(
-                 label="Model Id",
-                 choices=model_dropdown,
-                 info="Choose model",
-                 visible=True,
-                 allow_custom_value=True,
-                 # value=model_repo_id,
+                 placeholder="Enter model id",
                  value=model_id_default,
-             )
-
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=True,
              )

-             seed = gr.Slider(
+         prompt = gr.Textbox(
+             label="Prompt",
+             max_lines=1,
+             placeholder="Enter your prompt",
+         )
+
+         negative_prompt = gr.Textbox(
+             label="Negative prompt",
+             max_lines=1,
+             placeholder="Enter your negative prompt",
+         )
+
+         with gr.Row():
+             seed = gr.Number(
                  label="Seed",
                  minimum=0,
                  maximum=MAX_SEED,
                  step=1,
                  value=42,
              )

-             randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
+             guidance_scale = gr.Slider(
+                 label="Guidance scale",
+                 minimum=0.0,
+                 maximum=30.0,
+                 step=0.1,
+                 value=7.0, # Replace with defaults that work for your model
+             )
+         with gr.Row():
+             lora_scale = gr.Slider(
+                 label="LoRA scale",
+                 minimum=0.0,
+                 maximum=1.0,
+                 step=0.01,
+                 value=1.0,
+             )
+
+             num_inference_steps = gr.Slider(
+                 label="Number of inference steps",
+                 minimum=1,
+                 maximum=100,
+                 step=1,
+                 value=20, # Replace with defaults that work for your model
+             )
+         with gr.Row():
+             controlnet_checkbox = gr.Checkbox(
+                 label="ControlNet",
+                 value=False
+             )
+             with gr.Column(visible=False) as controlnet_params:
+                 controlnet_strength = gr.Slider(
+                     label="ControlNet conditioning scale",
+                     minimum=0.0,
+                     maximum=1.0,
+                     step=0.01,
+                     value=1.0,
+                 )
+                 controlnet_mode = gr.Dropdown(
+                     label="ControlNet mode",
+                     choices=["edge_detection",
+                              "depth_map",
+                              "pose_estimation",
+                              "normal_map",
+                              "scribbles"],
+                     value="edge_detection",
+                     max_choices=1
+                 )
+                 controlnet_image = gr.Image(
+                     label="ControlNet condition image",
+                     type="pil",
+                     format="png"
+                 )
+             controlnet_checkbox.change(
+                 fn=lambda x: gr.Row.update(visible=x),
+                 inputs=controlnet_checkbox,
+                 outputs=controlnet_params
+             )
+
+         with gr.Row():
+             ip_adapter_checkbox = gr.Checkbox(
+                 label="IPAdapter",
+                 value=False
+             )
+             with gr.Column(visible=False) as ip_adapter_params:
+                 ip_adapter_scale = gr.Slider(
+                     label="IPAdapter scale",
+                     minimum=0.0,
+                     maximum=1.0,
+                     step=0.01,
+                     value=1.0,
+                 )
+                 ip_adapter_image = gr.Image(
+                     label="IPAdapter condition image",
+                     type="pil"
+                 )
+             ip_adapter_checkbox.change(
+                 fn=lambda x: gr.Row.update(visible=x),
+                 inputs=ip_adapter_checkbox,
+                 outputs=ip_adapter_params
+             )
+
+         with gr.Accordion("Optional Settings", open=False):
+
          with gr.Row():
              width = gr.Slider(
                  label="Width",
@@ -232,61 +259,33 @@ with gr.Blocks(css=css) as demo:
                  step=32,
                  value=512, # Replace with defaults that work for your model
              )
-
-         with gr.Row():
-             guidance_scale = gr.Slider(
-                 label="Guidance scale",
-                 minimum=0.0,
-                 maximum=10.0,
-                 step=0.1,
-                 value=7.0, # Replace with defaults that work for your model
-             )
-
-             num_inference_steps = gr.Slider(
-                 label="Number of inference steps",
-                 minimum=1,
-                 maximum=50,
-                 step=1,
-                 value=20, # Replace with defaults that work for your model
-             )
-
-         with gr.Row():
-             model_lora_id = gr.Dropdown(
-                 label="Lora Id",
-                 choices=model_lora_dropdown,
-                 info="Choose LoRA model",
-                 visible=True,
-                 allow_custom_value=True,
-                 value=model_lora_default,
-             )
-
-             lora_scale = gr.Slider(
-                 label="LoRA scale",
-                 minimum=0.0,
-                 maximum=1.0,
-                 step=0.1,
-                 value=0.5,
-             )
-
-         gr.Examples(examples=examples, inputs=[prompt])
+
+         run_button = gr.Button("Run", scale=0, variant="primary")
+         result = gr.Image(label="Result", show_label=False)

      gr.on(
-         triggers=[run_button.click, prompt.submit],
+         triggers=[run_button.click],
          fn=infer,
          inputs=[
              prompt,
              negative_prompt,
-             randomize_seed,
              width,
              height,
-             model_repo_id,
+             model_id,
              seed,
-             guidance_scale,
-             num_inference_steps,
-             model_lora_id,
+             guidance_scale,
              lora_scale,
+             num_inference_steps,
+             controlnet_checkbox,
+             controlnet_strength,
+             controlnet_mode,
+             controlnet_image,
+             ip_adapter_checkbox,
+             ip_adapter_scale,
+             ip_adapter_image,
          ],
-         outputs=[result, seed],
+         outputs=[result],
      )

  if __name__ == "__main__":
-     demo.launch()
+     demo.launch()
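
For a quick sanity check of the new infer() signature outside the Gradio UI, a minimal sketch follows. It is not part of the commit: the file name, prompt text, and output path are made up, and it assumes the fine-tuned LoRA checkpoint is present at ./model_output (with unet/ and text_encoder/ subfolders) and that the default base model can be downloaded.

# smoke_test.py (hypothetical helper, not included in this commit)
from app import infer  # importing app.py builds the Blocks UI but does not launch it

# Generate one image with the default base model and the bundled LoRA weights.
image = infer(
    prompt="a minimalist logo of a cat playing an accordion",
    negative_prompt="low quality, blurry",
    width=512,
    height=512,
    seed=42,
    guidance_scale=7.0,
    lora_scale=1.0,
    num_inference_steps=20,
)
image.save("sample.png")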