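"""Training system for image-guided multi-view (IG2MV) generation with SDXL.

Wraps ``IG2MVSDXLPipeline``: multi-view targets and a reference image are encoded
with the VAE, noise is added, the adapted UNet denoises the views conditioned on
cached reference features and spatial conditioning features, and an (optionally
min-SNR-weighted) diffusion loss is optimized.
"""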
import os
import random
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler, UNet2DConditionModel
from diffusers.models import AutoencoderKL
from diffusers.training_utils import compute_snr
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from ..pipelines.ig2mv_sdxl_pipeline import IG2MVSDXLPipeline
from ..schedulers.scheduling_shift_snr import ShiftSNRScheduler
from ..utils.core import find
from ..utils.typing import *
from .base import BaseSystem
from .utils import encode_prompt, vae_encode
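# SDXL conditioning helper: besides the token-level prompt embeddings, the SDXL UNet
# expects the pooled text embedding ("text_embeds") and the micro-conditioning
# "time_ids" built from original_size + crops_coords_top_left + target_size.
# Prompts flagged in `empty_prompt_indices` are replaced with "" so the model also
# trains on the unconditional case (classifier-free guidance dropout).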
def compute_embeddings(
prompt_batch,
empty_prompt_indices,
text_encoders,
tokenizers,
is_train=True,
**kwargs,
):
original_size = kwargs["original_size"]
target_size = kwargs["target_size"]
crops_coords_top_left = kwargs["crops_coords_top_left"]
for i in range(empty_prompt_indices.shape[0]):
if empty_prompt_indices[i]:
prompt_batch[i] = ""
prompt_embeds, pooled_prompt_embeds = encode_prompt(
prompt_batch, text_encoders, tokenizers, 0, is_train
)
add_text_embeds = pooled_prompt_embeds.to(
device=prompt_embeds.device, dtype=prompt_embeds.dtype
)
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
add_time_ids = add_time_ids.to(
device=prompt_embeds.device, dtype=prompt_embeds.dtype
)
unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
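# Illustrative call (hypothetical prompts; shapes are for the default SDXL encoders
# and only indicative):
#   embeds = compute_embeddings(
#       ["a wooden chair", "a metal lamp"],
#       torch.tensor([False, True]),  # drop the second prompt
#       [text_encoder, text_encoder_2],
#       [tokenizer, tokenizer_2],
#       original_size=(768, 768),
#       target_size=(768, 768),
#       crops_coords_top_left=(0, 0),
#   )
#   # embeds["prompt_embeds"]: (2, 77, 2048), embeds["text_embeds"]: (2, 1280),
#   # embeds["time_ids"]: (2, 6)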
class IG2MVSDXLSystem(BaseSystem):
@dataclass
class Config(BaseSystem.Config):
# Model / Adapter
pretrained_model_name_or_path: str = "stabilityai/stable-diffusion-xl-base-1.0"
pretrained_vae_name_or_path: Optional[str] = "madebyollin/sdxl-vae-fp16-fix"
pretrained_adapter_name_or_path: Optional[str] = None
pretrained_unet_name_or_path: Optional[str] = None
init_adapter_kwargs: Dict[str, Any] = field(default_factory=dict)
use_fp16_vae: bool = True
use_fp16_clip: bool = True
# Training
trainable_modules: List[str] = field(default_factory=list)
train_cond_encoder: bool = True
prompt_drop_prob: float = 0.0
image_drop_prob: float = 0.0
cond_drop_prob: float = 0.0
gradient_checkpointing: bool = False
# Noise sampler
noise_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)
noise_offset: float = 0.0
input_perturbation: float = 0.0
snr_gamma: Optional[float] = 5.0
prediction_type: Optional[str] = None
shift_noise: bool = False
shift_noise_mode: str = "interpolated"
shift_noise_scale: float = 1.0
# Evaluation
eval_seed: int = 0
eval_num_inference_steps: int = 30
eval_guidance_scale: float = 1.0
eval_height: int = 512
eval_width: int = 512
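        # Illustrative OmegaConf override for this Config (values are only examples;
        # the surrounding file layout depends on how BaseSystem loads its config):
        #   trainable_modules: ["attn1"]
        #   snr_gamma: 5.0
        #   shift_noise: true
        #   shift_noise_mode: interpolated
        #   eval_num_inference_steps: 30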
cfg: Config
def configure(self):
super().configure()
# Prepare pipeline
pipeline_kwargs = {}
if self.cfg.pretrained_vae_name_or_path is not None:
pipeline_kwargs["vae"] = AutoencoderKL.from_pretrained(
self.cfg.pretrained_vae_name_or_path
)
if self.cfg.pretrained_unet_name_or_path is not None:
pipeline_kwargs["unet"] = UNet2DConditionModel.from_pretrained(
self.cfg.pretrained_unet_name_or_path
)
pipeline: IG2MVSDXLPipeline
pipeline = IG2MVSDXLPipeline.from_pretrained(
self.cfg.pretrained_model_name_or_path, **pipeline_kwargs
)
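        # Initialize the custom multi-view / reference attention adapter; a
        # self_attn_processor given as a dotted import path is resolved to a class
        # via find() before being handed to the pipeline.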
init_adapter_kwargs = OmegaConf.to_container(self.cfg.init_adapter_kwargs)
if "self_attn_processor" in init_adapter_kwargs:
self_attn_processor = init_adapter_kwargs["self_attn_processor"]
if self_attn_processor is not None and isinstance(self_attn_processor, str):
self_attn_processor = find(self_attn_processor)
init_adapter_kwargs["self_attn_processor"] = self_attn_processor
pipeline.init_custom_adapter(**init_adapter_kwargs)
if self.cfg.pretrained_adapter_name_or_path:
pretrained_path = os.path.dirname(self.cfg.pretrained_adapter_name_or_path)
adapter_name = os.path.basename(self.cfg.pretrained_adapter_name_or_path)
pipeline.load_custom_adapter(pretrained_path, weight_name=adapter_name)
noise_scheduler = DDPMScheduler.from_config(
pipeline.scheduler.config, **self.cfg.noise_scheduler_kwargs
)
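        # Optionally rescale the schedule's signal-to-noise ratio via ShiftSNRScheduler,
        # controlled by shift_noise_mode and shift_noise_scale.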
if self.cfg.shift_noise:
noise_scheduler = ShiftSNRScheduler.from_scheduler(
noise_scheduler,
shift_mode=self.cfg.shift_noise_mode,
shift_scale=self.cfg.shift_noise_scale,
scheduler_class=DDPMScheduler,
)
pipeline.scheduler = noise_scheduler
# Prepare models
self.pipeline: IG2MVSDXLPipeline = pipeline
self.vae = self.pipeline.vae.to(
dtype=torch.float16 if self.cfg.use_fp16_vae else torch.float32
)
self.tokenizer = self.pipeline.tokenizer
self.tokenizer_2 = self.pipeline.tokenizer_2
self.text_encoder = self.pipeline.text_encoder.to(
dtype=torch.float16 if self.cfg.use_fp16_clip else torch.float32
)
self.text_encoder_2 = self.pipeline.text_encoder_2.to(
dtype=torch.float16 if self.cfg.use_fp16_clip else torch.float32
)
self.feature_extractor = self.pipeline.feature_extractor
self.cond_encoder = self.pipeline.cond_encoder
self.unet = self.pipeline.unet
self.noise_scheduler = self.pipeline.scheduler
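        # The pipeline gets its own DDPMScheduler instance (built from the same config)
        # for evaluation, separate from the training scheduler in self.noise_scheduler.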
self.inference_scheduler = DDPMScheduler.from_config(
self.noise_scheduler.config
)
self.pipeline.scheduler = self.inference_scheduler
if self.cfg.prediction_type is not None:
self.noise_scheduler.register_to_config(
prediction_type=self.cfg.prediction_type
)
# Prepare trainable / non-trainable modules
trainable_modules = self.cfg.trainable_modules
if trainable_modules and len(trainable_modules) > 0:
self.unet.requires_grad_(False)
for name, module in self.unet.named_modules():
for trainable_module in trainable_modules:
if trainable_module in name:
module.requires_grad_(True)
else:
self.unet.requires_grad_(True)
self.cond_encoder.requires_grad_(self.cfg.train_cond_encoder)
self.vae.requires_grad_(False)
self.text_encoder.requires_grad_(False)
self.text_encoder_2.requires_grad_(False)
# Others
# Prepare gradient checkpointing
if self.cfg.gradient_checkpointing:
self.unet.enable_gradient_checkpointing()
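    # forward() runs one training denoising step:
    #   1. sample per-sample dropout masks for the prompt, the reference image, or both
    #      (classifier-free guidance dropout),
    #   2. encode the prompts into SDXL text + micro-conditioning embeddings,
    #   3. run the UNet once on the clean reference latents (t = 0) purely to cache its
    #      attention hidden states as reference features,
    #   4. denoise the multi-view latents conditioned on those reference features and the
    #      spatial features from cond_encoder (down_intrablock_additional_residuals).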
def forward(
self,
noisy_latents: Tensor,
conditioning_pixel_values: Tensor,
timesteps: Tensor,
ref_latents: Tensor,
prompts: List[str],
num_views: int,
**kwargs,
) -> Dict[str, Any]:
bsz = noisy_latents.shape[0]
b_samples = bsz // num_views
num_batch_images = num_views
prompt_drop_mask = (
torch.rand(b_samples, device=noisy_latents.device)
< self.cfg.prompt_drop_prob
)
image_drop_mask = (
torch.rand(b_samples, device=noisy_latents.device)
< self.cfg.image_drop_prob
)
cond_drop_mask = (
torch.rand(b_samples, device=noisy_latents.device) < self.cfg.cond_drop_prob
)
prompt_drop_mask = prompt_drop_mask | cond_drop_mask
image_drop_mask = image_drop_mask | cond_drop_mask
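        # cond_drop_prob drops the prompt and the reference image together so the model
        # also sees the fully unconditional case.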
with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
# Here, we compute not just the text embeddings but also the additional embeddings
# needed for the SD XL UNet to operate.
additional_embeds = compute_embeddings(
prompts,
prompt_drop_mask,
[self.text_encoder, self.text_encoder_2],
[self.tokenizer, self.tokenizer_2],
**kwargs,
)
# Process reference latents to obtain reference features
with torch.no_grad():
ref_timesteps = torch.zeros_like(timesteps[:b_samples])
ref_hidden_states = {}
self.unet(
ref_latents,
ref_timesteps,
encoder_hidden_states=additional_embeds["prompt_embeds"],
added_cond_kwargs={
"text_embeds": additional_embeds["text_embeds"],
"time_ids": additional_embeds["time_ids"],
},
cross_attention_kwargs={
"cache_hidden_states": ref_hidden_states,
"use_mv": False,
"use_ref": False,
},
return_dict=False,
)
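            # Zero the cached reference features of samples whose image condition is
            # dropped, then repeat them so every view of a sample shares the same
            # reference features.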
for k, v in ref_hidden_states.items():
v_ = v
v_[image_drop_mask] = 0.0
ref_hidden_states[k] = v_.repeat_interleave(num_batch_images, dim=0)
# Repeat additional embeddings for each image in the batch
for key, value in additional_embeds.items():
kwargs[key] = value.repeat_interleave(num_batch_images, dim=0)
conditioning_features = self.cond_encoder(conditioning_pixel_values)
added_cond_kwargs = {
"text_embeds": kwargs["text_embeds"],
"time_ids": kwargs["time_ids"],
}
noise_pred = self.unet(
noisy_latents,
timesteps,
encoder_hidden_states=kwargs["prompt_embeds"],
added_cond_kwargs=added_cond_kwargs,
down_intrablock_additional_residuals=conditioning_features,
cross_attention_kwargs={
"ref_hidden_states": ref_hidden_states,
"num_views": num_views,
},
).sample
return {"noise_pred": noise_pred}
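    # training_step: VAE-encode the multi-view targets (in slices to bound memory) and the
    # reference image, add (optionally offset / perturbed) noise, predict it with forward(),
    # and minimize an MSE loss, optionally weighted by min-SNR-gamma.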
def training_step(self, batch, batch_idx):
num_views = batch["num_views"]
vae_max_slice = 8
with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
latents = []
for i in range(0, batch["rgb"].shape[0], vae_max_slice):
latents.append(
vae_encode(
self.vae,
batch["rgb"][i : i + vae_max_slice].to(self.vae.dtype) * 2 - 1,
sample=True,
apply_scale=True,
).float()
)
latents = torch.cat(latents, dim=0)
with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
ref_latents = vae_encode(
self.vae,
batch["reference_rgb"].to(self.vae.dtype) * 2 - 1,
sample=True,
apply_scale=True,
).float()
bsz = latents.shape[0]
b_samples = bsz // num_views
noise = torch.randn_like(latents)
        if self.cfg.noise_offset:
            # https://www.crosslabs.org/blog/diffusion-with-offset-noise
noise += self.cfg.noise_offset * torch.randn(
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
)
noise_mask = (
batch["noise_mask"]
if "noise_mask" in batch
else torch.ones((bsz,), dtype=torch.bool, device=latents.device)
)
timesteps = torch.randint(
0,
self.noise_scheduler.config.num_train_timesteps,
(b_samples,),
device=latents.device,
dtype=torch.long,
)
timesteps = timesteps.repeat_interleave(num_views)
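        # All views of a sample share one timestep; views excluded by noise_mask keep
        # t = 0, stay clean (see below), and contribute no loss.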
timesteps[~noise_mask] = 0
        if self.cfg.input_perturbation:
new_noise = noise + self.cfg.input_perturbation * torch.randn_like(noise)
noisy_latents = self.noise_scheduler.add_noise(
latents, new_noise, timesteps
)
else:
noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
noisy_latents[~noise_mask] = latents[~noise_mask]
if self.noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif self.noise_scheduler.config.prediction_type == "v_prediction":
target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(
f"Unsupported prediction type {self.noise_scheduler.config.prediction_type}"
)
model_pred = self(
noisy_latents, batch["source_rgb"], timesteps, ref_latents, **batch
)["noise_pred"]
model_pred = model_pred[noise_mask]
target = target[noise_mask]
if self.cfg.snr_gamma is None:
loss = F.mse_loss(model_pred, target, reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
            # Align the weights with the masked loss: model_pred and target were filtered
            # with noise_mask above, so the SNR is computed on the same subset of timesteps.
            snr = compute_snr(self.noise_scheduler, timesteps[noise_mask])
            if self.noise_scheduler.config.prediction_type == "v_prediction":
                # Velocity objective requires that we add one to SNR values before we divide by them.
                snr = snr + 1
            mse_loss_weights = (
                torch.stack(
                    [snr, self.cfg.snr_gamma * torch.ones_like(snr)], dim=1
                ).min(dim=1)[0]
                / snr
            )
loss = F.mse_loss(model_pred, target, reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
self.log("train/loss", loss, prog_bar=True)
# will execute self.on_check_train every self.cfg.check_train_every_n_steps steps
self.check_train(batch)
return {"loss": loss}
def on_train_batch_end(self, outputs, batch, batch_idx):
pass
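    # Visualization helpers: views are tiled horizontally ("(N W)") and batch samples
    # stacked vertically ("(B H)") into a single image grid.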
def get_input_visualizations(self, batch):
return [
{
"type": "rgb",
"img": rearrange(
batch["source_rgb"],
"(B N) C H W -> (B H) (N W) C",
N=batch["num_views"],
),
"kwargs": {"data_format": "HWC"},
},
{
"type": "rgb",
"img": rearrange(batch["reference_rgb"], "B C H W -> (B H) W C"),
"kwargs": {"data_format": "HWC"},
},
{
"type": "rgb",
"img": rearrange(
batch["rgb"], "(B N) C H W -> (B H) (N W) C", N=batch["num_views"]
),
"kwargs": {"data_format": "HWC"},
},
]
def get_output_visualizations(self, batch, outputs):
images = [
{
"type": "rgb",
"img": rearrange(
batch["source_rgb"],
"(B N) C H W -> (B H) (N W) C",
N=batch["num_views"],
),
"kwargs": {"data_format": "HWC"},
},
{
"type": "rgb",
"img": rearrange(
batch["rgb"], "(B N) C H W -> (B H) (N W) C", N=batch["num_views"]
),
"kwargs": {"data_format": "HWC"},
},
{
"type": "rgb",
"img": rearrange(batch["reference_rgb"], "B C H W -> (B H) W C"),
"kwargs": {"data_format": "HWC"},
},
{
"type": "rgb",
"img": rearrange(
outputs, "(B N) C H W -> (B H) (N W) C", N=batch["num_views"]
),
"kwargs": {"data_format": "HWC"},
},
]
return images
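    # Evaluation path: run the full IG2MV pipeline with the configured seed, step count,
    # guidance scale, and resolution; images are returned as tensors (output_type="pt").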
def generate_images(self, batch, **kwargs):
return self.pipeline(
prompt=batch["prompts"],
control_image=batch["source_rgb"],
num_images_per_prompt=batch["num_views"],
generator=torch.Generator(device=self.device).manual_seed(
self.cfg.eval_seed
),
num_inference_steps=self.cfg.eval_num_inference_steps,
guidance_scale=self.cfg.eval_guidance_scale,
height=self.cfg.eval_height,
width=self.cfg.eval_width,
reference_image=batch["reference_rgb"],
output_type="pt",
).images
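    # Only rank 0 writes the adapter; include_keys restricts the saved weights to the
    # modules that were actually trained.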
def on_save_checkpoint(self, checkpoint):
if self.global_rank == 0:
self.pipeline.save_custom_adapter(
os.path.dirname(self.get_save_dir()),
"step1x-3d-ig2v.safetensors",
safe_serialization=True,
include_keys=self.cfg.trainable_modules,
)
def on_check_train(self, batch):
self.save_image_grid(
f"it{self.true_global_step}-train.jpg",
self.get_input_visualizations(batch),
name="train_step_input",
step=self.true_global_step,
)
def validation_step(self, batch, batch_idx):
out = self.generate_images(batch)
if (
self.cfg.check_val_limit_rank > 0
and self.global_rank < self.cfg.check_val_limit_rank
):
self.save_image_grid(
f"it{self.true_global_step}-validation-{self.global_rank}_{batch_idx}.jpg",
self.get_output_visualizations(batch, out),
name=f"validation_step_output_{self.global_rank}_{batch_idx}",
step=self.true_global_step,
)
def on_validation_epoch_end(self):
pass
def test_step(self, batch, batch_idx):
out = self.generate_images(batch)
self.save_image_grid(
f"it{self.true_global_step}-test-{self.global_rank}_{batch_idx}.jpg",
self.get_output_visualizations(batch, out),
name=f"test_step_output_{self.global_rank}_{batch_idx}",
step=self.true_global_step,
)
def on_test_end(self):
pass