import argparse
import os
import random
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

sys.path.append('.')
from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler
from stable_diffusion.ldm.util import instantiate_from_config

def load_model_from_config(config, ckpt, device="cpu", verbose=False):
    """Load a model from a config and a checkpoint.

    If `config` is a path, it is loaded with OmegaConf first.
    """
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(config)

    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        print(f"Loaded checkpoint from global step {pl_sd.get('global_step')}")
        if missing:
            print(f"Missing keys: {missing}")
        if unexpected:
            print(f"Unexpected keys: {unexpected}")
    model.to(device)
    model.eval()
    model.cond_stage_model.device = device
    return model

@torch.no_grad()
def sample_model(model, sampler, c, h, w, ddim_steps, scale, ddim_eta,
                 start_code=None, n_samples=1, t_start=-1, log_every_t=None,
                 till_T=None, verbose=True):
    """Sample the model with DDIM, optionally returning intermediate latents."""
    uc = None
    if scale != 1.0:
        # Classifier-free guidance needs an unconditional (empty-prompt) embedding.
        uc = model.get_learned_conditioning(n_samples * [""])
    log_t = log_every_t if log_every_t is not None else 100
    # Latent shape: 4 channels at 1/8 of the pixel resolution.
    shape = [4, h // 8, w // 8]
    samples_ddim, inters = sampler.sample(S=ddim_steps,
                                          conditioning=c,
                                          batch_size=n_samples,
                                          shape=shape,
                                          verbose=False,
                                          x_T=start_code,
                                          unconditional_guidance_scale=scale,
                                          unconditional_conditioning=uc,
                                          eta=ddim_eta,
                                          verbose_iter=verbose,
                                          t_start=t_start,
                                          log_every_t=log_t,
                                          till_T=till_T)
    if log_every_t is not None:
        return samples_ddim, inters
    return samples_ddim
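
# Illustrative call (hypothetical values; `c` is a prompt embedding from
# model.get_learned_conditioning): sample one 512x512 image with guidance
# scale 7.5 and deterministic DDIM (eta=0):
#   x = sample_model(model, sampler, c, 512, 512, ddim_steps=50, scale=7.5, ddim_eta=0)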

def load_img(path, target_size=512):
    """Load an image, resize/center-crop it, and scale it to [-1, 1]."""
    image = Image.open(path).convert("RGB")

    tform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
    ])
    image = tform(image)
    return 2. * image - 1.

def moving_average(a, n=3):
    """Trailing moving average of `a` with window size `n`."""
    ret = np.cumsum(a, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n
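
# Worked example: moving_average([1, 2, 3, 4], n=2)
#   cumsum -> [1, 3, 6, 10]; windowed sums -> [3, 5, 7]; result -> [1.5, 2.5, 3.5]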

def plot_loss(losses, path, word, n=100):
    """Plot the moving-average loss curve and save it to `path`."""
    v = moving_average(losses, n)
    plt.figure()
    plt.plot(v, label=f'{word}_loss')
    plt.legend(loc="upper left")
    plt.title('Average loss during training', fontsize=20)
    plt.xlabel('Data point', fontsize=16)
    plt.ylabel('Loss value', fontsize=16)
    plt.savefig(path)

def get_models(config_path, ckpt_path, devices):
    # Two copies of the same checkpoint: a frozen reference model on devices[1]
    # and the copy that will be fine-tuned on devices[0].
    model_orig = load_model_from_config(config_path, ckpt_path, devices[1])
    sampler_orig = DDIMSampler(model_orig)

    model = load_model_from_config(config_path, ckpt_path, devices[0])
    sampler = DDIMSampler(model)

    return model_orig, sampler_orig, model, sampler

def train_esd(prompt, train_method, start_guidance, negative_guidance, iterations, lr,
              config_path, ckpt_path, devices, output_name, seperator=None,
              image_size=512, ddim_steps=50):
    '''
    Train a diffusion model to erase a concept from its weights.

    Parameters
    ----------
    prompt : str
        The concept to erase from the diffusion model (e.g. "Van Gogh").
    train_method : str
        The parameter subset to train for erasure
        ('noxattn', 'selfattn', 'xattn', 'full', 'notime', 'xlayer', 'selflayer').
    start_guidance : float
        Guidance scale used to generate images for training.
    negative_guidance : float
        Guidance scale used to erase the concept from the diffusion model.
    iterations : int
        Number of training iterations.
    lr : float
        Learning rate for fine-tuning.
    config_path : str
        Config path for the CompVis diffusion format.
    ckpt_path : str
        Checkpoint path for pre-trained CompVis diffusion weights.
    devices : list of str
        Two devices used to load the models (e.g. ['cuda:0', 'cuda:1'] puts the
        trainable model on cuda:0 and the frozen reference model on cuda:1).
    output_name : str
        Path to save the fine-tuned checkpoint to.
    seperator : str, optional
        If the prompt contains this separator, it is split into individual words
        that are erased simultaneously. The default is None.
    image_size : int, optional
        Image size for generated images. The default is 512.
    ddim_steps : int, optional
        Number of diffusion time steps. The default is 50.

    Returns
    -------
    None

    '''
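    # ESD objective in brief: with e_p / e_0 the frozen model's concept-conditional
    # and unconditional noise predictions, the trainable model's prediction e_n
    # for the concept prompt is regressed onto
    #     e_0 - negative_guidance * (e_p - e_0),
    # i.e. classifier-free guidance pointed *away* from the concept.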
    word_print = prompt.replace(' ', '')
    # Expand a few shorthand prompts into their full word lists.
    if prompt == 'allartist':
        prompt = "Kelly Mckernan, Thomas Kinkade, Ajin Demi Human, Alena Aenami, Tyler Edlin, Kilian Eng"
    if prompt == 'i2p':
        prompt = "hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood"
    if prompt == "artifact":
        prompt = "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy"

    if seperator is not None:
        words = [word.strip() for word in prompt.split(seperator)]
    else:
        words = [prompt]
    print(words)
    ddim_eta = 0

    model_orig, sampler_orig, model, sampler = get_models(config_path, ckpt_path, devices)

    # Select which UNet parameters to fine-tune, depending on train_method.
    parameters = []
    for name, param in model.model.diffusion_model.named_parameters():
        # 'noxattn': everything except cross-attention, time embeddings, and the output layer
        if train_method == 'noxattn':
            if not (name.startswith('out.') or 'attn2' in name or 'time_embed' in name):
                parameters.append(param)
        # 'selfattn': only self-attention layers
        elif train_method == 'selfattn':
            if 'attn1' in name:
                parameters.append(param)
        # 'xattn': only cross-attention layers
        elif train_method == 'xattn':
            if 'attn2' in name:
                parameters.append(param)
        # 'full': all UNet parameters
        elif train_method == 'full':
            parameters.append(param)
        # 'notime': everything except time embeddings and the output layer
        elif train_method == 'notime':
            if not (name.startswith('out.') or 'time_embed' in name):
                parameters.append(param)
        # 'xlayer': cross-attention in selected output blocks only
        elif train_method == 'xlayer':
            if 'attn2' in name and ('output_blocks.6.' in name or 'output_blocks.8.' in name):
                parameters.append(param)
        # 'selflayer': self-attention in selected input blocks only
        elif train_method == 'selflayer':
            if 'attn1' in name and ('input_blocks.4.' in name or 'input_blocks.7.' in name):
                parameters.append(param)

    model.train()

    # DDIM sampler that runs from pure noise down to timestep t and returns the latent.
    quick_sample_till_t = lambda x, s, code, t: sample_model(model, sampler,
                                                             x, image_size, image_size, ddim_steps, s, ddim_eta,
                                                             start_code=code, till_T=t, verbose=False)

    losses = []
    opt = torch.optim.Adam(parameters, lr=lr)
    criteria = torch.nn.MSELoss()

    pbar = tqdm(range(iterations))
    for _ in pbar:
        word = random.sample(words, 1)[0]

        # Text embeddings: unconditional (empty prompt) and the concept prompt.
        emb_0 = model.get_learned_conditioning([''])
        emb_p = model.get_learned_conditioning([word])
        emb_n = model.get_learned_conditioning([word])

        opt.zero_grad()

        t_enc = torch.randint(ddim_steps, (1,), device=devices[0])
        # Map the sampled DDIM step onto the corresponding range of DDPM timesteps (out of 1000).
        og_num = round((int(t_enc) / ddim_steps) * 1000)
        og_num_lim = round((int(t_enc + 1) / ddim_steps) * 1000)
        t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=devices[0])

        start_code = torch.randn((1, 4, 64, 64)).to(devices[0])

        with torch.no_grad():
            # Generate a partially denoised latent z_t for the concept prompt.
            z = quick_sample_till_t(emb_p.to(devices[0]), start_guidance, start_code, int(t_enc))
            # Noise predictions of the frozen model: unconditional and concept-conditional.
            e_0 = model_orig.apply_model(z.to(devices[1]), t_enc_ddpm.to(devices[1]), emb_0.to(devices[1]))
            e_p = model_orig.apply_model(z.to(devices[1]), t_enc_ddpm.to(devices[1]), emb_p.to(devices[1]))

        # Noise prediction of the model being fine-tuned (gradients flow through this call).
        e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_n.to(devices[0]))
        e_0.requires_grad = False
        e_p.requires_grad = False

        # ESD loss against the negatively guided target.
        loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - (negative_guidance * (e_p.to(devices[0]) - e_0.to(devices[0]))))

        loss.backward()
        losses.append(loss.item())
        pbar.set_postfix({"loss": loss.item()})
        opt.step()

    model.eval()
    torch.save({"state_dict": model.state_dict()}, output_name)
    # Persist the loss curve as well; the run name here is the checkpoint stem.
    save_history(losses, Path(output_name).stem, word_print)

def save_history(losses, name, word_print):
    folder_path = f'models/{name}'
    os.makedirs(folder_path, exist_ok=True)
    with open(f'{folder_path}/loss.txt', 'w') as f:
        f.writelines([str(i) + '\n' for i in losses])
    plot_loss(losses, f'{folder_path}/loss.png', word_print, n=3)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        prog='TrainESD',
        description='Fine-tune a Stable Diffusion model to erase concepts with the ESD method')
    parser.add_argument('--train_method', help='method of training', type=str, default='noxattn',
                        choices=['xattn', 'noxattn', 'selfattn', 'full', 'notime', 'xlayer', 'selflayer'])
    parser.add_argument('--start_guidance', help='guidance scale for generating training images', type=float, default=3)
    parser.add_argument('--negative_guidance', help='guidance scale for erasure', type=float, default=1)
    parser.add_argument('--iterations', help='number of training iterations', type=int, default=1000)
    parser.add_argument('--lr', help='learning rate', type=float, default=1e-5)
    parser.add_argument('--config_path', help='config path for stable diffusion v1-4 inference', type=str, default='configs/train_esd.yaml')
    parser.add_argument('--ckpt_path', help='ckpt path for stable diffusion v1-4', type=str, required=True)
    parser.add_argument('--devices', help='cuda devices to train on', type=str, default='0,0')
    parser.add_argument('--seperator', help='separator if you want to erase several words simultaneously', type=str, default=None)
    parser.add_argument('--image_size', help='image size used to train', type=int, default=512)
    parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, default=50)
    parser.add_argument('--output_dir', help='output directory to save results', type=str, default='results/style50')
    parser.add_argument('--object_class', help='object class to erase', type=str, required=True)
    parser.add_argument('--dry-run', action='store_true', help='dry run')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    output_name = f'{args.output_dir}/{args.object_class}.pth'
    print(f"Saving the model to {output_name}")

    prompt = f'An image of {args.object_class}.'
    print(f"Prompt for unlearning: {prompt}")

    devices = [f'cuda:{int(d.strip())}' for d in args.devices.split(',')]

    train_esd(prompt=prompt,
              train_method=args.train_method,
              start_guidance=args.start_guidance,
              negative_guidance=args.negative_guidance,
              iterations=args.iterations,
              lr=args.lr,
              config_path=args.config_path,
              ckpt_path=args.ckpt_path,
              devices=devices,
              seperator=args.seperator,
              image_size=args.image_size,
              ddim_steps=args.ddim_steps,
              output_name=output_name)
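
# Example invocation (script and checkpoint filenames are assumptions; the flags
# match the parser above):
#   python train_esd.py --ckpt_path sd-v1-4-full-ema.ckpt --object_class church \
#       --train_method xattn --devices 0,1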