# V3D / sgm / models / video3d_diffusion.py
# heheyas
# init
# cfb7702
import re
import math
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import torch
from omegaconf import ListConfig, OmegaConf
from safetensors.torch import load_file as load_safetensors
from torch.optim.lr_scheduler import LambdaLR
from torchvision.utils import make_grid
from einops import rearrange, repeat
from ..modules import UNCONDITIONAL_CONFIG
from ..modules.autoencoding.temporal_ae import VideoDecoder
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ..modules.ema import LitEma
from ..modules.encoders.modules import VideoPredictionEmbedderWithEncoder
from ..util import (
default,
disabled_train,
get_obj_from_str,
instantiate_from_config,
log_txt_as_img,
video_frames_as_grid,
)
def flatten_for_video(input):
    """Return ``input`` collapsed to one dimension via its ``flatten()`` method."""
    flattened = input.flatten()
    return flattened
class Video3DDiffusionEngine(pl.LightningModule):
def __init__(
    self,
    network_config,
    denoiser_config,
    first_stage_config,
    conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
    sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
    optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
    scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
    loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
    network_wrapper: Union[None, str] = None,
    ckpt_path: Union[None, str] = None,
    use_ema: bool = False,
    ema_decay_rate: float = 0.9999,
    scale_factor: float = 1.0,
    disable_first_stage_autocast=False,
    input_key: str = "frames",  # for video inputs
    log_keys: Union[List, None] = None,
    no_cond_log: bool = False,
    compile_model: bool = False,
    en_and_decode_n_samples_a_time: Optional[int] = None,
):
    """Assemble the engine from its component configs.

    Args:
        network_config: Config instantiated into the network, which is then
            wrapped by ``network_wrapper``.
        denoiser_config: Config instantiated into ``self.denoiser``.
        first_stage_config: Config handed to ``_init_first_stage``.
        conditioner_config: Conditioner config; ``UNCONDITIONAL_CONFIG`` is
            used when ``None``.
        sampler_config: Optional sampler config; ``self.sampler`` is ``None``
            when omitted.
        optimizer_config: Optimizer config; defaults to
            ``{"target": "torch.optim.AdamW"}``.
        scheduler_config: Stored as-is and consumed by ``configure_optimizers``.
        loss_fn_config: Optional loss config; ``self.loss_fn`` is ``None``
            when omitted.
        network_wrapper: Import path of the wrapper class; defaults to
            ``OPENAIUNETWRAPPER``.
        ckpt_path: If given, ``init_from_ckpt`` is called with it.
        use_ema: Whether to keep a ``LitEma`` copy of the wrapped network.
        ema_decay_rate: Decay passed to ``LitEma``.
        scale_factor: Factor applied to latents in ``encode_first_stage`` and
            undone in ``decode_first_stage``.
        disable_first_stage_autocast: Disables autocast around the first stage.
        input_key: Batch key read by ``get_input``; with ``"latents"`` the
            first-stage encoder is bypassed.
        log_keys: Conditioning keys to log; ``None`` means no filtering.
        no_cond_log: Disables conditioning logging entirely.
        compile_model: Forwarded to the network wrapper.
        en_and_decode_n_samples_a_time: Chunk size for first-stage
            encode/decode; ``None`` processes everything at once.
    """
    super().__init__()

    self.log_keys = log_keys
    self.input_key = input_key
    self.optimizer_config = default(
        optimizer_config, {"target": "torch.optim.AdamW"}
    )

    # NOTE: the instantiation order below is kept deliberately; it determines
    # sub-module registration order and the order in which weights are created.
    network = instantiate_from_config(network_config)
    wrapper_cls = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))
    self.model = wrapper_cls(network, compile_model=compile_model)

    self.denoiser = instantiate_from_config(denoiser_config)

    if sampler_config is None:
        self.sampler = None
    else:
        self.sampler = instantiate_from_config(sampler_config)

    self.conditioner = instantiate_from_config(
        default(conditioner_config, UNCONDITIONAL_CONFIG)
    )
    self.scheduler_config = scheduler_config
    self._init_first_stage(first_stage_config)

    if loss_fn_config is None:
        self.loss_fn = None
    else:
        self.loss_fn = instantiate_from_config(loss_fn_config)

    self.use_ema = use_ema
    if self.use_ema:
        self.model_ema = LitEma(self.model, decay=ema_decay_rate)
        num_ema_buffers = len(list(self.model_ema.buffers()))
        print(f"Keeping EMAs of {num_ema_buffers}.")

    self.scale_factor = scale_factor
    self.disable_first_stage_autocast = disable_first_stage_autocast
    self.no_cond_log = no_cond_log

    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path)

    self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
def _load_last_embedder(self, original_state_dict):
    """Load the weights stored under ``conditioner.embedders.3`` into the frame encoder.

    Every entry of ``original_state_dict`` whose key starts with
    ``"conditioner.embedders.3."`` is collected with that prefix stripped.
    The result is loaded into the last embedder of ``self.conditioner`` that
    is a ``VideoPredictionEmbedderWithEncoder``; if there is none, index
    ``-1`` (the final embedder) is used.

    Args:
        original_state_dict: Mapping from full parameter names to tensors.
    """
    # Match the prefix literally. The module path contains dots, which the
    # previous regex-based matching treated as "any character".
    prefix = "conditioner.embedders.3."
    state_dict = {
        key[len(prefix):]: value
        for key, value in original_state_dict.items()
        if key.startswith(prefix)
    }

    idx = -1
    for i, embedder in enumerate(self.conditioner.embedders):
        if isinstance(embedder, VideoPredictionEmbedderWithEncoder):
            idx = i

    print(f"Embedder [{idx}] is the frame encoder, make sure this is expected")
    self.conditioner.embedders[idx].load_state_dict(state_dict)
def init_from_ckpt(
    self,
    path: str,
) -> None:
    """Initialise this module's weights from a checkpoint file.

    Tensors whose shape differs from the current model are dropped from the
    checkpoint before a non-strict ``load_state_dict``. The first input
    convolution (and its EMA copy) is handled specially: on a shape mismatch
    the current weight is zeroed and the checkpoint weight is copied into its
    first 8 input channels.

    Args:
        path: File ending in ``ckpt`` (read with ``torch.load``; its
            ``"state_dict"`` entry is used) or in ``safetensors``.

    Raises:
        NotImplementedError: If ``path`` has any other ending.
    """
    if path.endswith("ckpt"):
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        sd = torch.load(path, map_location="cpu")["state_dict"]
    elif path.endswith("safetensors"):
        sd = load_safetensors(path)
    else:
        raise NotImplementedError(f"Unsupported checkpoint format: {path}")

    self_sd = self.state_dict()

    # The second key is the EMA buffer name of the same weight (dots removed).
    input_keys = [
        "model.diffusion_model.input_blocks.0.0.weight",
        "model_ema.diffusion_modelinput_blocks00weight",
    ]
    for input_key in input_keys:
        if input_key not in sd or input_key not in self_sd:
            continue
        input_weight = self_sd[input_key]
        if input_weight.shape != sd[input_key].shape:
            print("Manual init: {}".format(input_key))
            # In-place ops on the state_dict tensor write through to the model.
            # Assumes the checkpoint conv has 8 input channels -- TODO confirm.
            input_weight.zero_()
            input_weight[:, :8, :, :].copy_(sd[input_key])

    # Resolve shape mismatches by dropping the offending checkpoint entries
    # (this also keeps the manual init above from being overwritten).
    deleted_keys = [
        k for k, v in self_sd.items() if k in sd and v.shape != sd[k].shape
    ]
    for k in deleted_keys:
        del sd[k]

    missing, unexpected = self.load_state_dict(sd, strict=False)
    print(
        f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
    )
    if len(missing) > 0:
        print(f"Missing Keys: {missing}")
    if len(unexpected) > 0:
        print(f"Unexpected Keys: {unexpected}")
    if len(deleted_keys) > 0:
        print(f"Deleted Keys: {deleted_keys}")

    if len(missing) > 0 or len(unexpected) > 0:
        # means we are loading from a checkpoint that has the old embedder (motion bucket id and fps id)
        print("Modified embedder to support 3d spiral video inputs")
        try:
            self._load_last_embedder(sd)
        except Exception as err:
            # Best-effort step: keep going, but report why it failed and let
            # KeyboardInterrupt/SystemExit propagate.
            print("Failed to load last embedder, make sure this is expected")
            print(f"Reason: {err!r}")
def _init_first_stage(self, config):
    """Instantiate the first-stage model from ``config`` and freeze it.

    The model is switched to eval mode, its ``train`` attribute is replaced
    with ``disabled_train``, and every parameter has ``requires_grad`` set to
    ``False``. The result is stored as ``self.first_stage_model``.
    """
    first_stage = instantiate_from_config(config)
    first_stage = first_stage.eval()
    # Presumably keeps later ``.train()`` calls from leaving eval mode -- see
    # ``disabled_train`` in ``..util``.
    first_stage.train = disabled_train
    for weight in first_stage.parameters():
        weight.requires_grad = False
    self.first_stage_model = first_stage
def get_input(self, batch):
    """Return the entry of ``batch`` stored under ``self.input_key``.

    Assumes a unified data format in which the dataloader returns a dict;
    image tensors should be scaled to -1 ... 1 and be in bchw format.
    """
    key = self.input_key
    return batch[key]
@torch.no_grad()
def decode_first_stage(self, z):
    """Decode latents with the first-stage model.

    ``z`` is divided by ``self.scale_factor`` and decoded in chunks of
    ``self.en_and_decode_n_samples_a_time`` samples (all at once when that
    is ``None``). A 5-D input is treated as ``b t c h w``: it is flattened to
    ``(b t) c h w`` for decoding and reshaped back afterwards.

    Args:
        z: Latent tensor, 4-D or 5-D.

    Returns:
        The decoded tensor, concatenated along dim 0 (5-D if ``z`` was 5-D).
    """
    z = (1.0 / self.scale_factor) * z
    batch_size = z.shape[0]

    video_input = z.dim() == 5
    if video_input:
        # for video diffusion
        z = rearrange(z, "b t c h w -> (b t) c h w")

    chunk_size = default(self.en_and_decode_n_samples_a_time, z.shape[0])
    num_chunks = math.ceil(z.shape[0] / chunk_size)

    decoded = []
    with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
        for chunk_idx in range(num_chunks):
            start = chunk_idx * chunk_size
            chunk = z[start : start + chunk_size]
            # A VideoDecoder is additionally told how many frames the chunk holds.
            if isinstance(self.first_stage_model.decoder, VideoDecoder):
                extra = {"timesteps": len(chunk)}
            else:
                extra = {}
            decoded.append(self.first_stage_model.decode(chunk, **extra))

    out = torch.cat(decoded, dim=0)
    if video_input:
        out = rearrange(out, "(b t) c h w -> b t c h w", b=batch_size)
    return out
@torch.no_grad()
def encode_first_stage(self, x):
    """Encode inputs into scaled first-stage latents.

    When ``self.input_key`` is ``"latents"`` the input is returned untouched.
    Otherwise ``x`` is encoded in chunks of
    ``self.en_and_decode_n_samples_a_time`` samples (all at once when that is
    ``None``) and the result is multiplied by ``self.scale_factor``.

    Args:
        x: Input tensor, 4-D or 5-D (``b t c h w``).

    Returns:
        The latents concatenated along dim 0. NOTE: a 5-D input is flattened
        to ``(b t) c h w`` and, unlike in ``decode_first_stage``, the result is
        not reshaped back.
    """
    if self.input_key == "latents":
        return x

    if x.dim() == 5:
        # for video diffusion
        x = rearrange(x, "b t c h w -> (b t) c h w")

    chunk_size = default(self.en_and_decode_n_samples_a_time, x.shape[0])
    num_chunks = math.ceil(x.shape[0] / chunk_size)

    with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
        encoded = [
            self.first_stage_model.encode(
                x[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size]
            )
            for chunk_idx in range(num_chunks)
        ]

    latents = torch.cat(encoded, dim=0)
    return self.scale_factor * latents
def forward(self, x, batch):
    """Run the loss function and reduce its result to a scalar.

    Args:
        x: Input handed to ``self.loss_fn`` (the encoded latents when called
            from ``shared_step``).
        batch: The data batch, forwarded to ``self.loss_fn``.

    Returns:
        ``(loss_mean, loss_dict)`` where ``loss_dict`` holds the same mean
        under ``"loss"`` and the network output under ``"model_output"``.
    """
    raw_loss, model_output = self.loss_fn(
        self.model,
        self.denoiser,
        self.conditioner,
        x,
        batch,
        return_model_output=True,
    )
    loss_mean = raw_loss.mean()
    return loss_mean, {"loss": loss_mean, "model_output": model_output}
def shared_step(self, batch: Dict) -> Any:
    """Encode the batch input and compute the loss for it.

    The current ``self.global_step`` is written into ``batch["global_step"]``
    before the module is called, so ``batch`` is modified in place.

    Args:
        batch: Batch dict containing ``self.input_key``.

    Returns:
        The ``(loss, loss_dict)`` pair produced by calling this module.
    """
    latents = self.encode_first_stage(self.get_input(batch))
    batch["global_step"] = self.global_step
    loss, loss_dict = self(latents, batch)
    return loss, loss_dict
def training_step(self, batch, batch_idx):
    """Compute the training loss and emit per-step logs.

    Every 100th batch, when the logger is a ``WandbLogger``, the first
    ``batch["num_video_frames"]`` entries of the model output are decoded and
    logged as an image grid. ``"model_output"`` is always removed from the
    loss dict before it is passed to ``log_dict``.

    Returns:
        The loss, or ``None`` if it contains NaN.
    """
    loss, loss_dict = self.shared_step(batch)

    with torch.no_grad():
        if "model_output" in loss_dict:
            model_output = loss_dict.pop("model_output")
            if batch_idx % 100 == 0 and isinstance(self.logger, WandbLogger):
                frames = model_output.detach()[: batch["num_video_frames"]]
                decoded = self.decode_first_stage(frames)
                # Map the decoder output from [-1, 1] to [0, 1] for display.
                recons = ((decoded + 1.0) / 2.0).clamp(0.0, 1.0)
                self.logger.log_image(
                    key="train/model_output_recon",
                    images=[make_grid(recons, nrow=4)],
                    step=self.global_step,
                )

    if torch.isnan(loss).any():
        print("Nan detected")
        loss = None

    step_log_opts = dict(prog_bar=True, logger=True, on_step=True, on_epoch=False)
    self.log_dict(loss_dict, **step_log_opts)
    self.log("global_step", self.global_step, **step_log_opts)

    if self.scheduler_config is not None:
        lr = self.optimizers().param_groups[0]["lr"]
        self.log("lr_abs", lr, **step_log_opts)

    return loss
def on_train_start(self, *args, **kwargs):
    """Refuse to start training unless both a sampler and a loss function exist.

    Raises:
        ValueError: If ``self.sampler`` or ``self.loss_fn`` is ``None``.
    """
    required = (self.sampler, self.loss_fn)
    if any(component is None for component in required):
        raise ValueError("Sampler and loss function need to be set for training.")
def on_train_batch_end(self, *args, **kwargs):
    """Pass the current model to ``self.model_ema`` after each batch when EMA is enabled."""
    if not self.use_ema:
        return
    self.model_ema(self.model)
@contextmanager
def ema_scope(self, context=None):
    """Context manager that runs its body with the EMA weights loaded.

    With ``self.use_ema`` set, the model's parameters are stored and replaced
    by the EMA copy on entry and restored on exit, even if the body raises.
    Without EMA it does nothing.

    Args:
        context: Optional label; when given, the switch and the restore are
            each announced with a printed line.
    """
    announce = context is not None

    if self.use_ema:
        self.model_ema.store(self.model.parameters())
        self.model_ema.copy_to(self.model)
        if announce:
            print(f"{context}: Switched to EMA weights")
    try:
        yield
    finally:
        if self.use_ema:
            self.model_ema.restore(self.model.parameters())
            if announce:
                print(f"{context}: Restored training weights")
def instantiate_optimizer_from_config(self, params, lr, cfg):
    """Build an optimizer from a config mapping.

    Args:
        params: Parameters to optimise.
        lr: Learning rate, passed as ``lr=``.
        cfg: Mapping with a ``"target"`` import path and an optional
            ``"params"`` mapping of extra keyword arguments.

    Returns:
        The object created by calling the resolved ``"target"``.
    """
    optimizer_cls = get_obj_from_str(cfg["target"])
    extra_kwargs = cfg.get("params", dict())
    return optimizer_cls(params, lr=lr, **extra_kwargs)
def configure_optimizers(self):
    """Create the optimizer and, if configured, a per-step ``LambdaLR`` schedule.

    The optimised parameters are those of ``self.model`` plus those of every
    conditioner embedder whose ``is_trainable`` flag is set. The learning rate
    is read from ``self.learning_rate``.

    Returns:
        The optimizer alone when ``self.scheduler_config`` is ``None``;
        otherwise ``([optimizer], [scheduler_dict])``.
    """
    lr = self.learning_rate

    params = list(self.model.parameters())
    for embedder in self.conditioner.embedders:
        if embedder.is_trainable:
            params.extend(embedder.parameters())

    optimizer = self.instantiate_optimizer_from_config(
        params, lr, self.optimizer_config
    )
    if self.scheduler_config is None:
        return optimizer

    schedule = instantiate_from_config(self.scheduler_config)
    print("Setting up LambdaLR scheduler...")
    lr_scheduler = {
        "scheduler": LambdaLR(optimizer, lr_lambda=schedule.schedule),
        "interval": "step",
        "frequency": 1,
    }
    return [optimizer], [lr_scheduler]
@torch.no_grad()
def sample(
    self,
    cond: Dict,
    uc: Union[Dict, None] = None,
    batch_size: int = 16,
    shape: Union[None, Tuple, List] = None,
    **kwargs,
):
    """Draw samples by running ``self.sampler`` from Gaussian noise.

    Args:
        cond: Conditioning passed to the sampler.
        uc: Optional unconditional conditioning passed to the sampler.
        batch_size: Number of noise samples to draw.
        shape: Shape of one sample without the batch dimension. Required,
            despite the ``None`` default kept for signature compatibility.
        **kwargs: Extra keyword arguments forwarded to ``self.denoiser`` on
            every call.

    Returns:
        Whatever ``self.sampler`` returns.

    Raises:
        TypeError: If ``shape`` is ``None``.
    """
    if shape is None:
        # Previously this surfaced as an opaque "argument after * must be an
        # iterable" TypeError from the unpacking below.
        raise TypeError(
            "`shape` must be provided (shape of one sample, without the batch dimension)"
        )

    # The noise is drawn on the default device and then moved, as before.
    randn = torch.randn(batch_size, *shape).to(self.device)

    def denoiser(input, sigma, c):
        # Bind the network and the extra kwargs so the sampler only has to
        # supply (input, sigma, c).
        return self.denoiser(self.model, input, sigma, c, **kwargs)

    samples = self.sampler(denoiser, randn, cond, uc=uc)
    return samples
@torch.no_grad()
def log_conditionings(self, batch: Dict, n: int) -> Dict:
    """
    Defines heuristics to log different conditionings.
    These can be lists of strings (text-to-image), tensors, ints, ...

    For every conditioner embedder with a single ``input_key`` (optionally
    filtered by ``self.log_keys``), the first ``n`` entries of the matching
    batch item are turned into something loggable: 1-D and 2-D tensors and
    lists of strings are rendered as text images, 4-D tensors are passed
    through unchanged.

    Args:
        batch: Batch dict; must contain ``self.input_key`` (used for the image
            size) and the input key of every logged embedder.
        n: Number of leading entries to log per conditioning.

    Returns:
        Dict mapping embedder input keys to loggable values; empty when
        ``self.no_cond_log`` is set.

    Raises:
        NotImplementedError: For tensors of any other rank, lists whose first
            element is not a string, and any other value type.
    """
    image_h, image_w = batch[self.input_key].shape[-2:]
    log = dict()
    if self.no_cond_log:
        return log

    for embedder in self.conditioner.embedders:
        key = embedder.input_key
        if key is None:
            # Embedders that consume several keys expose ``input_keys``
            # instead (see ``log_images``); there is no single batch entry to
            # log for them, and ``batch[None]`` would raise a KeyError.
            continue
        if self.log_keys is not None and key not in self.log_keys:
            continue

        x = batch[key][:n]
        if isinstance(x, torch.Tensor):
            if x.dim() == 1:
                # class-conditional, convert integer to string
                x = [str(x[i].item()) for i in range(x.shape[0])]
                xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
            elif x.dim() == 2:
                # size and crop cond and the like
                x = [
                    "x".join([str(xx) for xx in x[i].tolist()])
                    for i in range(x.shape[0])
                ]
                xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
            elif x.dim() == 4:
                # image
                xc = x
            else:
                raise NotImplementedError()
        elif isinstance(x, (List, ListConfig)):
            if isinstance(x[0], str):
                # strings
                xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
            else:
                raise NotImplementedError()
        else:
            raise NotImplementedError()
        log[key] = xc
    return log
# for video diffusions will be logging frames of a video
@torch.no_grad()
def log_images(
    self,
    batch: Dict,
    N: int = 1,
    sample: bool = True,
    ucg_keys: Optional[List[str]] = None,
    **kwargs,
) -> Dict:
    """Collect inputs, reconstructions, conditionings and samples for logging.

    Args:
        batch: Batch dict. Must contain ``"num_video_frames"``,
            ``"image_only_indicator"`` and ``self.input_key``.
        N: Maximum number of videos to log; the number of logged frames is
            this (capped by the batch) times ``num_video_frames``.
        sample: Whether to also draw samples (under ``ema_scope``).
        ucg_keys: Conditioner input keys whose embeddings are zeroed for the
            unconditional branch; defaults to all conditioner input keys.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        Dict with ``"inputs"``, ``"reconstructions"``, the entries produced by
        ``log_conditionings``, ``"pixelnerf_rgb"`` when the conditioning
        contains ``"rgb"``, and ``"samples"`` when ``sample`` is true.

    Raises:
        AssertionError: If ``"num_video_frames"`` is missing from ``batch`` or
            a ucg key is not a conditioner input key.
    """
    assert "num_video_frames" in batch, "num_video_frames must be in batch"
    num_video_frames = batch["num_video_frames"]

    # Embedders expose either a single ``input_key`` or several ``input_keys``.
    conditioner_input_keys = []
    for e in self.conditioner.embedders:
        if e.input_key is not None:
            conditioner_input_keys.append(e.input_key)
        else:
            conditioner_input_keys.extend(e.input_keys)

    if ucg_keys:
        assert all(key in conditioner_input_keys for key in ucg_keys), (
            "Each defined ucg key for sampling must be in the provided conditioner input keys, "
            f"but we have {ucg_keys} vs. {conditioner_input_keys}"
        )
    else:
        ucg_keys = conditioner_input_keys

    log = dict()

    x = self.get_input(batch)

    c, uc = self.conditioner.get_unconditional_conditioning(
        batch,
        force_uc_zero_embeddings=ucg_keys
        if len(self.conditioner.embedders) > 0
        else [],
    )

    # n counts whole videos, N counts frames.
    n = min(x.shape[0] // num_video_frames, N)
    N = n * num_video_frames

    sampling_kwargs = {
        "num_video_frames": num_video_frames,
        # Duplicated once; presumably one copy each for the conditional and
        # unconditional branch -- TODO confirm against the sampler/guider.
        "image_only_indicator": torch.cat([batch["image_only_indicator"][:n]] * 2),
    }

    x = x.to(self.device)[:N]
    log["inputs"] = x
    z = self.encode_first_stage(x)
    log["reconstructions"] = self.decode_first_stage(z)
    log.update(self.log_conditionings(batch, N))

    # Only log the rendered views if the conditioner actually produced them,
    # so conditioners without an "rgb" output do not break image logging.
    if "rgb" in c:
        log["pixelnerf_rgb"] = c["rgb"]

    for k in ["crossattn", "concat", "vector"]:
        if k in c:
            c[k] = c[k][:N]
            uc[k] = uc[k][:N]

    if sample:
        with self.ema_scope("Plotting"):
            samples = self.sample(
                c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
            )
        samples = self.decode_first_stage(samples)
        log["samples"] = samples
    return log