import argparse
import os
import random
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

# Add the current working directory to the import path so the vendored
# ``stable_diffusion`` package resolves (assumes the script is launched from the
# directory that contains ``stable_diffusion/``).
sys.path.append('.')
from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler  # noqa: E402
from stable_diffusion.ldm.util import instantiate_from_config  # noqa: E402


def load_model_from_config(config, ckpt, device="cpu", verbose=False):
    """Instantiate a model from ``config`` and load weights from ``ckpt``.

    Parameters
    ----------
    config : str, Path or OmegaConf config
        Model configuration. A path is loaded with ``OmegaConf.load``; the
        config's ``model`` entry is passed to ``instantiate_from_config``.
    ckpt : str or Path
        Checkpoint file. Either a dict with a ``"state_dict"`` entry
        (``"global_step"`` is optional) or a bare state dict.
    device : str or torch.device, optional
        Device the model is moved to. Default ``"cpu"``.
    verbose : bool, optional
        If True, print the checkpoint's global step (when present) and any
        missing / unexpected keys reported by ``load_state_dict``.

    Returns
    -------
    The instantiated model, in eval mode, on ``device``.
    """
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(config)

    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    # "global_step" is informational only; checkpoints saved by train_esd()
    # contain just {"state_dict": ...}, so it must not be required.
    if verbose and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd

    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; report them when asked to.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        if missing:
            print("missing keys:")
            print(missing)
        if unexpected:
            print("unexpected keys:")
            print(unexpected)

    model.to(device)
    model.eval()
    model.cond_stage_model.device = device
    return model


@torch.no_grad()
def sample_model(model, sampler, c, h, w, ddim_steps, scale, ddim_eta, start_code=None,
                 n_samples=1, t_start=-1, log_every_t=None, till_T=None, verbose=True):
    """Sample latents with ``sampler`` using classifier-free guidance.

    Parameters
    ----------
    model : latent diffusion model used to build the unconditional embedding.
    sampler : DDIM sampler wrapping ``model``.
    c : conditioning passed to ``sampler.sample``.
    h, w : int
        Image size in pixels; the sampled latent has shape ``(4, h // 8, w // 8)``.
    ddim_steps : int
        Number of DDIM steps.
    scale : float
        Guidance scale. The unconditional ("") embedding is only computed when
        ``scale != 1.0``; otherwise ``None`` is passed.
    ddim_eta : float
        DDIM eta.
    start_code : tensor, optional
        Initial noise ``x_T``.
    n_samples : int, optional
        Batch size.
    t_start, till_T : optional
        Forwarded unchanged to ``sampler.sample`` (train_esd uses ``till_T`` to
        stop denoising at an intermediate step).
    log_every_t : int, optional
        If given, it is used as the logging interval and ``(samples,
        intermediates)`` is returned; otherwise the interval is 100 and only
        the samples are returned.
    verbose : bool, optional
        Forwarded as ``verbose_iter``.
    """
    uc = None
    if scale != 1.0:
        uc = model.get_learned_conditioning(n_samples * [""])
    log_t = 100 if log_every_t is None else log_every_t
    shape = [4, h // 8, w // 8]
    samples_ddim, inters = sampler.sample(
        S=ddim_steps,
        conditioning=c,
        batch_size=n_samples,
        shape=shape,
        verbose=False,
        x_T=start_code,
        unconditional_guidance_scale=scale,
        unconditional_conditioning=uc,
        eta=ddim_eta,
        verbose_iter=verbose,
        t_start=t_start,
        log_every_t=log_t,
        till_T=till_T,
    )
    if log_every_t is not None:
        return samples_ddim, inters
    return samples_ddim


def load_img(path, target_size=512):
    """Load an RGB image, resize/center-crop it to ``target_size`` and scale to [-1, 1]."""
    # Use a context manager so the underlying file handle is closed promptly.
    with Image.open(path) as img:
        image = img.convert("RGB")
    tform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
    ])
    image = tform(image)
    return 2. * image - 1.
# UNet parameter subsets that train_esd() knows how to fine-tune.
_TRAIN_METHODS = ('noxattn', 'selfattn', 'xattn', 'full', 'notime', 'xlayer', 'selflayer')

# Shorthand prompts that expand to a fixed comma-separated list of concepts
# (pass seperator=',' to erase them individually).
_PROMPT_PRESETS = {
    'allartist': "Kelly Mckernan, Thomas Kinkade, Ajin Demi Human, Alena Aenami, Tyler Edlin, Kilian Eng",
    'i2p': "hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood",
    'artifact': "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy",
}


def moving_average(a, n=3):
    """Return the simple moving average of ``a`` over windows of ``n`` points.

    The result has ``len(a) - n + 1`` entries (empty when ``n > len(a)``).

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"moving average window must be >= 1, got {n}")
    ret = np.cumsum(a, dtype=float)
    # Turn running sums into window sums: S[i] - S[i - n].
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n


def plot_loss(losses, path, word, n=100):
    """Save a plot of the ``n``-point moving average of ``losses`` to ``path``.

    A dedicated figure is created and closed after saving, so repeated calls
    neither draw on top of each other nor keep figures alive.
    """
    v = moving_average(losses, n)
    fig, ax = plt.subplots()
    ax.plot(v, label=f'{word}_loss')
    ax.legend(loc="upper left")
    ax.set_title('Average loss in trainings', fontsize=20)
    ax.set_xlabel('Data point', fontsize=16)
    ax.set_ylabel('Loss value', fontsize=16)
    fig.savefig(path)
    plt.close(fig)


##################### ESD Functions

def get_models(config_path, ckpt_path, devices):
    """Load two copies of the checkpoint and wrap each in a DDIM sampler.

    ``devices[1]`` receives the original (reference) model and ``devices[0]``
    the model that will be fine-tuned.

    Returns
    -------
    (model_orig, sampler_orig, model, sampler)
    """
    model_orig = load_model_from_config(config_path, ckpt_path, devices[1])
    sampler_orig = DDIMSampler(model_orig)

    model = load_model_from_config(config_path, ckpt_path, devices[0])
    sampler = DDIMSampler(model)

    return model_orig, sampler_orig, model, sampler


def _expand_prompt(prompt, seperator=None):
    """Resolve preset prompt names and split ``prompt`` into the concepts to erase."""
    prompt = _PROMPT_PRESETS.get(prompt, prompt)
    if seperator is None:
        return [prompt]
    # Drop empty pieces (e.g. from a trailing separator); erasing "" would
    # target the unconditional embedding.
    words = [word.strip() for word in prompt.split(seperator) if word.strip()]
    if not words:
        raise ValueError(f"prompt {prompt!r} contains no concepts when split on {seperator!r}")
    return words


def _select_trainable_parameters(named_parameters, train_method):
    """Return the parameters, from ``(name, param)`` pairs, that ``train_method`` fine-tunes."""
    selected = []
    for name, param in named_parameters:
        out_or_time = name.startswith('out.') or 'time_embed' in name
        if train_method == 'noxattn':
            # all layers except cross-attention ('attn2'), 'out.' and time_embed layers
            keep = not (out_or_time or 'attn2' in name)
        elif train_method == 'selfattn':
            keep = 'attn1' in name
        elif train_method == 'xattn':
            keep = 'attn2' in name
        elif train_method == 'full':
            keep = True
        elif train_method == 'notime':
            # all layers except 'out.' and time_embed layers
            keep = not out_or_time
        elif train_method == 'xlayer':
            keep = 'attn2' in name and ('output_blocks.6.' in name or 'output_blocks.8.' in name)
        elif train_method == 'selflayer':
            keep = 'attn1' in name and ('input_blocks.4.' in name or 'input_blocks.7.' in name)
        else:
            raise ValueError(f"Unknown train_method {train_method!r}; expected one of {_TRAIN_METHODS}")
        if keep:
            selected.append(param)
    return selected


def train_esd(prompt, train_method, start_guidance, negative_guidance, iterations, lr,
              config_path, ckpt_path, devices, output_name, seperator=None, image_size=512,
              ddim_steps=50):
    '''
    Fine-tune a diffusion model to erase concepts from its weights (ESD).

    Parameters
    ----------
    prompt : str
        The concept to erase from the diffusion model (e.g. "Van Gogh"). The
        presets 'allartist', 'i2p' and 'artifact' expand to fixed
        comma-separated concept lists.
    train_method : str
        Which UNet parameters to fine-tune: 'noxattn' (all except
        cross-attention, 'out.' and time-embedding layers), 'selfattn',
        'xattn', 'full', 'notime' (all except 'out.' and time-embedding
        layers), 'xlayer' (cross-attention of output_blocks 6 and 8) or
        'selflayer' (self-attention of input_blocks 4 and 7).
    start_guidance : float
        Guidance scale used to generate the partially denoised training latents.
    negative_guidance : float
        Strength with which the concept direction is subtracted in the target.
    iterations : int
        Number of optimisation steps.
    lr : float
        Adam learning rate.
    config_path : str
        Config path in CompVis diffusion format.
    ckpt_path : str
        Checkpoint path for the pre-trained CompVis diffusion weights.
    devices : sequence of str
        Two torch devices, e.g. ['cuda:0', 'cuda:1']: devices[0] holds the
        model being fine-tuned, devices[1] the frozen original.
    output_name : str
        File the fine-tuned ``{"state_dict": ...}`` checkpoint is saved to.
    seperator : str, optional
        If given, ``prompt`` is split on it and each non-empty, stripped piece
        is erased simultaneously (one is sampled per iteration). Default None.
    image_size : int, optional
        Side length in pixels of the generated training images. Default 512.
    ddim_steps : int, optional
        Number of DDIM steps. Default 50.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``train_method`` is unknown or ``prompt`` yields no concepts.
    '''
    # Validate before the (expensive) model loading.
    if train_method not in _TRAIN_METHODS:
        raise ValueError(f"Unknown train_method {train_method!r}; expected one of {_TRAIN_METHODS}")

    # PROMPT CLEANING
    words = _expand_prompt(prompt, seperator)
    print(words)

    ddim_eta = 0

    # MODEL TRAINING SETUP
    model_orig, sampler_orig, model, sampler = get_models(config_path, ckpt_path, devices)
    parameters = _select_trainable_parameters(
        model.model.diffusion_model.named_parameters(), train_method)

    model.train()

    def quick_sample_till_t(cond, scale, code, t):
        """Denoise ``code`` with the model being trained, stopping at DDIM step ``t``."""
        return sample_model(model, sampler, cond, image_size, image_size, ddim_steps, scale,
                            ddim_eta, start_code=code, till_T=t, verbose=False)

    opt = torch.optim.Adam(parameters, lr=lr)
    criteria = torch.nn.MSELoss()

    # The latent is 1/8 of the image resolution, matching sample_model().
    latent_size = image_size // 8

    # TRAINING CODE
    pbar = tqdm(range(iterations))
    for _ in pbar:
        word = random.choice(words)

        # Text embeddings are inputs only: just UNet parameters are optimised,
        # so no autograd graph is needed through the text encoder.
        with torch.no_grad():
            emb_0 = model.get_learned_conditioning([''])
            emb_p = model.get_learned_conditioning([word])
        # The trained model is conditioned on the same prompt as the frozen one.
        emb_n = emb_p

        opt.zero_grad()

        # Sample a DDIM step, then a DDPM timestep (0..999) inside that step's range.
        t_enc = torch.randint(ddim_steps, (1,), device=devices[0])
        og_num = round((int(t_enc) / ddim_steps) * 1000)
        og_num_lim = round((int(t_enc + 1) / ddim_steps) * 1000)
        t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=devices[0])

        start_code = torch.randn((1, 4, latent_size, latent_size)).to(devices[0])

        with torch.no_grad():
            # Generate a partially denoised latent containing the concept with the
            # model being trained (emb_p works better than emb_0 here).
            z = quick_sample_till_t(emb_p.to(devices[0]), start_guidance, start_code, int(t_enc))
            # Unconditional and conditional noise predictions of the frozen model.
            e_0 = model_orig.apply_model(z.to(devices[1]), t_enc_ddpm.to(devices[1]), emb_0.to(devices[1]))
            e_p = model_orig.apply_model(z.to(devices[1]), t_enc_ddpm.to(devices[1]), emb_p.to(devices[1]))

        # Conditional noise prediction of the model being trained (with gradients).
        e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_n.to(devices[0]))

        # ESD target: move the conditional prediction away from the concept,
        # e_0 - negative_guidance * (e_p - e_0).
        e_0 = e_0.to(devices[0])
        e_p = e_p.to(devices[0])
        loss = criteria(e_n.to(devices[0]), e_0 - negative_guidance * (e_p - e_0))

        # Update the weights to erase the concept.
        loss.backward()
        pbar.set_postfix({"loss": loss.item()})
        opt.step()

    model.eval()
    torch.save({"state_dict": model.state_dict()}, output_name)


def save_history(losses, name, word_print):
    """Write ``losses`` (one per line) to models/<name>/loss.txt and plot them to loss.png."""
    folder_path = f'models/{name}'
    os.makedirs(folder_path, exist_ok=True)
    with open(f'{folder_path}/loss.txt', 'w') as f:
        f.writelines(f'{loss}\n' for loss in losses)
    plot_loss(losses, f'{folder_path}/loss.png', word_print, n=3)


def build_arg_parser():
    """Return the command-line parser for the TrainESD script."""
    parser = argparse.ArgumentParser(
        prog='TrainESD',
        description='Finetuning stable diffusion model to erase concepts using ESD method')
    parser.add_argument('--train_method', help='method of training', type=str, default='noxattn',
                        choices=list(_TRAIN_METHODS))
    parser.add_argument('--start_guidance', help='guidance of start image used to train', type=float,
                        required=False, default=3)
    parser.add_argument('--negative_guidance', help='guidance of negative training used to train',
                        type=float, required=False, default=1)
    parser.add_argument('--iterations', help='iterations used to train', type=int, required=False,
                        default=1000)
    # float, not int: learning rates such as 1e-5 must be accepted on the command line.
    parser.add_argument('--lr', help='learning rate used to train', type=float, required=False,
                        default=1e-5)
    parser.add_argument('--config_path', help='config path for stable diffusion v1-4 inference',
                        type=str, required=False, default='configs/train_esd.yaml')
    parser.add_argument('--ckpt_path', help='ckpt path for stable diffusion v1-4', type=str,
                        required=True)
    parser.add_argument('--devices', help='cuda devices to train on', type=str, required=False,
                        default='0,0')
    parser.add_argument('--seperator', help='separator if you want to train bunch of words separately',
                        type=str, required=False, default=None)
    parser.add_argument('--image_size', help='image size used to train', type=int, required=False,
                        default=512)
    parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int,
                        required=False, default=50)
    parser.add_argument('--output_dir', help='output directory to save results', type=str,
                        required=False, default='results/style50')
    parser.add_argument('--object_class', type=str, required=True)
    # Accepted for compatibility; currently unused (wandb logging is disabled).
    parser.add_argument('--dry-run', action='store_true', help='dry run')
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    output_name = f'{args.output_dir}/{args.object_class}.pth'
    print(f"Saving the model to {output_name}")

    prompt = f'An image of {args.object_class}.'
    print(f"Prompt for unlearning: {prompt}")

    # devices[0]: model being fine-tuned, devices[1]: frozen copy. A single id
    # (e.g. '0') places both models on the same GPU.
    devices = [f'cuda:{int(d.strip())}' for d in args.devices.split(',')]
    if len(devices) == 1:
        devices = devices * 2

    train_esd(prompt=prompt,
              train_method=args.train_method,
              start_guidance=args.start_guidance,
              negative_guidance=args.negative_guidance,
              iterations=args.iterations,
              lr=args.lr,
              config_path=args.config_path,
              ckpt_path=args.ckpt_path,
              devices=devices,
              seperator=args.seperator,
              image_size=args.image_size,
              ddim_steps=args.ddim_steps,
              output_name=output_name)