JadenFK commited on
Commit
7c89716
1 Parent(s): 87cd610

Updated diffuser to use seperate generator

Browse files
Files changed (3) hide show
  1. StableDiffuser.py +7 -23
  2. app.py +15 -17
  3. train.py +3 -7
StableDiffuser.py CHANGED
@@ -34,14 +34,11 @@ def default_parser():
34
  class StableDiffuser(torch.nn.Module):
35
 
36
  def __init__(self,
37
- scheduler='LMS',
38
- seed=None
39
  ):
40
 
41
  super().__init__()
42
 
43
- self._seed = seed
44
-
45
  # Load the autoencoder model which will be used to decode the latents into image space.
46
  self.vae = AutoencoderKL.from_pretrained(
47
  "CompVis/stable-diffusion-v1-4", subfolder="vae")
@@ -62,25 +59,16 @@ class StableDiffuser(torch.nn.Module):
62
  self.scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
63
  elif scheduler == 'DDPM':
64
  self.scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
65
- self.generator = torch.Generator()
66
-
67
- if self._seed is not None:
68
-
69
- self.seed(seed)
70
 
71
  self.eval()
72
 
73
- def seed(self, seed):
74
-
75
- self.generator = torch.manual_seed(seed)
76
-
77
- def get_noise(self, batch_size, img_size):
78
 
79
  param = list(self.parameters())[0]
80
 
81
  return torch.randn(
82
  (batch_size, self.unet.in_channels, img_size // 8, img_size // 8),
83
- generator=self.generator).type(param.dtype).to(param.device)
84
 
85
  def add_noise(self, latents, noise, step):
86
 
@@ -118,9 +106,9 @@ class StableDiffuser(torch.nn.Module):
118
  def set_scheduler_timesteps(self, n_steps):
119
  self.scheduler.set_timesteps(n_steps, device=self.unet.device)
120
 
121
- def get_initial_latents(self, n_imgs, img_size, n_prompts):
122
 
123
- noise = self.get_noise(n_imgs, img_size).repeat(n_prompts, 1, 1, 1)
124
 
125
  latents = noise * self.scheduler.init_noise_sigma
126
 
@@ -221,7 +209,7 @@ class StableDiffuser(torch.nn.Module):
221
  n_steps=50,
222
  n_imgs=1,
223
  end_iteration=None,
224
- reseed=False,
225
  **kwargs
226
  ):
227
 
@@ -233,11 +221,7 @@ class StableDiffuser(torch.nn.Module):
233
 
234
  self.set_scheduler_timesteps(n_steps)
235
 
236
- if reseed:
237
-
238
- self.seed(self._seed)
239
-
240
- latents = self.get_initial_latents(n_imgs, img_size, len(prompts))
241
 
242
  text_embeddings = self.get_text_embeddings(prompts,n_imgs=n_imgs)
243
 
34
  class StableDiffuser(torch.nn.Module):
35
 
36
  def __init__(self,
37
+ scheduler='LMS'
 
38
  ):
39
 
40
  super().__init__()
41
 
 
 
42
  # Load the autoencoder model which will be used to decode the latents into image space.
43
  self.vae = AutoencoderKL.from_pretrained(
44
  "CompVis/stable-diffusion-v1-4", subfolder="vae")
59
  self.scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
60
  elif scheduler == 'DDPM':
61
  self.scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
 
 
 
 
 
62
 
63
  self.eval()
64
 
65
+ def get_noise(self, batch_size, img_size, generator=None):
 
 
 
 
66
 
67
  param = list(self.parameters())[0]
68
 
69
  return torch.randn(
70
  (batch_size, self.unet.in_channels, img_size // 8, img_size // 8),
71
+ generator=generator).type(param.dtype).to(param.device)
72
 
73
  def add_noise(self, latents, noise, step):
74
 
106
  def set_scheduler_timesteps(self, n_steps):
107
  self.scheduler.set_timesteps(n_steps, device=self.unet.device)
108
 
109
+ def get_initial_latents(self, n_imgs, img_size, n_prompts, generator=None):
110
 
111
+ noise = self.get_noise(n_imgs, img_size, generator=generator).repeat(n_prompts, 1, 1, 1)
112
 
113
  latents = noise * self.scheduler.init_noise_sigma
114
 
209
  n_steps=50,
210
  n_imgs=1,
211
  end_iteration=None,
212
+ generator=None,
213
  **kwargs
214
  ):
215
 
221
 
222
  self.set_scheduler_timesteps(n_steps)
223
 
224
+ latents = self.get_initial_latents(n_imgs, img_size, len(prompts), generator=generator)
 
 
 
 
225
 
226
  text_embeddings = self.get_text_embeddings(prompts,n_imgs=n_imgs)
227
 
app.py CHANGED
@@ -17,18 +17,18 @@ class Demo:
17
  def __init__(self) -> None:
18
 
19
  self.training = False
 
20
 
21
- self.diffuser = StableDiffuser(scheduler='DDIM', seed=42).to('cuda').eval().half()
22
 
23
  with gr.Blocks() as demo:
24
  self.layout()
25
- demo.queue(concurrency_count=2).launch()
26
 
27
 
28
  def layout(self):
29
 
30
  with gr.Row():
31
-
32
 
33
  with gr.Tab("Test") as inference_column:
34
 
@@ -152,13 +152,10 @@ class Demo:
152
  )
153
 
154
  def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
155
- # self.diffuser = StableDiffuser(scheduler='DDIM', seed=42).to('cuda').eval().half()
156
  if self.training:
157
  return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
158
- # clear the diffusers
159
- # del self.diffuser
160
- # torch.cuda.empty_cache()
161
-
162
  if train_method == 'ESD-x':
163
 
164
  modules = ".*attn2$"
@@ -188,45 +185,46 @@ class Demo:
188
 
189
  model_map['Custom'] = save_path
190
 
191
- # del self.diffuser
192
- torch.cuda.empty_cache()
193
- # self.diffuser = StableDiffuser(scheduler='DDIM', seed=42).to('cuda').eval().half()
194
  return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
195
 
196
 
197
  def inference(self, prompt, seed, model_name, pbar = gr.Progress(track_tqdm=True)):
198
 
199
- self.diffuser._seed = seed or 42
 
 
200
 
201
  model_path = model_map[model_name]
202
 
203
  checkpoint = torch.load(model_path)
204
 
205
- self.finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()
206
 
207
  torch.cuda.empty_cache()
208
 
209
  images = self.diffuser(
210
  prompt,
211
  n_steps=50,
212
- reseed=True
213
  )
214
 
215
  orig_image = images[0][0]
216
 
217
  torch.cuda.empty_cache()
218
 
219
- with self.finetuner:
 
 
220
 
221
  images = self.diffuser(
222
  prompt,
223
  n_steps=50,
224
- reseed=True
225
  )
226
 
227
  edited_image = images[0][0]
228
 
229
- del self.finetuner
230
  torch.cuda.empty_cache()
231
 
232
  return edited_image, orig_image
17
  def __init__(self) -> None:
18
 
19
  self.training = False
20
+ self.generating = False
21
 
22
+ self.diffuser = StableDiffuser(scheduler='DDIM').to('cuda').eval().half()
23
 
24
  with gr.Blocks() as demo:
25
  self.layout()
26
+ demo.queue(concurrency_count=5).launch()
27
 
28
 
29
  def layout(self):
30
 
31
  with gr.Row():
 
32
 
33
  with gr.Tab("Test") as inference_column:
34
 
152
  )
153
 
154
  def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
155
+
156
  if self.training:
157
  return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
158
+
 
 
 
159
  if train_method == 'ESD-x':
160
 
161
  modules = ".*attn2$"
185
 
186
  model_map['Custom'] = save_path
187
 
 
 
 
188
  return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
189
 
190
 
191
  def inference(self, prompt, seed, model_name, pbar = gr.Progress(track_tqdm=True)):
192
 
193
+ seed = seed or 42
194
+
195
+ generator = torch.manual_seed(seed)
196
 
197
  model_path = model_map[model_name]
198
 
199
  checkpoint = torch.load(model_path)
200
 
201
+ finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()
202
 
203
  torch.cuda.empty_cache()
204
 
205
  images = self.diffuser(
206
  prompt,
207
  n_steps=50,
208
+ generator=generator
209
  )
210
 
211
  orig_image = images[0][0]
212
 
213
  torch.cuda.empty_cache()
214
 
215
+ generator = torch.manual_seed(seed)
216
+
217
+ with finetuner:
218
 
219
  images = self.diffuser(
220
  prompt,
221
  n_steps=50,
222
+ generator=generator
223
  )
224
 
225
  edited_image = images[0][0]
226
 
227
+ del finetuner
228
  torch.cuda.empty_cache()
229
 
230
  return edited_image, orig_image
train.py CHANGED
@@ -10,9 +10,6 @@ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, sa
10
  diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
11
  diffuser.train()
12
 
13
-
14
-
15
-
16
  finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
17
 
18
  optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
@@ -69,14 +66,13 @@ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, sa
69
 
70
  loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
71
 
72
- del negative_latents, neutral_latents, positive_latents, latents_steps, latents
73
- torch.cuda.empty_cache()
74
-
75
  loss.backward()
76
  optimizer.step()
77
 
78
  torch.save(finetuner.state_dict(), save_path)
79
- del diffuser
 
 
80
  torch.cuda.empty_cache()
81
  if __name__ == '__main__':
82
 
10
  diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
11
  diffuser.train()
12
 
 
 
 
13
  finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
14
 
15
  optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
66
 
67
  loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
68
 
 
 
 
69
  loss.backward()
70
  optimizer.step()
71
 
72
  torch.save(finetuner.state_dict(), save_path)
73
+
74
+ del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents
75
+
76
  torch.cuda.empty_cache()
77
  if __name__ == '__main__':
78