import random

import numpy as np
import torch
import torch.nn as nn
from diffusers import AutoencoderKL
from torch.optim import lr_scheduler

from ..utils.core import debug, find, info, warn
from ..utils.typing import *

"""Diffusers Model Utils"""

def vae_encode(
    vae: AutoencoderKL,
    pixel_values: Float[Tensor, "B 3 H W"],
    sample: bool = True,
    apply_scale: bool = True,
):
    """Encode images into VAE latents.

    If `sample` is True, draw a sample from the latent distribution; otherwise use
    its mode. If `apply_scale` is True, multiply by `vae.config.scaling_factor` so
    the latents are on the scale expected by the diffusion UNet.
    """
    latent_dist = vae.encode(pixel_values).latent_dist
    latents = latent_dist.sample() if sample else latent_dist.mode()
    if apply_scale:
        latents = latents * vae.config.scaling_factor
    return latents

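# Usage sketch (illustrative only; the SDXL VAE checkpoint below is an assumption,
# not something this module requires):
#
#   vae = AutoencoderKL.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
#   )
#   pixel_values = torch.randn(2, 3, 512, 512)  # images already scaled to [-1, 1]
#   with torch.no_grad():
#       latents = vae_encode(vae, pixel_values)  # (2, 4, 64, 64) scaled latents
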
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
    prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True
):
    """Encode a batch of prompts with one or more (tokenizer, text encoder) pairs."""
    prompt_embeds_list = []

    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            # Randomly drop the caption (classifier-free guidance training).
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # Take a random caption if there are multiple.
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # We are only interested in the pooled output of the final text encoder.
            pooled_prompt_embeds = prompt_embeds[0]
            # Use the penultimate hidden state, as in the SDXL pipelines.
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

    # Concatenate the embeddings from all text encoders along the feature dimension.
    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds

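# Usage sketch for SDXL-style dual text encoders (the model names below are
# assumptions, shown only for illustration):
#
#   from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection
#
#   base = "stabilityai/stable-diffusion-xl-base-1.0"
#   tokenizers = [
#       AutoTokenizer.from_pretrained(base, subfolder="tokenizer"),
#       AutoTokenizer.from_pretrained(base, subfolder="tokenizer_2"),
#   ]
#   text_encoders = [
#       CLIPTextModel.from_pretrained(base, subfolder="text_encoder"),
#       CLIPTextModelWithProjection.from_pretrained(base, subfolder="text_encoder_2"),
#   ]
#   prompt_embeds, pooled = encode_prompt(
#       ["a photo of a cat"], text_encoders, tokenizers, proportion_empty_prompts=0.0
#   )
#   # prompt_embeds: (1, 77, 2048); pooled: (1, 1280)
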
# Per-channel statistics used by the CLIP image preprocessor, reshaped so they
# broadcast over (B, C, H, W) image tensors.
CLIP_INPUT_MEAN = torch.as_tensor(
    [0.48145466, 0.4578275, 0.40821073], dtype=torch.float32
)[None, :, None, None]
CLIP_INPUT_STD = torch.as_tensor(
    [0.26862954, 0.26130258, 0.27577711], dtype=torch.float32
)[None, :, None, None]


def normalize_image_for_clip(image: Float[Tensor, "B C H W"]):
    # `.to(image)` moves the constants to the image's device and dtype.
    return (image - CLIP_INPUT_MEAN.to(image)) / CLIP_INPUT_STD.to(image)


"""Training"""


def get_scheduler(name):
    if hasattr(lr_scheduler, name):
        return getattr(lr_scheduler, name)
    else:
        raise NotImplementedError(f"Unknown learning rate scheduler: {name}")

def getattr_recursive(m, attr):
    # Resolve a dotted attribute path on a module, one attribute at a time.
    for name in attr.split("."):
        m = getattr(m, name)
    return m


def get_parameters(model, name):
    module = getattr_recursive(model, name)
    if isinstance(module, nn.Module):
        return module.parameters()
    elif isinstance(module, nn.Parameter):
        # Wrap a bare Parameter in a list so the return type is always an
        # iterable of parameters.
        return [module]
    return []

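# Usage sketch (the attribute paths below are hypothetical module names, shown
# only to illustrate the dotted-path convention):
#
#   head_weight = getattr_recursive(model, "head.weight")
#   mid_block_params = get_parameters(model, "unet.mid_block")
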
def parse_optimizer(config, model):
    if hasattr(config, "params"):
        # Build per-module parameter groups, each with its own optimizer arguments.
        params = [
            {"params": get_parameters(model, name), "name": name, **args}
            for name, args in config.params.items()
        ]
        debug(f"Specify optimizer params: {config.params}")
    else:
        params = model.parameters()
    if config.name in ["FusedAdam"]:
        import apex  # optional dependency

        optim = getattr(apex.optimizers, config.name)(params, **config.args)
    elif config.name in ["Adam8bit", "AdamW8bit"]:
        import bitsandbytes as bnb  # optional dependency

        optim = getattr(bnb.optim, config.name)(params, **config.args)
    else:
        optim = getattr(torch.optim, config.name)(params, **config.args)
    return optim

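# Usage sketch (assumes an OmegaConf-style config with attribute access; any object
# exposing `.name`, `.args`, and optionally `.params` works the same way; `model`,
# "unet", and "image_proj" are hypothetical names):
#
#   from omegaconf import OmegaConf
#
#   optimizer_cfg = OmegaConf.create(
#       {
#           "name": "AdamW",
#           "args": {"lr": 1e-4, "weight_decay": 0.01},
#           "params": {"unet": {"lr": 1e-4}, "image_proj": {"lr": 1e-5}},
#       }
#   )
#   optimizer = parse_optimizer(optimizer_cfg, model)
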
def parse_scheduler_to_instance(config, optimizer):
    """Instantiate an LR scheduler directly (no interval dict), recursing into
    ChainedScheduler / SequentialLR children."""
    if config.name == "ChainedScheduler":
        schedulers = [
            parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers
        ]
        scheduler = lr_scheduler.ChainedScheduler(schedulers)
    elif config.name in ("Sequential", "SequentialLR"):
        # Accept either spelling, for consistency with parse_scheduler below.
        schedulers = [
            parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers
        ]
        scheduler = lr_scheduler.SequentialLR(
            optimizer, schedulers, milestones=config.milestones
        )
    else:
        scheduler = getattr(lr_scheduler, config.name)(optimizer, **config.args)
    return scheduler

def parse_scheduler(config, optimizer):
    """Build a {"scheduler": ..., "interval": ...} dict from config, recursing into
    SequentialLR / ChainedScheduler children."""
    interval = config.get("interval", "epoch")
    assert interval in ["epoch", "step"]
    if config.name == "SequentialLR":
        scheduler = {
            "scheduler": lr_scheduler.SequentialLR(
                optimizer,
                [
                    parse_scheduler(conf, optimizer)["scheduler"]
                    for conf in config.schedulers
                ],
                milestones=config.milestones,
            ),
            "interval": interval,
        }
    elif config.name == "ChainedScheduler":
        scheduler = {
            "scheduler": lr_scheduler.ChainedScheduler(
                [
                    parse_scheduler(conf, optimizer)["scheduler"]
                    for conf in config.schedulers
                ]
            ),
            "interval": interval,
        }
    else:
        scheduler = {
            "scheduler": get_scheduler(config.name)(optimizer, **config.args),
            "interval": interval,
        }
    return scheduler

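# Usage sketch (same OmegaConf-style config assumption as above; the milestone and
# scheduler choices are arbitrary examples):
#
#   scheduler_cfg = OmegaConf.create(
#       {
#           "name": "SequentialLR",
#           "interval": "step",
#           "milestones": [1000],
#           "schedulers": [
#               {"name": "LinearLR", "args": {"start_factor": 0.01, "total_iters": 1000}},
#               {"name": "CosineAnnealingLR", "args": {"T_max": 10000}},
#           ],
#       }
#   )
#   sched = parse_scheduler(scheduler_cfg, optimizer)
#   # -> {"scheduler": <SequentialLR>, "interval": "step"}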