Damian Stewart committed
Commit d8ffb68
1 Parent(s): 209d166

allow different resolutions for w/h

Files changed (4)
  1. StableDiffuser.py +23 -33
  2. app.py +78 -14
  3. finetuning.py +2 -7
  4. train.py +1 -1
StableDiffuser.py CHANGED
@@ -1,17 +1,13 @@
 import argparse
-import traceback
 
 import torch
 from baukit import TraceDict
-from diffusers import AutoencoderKL, UNet2DConditionModel
+from diffusers import StableDiffusionPipeline
 from PIL import Image
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
-from diffusers.schedulers import EulerAncestralDiscreteScheduler
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
-from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 import util
 
 
@@ -39,31 +35,17 @@ class StableDiffuser(torch.nn.Module):
     def __init__(self,
                  scheduler='LMS',
                  repo_id_or_path="CompVis/stable-diffusion-v1-4",
+                 variant='fp16'
                  ):
 
         super().__init__()
 
-        # Load the autoencoder model which will be used to decode the latents into image space.
-        self.vae = AutoencoderKL.from_pretrained(
-            repo_id_or_path, subfolder="vae")
-
-        # Load the tokenizer and text encoder to tokenize and encode the text.
-        self.tokenizer = CLIPTokenizer.from_pretrained(
-            repo_id_or_path, subfolder="tokenizer")
-        self.text_encoder = CLIPTextModel.from_pretrained(
-            repo_id_or_path, subfolder="text_encoder")
-
-        # The UNet model for generating the latents.
-        self.unet = UNet2DConditionModel.from_pretrained(
-            repo_id_or_path, subfolder="unet")
-
-        try:
-            self.feature_extractor = CLIPFeatureExtractor.from_pretrained(repo_id_or_path, subfolder="feature_extractor")
-            self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(repo_id_or_path, subfolder="safety_checker")
-        except Exception as error:
-            print(f"caught exception {error} making feature extractor / safety checker")
-            self.feature_extractor = None
-            self.safety_checker = None
+        self.pipeline = StableDiffusionPipeline.from_pretrained(repo_id_or_path, variant=variant)
+
+        self.vae = self.pipeline.vae
+        self.unet = self.pipeline.unet
+        self.tokenizer = self.pipeline.tokenizer
+        self.text_encoder = self.pipeline.text_encoder
 
         if scheduler == 'LMS':
             self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
@@ -74,10 +56,14 @@ class StableDiffuser(torch.nn.Module):
 
         self.eval()
 
-    def get_noise(self, batch_size, img_size, generator=None):
+    @property
+    def safety_checker(self):
+        return self.pipeline.safety_checker
+
+    def get_noise(self, batch_size, width, height, generator=None):
         param = list(self.parameters())[0]
         return torch.randn(
-            (batch_size, self.unet.in_channels, img_size // 8, img_size // 8),
+            (batch_size, self.unet.in_channels, width // 8, height // 8),
             generator=generator).type(param.dtype).to(param.device)
 
     def add_noise(self, latents, noise, step):
@@ -109,8 +95,8 @@ class StableDiffuser(torch.nn.Module):
     def set_scheduler_timesteps(self, n_steps):
         self.scheduler.set_timesteps(n_steps, device=self.unet.device)
 
-    def get_initial_latents(self, n_imgs, img_size, n_prompts, generator=None):
-        noise = self.get_noise(n_imgs, img_size, generator=generator).repeat(n_prompts, 1, 1, 1)
+    def get_initial_latents(self, n_imgs, width, height, n_prompts, generator=None):
+        noise = self.get_noise(n_imgs, width, height, generator=generator).repeat(n_prompts, 1, 1, 1)
         latents = noise * self.scheduler.init_noise_sigma
         return latents
 
@@ -196,7 +182,8 @@ class StableDiffuser(torch.nn.Module):
     def __call__(self,
                  prompts,
                  negative_prompts,
-                 img_size=512,
+                 width=512,
+                 height=512,
                  n_steps=50,
                  n_imgs=1,
                  end_iteration=None,
@@ -210,7 +197,7 @@ class StableDiffuser(torch.nn.Module):
             prompts = [prompts]
 
         self.set_scheduler_timesteps(n_steps)
-        latents = self.get_initial_latents(n_imgs, img_size, len(prompts), generator=generator)
+        latents = self.get_initial_latents(n_imgs, width, height, len(prompts), generator=generator)
         text_embeddings = self.get_text_embeddings(prompts,negative_prompts,n_imgs=n_imgs)
         end_iteration = end_iteration or n_steps
         latents_steps, trace_steps = self.diffusion(
@@ -239,13 +226,16 @@ class StableDiffuser(torch.nn.Module):
 
         return images_steps
 
+    def save_pretrained(self, path, **kwargs):
+        self.pipeline.save_pretrained(path, **kwargs)
+
 
 if __name__ == '__main__':
 
     parser = default_parser()
     args = parser.parse_args()
 
-    diffuser = StableDiffuser(seed=args.seed, scheduler='DDIM').to(torch.device(args.device)).half()
+    diffuser = StableDiffuser(scheduler='DDIM').to(torch.device(args.device)).half()
 
     images = diffuser(args.prompts,
                       n_steps=args.nsteps,
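With this change a non-square image can be requested by passing width and height separately. A minimal usage sketch, not part of the commit; the prompt, device, and the .save() call are illustrative, and the returned objects are assumed to be PIL images (per the PIL-based decoding in this file):

import torch
from StableDiffuser import StableDiffuser

# Load the model and move it to the GPU in fp16, mirroring the __main__ block above.
diffuser = StableDiffuser(scheduler='DDIM').to(torch.device('cuda')).half()

# Seed the generator so the run is reproducible.
generator = torch.manual_seed(42)

# width and height are now independent; each is divided by 8 when the
# initial latent tensor is allocated, so both should be multiples of 8.
images = diffuser(
    "a wide panoramic landscape",   # prompt (illustrative)
    "",                             # negative prompt
    width=768,
    height=512,
    n_steps=50,
    generator=generator,
)

# images[0][0] is assumed to be a PIL image here.
images[0][0].save("landscape.png")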
app.py CHANGED
@@ -86,8 +86,16 @@ class Demo:
                 label="Seed",
                 value=42
             )
-            self.img_size_infr = gr.Slider(
-                label="Image size",
+            self.img_width_infr = gr.Slider(
+                label="Image width",
+                minimum=256,
+                maximum=1024,
+                value=512,
+                step=64
+            )
+
+            self.img_height_infr = gr.Slider(
+                label="Image height",
                 minimum=256,
                 maximum=1024,
                 value=512,
@@ -190,11 +198,51 @@ class Demo:
 
             self.download = gr.Files()
 
+        with gr.Tab("Export") as export_column:
+
+            with gr.Row():
+
+                self.explain_train= gr.Markdown(interactive=False,
+                    value='Export a model to Diffusers format. Please enter the base model and select the editing weights.')
+
+            with gr.Row():
+
+                with gr.Column(scale=3):
+
+                    self.base_repo_id_or_path_input_export = gr.Text(
+                        label="Base model",
+                        value="CompVis/stable-diffusion-v1-4",
+                        info="Path or huggingface repo id of the base model that this edit was done against"
+                    )
+
+                    self.model_dropdown_export = gr.Dropdown(
+                        label="ESD Model",
+                        choices=list(model_map.keys()),
+                        value='Van Gogh',
+                        interactive=True
+                    )
+
+                    self.save_path_input_export = gr.Text(
+                        label="Output path",
+                        placeholder="./exported_models/model_name",
+                        info="Path to export the model to. A diffusers folder will be written to this location."
+                    )
+
+                    self.save_half_export = gr.Checkbox(
+                        label="Save as fp16"
+                    )
+
+                with gr.Column(scale=1):
+                    self.export_button = gr.Button(
+                        value="Export",
+                    )
+
         self.infr_button.click(self.inference, inputs = [
             self.prompt_input_infr,
             self.negative_prompt_input_infr,
             self.seed_infr,
-            self.img_size_infr,
+            self.img_width_infr,
+            self.img_height_infr,
             self.model_dropdown,
             self.base_repo_id_or_path_input_infr
         ],
@@ -214,6 +262,14 @@
         ],
         outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
         )
+        self.export_button.click(self.export, inputs = [
+            self.model_dropdown_export,
+            self.base_repo_id_or_path_input_export,
+            self.save_path_input_export,
+            self.save_half_export
+        ],
+        outputs=[self.export_button]
+        )
 
     def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
 
@@ -251,42 +307,50 @@
 
         return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
 
+    def export(self, model_name, base_repo_id_or_path, save_path, save_half):
+        model_path = model_map[model_name]
+        checkpoint = torch.load(model_path)
+        self.diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=base_repo_id_or_path).to('cuda').eval()
+        finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval()
+        with finetuner:
+            if save_half:
+                self.diffuser = self.diffuser.half()
+                self.diffuser.pipeline.to(torch.float16, torch_device=self.diffuser.device)
+            self.diffuser.save_pretrained(save_path)
+
 
-    def inference(self, prompt, negative_prompt, seed, img_size, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
+    def inference(self, prompt, negative_prompt, seed, width, height, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
 
         seed = seed or 42
-        generator = torch.manual_seed(seed)
         model_path = model_map[model_name]
         checkpoint = torch.load(model_path)
 
         self.diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=base_repo_id_or_path).to('cuda').eval().half()
         finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()
-        torch.cuda.empty_cache()
 
+        generator = torch.manual_seed(seed)
+
+        torch.cuda.empty_cache()
         images = self.diffuser(
             prompt,
             negative_prompt,
-            img_size=img_size,
+            width=width,
+            height=height,
             n_steps=50,
             generator=generator
         )
-
-
         orig_image = images[0][0]
 
         torch.cuda.empty_cache()
-
-        generator = torch.manual_seed(seed)
-
         with finetuner:
-
             images = self.diffuser(
                 prompt,
                 negative_prompt,
+                width=width,
+                height=height,
                 n_steps=50,
                 generator=generator
             )
-
         edited_image = images[0][0]
 
         del finetuner
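The new Export tab drives StableDiffuser.save_pretrained() through the export() method above. The same flow can also be run outside the Gradio UI; a rough sketch, assuming a finetuned checkpoint on disk (the paths and model name below are illustrative, not taken from the commit):

import torch
from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel

# Hypothetical paths; substitute your own checkpoint and output directory.
checkpoint = torch.load("models/van_gogh.pt")

diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path="CompVis/stable-diffusion-v1-4").to('cuda').eval()
finetuner = FineTunedModel.from_checkpoint(diffuser, checkpoint).eval()

# Inside the finetuner context the edited modules are active in the model,
# so save_pretrained() writes the edited weights in Diffusers folder format.
with finetuner:
    # Optional fp16 export, mirroring the "Save as fp16" checkbox.
    diffuser = diffuser.half()
    diffuser.pipeline.to(torch.float16, torch_device=diffuser.device)
    diffuser.save_pretrained("./exported_models/van_gogh_esd")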
finetuning.py CHANGED
@@ -2,11 +2,12 @@ import copy
 import re
 import torch
 import util
+from StableDiffuser import StableDiffuser
 
 class FineTunedModel(torch.nn.Module):
 
     def __init__(self,
-                 model,
+                 model: StableDiffuser,
                  modules,
                  frozen_modules=[]
                  ):
@@ -24,11 +25,8 @@ class FineTunedModel(torch.nn.Module):
 
         for module_name, module in model.named_modules():
             for ft_module_regex in modules:
-
                 match = re.search(ft_module_regex, module_name)
-
                 if match is not None:
-
                     ft_module = copy.deepcopy(module)
 
                     self.orig_modules[module_name] = module
@@ -39,13 +37,10 @@ class FineTunedModel(torch.nn.Module):
                     print(f"=> Finetuning {module_name}")
 
                     for ft_module_name, module in ft_module.named_modules():
-
                         ft_module_name = f"{module_name}.{ft_module_name}"
-
                         for freeze_module_name in frozen_modules:
 
                             match = re.search(freeze_module_name, ft_module_name)
-
                             if match:
                                 print(f"=> Freezing {ft_module_name}")
                                 util.freeze(module)
train.py CHANGED
@@ -36,7 +36,7 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
         optimizer.zero_grad()
 
         iteration = torch.randint(1, nsteps - 1, (1,)).item()
-        latents = diffuser.get_initial_latents(1, img_size, 1)
+        latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)
 
         with finetuner:
            latents_steps, _ = diffuser.diffusion(
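Training continues to use square latents by passing img_size for both dimensions. For reference, a hedged sketch of the resulting latent shape under the new signature (the img_size value is assumed; in_channels is 4 for the SD v1.x UNet):

# With img_size = 512, get_noise() allocates
# (n_imgs, unet.in_channels, width // 8, height // 8) and get_initial_latents()
# repeats it n_prompts times along the batch dimension:
latents = diffuser.get_initial_latents(1, width=512, height=512, n_prompts=1)
# latents.shape == (1, 4, 64, 64)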