|
from StableDiffuser import StableDiffuser |
|
from finetuning import FineTunedModel |
|
import torch |
|
from tqdm import tqdm |
|
|
|
def train(prompt: str, modules, freeze_modules, iterations: int, negative_guidance: float, lr: float, save_path: str) -> None:
    """Fine-tune a Stable Diffusion model so it stops generating `prompt`.

    ESD-style concept erasure: at a random diffusion timestep, the
    fine-tuned UNet's noise prediction for the concept prompt is pushed
    toward the frozen model's unconditional prediction, guided *away*
    from the frozen model's conditional prediction.

    Args:
        prompt: text description of the concept to erase.
        modules: module selector(s) handed to FineTunedModel — these
            weights are trained. (Exact format is defined by
            FineTunedModel; not visible here.)
        freeze_modules: module selector(s) kept frozen inside the
            fine-tuned copy.
        iterations: number of optimization steps.
        negative_guidance: scale of the erasure term in the loss.
        lr: Adam learning rate.
        save_path: where the fine-tuned state_dict is written.
    """

    # Number of DDIM steps used for the partial denoising rollout below.
    nsteps = 50

    diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
    diffuser.train()

    # Wraps `diffuser` with a trainable copy of `modules`; entering the
    # context (`with finetuner:`) swaps the fine-tuned weights in.
    finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

    optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
    criteria = torch.nn.MSELoss()

    pbar = tqdm(range(iterations))

    # Text embeddings are fixed for the whole run, so compute them once
    # without grad.
    with torch.no_grad():
        neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
        positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)

    # Embeddings are cached above; drop the encoders/decoder to free VRAM.
    del diffuser.vae
    del diffuser.text_encoder
    del diffuser.tokenizer

    torch.cuda.empty_cache()

    for i in pbar:

        # Everything up to the supervised prediction is target/setup
        # work — no gradients needed.
        with torch.no_grad():

            diffuser.set_scheduler_timesteps(nsteps)

            optimizer.zero_grad()

            # Random intermediate DDIM step in [1, nsteps-2] (randint's
            # upper bound is exclusive).
            iteration = torch.randint(1, nsteps - 1, (1,)).item()

            # Fresh starting noise; presumably 1 image at 512px — confirm
            # against StableDiffuser.get_initial_latents.
            latents = diffuser.get_initial_latents(1, 512, 1)

            # Partially denoise to step `iteration` USING the fine-tuned
            # weights, so training sees latents from the model's own
            # current distribution.
            with finetuner:
                latents_steps, _ = diffuser.diffusion(
                    latents,
                    positive_text_embeddings,
                    start_iteration=0,
                    end_iteration=iteration,
                    guidance_scale=3,
                    show_progress=False
                )

            # Switch to the full 1000-step training schedule and map the
            # DDIM step index onto it.
            diffuser.set_scheduler_timesteps(1000)

            iteration = int(iteration / nsteps * 1000)

            # Frozen-model noise predictions: conditioned on the concept
            # (positive) and unconditioned (neutral). These form the
            # regression target.
            positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
            neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

        # Fine-tuned model's prediction for the concept — the only
        # forward pass that carries gradients.
        with finetuner:
            negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

        # NOTE(review): both targets were computed under no_grad above,
        # so these look like defensive no-ops — confirm.
        positive_latents.requires_grad = False
        neutral_latents.requires_grad = False

        # ESD objective: pull the fine-tuned prediction toward
        # neutral - g * (positive - neutral), i.e. away from the concept.
        loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))

        loss.backward()
        optimizer.step()

    # Persist only the fine-tuned parameters, not the full diffuser.
    torch.save(finetuner.state_dict(), save_path)

    del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents

    torch.cuda.empty_cache()
|
if __name__ == '__main__':

    import argparse

    # CLI entry point: collect hyperparameters and launch training.
    parser = argparse.ArgumentParser(
        description='Fine-tune a Stable Diffusion model to erase a concept '
                    '(ESD-style training).')

    parser.add_argument('--prompt', required=True,
                        help='Text prompt describing the concept to erase.')
    parser.add_argument('--modules', required=True,
                        help='Module selector passed to FineTunedModel; these '
                             'weights are fine-tuned.')
    parser.add_argument('--freeze_modules', nargs='+', required=True,
                        help='One or more module selectors kept frozen during '
                             'fine-tuning.')
    parser.add_argument('--save_path', required=True,
                        help='File path where the fine-tuned state_dict is saved.')
    parser.add_argument('--iterations', type=int, required=True,
                        help='Number of optimization steps.')
    parser.add_argument('--lr', type=float, required=True,
                        help='Adam learning rate.')
    parser.add_argument('--negative_guidance', type=float, required=True,
                        help='Scale of the erasure (negative guidance) term in '
                             'the training loss.')

    # Argument names match train()'s parameters exactly, so unpack directly.
    train(**vars(parser.parse_args()))