import random
import numpy as np
import torch
import torch.nn as nn
from diffusers import AutoencoderKL
from torch.optim import lr_scheduler
from ..utils.core import debug, find, info, warn
from ..utils.typing import *
"""Diffusers Model Utils"""


def vae_encode(
    vae: AutoencoderKL,
    pixel_values: Float[Tensor, "B 3 H W"],
    sample: bool = True,
    apply_scale: bool = True,
):
    latent_dist = vae.encode(pixel_values).latent_dist
    latents = latent_dist.sample() if sample else latent_dist.mode()
    if apply_scale:
        latents = latents * vae.config.scaling_factor
    return latents
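
# Usage sketch (illustrative; the checkpoint id and tensor shapes are assumptions):
#   vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
#   images = torch.rand(2, 3, 512, 512) * 2.0 - 1.0  # AutoencoderKL expects pixels in [-1, 1]
#   latents = vae_encode(vae, images)  # (2, 4, 64, 64), scaled by vae.config.scaling_factor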


# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
    prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True
):
    prompt_embeds_list = []

    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # We are only interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds
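
# Usage sketch (illustrative; the model id and subfolder names are assumptions, following the
# standard SDXL layout with two text encoders):
#   from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
#   base = "stabilityai/stable-diffusion-xl-base-1.0"
#   tok_1 = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
#   tok_2 = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer_2")
#   enc_1 = CLIPTextModel.from_pretrained(base, subfolder="text_encoder")
#   enc_2 = CLIPTextModelWithProjection.from_pretrained(base, subfolder="text_encoder_2")
#   prompt_embeds, pooled = encode_prompt(
#       ["a photo of a cat"], [enc_1, enc_2], [tok_1, tok_2], proportion_empty_prompts=0.0
#   )
# prompt_embeds concatenates the penultimate hidden states of both encoders along the feature
# dim; pooled comes from the last encoder in the list only.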


# OpenAI CLIP image preprocessing statistics, reshaped for broadcasting over B C H W tensors
CLIP_INPUT_MEAN = torch.as_tensor(
    [0.48145466, 0.4578275, 0.40821073], dtype=torch.float32
)[None, :, None, None]
CLIP_INPUT_STD = torch.as_tensor(
    [0.26862954, 0.26130258, 0.27577711], dtype=torch.float32
)[None, :, None, None]


def normalize_image_for_clip(image: Float[Tensor, "B C H W"]):
    return (image - CLIP_INPUT_MEAN.to(image)) / CLIP_INPUT_STD.to(image)
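
# Usage sketch (illustrative; `clip_image_encoder` is an assumption): inputs are expected in [0, 1].
#   clip_in = normalize_image_for_clip(images.clamp(0.0, 1.0))
#   image_embeds = clip_image_encoder(clip_in).image_embeds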
"""Training"""
def get_scheduler(name):
if hasattr(lr_scheduler, name):
return getattr(lr_scheduler, name)
else:
raise NotImplementedError


def getattr_recursive(m, attr):
    for name in attr.split("."):
        m = getattr(m, name)
    return m


def get_parameters(model, name):
    module = getattr_recursive(model, name)
    if isinstance(module, nn.Module):
        return module.parameters()
    elif isinstance(module, nn.Parameter):
        return module
    return []
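
# Example (illustrative; the attribute paths are assumptions):
#   get_parameters(model, "unet.conv_in")         # -> parameters() of model.unet.conv_in
#   get_parameters(model, "unet.conv_in.weight")  # -> the single nn.Parameter itself
# Dotted names let a config assign per-submodule optimizer settings (see parse_optimizer below).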


def parse_optimizer(config, model):
    if hasattr(config, "params"):
        params = [
            {"params": get_parameters(model, name), "name": name, **args}
            for name, args in config.params.items()
        ]
        debug(f"Specify optimizer params: {config.params}")
    else:
        params = model.parameters()
    if config.name in ["FusedAdam"]:
        import apex

        optim = getattr(apex.optimizers, config.name)(params, **config.args)
    elif config.name in ["Adam8bit", "AdamW8bit"]:
        import bitsandbytes as bnb

        # look up the requested 8-bit optimizer instead of always using Adam8bit
        optim = getattr(bnb.optim, config.name)(params, **config.args)
    else:
        optim = getattr(torch.optim, config.name)(params, **config.args)
    return optim
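
# Config sketch (illustrative; assumes an OmegaConf-style config with attribute access):
#   name: AdamW
#   args: {lr: 1.e-4, betas: [0.9, 0.999], weight_decay: 0.01}
#   params:
#     unet: {lr: 1.e-5}
#     text_encoder: {lr: 1.e-6}
# Each key under `params` is a dotted attribute path resolved by get_parameters, giving that
# submodule its own parameter group.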


def parse_scheduler_to_instance(config, optimizer):
    if config.name == "ChainedScheduler":
        schedulers = [
            parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers
        ]
        scheduler = lr_scheduler.ChainedScheduler(schedulers)
    elif config.name == "Sequential":
        schedulers = [
            parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers
        ]
        scheduler = lr_scheduler.SequentialLR(
            optimizer, schedulers, milestones=config.milestones
        )
    else:
        scheduler = getattr(lr_scheduler, config.name)(optimizer, **config.args)
    return scheduler
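
# Config sketch (illustrative): "Sequential" runs warmup then decay back to back, e.g.
#   name: Sequential
#   schedulers:
#     - {name: LinearLR, args: {start_factor: 0.01, total_iters: 500}}
#     - {name: CosineAnnealingLR, args: {T_max: 10000}}
#   milestones: [500]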


def parse_scheduler(config, optimizer):
    interval = config.get("interval", "epoch")
    assert interval in ["epoch", "step"]
    if config.name == "SequentialLR":
        scheduler = {
            "scheduler": lr_scheduler.SequentialLR(
                optimizer,
                [
                    parse_scheduler(conf, optimizer)["scheduler"]
                    for conf in config.schedulers
                ],
                milestones=config.milestones,
            ),
            "interval": interval,
        }
    elif config.name == "ChainedScheduler":
        scheduler = {
            "scheduler": lr_scheduler.ChainedScheduler(
                [
                    parse_scheduler(conf, optimizer)["scheduler"]
                    for conf in config.schedulers
                ]
            ),
            "interval": interval,
        }
    else:
        scheduler = {
            "scheduler": get_scheduler(config.name)(optimizer, **config.args),
            "interval": interval,
        }
    return scheduler
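
# Config sketch (illustrative; the returned dict mirrors PyTorch Lightning's lr-scheduler
# config with "scheduler" and "interval" keys):
#   name: CosineAnnealingLR
#   interval: step
#   args: {T_max: 100000, eta_min: 1.e-6}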