import sys
import copy
import json
import time
from typing import List

import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.renderer import look_at_view_transform
from pytorch3d.renderer.camera_utils import join_cameras_as_batch

sys.path.append('./')
from sgm.util import instantiate_from_config, load_safetensors

# Global list of reference-view indices, populated in sample() and read inside
# the patched transformer blocks (_customforward).
choices = []

def load_base_model(config, ckpt=None, verbose=True):
    config = OmegaConf.load(config)
    # load model
    config.model.params.network_config.params.far = 3
    config.model.params.first_stage_config.params.ckpt_path = "pretrained-models/sdxl_vae.safetensors"
    guider_config = {'target': 'sgm.modules.diffusionmodules.guiders.ScheduledCFGImgTextRef',
                     'params': {'scale': 7.5, 'scale_im': 3.5}
                     }
    config.model.params.sampler_config.params.guider_config = guider_config
    model = instantiate_from_config(config.model)

    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        if ckpt.endswith("ckpt"):
            pl_sd = torch.load(ckpt, map_location="cpu")
            if "global_step" in pl_sd:
                print(f"Global Step: {pl_sd['global_step']}")
            sd = pl_sd["state_dict"]
        elif ckpt.endswith("safetensors"):
            sd = load_safetensors(ckpt)
            if 'modifier_token' in config.data.params:
                # the text-encoder embedding tables are resized for the new token(s),
                # so drop the pretrained ones and load the rest non-strictly
                del sd['conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight']
                del sd['conditioner.embedders.1.model.token_embedding.weight']
        else:
            raise NotImplementedError
        m, u = model.load_state_dict(sd, strict=False)
        if verbose:
            if len(m) > 0:
                print("missing keys:", m)
            if len(u) > 0:
                print("unexpected keys:", u)

    model.cuda()
    model.eval()
    return model

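
# Usage sketch for load_base_model. The config path below is a placeholder for
# whichever SDXL-based config ships with this repo; only the guider settings and
# VAE path hard-coded above are taken from the source.
#
#   base_model = load_base_model(
#       "configs/sdxl_base.yaml",                        # hypothetical config path
#       ckpt="/path/to/sd_xl_base_1.0.safetensors",      # base SDXL weights (placeholder path)
#   )
#
# The returned model is on CUDA and in eval mode, with the ScheduledCFGImgTextRef
# guider configured (scale=7.5, scale_im=3.5).
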
def load_delta_model(model, delta_ckpt=None, verbose=True, freeze=True):
    """
    model is the preloaded base stable diffusion model
    """
    msg = None
    if delta_ckpt is not None:
        pl_sd_delta = torch.load(delta_ckpt, map_location="cpu")
        sd_delta = pl_sd_delta["delta_state_dict"]
        # TODO: add new delta loading embedding stuff?
        # the delta checkpoint also stores pre-computed reference features; register
        # them as buffers on the matching transformer blocks before loading the rest
        for name, module in model.model.diffusion_model.named_modules():
            if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks':
                if hasattr(module, 'pose_emb_layers'):
                    module.register_buffer('references', sd_delta[f'model.diffusion_model.{name}.references'])
                    del sd_delta[f'model.diffusion_model.{name}.references']

        m, u = model.load_state_dict(sd_delta, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:", m)
        if len(u) > 0 and verbose:
            print("unexpected keys:", u)

    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    model.cuda()
    model.eval()
    return model, msg

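
# Usage sketch for load_delta_model: apply a fine-tuned "delta" checkpoint on top
# of the base model. The delta file holds the fine-tuned weights plus the
# per-block '...references' buffers registered above. The path is a hypothetical
# placeholder.
#
#   model, _ = load_delta_model(base_model, delta_ckpt="pretrained-models/car0/delta.ckpt")
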
def get_unique_embedder_keys_from_conditioner(conditioner):
    p = [x.input_keys for x in conditioner.embedders]
    return list(set([item for sublist in p for item in sublist])) + ['jpg_ref']

def customforward(self, x, xr, context=None, contextr=None, pose=None, mask_ref=None, prev_weights=None, timesteps=None, drop_im=None):
    # note: if no context is given, cross-attention defaults to self-attention
    if not isinstance(context, list):
        context = [context]
    b, c, h, w = x.shape
    x_in = x
    fg_masks = []
    alphas = []
    rgbs = []
    x = self.norm(x)
    if not self.use_linear:
        x = self.proj_in(x)
    x = rearrange(x, "b c h w -> b (h w) c").contiguous()
    if self.use_linear:
        x = self.proj_in(x)
    prev_weights = None
    counter = 0
    for i, block in enumerate(self.transformer_blocks):
        if i > 0 and len(context) == 1:
            i = 0  # use same context for each block
        if self.image_cross and (counter % self.poscontrol_interval == 0):
            x, fg_mask, weights, alpha, rgb = block(x, context=context[i], context_ref=x, pose=pose, mask_ref=mask_ref, prev_weights=prev_weights, drop_im=drop_im)
            prev_weights = weights
            fg_masks.append(fg_mask)
            if alpha is not None:
                alphas.append(alpha)
            if rgb is not None:
                rgbs.append(rgb)
        else:
            x, _, _, _, _ = block(x, context=context[i], drop_im=drop_im)
        counter += 1
    if self.use_linear:
        x = self.proj_out(x)
    x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
    if not self.use_linear:
        x = self.proj_out(x)
    if len(fg_masks) > 0:
        if len(rgbs) <= 0:
            rgbs = None
        if len(alphas) <= 0:
            alphas = None
        return x + x_in, None, fg_masks, prev_weights, alphas, rgbs
    else:
        return x + x_in, None, None, prev_weights, None, None

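
# customforward above is not called directly; load_and_return_model_and_data binds
# it onto every SpatialTransformer in the UNet so reference features and rendering
# outputs can be threaded through the blocks. A minimal sketch of that binding
# pattern (Dummy and patched_forward are purely illustrative):
#
#   class Dummy:
#       pass
#
#   def patched_forward(self, x):
#       return x
#
#   obj = Dummy()
#   obj.forward = patched_forward.__get__(obj, Dummy)   # bound method, self == obj
#   assert obj.forward(3) == 3
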
def _customforward(
    self, x, context=None, context_ref=None, pose=None, mask_ref=None, prev_weights=None, drop_im=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
    if context_ref is not None:
        global choices
        batch_size = x.size(0)
        # IP2P-like sampling (3-way CFG split) or default sampling (2-way split)
        if batch_size % 3 == 0:
            batch_size = batch_size // 3
            context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1)
            context_ref = torch.cat([self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1), context_ref, context_ref], dim=0)
        else:
            batch_size = batch_size // 2
            context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1)
            context_ref = torch.cat([self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1), context_ref], dim=0)

    fg_mask = None
    weights = None
    alphas = None
    predicted_rgb = None

    x = (
        self.attn1(
            self.norm1(x),
            context=context if self.disable_self_attn else None,
            additional_tokens=additional_tokens,
            n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
            if not self.disable_self_attn
            else 0,
        )
        + x
    )
    x = (
        self.attn2(
            self.norm2(x), context=context, additional_tokens=additional_tokens,
        )
        + x
    )
    if context_ref is not None:
        if self.rendered_feat is not None:
            # reuse the reference features rendered on the first call
            x = self.pose_emb_layers(torch.cat([x, self.rendered_feat], dim=-1))
        else:
            xref, fg_mask, weights, alphas, predicted_rgb = self.reference_attn(x,
                                                                                context_ref,
                                                                                context,
                                                                                pose,
                                                                                prev_weights,
                                                                                mask_ref)
            self.rendered_feat = xref
            x = self.pose_emb_layers(torch.cat([x, xref], -1))
    x = self.ff(self.norm3(x)) + x
    return x, fg_mask, weights, alphas, predicted_rgb

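
# Notes on _customforward (a hedged reading of the code above):
# - self.references holds pre-computed per-view reference features; `choices`
#   (set in sample()) selects which views are stacked into context_ref, while the
#   last entry of self.references is expanded into the first CFG chunk.
# - A batch divisible by 3 is treated as the image+text CFG split ("IP2P-like"),
#   otherwise a 2-chunk split is assumed.
# - self.rendered_feat caches the reference_attn output so later denoising steps
#   skip re-rendering; model.clear_rendered_feat() resets it between generations.
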
def log_images(
    model,
    batch,
    N: int = 1,
    noise=None,
    scale_im=3.5,
    num_steps: int = 10,
    ucg_keys: List[str] = None,
    **kwargs,
):
    log = dict()
    conditioner_input_keys = [e.input_keys for e in model.conditioner.embedders]
    ucg_keys = conditioner_input_keys
    pose = batch['pose']
    c, uc = model.conditioner.get_unconditional_conditioning(
        batch,
        force_uc_zero_embeddings=ucg_keys
        if len(model.conditioner.embedders) > 0
        else [],
        force_ref_zero_embeddings=True
    )
    _, n = 1, len(pose) - 1

    sampling_kwargs = {}
    # replicate the target pose for the conditional/unconditional branches:
    # three copies when image CFG is enabled, otherwise two
    if scale_im > 0:
        if uc is not None:
            if isinstance(pose, list):
                pose = pose[:N] * 3
            else:
                pose = torch.cat([pose[:N]] * 3)
    else:
        if uc is not None:
            if isinstance(pose, list):
                pose = pose[:N] * 2
            else:
                pose = torch.cat([pose[:N]] * 2)
    sampling_kwargs['pose'] = pose
    sampling_kwargs['drop_im'] = None
    sampling_kwargs['mask_ref'] = None

    for k in c:
        if isinstance(c[k], torch.Tensor):
            c[k], uc[k] = map(lambda y: y[k][:(n + 1) * N].to('cuda'), (c, uc))

    st = time.time()
    with model.ema_scope("Plotting"):
        samples = model.sample(
            c, shape=noise.shape[1:], uc=uc, batch_size=N, num_steps=num_steps, noise=noise, **sampling_kwargs
        )
        model.clear_rendered_feat()
        samples = model.decode_first_stage(samples)
    print("Time taken for sampling", time.time() - st)
    log["samples"] = samples.cpu()
    return log

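
# log_images expects a batch shaped like the one built in sample() below:
#   'pose': [joined PyTorch3D cameras, target view followed by the reference views]
#   'original_size_as_tuple', 'target_size_as_tuple', 'crop_coords_top_left',
#   'original_size_as_tuple_ref', 'target_size_as_tuple_ref', 'crop_coords_top_left_ref':
#       (1, 2) int tensors
#   'txt': [prompt], 'txt_ref': list of prompts for the reference views
# plus `noise` of shape (N, 4, 64, 64) matching the latent resolution used in sample().
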
def process_camera_json(camera_json, example_cam):
    # replace all single quotes in camera_json with double quotes so it parses as valid JSON
    camera_json = camera_json.replace("'", "\"")
    print("input camera json")
    print(camera_json)

    camera_dict = json.loads(camera_json)["scene.camera"]
    eye = torch.tensor([camera_dict["eye"]["x"], camera_dict["eye"]["y"], camera_dict["eye"]["z"]], dtype=torch.float32).unsqueeze(0)
    up = torch.tensor([camera_dict["up"]["x"], camera_dict["up"]["y"], camera_dict["up"]["z"]], dtype=torch.float32).unsqueeze(0)
    center = torch.tensor([camera_dict["center"]["x"], camera_dict["center"]["y"], camera_dict["center"]["z"]], dtype=torch.float32).unsqueeze(0)

    new_R, new_T = look_at_view_transform(eye=eye, at=center, up=up)
    print("focal length", example_cam.focal_length)
    print("principal point", example_cam.principal_point)
    newcam = PerspectiveCameras(R=new_R,
                                T=new_T,
                                focal_length=example_cam.focal_length,
                                principal_point=example_cam.principal_point,
                                image_size=512)
    print("input pose")
    print(newcam.get_world_to_view_transform().get_matrix())
    return newcam

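
# Illustrative camera_json accepted by process_camera_json; the numeric values are
# arbitrary examples, only the "scene.camera" structure with eye/up/center matters.
# Single-quoted strings also work because of the quote replacement above.
#
#   example_json = ('{"scene.camera": {"eye": {"x": 0.0, "y": 0.5, "z": 2.0}, '
#                   '"up": {"x": 0.0, "y": 1.0, "z": 0.0}, '
#                   '"center": {"x": 0.0, "y": 0.0, "z": 0.0}}}')
#   newcam = process_camera_json(example_json, example_cam)
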
def load_and_return_model_and_data(config, model,
                                   ckpt="/data/gdsu/customization3d/stable-diffusion-xl-base-1.0/sd_xl_base_1.0.safetensors",
                                   delta_ckpt=None,
                                   train=False,
                                   valid=False,
                                   far=3,
                                   num_images=1,
                                   num_ref=8,
                                   max_images=20,
                                   ):
    config = OmegaConf.load(config)

    # load data
    data = None
    # config.data.params.jitter = False
    # config.data.params.addreg = False
    # config.data.params.bbox = False
    # data = instantiate_from_config(config.data)
    # data = data.train_dataset
    # single_id = data.single_id
    # if hasattr(data, 'rotations'):
    #     total_images = len(data.rotations[data.sequence_list[single_id]])
    # else:
    #     total_images = len(data.annotations['chair'])
    # print(f"Total images in dataset: {total_images}")

    model, msg = load_delta_model(model, delta_ckpt,)

    # change forward methods to store rendered features and use the pre-calculated reference features
    def register_recr(net_):
        if net_.__class__.__name__ == 'SpatialTransformer':
            print(net_.__class__.__name__, "adding control")
            bound_method = customforward.__get__(net_, net_.__class__)
            setattr(net_, 'forward', bound_method)
            return
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                register_recr(net__)
            return

    def register_recr2(net_):
        if net_.__class__.__name__ == 'BasicTransformerBlock':
            print(net_.__class__.__name__, "adding control")
            bound_method = _customforward.__get__(net_, net_.__class__)
            setattr(net_, 'forward', bound_method)
            return
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                register_recr2(net__)
            return

    sub_nets = model.model.diffusion_model.named_children()
    for net in sub_nets:
        register_recr(net[1])
        register_recr2(net[1])

    # start sampling
    model.clear_rendered_feat()
    return model, data

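
# Usage sketch chaining the loaders (paths are hypothetical placeholders):
#
#   base = load_base_model("configs/sdxl_base.yaml", ckpt="/path/to/sd_xl_base_1.0.safetensors")
#   model, data = load_and_return_model_and_data(
#       "configs/sdxl_base.yaml", base,
#       delta_ckpt="pretrained-models/car0/delta.ckpt",
#   )
#
# After this call every SpatialTransformer / BasicTransformerBlock in the UNet uses
# customforward / _customforward, and `data` is None (the dataset loading above is
# commented out).
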
def sample(model, data,
           num_images=1,
           prompt="",
           appendpath="",
           camera_json=None,
           train=False,
           scale=7.5,
           scale_im=3.5,
           beta=1.0,
           num_ref=8,
           skipreflater=False,
           num_steps=10,
           valid=False,
           max_images=20,
           seed=42,
           camera_path="pretrained-models/car0/camera.bin",
           ):
    """
    Only works with num_images=1 (because of camera_json processing)
    """
    if num_images != 1:
        print("forcing num_images to be 1")
        num_images = 1

    # set guidance scales
    model.sampler.guider.scale_im = scale_im
    model.sampler.guider.scale = scale

    seed_everything(seed)

    # load cameras and pick num_ref evenly spaced reference views
    cameras_val, cameras_train = torch.load(camera_path)
    global choices
    num_ref = 8
    max_diff = len(cameras_train) / num_ref
    choices = [int(x) for x in torch.linspace(0, len(cameras_train) - max_diff, num_ref)]
    cameras_train_final = [cameras_train[i] for i in choices]

    # start sampling
    model.clear_rendered_feat()

    if prompt == "":
        prompt = None

    noise = torch.randn(1, 4, 64, 64).to('cuda').repeat(num_images, 1, 1, 1)

    # randomly sample camera poses from the validation cameras
    pose_ids = np.random.choice(len(cameras_val), num_images, replace=False)
    print(pose_ids)
    pose_ids[0] = 21
    pose = [cameras_val[i] for i in pose_ids]
    print("example camera")
    print(pose[0].R)
    print(pose[0].T)
    print(pose[0].focal_length)
    print(pose[0].principal_point)
    # prepare batches [if translating then call required functions on the target pose]
    batches = []
    for i in range(num_images):
        batch = {'pose': [pose[i]] + cameras_train_final,
                 "original_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2),
                 "target_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2),
                 "crop_coords_top_left": torch.tensor([0, 0]).reshape(-1, 2),
                 "original_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2),
                 "target_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2),
                 "crop_coords_top_left_ref": torch.tensor([0, 0]).reshape(-1, 2),
                 }
        batch_ = copy.deepcopy(batch)
        # replace the target pose with the camera parsed from the user-provided JSON
        batch_["pose"][0] = process_camera_json(camera_json, pose[0])
        batch_["pose"] = [join_cameras_as_batch(batch_["pose"])]
        # print('batched')
        # print(batch_["pose"][0].get_world_to_view_transform().get_matrix())
        batches.append(batch_)

    print(f'len batches: {len(batches)}')

    image = None
    with torch.no_grad():
        for batch in batches:
            for key in batch.keys():
                if isinstance(batch[key], torch.Tensor):
                    batch[key] = batch[key].to('cuda')
                elif 'pose' in key:
                    batch[key] = [x.to('cuda') for x in batch[key]]
                else:
                    pass
            if prompt is not None:
                batch["txt"] = [prompt for _ in range(1)]
                batch["txt_ref"] = [prompt for _ in range(len(batch["pose"]) - 1)]
                print(batch["txt"])

            N = 1
            log_ = log_images(model, batch, N=N, noise=noise.clone()[:N], num_steps=num_steps, scale_im=scale_im)
            image = log_["samples"]
            torch.cuda.empty_cache()

    model.clear_rendered_feat()
    print("generation done")
    return image
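
# End-to-end sketch (the prompt, JSON values, and delta path are illustrative placeholders):
#
#   camera_json = ('{"scene.camera": {"eye": {"x": 0.0, "y": 0.5, "z": 2.0}, '
#                  '"up": {"x": 0.0, "y": 1.0, "z": 0.0}, '
#                  '"center": {"x": 0.0, "y": 0.0, "z": 0.0}}}')
#   image = sample(model, data,
#                  prompt="a photo of a car",
#                  camera_json=camera_json,
#                  camera_path="pretrained-models/car0/camera.bin",
#                  num_steps=10)
#
# `image` is a (1, 3, H, W) CPU tensor as returned by decode_first_stage (roughly
# in [-1, 1]); rescale and clamp before converting to a PIL Image for display.
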