from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm


def _esd_loss(diffuser, finetuner, criteria, positive_text_embeddings,
              neutral_text_embeddings, negative_guidance, nsteps):
    """Run one ESD sampling round and return the loss tensor for it.

    Starting from fresh initial latents, a random number of DDIM steps in
    ``[1, nsteps - 2]`` is run with the ``finetuner`` context active. At the
    resulting latent, two noise predictions are taken outside that context
    (presumably the original, un-fine-tuned weights — confirm against
    ``FineTunedModel``): one for the prompt (``e_p``) and one for the empty
    prompt (``e_0``). The prediction made inside the ``finetuner`` context for
    the prompt (``e_n``) is regressed with ``criteria`` onto

        e_0 - negative_guidance * (e_p - e_0)

    Only the last ``predict_noise`` call runs with autograd enabled, so the
    returned loss carries gradients from that call alone. All per-step
    tensors are locals here and are released when the function returns.
    """
    with torch.no_grad():
        diffuser.set_scheduler_timesteps(nsteps)

        # randint's upper bound is exclusive: samples a step in [1, nsteps - 2].
        iteration = torch.randint(1, nsteps - 1, (1,)).item()

        # NOTE(review): arguments look like (n_imgs, resolution, n_prompts) —
        # confirm against StableDiffuser.get_initial_latents.
        latents = diffuser.get_initial_latents(1, 512, 1)

        with finetuner:
            latents_steps, _ = diffuser.diffusion(
                latents,
                positive_text_embeddings,
                start_iteration=0,
                end_iteration=iteration,
                guidance_scale=3,
                show_progress=False,
            )

        # Re-express the sampled DDIM step index on a 1000-step schedule
        # before taking single-step noise predictions.
        diffuser.set_scheduler_timesteps(1000)
        iteration = int(iteration / nsteps * 1000)

        positive_latents = diffuser.predict_noise(
            iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
        neutral_latents = diffuser.predict_noise(
            iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

    with finetuner:
        negative_latents = diffuser.predict_noise(
            iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

    # positive/neutral were computed under no_grad, so they already act as
    # constant targets; no detach or requires_grad change is needed.
    # NOTE(original author): `criteria(e_n, e_0)` reportedly works best;
    # try ~5000 iterations.
    target = neutral_latents - negative_guidance * (positive_latents - neutral_latents)
    return criteria(negative_latents, target)


def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path):
    """Fine-tune selected diffuser modules with the ESD objective and save them.

    Args:
        prompt: Text prompt whose noise prediction the fine-tuned model is
            pushed away from. The target is ``e_0 - negative_guidance * (e_p - e_0)``.
        modules: Selector for the modules to fine-tune, passed straight to
            ``FineTunedModel``.
        freeze_modules: Selectors passed to ``FineTunedModel`` as
            ``frozen_modules``.
        iterations: Number of optimisation steps. With ``iterations <= 0``, no
            training happens and the initial fine-tune state is saved.
        negative_guidance: Scale applied to ``(e_p - e_0)`` in the target.
        lr: Adam learning rate.
        save_path: Destination for ``finetuner.state_dict()`` (via ``torch.save``).

    Requires a CUDA device: the diffuser is moved to ``'cuda'``.
    """
    nsteps = 50

    diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
    diffuser.train()

    finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

    optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
    criteria = torch.nn.MSELoss()

    pbar = tqdm(range(iterations))

    # Both prompt embeddings are fixed for the whole run; compute them once.
    with torch.no_grad():
        neutral_text_embeddings = diffuser.get_text_embeddings([''], n_imgs=1)
        positive_text_embeddings = diffuser.get_text_embeddings([prompt], n_imgs=1)

    # Text encoding is finished: release the VAE, text encoder and tokenizer
    # to free GPU memory for training.
    del diffuser.vae
    del diffuser.text_encoder
    del diffuser.tokenizer

    torch.cuda.empty_cache()

    for _ in pbar:
        optimizer.zero_grad()

        loss = _esd_loss(
            diffuser,
            finetuner,
            criteria,
            positive_text_embeddings,
            neutral_text_embeddings,
            negative_guidance,
            nsteps,
        )

        loss.backward()
        optimizer.step()

        # Drop this step's graph now instead of keeping it alive into the next step.
        del loss

    torch.save(finetuner.state_dict(), save_path)

    # Only names that are bound on every path are deleted here, so this also
    # works when the loop never ran (iterations <= 0). References are released
    # first so that empty_cache() can actually return the memory.
    del diffuser, optimizer, finetuner
    torch.cuda.empty_cache()


if __name__ == '__main__':
    import argparse

    # Argument dests match train()'s parameter names, so the parsed
    # namespace can be splatted directly into it.
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', required=True)
    parser.add_argument('--modules', required=True)
    parser.add_argument('--freeze_modules', nargs='+', required=True)
    parser.add_argument('--save_path', required=True)
    parser.add_argument('--iterations', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--negative_guidance', type=float, required=True)

    train(**vars(parser.parse_args()))