customdiffusion360 / sampling_for_demo.py
customdiffusion360's picture
reduce memory usage, use github code
8d3da67
raw
history blame contribute delete
No virus
15.9 kB
import sys
import copy
from typing import List
import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.renderer import look_at_view_transform
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
import json
sys.path.append('./custom-diffusion360/')
from sgm.util import instantiate_from_config, load_safetensors
# Indices of the training cameras used as reference views during sampling.
# Rebound by ``sample()`` and read by ``_customforward`` (via ``global
# choices``) to select rows of each transformer block's ``references`` buffer.
choices = []
def load_base_model(config, ckpt=None, verbose=True):
    """Instantiate the base diffusion model from a config file and optional weights.

    Before instantiation the config is patched in place:

    * the network's ``far`` is set to 3,
    * the first-stage checkpoint is pointed at
      ``pretrained-models/sdxl_vae.safetensors``,
    * the sampler guider is replaced by ``ScheduledCFGImgTextRef``
      (``scale=7.5``, ``scale_im=3.5``).

    Args:
        config: path to an OmegaConf-loadable config file.
        ckpt: optional checkpoint path. ``*ckpt`` files must hold a
            ``"state_dict"`` entry; ``*safetensors`` files are loaded directly.
            Weights are loaded non-strictly.
        verbose: print how many keys were missing / unexpected after loading.

    Returns:
        The instantiated model in eval mode.

    Raises:
        NotImplementedError: if ``ckpt`` has any other extension.
    """
    config = OmegaConf.load(config)
    # load model
    config.model.params.network_config.params.far = 3
    config.model.params.first_stage_config.params.ckpt_path = "pretrained-models/sdxl_vae.safetensors"
    guider_config = {
        'target': 'sgm.modules.diffusionmodules.guiders.ScheduledCFGImgTextRef',
        'params': {'scale': 7.5, 'scale_im': 3.5},
    }
    config.model.params.sampler_config.params.guider_config = guider_config
    model = instantiate_from_config(config.model)

    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        if ckpt.endswith("ckpt"):
            # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True;
            # if this checkpoint pickles non-tensor objects it will be rejected.
            # Only pass weights_only=False for trusted files.
            pl_sd = torch.load(ckpt, map_location="cpu")
            if "global_step" in pl_sd:
                print(f"Global Step: {pl_sd['global_step']}")
            sd = pl_sd["state_dict"]
        elif ckpt.endswith("safetensors"):
            sd = load_safetensors(ckpt)
            # OmegaConf.select returns None instead of raising when the config
            # has no data section.
            data_params = OmegaConf.select(config, "data.params")
            if data_params is not None and 'modifier_token' in data_params:
                # With modifier tokens configured, the base token-embedding
                # tables are dropped from the state dict so they are not
                # loaded into the model (presumably because the embedding
                # table was resized).
                sd.pop('conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight', None)
                sd.pop('conditioner.embedders.1.model.token_embedding.weight', None)
        else:
            raise NotImplementedError(f"Unsupported checkpoint format: {ckpt}")
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if verbose:
            print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
    model.eval()
    return model
def load_delta_model(model, delta_ckpt=None, verbose=True, freeze=True):
    """Load a fine-tuned delta checkpoint on top of a preloaded base model.

    ``model`` is the preloaded base stable diffusion model. Some modules sit
    directly inside a ``transformer_blocks`` container and have a
    ``pose_emb_layers`` attribute. For each of them, the matching
    ``model.diffusion_model.<name>.references`` tensor is removed from the
    delta state dict and registered on the module as a ``references`` buffer.
    All remaining delta weights are then loaded non-strictly.

    Args:
        model: base model exposing ``model.diffusion_model``.
        delta_ckpt: path to a ``torch.save``-d dict with a
            ``"delta_state_dict"`` entry, or None to skip loading.
        verbose: report missing / unexpected keys after loading.
        freeze: set ``requires_grad = False`` on every parameter.

    Returns:
        ``(model, msg)``. ``msg`` is always None and is kept for API
        compatibility.

    Raises:
        KeyError: if a pose-conditioned block has no ``references`` entry in
            the delta checkpoint.
    """
    msg = None
    if delta_ckpt is not None:
        pl_sd_delta = torch.load(delta_ckpt, map_location="cpu")
        sd_delta = pl_sd_delta["delta_state_dict"]
        # TODO: add new delta loading embedding stuff?
        for name, module in model.model.diffusion_model.named_modules():
            parts = name.split('.')
            if len(parts) > 1 and parts[-2] == 'transformer_blocks' and hasattr(module, 'pose_emb_layers'):
                # The module has no `references` entry for load_state_dict to
                # fill, so the tensor is attached as a buffer directly.
                key = f'model.diffusion_model.{name}.references'
                module.register_buffer('references', sd_delta.pop(key))
        missing, unexpected = model.load_state_dict(sd_delta, strict=False)
        if verbose:
            # The delta presumably covers only the fine-tuned subset of the
            # weights, so the missing list can be very long: report its size.
            if missing:
                print(f"missing keys: {len(missing)}")
            if unexpected:
                print("unexpected keys:")
                print(unexpected)
    if freeze:
        for param in model.parameters():
            param.requires_grad = False
    model.eval()
    return model, msg
def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct input keys of all conditioner embedders, plus ``'jpg_ref'``.

    Keys are de-duplicated in first-seen order, so the result is the same
    on every run. A ``set`` would make the order depend on string-hash
    randomisation. ``'jpg_ref'`` is always appended last.
    """
    unique_keys = dict.fromkeys(
        key for embedder in conditioner.embedders for key in embedder.input_keys
    )
    return list(unique_keys) + ['jpg_ref']
def customforward(self, x, xr, context=None, contextr=None, pose=None, mask_ref=None, prev_weights=None, timesteps=None, drop_im=None):
    """Replacement ``forward`` bound onto ``SpatialTransformer`` modules by
    ``load_and_return_model_and_data``.

    ``x`` is unpacked as a 4-D ``(b, c, h, w)`` feature map. It is normalized
    and projected with ``proj_in``, before or after flattening the spatial
    dims depending on ``self.use_linear``. It then passes through every block
    in ``self.transformer_blocks``, is projected and reshaped back, and the
    input is added as a residual.

    When ``self.image_cross`` is set, every ``self.poscontrol_interval``-th
    block (counting from 0) also receives ``context_ref``, ``pose``,
    ``mask_ref`` and the chained ``prev_weights``. All other blocks receive
    only ``context`` and ``drop_im``. Every block is expected to return a
    5-tuple ``(x, fg_mask, weights, alpha, rgb)``.

    Note: ``xr``, ``contextr`` and ``timesteps`` are not used here. The
    incoming ``prev_weights`` is overwritten with None before the block loop,
    so only weights produced inside this call are chained between blocks.

    Returns:
        ``(x + x_in, None, fg_masks, prev_weights, alphas, rgbs)``.
        ``fg_masks`` is the list collected from the reference-aware blocks,
        or None if none ran. ``alphas`` and ``rgbs`` hold the non-None values
        those blocks returned, or None when empty.
    """
    # note: if no context is given, cross-attention defaults to self-attention
    if not isinstance(context, list):
        context = [context]
    b, c, h, w = x.shape
    x_in = x
    fg_masks = []
    alphas = []
    rgbs = []
    x = self.norm(x)
    if not self.use_linear:
        x = self.proj_in(x)
    x = rearrange(x, "b c h w -> b (h w) c").contiguous()
    if self.use_linear:
        x = self.proj_in(x)
    # NOTE(review): discards the caller-supplied prev_weights.
    prev_weights = None
    counter = 0
    for i, block in enumerate(self.transformer_blocks):
        if i > 0 and len(context) == 1:
            i = 0  # use same context for each block
        if self.image_cross and (counter % self.poscontrol_interval == 0):
            # context_ref=x is passed; a patched block may replace it with its
            # own stored references.
            x, fg_mask, weights, alpha, rgb = block(x, context=context[i], context_ref=x, pose=pose, mask_ref=mask_ref, prev_weights=prev_weights, drop_im=drop_im)
            prev_weights = weights
            fg_masks.append(fg_mask)
            if alpha is not None:
                alphas.append(alpha)
            if rgb is not None:
                rgbs.append(rgb)
        else:
            x, _, _, _, _ = block(x, context=context[i], drop_im=drop_im)
        counter += 1
    if self.use_linear:
        x = self.proj_out(x)
    x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
    if not self.use_linear:
        x = self.proj_out(x)
    if len(fg_masks) > 0:
        if len(rgbs) <= 0:
            rgbs = None
        if len(alphas) <= 0:
            alphas = None
        return x + x_in, None, fg_masks, prev_weights, alphas, rgbs
    else:
        return x + x_in, None, None, prev_weights, None, None
def _customforward(
    self, x, context=None, context_ref=None, pose=None, mask_ref=None, prev_weights=None, drop_im=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
    """Replacement ``forward`` bound onto ``BasicTransformerBlock`` modules by
    ``load_and_return_model_and_data``.

    Runs, in order:

    * self-attention (``attn1``),
    * text cross-attention (``attn2``),
    * an optional reference/pose branch,
    * the feed-forward (``ff``).

    ``attn1``, ``attn2`` and ``ff`` each add a residual.

    The reference branch runs only when ``context_ref`` is not None:

    * The incoming ``context_ref`` value is discarded and rebuilt from this
      block's ``references`` buffer (registered by ``load_delta_model``).
      Rows ``self.references[:-1][y]`` are taken for every ``y`` in the
      module-global ``choices``. The last row, ``self.references[-1:]``, is
      expanded to the same count and used for the first batch group.
    * If the batch size is divisible by 3, the batch is split into three
      equal groups (last-row refs, chosen refs, chosen refs). Otherwise it is
      split into two (last-row refs, chosen refs). Which guidance branch each
      group corresponds to is decided by the sampler (TODO confirm).
    * On the first call ``self.reference_attn`` computes features, which are
      cached in ``self.rendered_feat``. Later calls reuse the cache and skip
      ``reference_attn``, so ``fg_mask``, ``weights``, ``alphas`` and
      ``predicted_rgb`` stay None.
    * In both cases ``x`` is replaced by ``self.pose_emb_layers`` applied to
      ``x`` concatenated with the reference features along the last dim.
      This step adds no residual.

    ``drop_im`` is accepted but not used in this body.

    Returns:
        ``(x, fg_mask, weights, alphas, predicted_rgb)``.
    """
    if context_ref is not None:
        global choices  # only read here; `sample()` rebinds it
        batch_size = x.size(0)
        # IP2P like sampling or default sampling
        # NOTE(review): the 4-argument expand implies each references[i] is
        # 2-D; confirm against the checkpoint.
        if batch_size % 3 == 0:
            batch_size = batch_size // 3
            context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1)
            context_ref = torch.cat([self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1), context_ref, context_ref], dim=0)
        else:
            batch_size = batch_size // 2
            context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1)
            context_ref = torch.cat([self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1), context_ref], dim=0)
    fg_mask = None
    weights = None
    alphas = None
    predicted_rgb = None
    x = (
        self.attn1(
            self.norm1(x),
            context=context if self.disable_self_attn else None,
            additional_tokens=additional_tokens,
            n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
            if not self.disable_self_attn
            else 0,
        )
        + x
    )
    x = (
        self.attn2(
            self.norm2(x), context=context, additional_tokens=additional_tokens,
        )
        + x
    )
    if context_ref is not None:
        if self.rendered_feat is not None:
            # Cached path: reuse features from an earlier call.
            # NOTE(review): the context_ref rebuilt above is unused here.
            x = self.pose_emb_layers(torch.cat([x, self.rendered_feat], dim=-1))
        else:
            xref, fg_mask, weights, alphas, predicted_rgb = self.reference_attn(x,
                                                                                context_ref,
                                                                                context,
                                                                                pose,
                                                                                prev_weights,
                                                                                mask_ref)
            # Cache for later calls; cleared via model.clear_rendered_feat().
            self.rendered_feat = xref
            x = self.pose_emb_layers(torch.cat([x, xref], -1))
    x = self.ff(self.norm3(x)) + x
    return x, fg_mask, weights, alphas, predicted_rgb
def log_images(
    model,
    batch,
    N: int = 1,
    noise=None,
    scale_im=3.5,
    num_steps: int = 10,
    ucg_keys: List[str] = None,
    **kwargs,
):
    """Run the sampler on one prepared batch and decode the result.

    Args:
        model: diffusion model exposing ``conditioner``, ``ema_scope``,
            ``sample``, ``clear_rendered_feat`` and ``decode_first_stage``.
        batch: dict with at least a ``'pose'`` entry (list or tensor) plus
            whatever keys the conditioner consumes.
        N: number of samples to draw.
        noise: initial latent noise. Its shape without the leading dim is
            the sample shape. Required.
        scale_im: when > 0 the first ``N`` poses are tiled three times,
            otherwise twice. This tiling happens only when an unconditional
            conditioning is returned.
        num_steps: number of sampler steps.
        ucg_keys: value forwarded as ``force_uc_zero_embeddings``. Defaults
            to the list of every embedder's ``input_keys``, which was the
            previous hard-coded behaviour.
        **kwargs: accepted for call compatibility and ignored.

    Returns:
        dict whose ``"samples"`` entry holds the decoded images on CPU.

    Raises:
        ValueError: if ``noise`` is None.
    """
    import time

    if noise is None:
        raise ValueError("log_images requires `noise`: its shape defines the sample shape")

    log = dict()
    embedders = model.conditioner.embedders
    if ucg_keys is None:
        ucg_keys = [e.input_keys for e in embedders]
    pose = batch['pose']
    c, uc = model.conditioner.get_unconditional_conditioning(
        batch,
        force_uc_zero_embeddings=ucg_keys if len(embedders) > 0 else [],
        force_ref_zero_embeddings=True,
    )
    # Conditioning tensors are truncated to (n + 1) * N rows below.
    n = len(pose) - 1

    if uc is not None:
        # Presumably one pose copy per guidance branch batched by the sampler.
        copies = 3 if scale_im > 0 else 2
        if isinstance(pose, list):
            pose = pose[:N] * copies
        else:
            pose = torch.cat([pose[:N]] * copies)
    sampling_kwargs = {'pose': pose, 'drop_im': None, 'mask_ref': None}

    rows = (n + 1) * N
    for k in c:
        if isinstance(c[k], torch.Tensor):
            c[k] = c[k][:rows].to('cuda')
            # Previously crashed with a TypeError when uc was None.
            if uc is not None:
                uc[k] = uc[k][:rows].to('cuda')

    # perf_counter is monotonic, unlike time.time().
    start = time.perf_counter()
    with model.ema_scope("Plotting"):
        samples = model.sample(
            c, shape=noise.shape[1:], uc=uc, batch_size=N, num_steps=num_steps, noise=noise, **sampling_kwargs
        )
    model.clear_rendered_feat()
    samples = model.decode_first_stage(samples)
    print("Time taken for sampling", time.perf_counter() - start)
    log["samples"] = samples.cpu()
    return log
def process_camera_json(camera_json, example_cam):
    """Build a single-view ``PerspectiveCameras`` from a JSON camera description.

    Every single quote in ``camera_json`` is first turned into a double
    quote, so Python-dict-style text is accepted. This also rewrites any
    apostrophes inside string values.

    The decoded object must contain a ``"scene.camera"`` entry with ``eye``,
    ``up`` and ``center`` points, each an object with ``x``/``y``/``z``.
    Those points fix the extrinsics via ``look_at_view_transform``. Focal
    length and principal point are copied from ``example_cam``, and the
    image size is fixed at 512. Intermediate values are printed for
    debugging.
    """
    text = camera_json.replace("'", "\"")
    print("input camera json")
    print(text)
    scene_camera = json.loads(text)["scene.camera"]

    def xyz(field):
        # {"x": .., "y": .., "z": ..} -> float32 tensor of shape (1, 3)
        p = scene_camera[field]
        return torch.tensor([p["x"], p["y"], p["z"]], dtype=torch.float32).unsqueeze(0)

    eye, up, center = xyz("eye"), xyz("up"), xyz("center")
    rotation, translation = look_at_view_transform(eye=eye, at=center, up=up)

    print("focal length", example_cam.focal_length)
    print("principal point", example_cam.principal_point)
    camera = PerspectiveCameras(
        R=rotation,
        T=translation,
        focal_length=example_cam.focal_length,
        principal_point=example_cam.principal_point,
        image_size=512,
    )
    print("input pose")
    print(camera.get_world_to_view_transform().get_matrix())
    return camera
def load_and_return_model_and_data(config, model,
                                   ckpt="pretrained-models/sd_xl_base_1.0.safetensors",
                                   delta_ckpt=None,
                                   train=False,
                                   valid=False,
                                   far=3,
                                   num_images=1,
                                   num_ref=8,
                                   max_images=20,
                                   ):
    """Attach the delta weights to ``model`` and patch its forwards for sampling.

    Steps:

    1. ``load_delta_model`` loads ``delta_ckpt`` and freezes the model.
    2. The model is moved to CUDA.
    3. ``SpatialTransformer`` modules get ``customforward`` as their
       ``forward``; ``BasicTransformerBlock`` modules get ``_customforward``.
       Together these store rendered features and use the pre-calculated
       reference features.

    Only ``config``, ``model`` and ``delta_ckpt`` are used. The remaining
    parameters are accepted for call compatibility. Dataset loading is
    disabled, so the returned ``data`` is always None.

    Returns:
        ``(model, data)``.
    """
    # Parsed even though the dataset it describes is no longer instantiated,
    # so an invalid config path still fails here.
    config = OmegaConf.load(config)
    data = None

    model, _ = load_delta_model(model, delta_ckpt,)
    model = model.cuda()

    def patch_forward(module, class_name, new_forward):
        # Bind new_forward as `forward` on modules named class_name; recursion
        # stops below a patched module.
        if module.__class__.__name__ == class_name:
            print(module.__class__.__name__, "adding control")
            setattr(module, 'forward', new_forward.__get__(module, module.__class__))
            return
        if hasattr(module, 'children'):
            for child in module.children():
                patch_forward(child, class_name, new_forward)

    # Two passes per top-level child: all SpatialTransformers first, then all
    # BasicTransformerBlocks.
    for _, top_level in model.model.diffusion_model.named_children():
        patch_forward(top_level, 'SpatialTransformer', customforward)
        patch_forward(top_level, 'BasicTransformerBlock', _customforward)

    # start sampling
    model.clear_rendered_feat()
    return model, data
def sample(model, data,
           num_images=1,
           prompt="",
           appendpath="",
           camera_json=None,
           train=False,
           scale=7.5,
           scale_im=3.5,
           beta=1.0,
           num_ref=8,
           skipreflater=False,
           num_steps=10,
           valid=False,
           max_images=20,
           seed=42,
           camera_path="pretrained-models/car0/camera.bin",
           ):
    """
    Only works with num_images=1 (because of camera_json processing)

    Generate an image of the customized concept seen from the camera
    described by ``camera_json``.

    Args:
        model: model prepared by ``load_and_return_model_and_data``. Its
            transformer blocks are expected to carry ``references`` buffers.
        data: unused.
        num_images: forced to 1.
        prompt: text prompt. An empty string means no ``"txt"``/``"txt_ref"``
            entries are added to the batch.
        camera_json: camera description string passed to
            ``process_camera_json``. Must not be None, because that function
            calls ``.replace`` on it.
        scale, scale_im: written onto ``model.sampler.guider`` before sampling.
        num_steps: sampler steps, forwarded to ``log_images``.
        seed: passed to ``seed_everything``.
        camera_path: file loaded with ``torch.load`` and unpacked into
            ``(cameras_val, cameras_train)``.
        appendpath, train, beta, skipreflater, valid, max_images: unused.
        num_ref: ignored; overridden with 8 below.

    Returns:
        The ``"samples"`` entry returned by ``log_images``.

    Side effects:
        Rebinds the module-global ``choices`` and clears the model's cached
        rendered features before and after sampling.
    """
    if num_images != 1:
        print("forcing num_images to be 1")
        num_images = 1
    # set guidance scales
    model.sampler.guider.scale_im = scale_im
    model.sampler.guider.scale = scale
    seed_everything(seed)
    # load cameras
    # NOTE(review): this file holds pickled camera objects. torch>=2.6
    # defaults torch.load to weights_only=True, which may reject them; confirm
    # the pinned torch version, and only load trusted files.
    cameras_val, cameras_train = torch.load(camera_path)
    global choices
    # NOTE(review): hard-codes the reference count, silently ignoring the
    # num_ref argument.
    num_ref = 8
    # Pick num_ref evenly spaced training-camera indices. They select the
    # reference cameras here and, via the global, the matching rows of each
    # block's `references` buffer in _customforward.
    max_diff = len(cameras_train)/num_ref
    choices = [int(x) for x in torch.linspace(0, len(cameras_train) - max_diff, num_ref)]
    cameras_train_final = [cameras_train[i] for i in choices]
    # start sampling
    model.clear_rendered_feat()
    if prompt == "":
        prompt = None
    # Drawn after seed_everything, so the initial latent depends only on seed.
    noise = torch.randn(1, 4, 64, 64).to('cuda').repeat(num_images, 1, 1, 1)
    # random sample camera poses
    pose_ids = np.random.choice(len(cameras_val), num_images, replace=False)
    print(pose_ids)
    # NOTE(review): the random pick above is immediately overridden, so
    # cameras_val must have more than 21 entries. This camera only supplies
    # intrinsics: its extrinsics are replaced by camera_json below.
    pose_ids[0] = 21
    pose = [cameras_val[i] for i in pose_ids]
    print("example camera")
    print(pose[0].R)
    print(pose[0].T)
    print(pose[0].focal_length)
    print(pose[0].principal_point)
    # prepare batches [if translating then call required functions on the target pose]
    batches = []
    for i in range(num_images):
        # Target camera first, followed by the reference cameras.
        batch = {'pose': [pose[i]] + cameras_train_final,
                 "original_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2),
                 "target_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2),
                 "crop_coords_top_left": torch.tensor([0, 0]).reshape(-1, 2),
                 "original_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2),
                 "target_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2),
                 "crop_coords_top_left_ref": torch.tensor([0, 0]).reshape(-1, 2),
                 }
        batch_ = copy.deepcopy(batch)
        # Replace the target camera with the one described by camera_json,
        # then merge all cameras into a single batched camera (a 1-element list).
        batch_["pose"][0] = process_camera_json(camera_json, pose[0])
        batch_["pose"] = [join_cameras_as_batch(batch_["pose"])]
        # print('batched')
        # print(batch_["pose"][0].get_world_to_view_transform().get_matrix())
        batches.append(batch_)
    print(f'len batches: {len(batches)}')
    image = None
    with torch.no_grad():
        for batch in batches:
            # Move tensors and camera lists to the GPU; other values are untouched.
            for key in batch.keys():
                if isinstance(batch[key], torch.Tensor):
                    batch[key] = batch[key].to('cuda')
                elif 'pose' in key:
                    batch[key] = [x.to('cuda') for x in batch[key]]
                else:
                    pass
            if prompt is not None:
                batch["txt"] = [prompt for _ in range(1)]
                # NOTE(review): batch["pose"] holds one joined camera at this
                # point, so this list is empty. Confirm that is intended.
                batch["txt_ref"] = [prompt for _ in range(len(batch["pose"])-1)]
                print(batch["txt"])
            N = 1
            log_ = log_images(model, batch, N=N, noise=noise.clone()[:N], num_steps=num_steps, scale_im=scale_im)
            image = log_["samples"]
            torch.cuda.empty_cache()
            model.clear_rendered_feat()
    print("generation done")
    return image