Damian Stewart committed
Commit 2c1839c · 1 Parent(s): 50b9662

cleanup and try to get cancellation working

Files changed (3)
  1. README.md +14 -0
  2. StableDiffuser.py +4 -2
  3. train.py +18 -3
README.md CHANGED
@@ -10,7 +10,21 @@ pinned: false
 license: mit
 ---
 
+# A GUI with custom model support, validation, and sample generation for "Erasing Concepts from Diffusion Models"
 
+Enables xformers, 8-bit AdamW via bitsandbytes, and AMP - editing SD1.5 models works with 16GB VRAM, and SD2.x models (including ESD-u training) work with 24GB VRAM.
+
+## Quick start
+
+To run on vast.ai, use e.g. `pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel` - you need the `-devel` image for 8-bit AdamW to work.
+
+On the dev machine:
+```
+pip install -r requirements.txt
+python app.py
+```
+
+then use the Gradio interface at port 7860.
 
 # Erasing Concepts from Diffusion Models
 
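Per the README note above, 8-bit AdamW needs the `-devel` image. A quick way to check the bitsandbytes install before starting a long run is to build and step an `AdamW8bit` optimizer once. A minimal smoke test (illustrative, not part of this repo; assumes a CUDA GPU and a working `bitsandbytes` install):

```python
# Smoke test for the 8-bit AdamW setup (illustrative, not part of this repo).
import torch
import bitsandbytes as bnb

param = torch.nn.Parameter(torch.randn(64, 64, device="cuda"))
optimizer = bnb.optim.AdamW8bit([param], lr=1e-4)

loss = (param ** 2).mean()
loss.backward()
optimizer.step()  # fails here if the CUDA toolchain bitsandbytes needs is missing
print("8-bit AdamW OK")
```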
StableDiffuser.py CHANGED
@@ -107,6 +107,7 @@ class StableDiffuser(torch.nn.Module):
         return latents
 
     def get_cond_and_uncond_embeddings(self, prompts, negative_prompts=None, n_imgs=1):
+        assert n_imgs == 1
         text_tokens = self.text_tokenize(prompts)
         text_embeddings = self.text_encode(text_tokens)
         if negative_prompts is None:
@@ -115,8 +116,9 @@ class StableDiffuser(torch.nn.Module):
             negative_prompts.append("")
         unconditional_tokens = self.text_tokenize(negative_prompts)
         unconditional_embeddings = self.text_encode(unconditional_tokens)
-        text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]).repeat_interleave(n_imgs, dim=0)
-        return text_embeddings
+        combined_embeddings = [torch.cat([unconditional_embeddings[i:i+1], text_embeddings[i:i+1]]) for i in range(len(prompts))]
+        combined_embeddings = torch.cat(combined_embeddings)
+        return combined_embeddings
 
     def predict_noise(self,
                       iteration,
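The `get_cond_and_uncond_embeddings` change is a layout change: the old code batched all unconditional embeddings ahead of all conditional ones, while the new code emits one (uncond, cond) pair per prompt, which is what the `validation_embeddings[i*2:i*2+2]` slicing in train.py expects; the added assert reflects that the new path only handles `n_imgs == 1`. A standalone illustration with dummy tensors (placeholder shapes):

```python
# Illustration of the embedding layout change (dummy tensors, not repo code).
# Old: all unconditional rows first -> [u0, u1, c0, c1].
# New: one (uncond, cond) pair per prompt -> [u0, c0, u1, c1].
import torch

n_prompts, seq_len, dim = 2, 77, 768  # placeholder shapes
text_embeddings = torch.randn(n_prompts, seq_len, dim)           # conditional
unconditional_embeddings = torch.randn(n_prompts, seq_len, dim)  # unconditional

old_layout = torch.cat([unconditional_embeddings, text_embeddings])
new_layout = torch.cat([torch.cat([unconditional_embeddings[i:i+1],
                                   text_embeddings[i:i+1]])
                        for i in range(n_prompts)])

assert torch.equal(new_layout[0], old_layout[0])  # u0 stays first
assert torch.equal(new_layout[1], old_layout[2])  # c0 now sits next to u0
```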
train.py CHANGED
@@ -34,7 +34,11 @@ def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
 
     nsteps=50
     num_validation_prompts = validation_embeddings.shape[0] // 2
-    for i in range(0, num_validation_prompts):
+
+    for i in tqdm(range(num_validation_prompts)):
+        if training_should_cancel:
+            print("cancel requested, bailing")
+            return
         accumulated_loss = None
         this_validation_embeddings = validation_embeddings[i*2:i*2+2]
         for j in range(val_count):
@@ -51,10 +55,13 @@ def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
             loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
             accumulated_loss = (accumulated_loss or 0) + loss.item()
         logger.add_scalar(f"loss/val_{i}", accumulated_loss/val_count, global_step=global_step)
 
     num_samples = sample_embeddings.shape[0] // 2
-    for i in range(0, num_samples):
+    for i in tqdm(range(0, num_samples)):
         print(f'making sample {i}...')
+        if training_should_cancel:
+            print("cancel requested, bailing")
+            return
         with finetuner:
             pipeline = StableDiffusionPipeline(vae=diffuser.vae,
                                                text_encoder=diffuser.text_encoder,
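For reference, the `criteria` line in the hunk above is the ESD objective from "Erasing Concepts from Diffusion Models": the fine-tuned model's noise prediction on the concept prompt (`negative_latents`) is trained toward the frozen model's neutral prediction, pushed away from the concept direction by `negative_guidance`. A standalone restatement with dummy tensors (shapes illustrative):

```python
# ESD loss as computed in validate()/train() above (dummy tensors).
#   neutral_latents:  frozen model's prediction for the empty prompt
#   positive_latents: frozen model's prediction for the concept prompt
#   negative_latents: fine-tuned model's prediction for the concept prompt
import torch

criteria = torch.nn.MSELoss()
negative_guidance = 1.0

neutral_latents = torch.randn(1, 4, 64, 64)
positive_latents = torch.randn(1, 4, 64, 64)
negative_latents = torch.randn(1, 4, 64, 64)

# push the fine-tuned prediction away from the concept direction
target = neutral_latents - negative_guidance * (positive_latents - neutral_latents)
loss = criteria(negative_latents, target)
```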
@@ -93,6 +100,8 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
     neutral_latents = None
     positive_latents = None
 
+    global training_should_cancel
+
     nsteps = 50
     print(f"using img_size of {img_size}")
     diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path, native_img_size=img_size).to('cuda')
@@ -137,13 +146,19 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
     seed = random.randint(0, 2 ** 30)
     set_seed(int(seed))
 
+    validate(diffuser, finetuner,
+             validation_embeddings=validation_embeddings,
+             sample_embeddings=sample_embeddings,
+             neutral_embeddings=neutral_text_embeddings,
+             logger=logger, use_amp=False, global_step=0)
+
     prev_losses = []
     start_loss = None
     max_prev_loss_count = 10
     try:
         for i in pbar:
             if training_should_cancel:
-                print("received cancellation request")
+                print("cancel requested, bailing")
                 return None
 
             with torch.no_grad():
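The cancellation this commit wires up is cooperative: a module-level `training_should_cancel` flag is polled at the top of the training loop and in both validation loops, each bailing at the next safe point. A minimal sketch of the pattern outside this repo (the threading scaffolding and `request_cancel` are illustrative; in the app, e.g. a Gradio cancel button handler would set the flag):

```python
# Sketch of cooperative cancellation via a module-level flag
# (request_cancel and the threading scaffolding are illustrative, not repo code).
import threading
import time

training_should_cancel = False

def request_cancel():
    # e.g. called from a Gradio button handler on the UI thread
    global training_should_cancel
    training_should_cancel = True

def train_loop(iterations=100):
    global training_should_cancel
    training_should_cancel = False  # reset at the start of each run
    for i in range(iterations):
        if training_should_cancel:
            print("cancel requested, bailing")
            return None
        time.sleep(0.1)  # stand-in for one training/validation step
    return "finished"

worker = threading.Thread(target=train_loop)
worker.start()
time.sleep(0.35)
request_cancel()  # the worker sees the flag and exits at its next check
worker.join()
```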