import copy
import glob
import json
import os
import sys
from typing import List

import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch3d.renderer import look_at_view_transform
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import PerspectiveCameras

sys.path.append('./')
from sgm.util import instantiate_from_config, load_safetensors

# Indices of the selected reference training cameras; set in `sample` and read inside
# the patched transformer-block forward (`_customforward`).
choices = []


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    return x[(...,) + (None,) * dims_to_append]


def load_base_model(config, ckpt=None, verbose=True):
    """Instantiate the base SDXL model from a config file and optionally load weights."""
    config = OmegaConf.load(config)

    # Model overrides: rendering far plane, VAE checkpoint, and the image+text CFG guider.
    config.model.params.network_config.params.far = 3
    config.model.params.first_stage_config.params.ckpt_path = "pretrained-models/sdxl_vae.safetensors"
    guider_config = {
        'target': 'sgm.modules.diffusionmodules.guiders.ScheduledCFGImgTextRef',
        'params': {'scale': 7.5, 'scale_im': 3.5},
    }
    config.model.params.sampler_config.params.guider_config = guider_config
    model = instantiate_from_config(config.model)

    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        if ckpt.endswith("ckpt"):
            pl_sd = torch.load(ckpt, map_location="cpu")
            if "global_step" in pl_sd:
                print(f"Global Step: {pl_sd['global_step']}")
            sd = pl_sd["state_dict"]
        elif ckpt.endswith("safetensors"):
            sd = load_safetensors(ckpt)
            # When modifier tokens were added during training, the base token embeddings
            # have a different shape, so drop them and rely on strict=False below.
            if 'modifier_token' in config.data.params:
                del sd['conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight']
                del sd['conditioner.embedders.1.model.token_embedding.weight']
        else:
            raise NotImplementedError
        m, u = model.load_state_dict(sd, strict=False)

    model.cuda()
    model.eval()
    return model
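
# Hedged usage sketch (not part of the original script; the config path and the delta
# checkpoint path are placeholders): the base SDXL model is built first and then
# specialised with a per-object delta checkpoint via `load_delta_model` below, or via
# the higher-level `load_and_return_model_and_data`.
#
#     base = load_base_model(
#         "configs/sdxl_base.yaml",
#         ckpt="/data/gdsu/customization3d/stable-diffusion-xl-base-1.0/sd_xl_base_1.0.safetensors",
#     )
#     model, _ = load_delta_model(base, delta_ckpt="path/to/delta.ckpt")
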
def load_delta_model(model, delta_ckpt=None, verbose=True, freeze=True):
    """model is the preloaded base Stable Diffusion model."""
    msg = None
    if delta_ckpt is not None:
        pl_sd_delta = torch.load(delta_ckpt, map_location="cpu")
        sd_delta = pl_sd_delta["delta_state_dict"]
        # TODO: add new delta loading embedding stuff?

        # Store the pre-computed reference features as buffers on the transformer blocks
        # that have pose-embedding layers, then drop them from the state dict.
        for name, module in model.model.diffusion_model.named_modules():
            if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks':
                if hasattr(module, 'pose_emb_layers'):
                    module.register_buffer('references', sd_delta[f'model.diffusion_model.{name}.references'])
                    del sd_delta[f'model.diffusion_model.{name}.references']

        m, u = model.load_state_dict(sd_delta, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:", m)
        if len(u) > 0 and verbose:
            print("unexpected keys:", u)

    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    model.cuda()
    model.eval()
    return model, msg


def get_unique_embedder_keys_from_conditioner(conditioner):
    p = [x.input_keys for x in conditioner.embedders]
    return list(set([item for sublist in p for item in sublist])) + ['jpg_ref']


def customforward(self, x, xr, context=None, contextr=None, pose=None, mask_ref=None,
                  prev_weights=None, timesteps=None, drop_im=None):
    """Patched SpatialTransformer.forward that threads reference features through the blocks."""
    # note: if no context is given, cross-attention defaults to self-attention
    if not isinstance(context, list):
        context = [context]
    b, c, h, w = x.shape
    x_in = x
    fg_masks = []
    alphas = []
    rgbs = []

    x = self.norm(x)
    if not self.use_linear:
        x = self.proj_in(x)
    x = rearrange(x, "b c h w -> b (h w) c").contiguous()
    if self.use_linear:
        x = self.proj_in(x)

    prev_weights = None
    counter = 0
    for i, block in enumerate(self.transformer_blocks):
        if i > 0 and len(context) == 1:
            i = 0  # use same context for each block
        if self.image_cross and (counter % self.poscontrol_interval == 0):
            # Reference-aware block: also returns foreground mask, attention weights,
            # and (optionally) rendered alpha/RGB.
            x, fg_mask, weights, alpha, rgb = block(
                x, context=context[i], context_ref=x, pose=pose, mask_ref=mask_ref,
                prev_weights=prev_weights, drop_im=drop_im)
            prev_weights = weights
            fg_masks.append(fg_mask)
            if alpha is not None:
                alphas.append(alpha)
            if rgb is not None:
                rgbs.append(rgb)
        else:
            x, _, _, _, _ = block(x, context=context[i], drop_im=drop_im)
        counter += 1

    if self.use_linear:
        x = self.proj_out(x)
    x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
    if not self.use_linear:
        x = self.proj_out(x)

    if len(fg_masks) > 0:
        if len(rgbs) <= 0:
            rgbs = None
        if len(alphas) <= 0:
            alphas = None
        return x + x_in, None, fg_masks, prev_weights, alphas, rgbs
    else:
        return x + x_in, None, None, prev_weights, None, None
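
# Note on how `customforward` (above) and `_customforward` (below) are used: they are not
# called directly. `load_and_return_model_and_data` (further down) binds them onto the
# existing SpatialTransformer / BasicTransformerBlock instances of the UNet, e.g.
#
#     bound_method = customforward.__get__(net_, net_.__class__)
#     setattr(net_, 'forward', bound_method)
#
# so `self` inside them is the patched module. The `self.references` buffer is registered
# in `load_delta_model`, and `self.rendered_feat` is reset via `model.clear_rendered_feat()`.
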
def _customforward(self, x, context=None, context_ref=None, pose=None, mask_ref=None,
                   prev_weights=None, drop_im=None, additional_tokens=None,
                   n_times_crossframe_attn_in_self=0):
    """Patched BasicTransformerBlock.forward that injects the stored reference features."""
    if context_ref is not None:
        global choices
        batch_size = x.size(0)
        # IP2P-like sampling (three-way CFG batch) or default two-way CFG batch.
        if batch_size % 3 == 0:
            batch_size = batch_size // 3
            context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1)
            context_ref = torch.cat(
                [self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1),
                 context_ref, context_ref], dim=0)
        else:
            batch_size = batch_size // 2
            context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1)
            context_ref = torch.cat(
                [self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1),
                 context_ref], dim=0)

    fg_mask = None
    weights = None
    alphas = None
    predicted_rgb = None

    x = (
        self.attn1(
            self.norm1(x),
            context=context if self.disable_self_attn else None,
            additional_tokens=additional_tokens,
            n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
            if not self.disable_self_attn
            else 0,
        )
        + x
    )
    x = (
        self.attn2(
            self.norm2(x),
            context=context,
            additional_tokens=additional_tokens,
        )
        + x
    )

    if context_ref is not None:
        if self.rendered_feat is not None:
            # Reuse the features rendered at a previous denoising step.
            x = self.pose_emb_layers(torch.cat([x, self.rendered_feat], dim=-1))
        else:
            xref, fg_mask, weights, alphas, predicted_rgb = self.reference_attn(
                x, context_ref, context, pose, prev_weights, mask_ref)
            self.rendered_feat = xref
            x = self.pose_emb_layers(torch.cat([x, xref], -1))

    x = self.ff(self.norm3(x)) + x
    return x, fg_mask, weights, alphas, predicted_rgb


def log_images(
    model,
    batch,
    N: int = 1,
    noise=None,
    scale_im=3.5,
    num_steps: int = 10,
    ucg_keys: List[str] = None,
    **kwargs,
):
    """Sample images for a single batch and return them in a log dict."""
    log = dict()
    conditioner_input_keys = [e.input_keys for e in model.conditioner.embedders]
    ucg_keys = conditioner_input_keys
    pose = batch['pose']

    c, uc = model.conditioner.get_unconditional_conditioning(
        batch,
        force_uc_zero_embeddings=ucg_keys if len(model.conditioner.embedders) > 0 else [],
        force_ref_zero_embeddings=True,
    )

    # n = number of reference views in the batch.
    _, n = 1, len(pose) - 1
    sampling_kwargs = {}
    if scale_im > 0:
        # Image+text guidance: triple the poses (uncond / image-uncond / cond).
        if uc is not None:
            if isinstance(pose, list):
                pose = pose[:N] * 3
            else:
                pose = torch.cat([pose[:N]] * 3)
    else:
        # Standard CFG: double the poses (uncond / cond).
        if uc is not None:
            if isinstance(pose, list):
                pose = pose[:N] * 2
            else:
                pose = torch.cat([pose[:N]] * 2)
    sampling_kwargs['pose'] = pose
    sampling_kwargs['drop_im'] = None
    sampling_kwargs['mask_ref'] = None

    for k in c:
        if isinstance(c[k], torch.Tensor):
            c[k], uc[k] = map(lambda y: y[k][:(n + 1) * N].to('cuda'), (c, uc))

    import time
    st = time.time()
    with model.ema_scope("Plotting"):
        samples = model.sample(
            c, shape=noise.shape[1:], uc=uc, batch_size=N, num_steps=num_steps,
            noise=noise, **sampling_kwargs
        )
        model.clear_rendered_feat()
        samples = model.decode_first_stage(samples)
    print("Time taken for sampling", time.time() - st)

    log["samples"] = samples.cpu()
    return log


def process_camera_json(camera_json, example_cam):
    # Replace all single quotes in camera_json with double quotes so it parses as JSON.
    camera_json = camera_json.replace("'", "\"")
    print("input camera json")
    print(camera_json)
    camera_dict = json.loads(camera_json)["scene.camera"]

    eye = torch.tensor([camera_dict["eye"]["x"], camera_dict["eye"]["y"], camera_dict["eye"]["z"]], dtype=torch.float32).unsqueeze(0)
    up = torch.tensor([camera_dict["up"]["x"], camera_dict["up"]["y"], camera_dict["up"]["z"]], dtype=torch.float32).unsqueeze(0)
    center = torch.tensor([camera_dict["center"]["x"], camera_dict["center"]["y"], camera_dict["center"]["z"]], dtype=torch.float32).unsqueeze(0)

    new_R, new_T = look_at_view_transform(eye=eye, at=center, up=up)

    ## temp: hard-coded debug poses kept for reference
    # new_R = torch.tensor([[[ 0.4988,  0.2666,  0.8247],
    #                        [-0.1917, -0.8940,  0.4049],
    #                        [ 0.8453, -0.3601, -0.3948]]], dtype=torch.float32)
    # new_T = torch.tensor([[ 0.0739, -0.0013,  0.9973]], dtype=torch.float32)
    # new_R = torch.tensor([[[ 0.2530,  0.2989,  0.9201],
    #                        [-0.2652, -0.8932,  0.3631],
    #                        [ 0.9304, -0.3359, -0.1467]]], dtype=torch.float32)
    # new_T = torch.tensor([[ 0.0081,  0.0337,  1.0452]], dtype=torch.float32)

    print("focal length", example_cam.focal_length)
    print("principal point", example_cam.principal_point)
    newcam = PerspectiveCameras(R=new_R, T=new_T,
                                focal_length=example_cam.focal_length,
                                principal_point=example_cam.principal_point,
                                image_size=512)
    print("input pose")
    print(newcam.get_world_to_view_transform().get_matrix())
    return newcam
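
# Example of the JSON layout that `process_camera_json` expects (key names taken from the
# parsing code above; the numeric values are placeholders for illustration):
#
#     camera_json = ('{"scene.camera": {"eye":    {"x": 0.0, "y": 0.0, "z": 2.0},'
#                    '                  "up":     {"x": 0.0, "y": 1.0, "z": 0.0},'
#                    '                  "center": {"x": 0.0, "y": 0.0, "z": 0.0}}}')
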
def load_and_return_model_and_data(config, model,
                                   ckpt="/data/gdsu/customization3d/stable-diffusion-xl-base-1.0/sd_xl_base_1.0.safetensors",
                                   delta_ckpt=None,
                                   train=False,
                                   valid=False,
                                   far=3,
                                   num_images=1,
                                   num_ref=8,
                                   max_images=20,
                                   ):
    config = OmegaConf.load(config)

    # load data (dataset loading is currently disabled; only the model is prepared)
    data = None
    # config.data.params.jitter = False
    # config.data.params.addreg = False
    # config.data.params.bbox = False
    # data = instantiate_from_config(config.data)
    # data = data.train_dataset
    # single_id = data.single_id
    # if hasattr(data, 'rotations'):
    #     total_images = len(data.rotations[data.sequence_list[single_id]])
    # else:
    #     total_images = len(data.annotations['chair'])
    # print(f"Total images in dataset: {total_images}")

    model, msg = load_delta_model(model, delta_ckpt)

    # Swap in the custom forward methods so that rendered features are cached and the
    # pre-computed reference features are used during sampling.
    def register_recr(net_):
        if net_.__class__.__name__ == 'SpatialTransformer':
            print(net_.__class__.__name__, "adding control")
            bound_method = customforward.__get__(net_, net_.__class__)
            setattr(net_, 'forward', bound_method)
            return
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                register_recr(net__)
        return

    def register_recr2(net_):
        if net_.__class__.__name__ == 'BasicTransformerBlock':
            print(net_.__class__.__name__, "adding control")
            bound_method = _customforward.__get__(net_, net_.__class__)
            setattr(net_, 'forward', bound_method)
            return
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                register_recr2(net__)
        return

    sub_nets = model.model.diffusion_model.named_children()
    for net in sub_nets:
        register_recr(net[1])
        register_recr2(net[1])

    # start sampling from a clean state
    model.clear_rendered_feat()

    return model, data
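
# Worked example of the reference-view selection performed in `sample` below (the camera
# count of 100 is an assumed value for illustration): with num_ref = 8,
# max_diff = 100 / 8 = 12.5 and torch.linspace(0, 100 - 12.5, 8) gives
# [0.0, 12.5, 25.0, ..., 87.5], so after int() the chosen indices are
# choices = [0, 12, 25, 37, 50, 62, 75, 87], i.e. eight roughly evenly spaced training views.
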
def sample(model, data,
           num_images=1,
           prompt="",
           appendpath="",
           camera_json=None,
           train=False,
           scale=7.5,
           scale_im=3.5,
           beta=1.0,
           num_ref=8,
           skipreflater=False,
           num_steps=10,
           valid=False,
           max_images=20,
           seed=42,
           camera_path="pretrained-models/car0/camera.bin",
           ):
    """Only works with num_images=1 (because of camera_json processing)."""
    if num_images != 1:
        print("forcing num_images to be 1")
        num_images = 1

    # set guidance scales
    model.sampler.guider.scale_im = scale_im
    model.sampler.guider.scale = scale

    seed_everything(seed)

    # load cameras and pick num_ref evenly spaced reference views from the training split
    cameras_val, cameras_train = torch.load(camera_path)
    global choices
    num_ref = 8  # fixed here, overriding the argument
    max_diff = len(cameras_train) / num_ref
    choices = [int(x) for x in torch.linspace(0, len(cameras_train) - max_diff, num_ref)]
    cameras_train_final = [cameras_train[i] for i in choices]

    # start sampling
    model.clear_rendered_feat()
    if prompt == "":
        prompt = None
    noise = torch.randn(1, 4, 64, 64).to('cuda').repeat(num_images, 1, 1, 1)

    # randomly sample target camera poses (the first one is then fixed to index 21)
    pose_ids = np.random.choice(len(cameras_val), num_images, replace=False)
    print(pose_ids)
    pose_ids[0] = 21
    pose = [cameras_val[i] for i in pose_ids]
    print("example camera")
    print(pose[0].R)
    print(pose[0].T)
    print(pose[0].focal_length)
    print(pose[0].principal_point)

    # prepare batches [if translating then call required functions on the target pose]
    batches = []
    for i in range(num_images):
        batch = {'pose': [pose[i]] + cameras_train_final,
                 "original_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2),
                 "target_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2),
                 "crop_coords_top_left": torch.tensor([0, 0]).reshape(-1, 2),
                 "original_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2),
                 "target_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2),
                 "crop_coords_top_left_ref": torch.tensor([0, 0]).reshape(-1, 2),
                 }
        batch_ = copy.deepcopy(batch)
        # Replace the target pose with the camera described by the input JSON, then join
        # target and reference cameras into a single batched camera object.
        batch_["pose"][0] = process_camera_json(camera_json, pose[0])
        batch_["pose"] = [join_cameras_as_batch(batch_["pose"])]
        # print('batched')
        # print(batch_["pose"][0].get_world_to_view_transform().get_matrix())
        batches.append(batch_)
    print(f'len batches: {len(batches)}')

    image = None
    with torch.no_grad():
        for batch in batches:
            for key in batch.keys():
                if isinstance(batch[key], torch.Tensor):
                    batch[key] = batch[key].to('cuda')
                elif 'pose' in key:
                    batch[key] = [x.to('cuda') for x in batch[key]]
                else:
                    pass
            if prompt is not None:
                batch["txt"] = [prompt for _ in range(1)]
                batch["txt_ref"] = [prompt for _ in range(len(batch["pose"]) - 1)]
                print(batch["txt"])
            N = 1
            log_ = log_images(model, batch, N=N, noise=noise.clone()[:N],
                              num_steps=num_steps, scale_im=scale_im)
            image = log_["samples"]
            torch.cuda.empty_cache()
            model.clear_rendered_feat()

    print("generation done")
    return image
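
# Hedged end-to-end usage sketch: everything under the `__main__` guard is illustrative
# only. The config path, delta checkpoint path, and prompt are placeholders (assumptions,
# not values from the original script); the SDXL base checkpoint and camera paths reuse
# defaults that already appear above.
if __name__ == "__main__":
    example_camera_json = ('{"scene.camera": {"eye": {"x": 0.0, "y": 0.0, "z": 2.0}, '
                           '"up": {"x": 0.0, "y": 1.0, "z": 0.0}, '
                           '"center": {"x": 0.0, "y": 0.0, "z": 0.0}}}')
    base = load_base_model(
        "configs/sdxl_base.yaml",  # placeholder config path
        ckpt="/data/gdsu/customization3d/stable-diffusion-xl-base-1.0/sd_xl_base_1.0.safetensors",
    )
    model, data = load_and_return_model_and_data(
        "configs/sdxl_base.yaml", base,
        delta_ckpt="pretrained-models/car0/delta.ckpt",  # placeholder delta checkpoint
    )
    image = sample(model, data, prompt="a photo of a car", camera_json=example_camera_json)
    # `image` is a (1, 3, H, W) tensor, typically in [-1, 1]; rescale before saving, e.g.
    # arr = ((image[0].permute(1, 2, 0).clamp(-1, 1) + 1) * 127.5).byte().cpu().numpy()
    # Image.fromarray(arr).save("sample.png")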