Damian Stewart committed · Commit ab11bdd · Parent(s): fc73e59

actually use AMP=3x speedup

Files changed:
- StableDiffuser.py +4 -8
- app.py +23 -12
- train.py +13 -12
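
The commit title refers to PyTorch automatic mixed precision (torch.cuda.amp). For orientation, here is a minimal sketch of the standard autocast + GradScaler idiom the changes below adopt; model, optimizer, and batch are hypothetical placeholders, and in this Space the scaled-backward half presumably lives inside MemoryEfficiencyWrapper.step, whose source is not part of this diff.

import torch
from torch.cuda.amp import autocast, GradScaler

# Hypothetical stand-ins, just to make the idiom concrete.
model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch = torch.randn(4, 8, device='cuda')
scaler = GradScaler()

optimizer.zero_grad()
with autocast():                  # forward pass runs in mixed precision
    loss = model(batch).square().mean()
scaler.scale(loss).backward()     # scale the loss so fp16 gradients don't underflow
scaler.step(optimizer)            # unscales gradients; skips the step on overflow
scaler.update()                   # adapts the scale factor for the next step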
StableDiffuser.py CHANGED
@@ -4,7 +4,6 @@ import torch
 from baukit import TraceDict
 from diffusers import StableDiffusionPipeline
 from PIL import Image
-from torch.cuda.amp import GradScaler
 from tqdm.auto import tqdm
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
@@ -35,6 +34,7 @@ class StableDiffuser(torch.nn.Module):
 
     def __init__(self,
                  scheduler='LMS',
+                 keep_pipeline=False,
                  repo_id_or_path="CompVis/stable-diffusion-v1-4"):
 
         super().__init__()
@@ -46,6 +46,7 @@ class StableDiffuser(torch.nn.Module):
         self.tokenizer = self.pipeline.tokenizer
         self.text_encoder = self.pipeline.text_encoder
         self.safety_checker = self.pipeline.safety_checker
+        self.feature_extractor = self.pipeline.feature_extractor
 
         if scheduler == 'LMS':
             self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
@@ -55,10 +56,8 @@ class StableDiffuser(torch.nn.Module):
             self.scheduler = DDPMScheduler.from_pretrained(repo_id_or_path, subfolder="scheduler")
 
         self.eval()
-
-
-    def feature_extractor(self):
-        return self.pipeline.feature_extractor
+        if not keep_pipeline:
+            del self.pipeline
 
     def get_noise(self, batch_size, width, height, generator=None):
         param = list(self.parameters())[0]
@@ -226,9 +225,6 @@ class StableDiffuser(torch.nn.Module):
 
         return images_steps
 
-    def save_pretrained(self, path, **kwargs):
-        self.pipeline.save_pretrained(path, **kwargs)
-
 
 if __name__ == '__main__':
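
Net effect of this file's changes: feature_extractor becomes a plain attribute instead of a method, and the wrapped StableDiffusionPipeline is deleted right after its components are extracted unless the new keep_pipeline=True is passed (which is why save_pretrained had to go). A short usage sketch; the repo id and save path are illustrative:

from StableDiffuser import StableDiffuser

# Training path: pipeline components are copied out, then the pipeline is
# dropped to free memory the training loop never touches.
diffuser = StableDiffuser(scheduler='DDIM',
                          repo_id_or_path="CompVis/stable-diffusion-v1-4").to('cuda')

# Export path: keep the pipeline alive so it can still be written to disk.
exporter = StableDiffuser(scheduler='DDIM', keep_pipeline=True,
                          repo_id_or_path="CompVis/stable-diffusion-v1-4").eval()
exporter.pipeline.save_pretrained("/tmp/exported-model")  # illustrative path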
app.py CHANGED
@@ -162,9 +162,9 @@ class Demo:
                 info="Prompt corresponding to concept to erase"
             )
 
-            choices = ['ESD-x', 'ESD-self']
-            if torch.cuda.get_device_properties(0).total_memory * 1e-9 >= 40 or is_xformers_available():
-                choices.append('ESD-u')
+            choices = ['ESD-x', 'ESD-self', 'ESD-u']
+            #if torch.cuda.get_device_properties(0).total_memory * 1e-9 >= 40 or is_xformers_available():
+            #    choices.append('ESD-u')
 
             self.train_method_input = gr.Dropdown(
                 choices=choices,
@@ -274,7 +274,7 @@ class Demo:
                 self.train_use_amp_input,
                 #self.train_use_gradient_checkpointing_input
             ],
-            outputs=[self.train_button,
+            outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
         )
         self.export_button.click(self.export, inputs = [
             self.model_dropdown_export,
@@ -286,12 +286,19 @@ class Demo:
         )
 
     def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
-              use_adamw8bit=True, use_xformers=
+              use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
               pbar = gr.Progress(track_tqdm=True)):
 
         if self.training:
             return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
 
+        print(f"Training {repo_id_or_path} at {img_size} to remove '{prompt}'.")
+        print(f"  {train_method}, negative guidance {neg_guidance}, lr {lr}, {iterations} iterations.")
+        print(f"  {'✅' if use_gradient_checkpointing else '❌'} gradient checkpointing")
+        print(f"  {'✅' if use_amp else '❌'} AMP")
+        print(f"  {'✅' if use_xformers else '❌'} xformers")
+        print(f"  {'✅' if use_adamw8bit else '❌'} 8-bit AdamW")
+
         if train_method == 'ESD-x':
             modules = ".*attn2$"
             frozen = []
@@ -319,20 +326,24 @@ class Demo:
         new_model_name = f'*new* {os.path.basename(save_path)}'
         model_map[new_model_name] = save_path
 
-        return [gr.update(interactive=True, value='Train'),
-                'Try your model ({new_model_name}) in the "Test" tab'),
+        return [gr.update(interactive=True, value='Train'),
+                gr.update(value=f'Done Training! Try your model ({new_model_name}) in the "Test" tab'),
+                save_path,
                 gr.Dropdown.update(choices=list(model_map.keys()), value=new_model_name)]
 
     def export(self, model_name, base_repo_id_or_path, save_path, save_half):
         model_path = model_map[model_name]
         checkpoint = torch.load(model_path)
-
-
+        diffuser = StableDiffuser(scheduler='DDIM',
+                                  keep_pipeline=True,
+                                  repo_id_or_path=base_repo_id_or_path
+                                  ).eval()
+        finetuner = FineTunedModel.from_checkpoint(diffuser, checkpoint).eval()
         with finetuner:
             if save_half:
-
-
-
+                diffuser = diffuser.half()
+                diffuser.pipeline.to(torch.float16, torch_device=diffuser.device)
+            diffuser.pipeline.save_pretrained(save_path)
 
 
     def inference(self, prompt, negative_prompt, seed, width, height, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
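
The rewritten export path leans on diffusers' standard save/load round trip. For comparison, a hedged sketch of the equivalent standalone fp16 export with a stock pipeline (repo id and output path are illustrative):

import torch
from diffusers import StableDiffusionPipeline

# Load the weights directly in half precision, then write a diffusers folder.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4",
                                               torch_dtype=torch.float16)
pipe.save_pretrained("/tmp/my-erased-model")  # illustrative path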
train.py CHANGED
@@ -1,3 +1,5 @@
+from torch.cuda.amp import autocast
+
 from StableDiffuser import StableDiffuser
 from finetuning import FineTunedModel
 import torch
@@ -8,20 +10,17 @@ from memory_efficiency import MemoryEfficiencyWrapper
 
 def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
           use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False):
-
-    nsteps = 50
 
+    nsteps = 50
     diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')
 
     memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
                                                         use_gradient_checkpointing=use_gradient_checkpointing )
     with memory_efficiency_wrapper:
-
         diffuser.train()
-
         finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
-
     if use_adamw8bit:
+        print("using AdamW 8Bit optimizer")
         import bitsandbytes as bnb
         optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
                                         lr=lr,
@@ -30,13 +29,13 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
                                         eps=1e-8
                                         )
     else:
+        print("using Adam optimizer")
         optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
     criteria = torch.nn.MSELoss()
 
     pbar = tqdm(range(iterations))
 
     with torch.no_grad():
-
         neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
         positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)
 
@@ -56,7 +55,7 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
         iteration = torch.randint(1, nsteps - 1, (1,)).item()
         latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)
 
-        with finetuner:
+        with autocast(enabled=use_amp), finetuner:
             latents_steps, _ = diffuser.diffusion(
                 latents,
                 positive_text_embeddings,
@@ -67,19 +66,21 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
             )
 
         diffuser.set_scheduler_timesteps(1000)
-
         iteration = int(iteration / nsteps * 1000)
 
-        positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
-        neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
+        with autocast(enabled=use_amp):
+            positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
+            neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
 
         with finetuner:
-            negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
+            with autocast(enabled=use_amp):
+                negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
 
         positive_latents.requires_grad = False
         neutral_latents.requires_grad = False
 
-        loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
+        # loss = criteria(e_n, e_0) works the best try 5000 epochs
+        loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
         memory_efficiency_wrapper.step(optimizer, loss)
         optimizer.step()
 
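
The loss this diff leaves in place is the ESD erasure objective: the finetuned model's noise prediction for the erased prompt is pulled toward the negatively guided target neutral - negative_guidance * (positive - neutral), i.e. away from the concept direction. A toy illustration with stand-in tensors (values are arbitrary):

import torch

neutral = torch.tensor([0.0, 0.0])    # stand-in for neutral_latents
positive = torch.tensor([1.0, 0.0])   # stand-in for positive_latents (erased concept)
eta = 1.0                             # stand-in for negative_guidance

target = neutral - eta * (positive - neutral)
print(target)  # tensor([-1., 0.]) -- the prediction is pushed away from the concept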