CraftsMan3D / craftsman /systems /pixart_diffusion.py
wyysf's picture
update
9505fe5
from dataclasses import dataclass, field
import numpy as np
import json
import copy
import torch
import torch.nn.functional as F
from skimage import measure
from einops import repeat
from tqdm import tqdm
from PIL import Image
from diffusers import (
DDPMScheduler,
DDIMScheduler,
UniPCMultistepScheduler,
KarrasVeScheduler,
DPMSolverMultistepScheduler
)
import craftsman
from craftsman.systems.base import BaseSystem
from craftsman.utils.misc import get_rank
from craftsman.utils.typing import *
from diffusers import DDIMScheduler
def compute_snr(noise_scheduler, timesteps):
    """
    Return the signal-to-noise ratio of ``noise_scheduler`` at ``timesteps``.

    The SNR is ``(alpha / sigma) ** 2`` where ``alpha = sqrt(alphas_cumprod[t])``
    and ``sigma = sqrt(1 - alphas_cumprod[t])``. The result is a float32 tensor
    with the same shape as ``timesteps``.

    Follows
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """

    def _lookup(table):
        # Gather the per-timestep entries on the timesteps' device, then pad
        # trailing dims and broadcast to the timesteps' shape.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        picked = table.to(device=timesteps.device)[timesteps].float()
        while picked.dim() < timesteps.dim():
            picked = picked[..., None]
        return picked.expand(timesteps.shape)

    cumprod = noise_scheduler.alphas_cumprod
    alpha = _lookup(cumprod ** 0.5)
    sigma = _lookup((1.0 - cumprod) ** 0.5)

    # Compute SNR.
    return (alpha / sigma) ** 2
def ddim_sample(ddim_scheduler: DDIMScheduler,
                diffusion_model: torch.nn.Module,
                shape: Union[List[int], Tuple[int]],
                cond: torch.FloatTensor,
                steps: int,
                eta: float = 0.0,
                guidance_scale: float = 3.0,
                do_classifier_free_guidance: bool = True,
                generator: Optional[torch.Generator] = None,
                device: torch.device = "cuda:0",
                disable_prog: bool = True):
    """
    Run the reverse diffusion loop and yield the latents after every step.

    This is a generator: each iteration yields ``(latents, t)`` where
    ``latents`` is the sample returned by ``ddim_scheduler.step`` for timestep
    ``t``. The last yielded latents are the final sample.

    Args:
        ddim_scheduler: scheduler providing ``init_noise_sigma``,
            ``set_timesteps``, ``timesteps`` and ``step``.
        diffusion_model: denoiser, called as ``forward(latents, timesteps, cond)``.
        shape: shape of one latent (without the batch dimension).
        cond: conditioning tensor. With classifier-free guidance its batch is
            expected to be ``[unconditional; conditional]`` stacked on dim 0,
            so the number of samples is ``cond.shape[0] // 2``.
        steps: number of inference steps; must be positive.
        eta: forwarded to ``ddim_scheduler.step`` when that method accepts an
            ``eta`` argument, ignored otherwise.
        guidance_scale: classifier-free guidance weight.
        do_classifier_free_guidance: whether ``cond`` holds both halves and
            guidance should be applied.
        generator: optional RNG used for the initial noise and forwarded to
            ``ddim_scheduler.step`` when that method accepts it.
        device: device the timesteps are moved to.
        disable_prog: disable the tqdm progress bar.
    """
    # Local import keeps this function self-contained.
    import inspect

    assert steps > 0, f"{steps} must > 0."

    # init latents
    bsz = cond.shape[0]
    if do_classifier_free_guidance:
        bsz = bsz // 2

    latents = torch.randn(
        (bsz, *shape),
        generator=generator,
        device=cond.device,
        dtype=cond.dtype,
    )
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * ddim_scheduler.init_noise_sigma

    # set timesteps
    ddim_scheduler.set_timesteps(steps)
    timesteps = ddim_scheduler.timesteps.to(device)

    # Prepare extra kwargs for the scheduler step. Not all schedulers share the
    # same ``step`` signature, so only pass what this scheduler accepts instead
    # of silently dropping ``eta`` or crashing on an unexpected ``generator``.
    step_params = inspect.signature(ddim_scheduler.step).parameters
    extra_step_kwargs = {}
    if "eta" in step_params:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_params:
        extra_step_kwargs["generator"] = generator

    # reverse
    for t in tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = (
            torch.cat([latents] * 2)
            if do_classifier_free_guidance
            else latents
        )

        # predict the noise residual; the same timestep is used for every batch item
        timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
        timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
        noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        # compute the previous noisy sample x_t -> x_t-1
        latents = ddim_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        yield latents, t
# DEBUG = True
@craftsman.register("pixart-diffusion-system")
class PixArtDiffusionSystem(BaseSystem):
    """
    Latent diffusion system: a frozen shape model provides the latents, a
    condition model provides the conditioning tokens and a denoiser is trained
    to predict the scheduler's target (noise or velocity).
    """

    @dataclass
    class Config(BaseSystem.Config):
        # JSON file with the conditions ("image" or "mvimages") sampled at validation time
        val_samples_json: Optional[str] = None
        # forwarded to shape_model.extract_geometry as extract_mesh_func
        extract_mesh_func: str = "mc"

        # diffusion config
        z_scale_factor: float = 1.0  # latents are multiplied by this before diffusion
        guidance_scale: float = 7.5
        num_inference_steps: int = 50
        eta: float = 0.0
        snr_gamma: float = 5.0  # 0 disables the min-SNR loss weighting

        # shape vae model
        shape_model_type: Optional[str] = None
        shape_model: dict = field(default_factory=dict)

        # condition model
        condition_model_type: Optional[str] = None
        condition_model: dict = field(default_factory=dict)

        # diffusion model
        denoiser_model_type: Optional[str] = None
        denoiser_model: dict = field(default_factory=dict)

        # noise scheduler
        noise_scheduler_type: Optional[str] = None
        noise_scheduler: dict = field(default_factory=dict)

        # denoise scheduler
        denoise_scheduler_type: Optional[str] = None
        denoise_scheduler: dict = field(default_factory=dict)

    cfg: Config

    def configure(self):
        """Instantiate the sub-models and schedulers named in the config."""
        super().configure()

        # The shape model is only used as a fixed encoder/decoder.
        self.shape_model = craftsman.find(self.cfg.shape_model_type)(self.cfg.shape_model)
        self.shape_model.eval()
        self.shape_model.requires_grad_(False)

        self.condition = craftsman.find(self.cfg.condition_model_type)(self.cfg.condition_model)

        self.denoiser_model = craftsman.find(self.cfg.denoiser_model_type)(self.cfg.denoiser_model)

        self.noise_scheduler = craftsman.find(self.cfg.noise_scheduler_type)(**self.cfg.noise_scheduler)
        self.denoise_scheduler = craftsman.find(self.cfg.denoise_scheduler_type)(**self.cfg.denoise_scheduler)

    def forward(self, batch: Dict[str, Any], skip_noise=False) -> Dict[str, Any]:
        """
        Run one diffusion training forward pass and compute the loss.

        Args:
            batch: must contain ``"surface"`` plus whatever the condition model
                reads. A 5-dim ``batch["image"]`` is treated as
                ``(batch, n_images, ...)``; NOTE that it is rewritten in place.
            skip_noise: unused; kept for interface compatibility.

        Returns:
            Dict with ``loss_diffusion``, the clean ``latents``, the noised
            latents ``x_t``, the sampled ``noise``, the denoiser output
            ``noise_pred`` and the sampled ``timesteps``.

        Raises:
            ValueError: for an unsupported prediction type or loss type.
        """
        # 1. encode shape latents
        _, kl_embed, _ = self.shape_model.encode(
            batch["surface"][..., :3 + self.cfg.shape_model.point_feats],
            sample_posterior=True
        )

        latents = kl_embed * self.cfg.z_scale_factor

        # 2. gain condition. assert not (text_cond and image_cond), "Only one of text or image condition must be provided."
        if "image" in batch and batch['image'].dim() == 5:
            if self.training:
                # Treat every image as its own sample: fold n_images into the batch dim.
                bs, n_images = batch['image'].shape[:2]
                batch['image'] = batch['image'].view(bs*n_images, *batch['image'].shape[-3:])
            else:
                # Outside training only the first image of each sample is used.
                batch['image'] = batch['image'][:, 0, ...]
                n_images = 1
                bs = batch['image'].shape[0]
            cond_latents = self.condition(batch).to(latents)
            # Repeat the shape latents so they line up with the folded images.
            latents = latents.unsqueeze(1).repeat(1, n_images, 1, 1)
            latents = latents.view(bs*n_images, *latents.shape[-2:])
        else:
            cond_latents = self.condition(batch).to(latents)
            # Flatten any intermediate dims into a single token dimension.
            cond_latents = cond_latents.view(cond_latents.shape[0], -1, cond_latents.shape[-1])

        # 3. sample noise that we'll add to the latents
        noise = torch.randn_like(latents)  # [batch_size, n_token, latent_dim]
        bs = latents.shape[0]

        # 4. Sample a random timestep for each sample
        timesteps = torch.randint(
            0,
            self.cfg.noise_scheduler.num_train_timesteps,
            (bs,),
            device=latents.device,
        )
        timesteps = timesteps.long()

        # 5. add noise
        noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # 6. diffusion model forward
        noise_pred = self.denoiser_model(noisy_z, timesteps, cond_latents)

        # 7. compute loss
        prediction_type = self.noise_scheduler.config.prediction_type
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "v_prediction":
            target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Prediction Type: {prediction_type} not supported.")

        loss_type = self.cfg.loss.loss_type
        if loss_type == "l1":
            loss_fn = F.l1_loss
        elif loss_type in ["mse", "l2"]:
            loss_fn = F.mse_loss
        else:
            raise ValueError(f"Loss Type: {loss_type} not supported.")

        if self.cfg.snr_gamma == 0:
            loss = loss_fn(noise_pred, target, reduction="mean")
        else:
            # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
            # Since we predict the noise instead of x_0, the original formulation is slightly changed.
            # This is discussed in Section 4.2 of the same paper.
            snr = compute_snr(self.noise_scheduler, timesteps)
            # Element-wise min(snr, snr_gamma).
            mse_loss_weights = torch.stack([snr, self.cfg.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
                dim=1
            )[0]
            if prediction_type == "epsilon":
                mse_loss_weights = mse_loss_weights / snr
            elif prediction_type == "v_prediction":
                mse_loss_weights = mse_loss_weights / (snr + 1)

            loss = loss_fn(noise_pred, target, reduction="none")
            # Average over everything but the batch dim, weight per sample, then average.
            loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
            loss = loss.mean()

        return {
            "loss_diffusion": loss,
            "latents": latents,
            "x_t": noisy_z,
            "noise": noise,
            "noise_pred": noise_pred,
            "timesteps": timesteps,
        }

    def training_step(self, batch, batch_idx):
        """Log every ``loss_*`` output and return their ``lambda_*``-weighted sum."""
        out = self(batch)

        loss = 0.
        for name, value in out.items():
            if name.startswith("loss_"):
                self.log(f"train/{name}", value)
                loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")])

        for name, value in self.cfg.loss.items():
            if name.startswith("lambda_"):
                self.log(f"train_params/{name}", self.C(value))

        return {"loss": loss}

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """
        On rank 0, sample meshes for the conditions in ``cfg.val_samples_json``
        and save them; on every rank, compute the diffusion loss for ``batch``.
        At global step 0 the decoded input latents are additionally saved as a
        mesh.

        Returns:
            ``{"val/loss": <diffusion loss>}``.
        """
        self.eval()

        if get_rank() == 0:
            with open(self.cfg.val_samples_json) as f:
                sample_inputs = json.load(f)  # condition
            # Keep an untouched copy: sample() may replace paths with opened images.
            sample_inputs_ = copy.deepcopy(sample_inputs)
            sample_outputs = self.sample(sample_inputs)  # list
            for i, sample_output in enumerate(sample_outputs):
                mesh_v_f, has_surface = self.shape_model.extract_geometry(sample_output, octree_depth=7, extract_mesh_func=self.cfg.extract_mesh_func)

                for j, mesh in enumerate(mesh_v_f):
                    if "image" in sample_inputs_:
                        # file name of the image without the extension
                        name = sample_inputs_["image"][j].split("/")[-1].replace(".png", "")
                    elif "mvimages" in sample_inputs_:
                        # parent directory of the first view
                        name = sample_inputs_["mvimages"][j][0].split("/")[-2].replace(".png", "")
                    self.save_mesh(
                        f"it{self.true_global_step}/{name}_{i}.obj",
                        mesh[0], mesh[1]
                    )

        out = self(batch)
        if self.global_step == 0:
            # forward() returns latents scaled by z_scale_factor; undo the scaling
            # before decoding, exactly as sample() does.
            latents = self.shape_model.decode(out["latents"] / self.cfg.z_scale_factor)
            mesh_v_f, has_surface = self.shape_model.extract_geometry(latents=latents, extract_mesh_func=self.cfg.extract_mesh_func)
            self.save_mesh(
                f"it{self.true_global_step}/{batch['uid'][0]}_{batch['sel_idx'][0] if 'sel_idx' in batch.keys() else 0}.obj",
                mesh_v_f[0][0], mesh_v_f[0][1]
            )

        return {"val/loss": out["loss_diffusion"]}

    @torch.no_grad()
    def sample(self,
               sample_inputs: Dict[str, Union[torch.FloatTensor, List[str]]],
               sample_times: int = 1,
               steps: Optional[int] = None,
               guidance_scale: Optional[float] = None,
               eta: float = 0.0,
               seed: Optional[int] = None,
               **kwargs):
        """
        Sample shapes conditioned on images.

        Args:
            sample_inputs: either ``{"image": [...]}`` (paths or already loaded
                images) or ``{"mvimages": [[...], ...]}`` (one list of views
                per sample). NOTE: path entries of ``sample_inputs["image"]``
                are replaced in place by the opened images.
            sample_times: number of independent sampling runs.
            steps: inference steps; defaults to ``cfg.num_inference_steps``.
            guidance_scale: guidance weight; defaults to ``cfg.guidance_scale``.
                A value of exactly 1.0 disables classifier-free guidance.
            eta: forwarded to ``ddim_sample``.
            seed: seed for the sampling generator; ``None`` means unseeded.
            **kwargs: forwarded to ``shape_model.decode``.

        Returns:
            List with one decoded output per sampling run.

        Raises:
            NotImplementedError: if neither ``"image"`` nor ``"mvimages"`` is given.
        """
        if steps is None:
            steps = self.cfg.num_inference_steps
        if guidance_scale is None:
            guidance_scale = self.cfg.guidance_scale
        do_classifier_free_guidance = guidance_scale != 1.0

        # conditional encode
        if "image" in sample_inputs:
            sample_inputs["image"] = [Image.open(img) if isinstance(img, str) else img for img in sample_inputs["image"]]
            cond = self.condition.encode_image(sample_inputs["image"])
            if do_classifier_free_guidance:
                # Layout expected by ddim_sample: [unconditional; conditional].
                un_cond = self.condition.empty_image_embeds.repeat(len(sample_inputs["image"]), 1, 1).to(cond)
                cond = torch.cat([un_cond, cond], dim=0)
        elif "mvimages" in sample_inputs:  # by default 4 views
            bs = len(sample_inputs["mvimages"])
            cond = []
            for image in sample_inputs["mvimages"]:
                # Use a local instead of writing an "image" key into the caller's
                # dict, which would make a later call take the "image" branch.
                if isinstance(image, list) and isinstance(image[0], str):
                    views = [Image.open(img) for img in image]  # List[PIL]
                else:
                    views = image
                cond += [self.condition.encode_image(views)]
            cond = torch.stack(cond, dim=0).view(bs, -1, self.cfg.denoiser_model.context_dim)
            if do_classifier_free_guidance:
                # shape: [len(sample_inputs["mvimages"]), 4*(num_latents+1), context_dim] — per the original note; TODO confirm
                un_cond = self.condition.empty_image_embeds.unsqueeze(0).repeat(len(sample_inputs["mvimages"]), 1, 1, 1).view(bs, cond.shape[1], self.cfg.denoiser_model.context_dim).to(cond)
                cond = torch.cat([un_cond, cond], dim=0).view(bs * 2, -1, cond[0].shape[-1])
        else:
            raise NotImplementedError("Only image or mvimages condition is supported.")

        outputs = []
        latents = None

        if seed is not None:
            # The initial noise is drawn on cond.device, so the generator must live there too.
            generator = torch.Generator(device=cond.device).manual_seed(seed)
        else:
            generator = None

        for _ in range(sample_times):
            sample_loop = ddim_sample(
                self.denoise_scheduler,
                self.denoiser_model.eval(),
                shape=self.shape_model.latent_shape,
                cond=cond,
                steps=steps,
                guidance_scale=guidance_scale,
                do_classifier_free_guidance=do_classifier_free_guidance,
                device=self.device,
                eta=eta,
                disable_prog=False,
                generator=generator
            )
            # Only the latents of the final step are kept.
            for sample, t in sample_loop:
                latents = sample
            outputs.append(self.shape_model.decode(latents / self.cfg.z_scale_factor, **kwargs))

        return outputs

    def on_validation_epoch_end(self):
        """No end-of-epoch validation work."""
        pass