diff --git a/.gitignore b/.gitignore index f478597b97dfefc44ea7c2ddc3e2d9cba4ee8f45..bd8df4a0364212e725781a9085e6feca84542f14 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,6 @@ __pycache__/ # ignore sdxl pretrained-models/*.safetensors + +# ignore internal code repo +custom-diffusion360/ diff --git a/Dockerfile b/Dockerfile index 0d912a4c4d34feb20c07f6da18e82bc7042d402c..2cf46d559315994da49ef14bf3f364bdf113a9c2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -32,6 +32,8 @@ RUN pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" RUN wget https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P pretrained-models RUN wget https://huggingface.co/stabilityai/sdxl-vae/resolve/main/sdxl_vae.safetensors -P pretrained-models +RUN git clone https://github.com/customdiffusion360/custom-diffusion360.git + ENV GRADIO_SERVER_NAME=0.0.0.0 ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "pose", "python", "app.py"] diff --git a/app.py b/app.py index 0f917bba581e1852d100687c7bca7fac7da84755..c94fcbff665f68763e6471aa4895111c84b79a11 100644 --- a/app.py +++ b/app.py @@ -13,13 +13,10 @@ import sys # Mesh imports from pytorch3d.io import load_objs_as_meshes from pytorch3d.vis.plotly_vis import AxisArgs, plot_scene -from pytorch3d.transforms import Transform3d, RotateAxisAngle, Translate, Rotate +from pytorch3d.transforms import RotateAxisAngle, Translate from sampling_for_demo import load_and_return_model_and_data, sample, load_base_model -# add current directory to path -# sys.path.append(os.path.dirname(os.path.realpath(__file__))) - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") @@ -184,7 +181,7 @@ current_data = None current_model = None global base_model -BASE_CONFIG = "configs/train_co3d_concept.yaml" +BASE_CONFIG = "custom-diffusion360/configs/train_co3d_concept.yaml" BASE_CKPT = "pretrained-models/sd_xl_base_1.0.safetensors" start_time = time.time() diff --git a/configs/train_co3d_concept.yaml b/configs/train_co3d_concept.yaml deleted file mode 100644 index a159974476846410847249f4112aaad2c1168632..0000000000000000000000000000000000000000 --- a/configs/train_co3d_concept.yaml +++ /dev/null @@ -1,198 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: sgm.models.diffusion.DiffusionEngine - params: - scale_factor: 0.13025 - disable_first_stage_autocast: True - trainkeys: pose - multiplier: 0.05 - loss_rgb_lambda: 5 - loss_fg_lambda: 10 - loss_bg_lambda: 10 - log_keys: - - txt - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser - params: - num_idx: 1000 - - weighting_config: - target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - network_config: - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - adm_in_channels: 2816 - num_classes: sequential - use_checkpoint: False - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [4, 2] - num_res_blocks: 2 - channel_mult: [1, 2, 4] - num_head_channels: 64 - use_linear_in_transformer: True - transformer_depth: [1, 2, 10] - context_dim: 2048 - spatial_transformer_attn_type: softmax-xformers - image_cross_blocks: [0, 2, 4, 6, 8, 10] - rgb: True - far: 2 - num_samples: 24 - not_add_context_in_triplane: False - rgb_predict: True - add_lora: False - average: False - use_prev_weights_imp_sample: True - stratified: True - imp_sampling_percent: 0.9 - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - # crossattn cond - - is_trainable: False - input_keys: txt,txt_ref - target: sgm.modules.encoders.modules.FrozenCLIPEmbedder - params: - layer: hidden - layer_idx: 11 - modifier_token: - # crossattn and vector cond - - is_trainable: False - input_keys: txt,txt_ref - target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - arch: ViT-bigG-14 - version: laion2b_s39b_b160k - layer: penultimate - always_return_pooled: True - legacy: False - modifier_token: - # vector cond - - is_trainable: False - input_keys: original_size_as_tuple,original_size_as_tuple_ref - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 # multiplied by two - # vector cond - - is_trainable: False - input_keys: crop_coords_top_left,crop_coords_top_left_ref - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 # multiplied by two - # vector cond - - is_trainable: False - input_keys: target_size_as_tuple,target_size_as_tuple_ref - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 # multiplied by two - - first_stage_config: - target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper - params: - ckpt_path: pretrained-models/sdxl_vae.safetensors - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - attn_type: vanilla-xformers - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - loss_fn_config: - target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef - params: - sigma_sampler_config: - target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling - params: - num_idx: 1000 - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - sigma_sampler_config_ref: - target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling - params: - num_idx: 50 - - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - sampler_config: - target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler - params: - num_steps: 50 - - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - guider_config: - target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef - params: - scale: 7.5 - -data: - target: sgm.data.data_co3d.CustomDataDictLoader - params: - batch_size: 1 - num_workers: 4 - category: teddybear - img_size: 512 - skip: 2 - num_images: 5 - mask_images: True - single_id: 0 - bbox: True - addreg: True - drop_ratio: 0.25 - drop_txt: 0.1 - modifier_token: - -lightning: - modelcheckpoint: - params: - every_n_train_steps: 1600 - save_top_k: -1 - save_on_train_epoch_end: False - - callbacks: - metrics_over_trainsteps_checkpoint: - params: - every_n_train_steps: 25000 - - image_logger: - target: main.ImageLogger - params: - disabled: False - enable_autocast: False - batch_frequency: 5000 - max_images: 8 - increase_log_steps: False - log_first_step: False - log_images_kwargs: - use_ema_scope: False - N: 1 - n_rows: 2 - - trainer: - devices: 0,1,2,3 - benchmark: True - num_sanity_val_steps: 0 - accumulate_grad_batches: 1 - max_steps: 1610 - # val_check_interval: 400 diff --git a/sampling_for_demo.py b/sampling_for_demo.py index c6dd74bbe410b35dbc3ef1ae3ae96904edec62a7..6800fd643687a52ad9fce0f602d46b674f770969 100644 --- a/sampling_for_demo.py +++ b/sampling_for_demo.py @@ -14,7 +14,7 @@ from pytorch3d.renderer.camera_utils import join_cameras_as_batch import json -sys.path.append('./') +sys.path.append('./custom-diffusion360/') from sgm.util import instantiate_from_config, load_safetensors choices = [] @@ -49,7 +49,6 @@ def load_base_model(config, ckpt=None, verbose=True): m, u = model.load_state_dict(sd, strict=False) - model.cuda() model.eval() return model @@ -84,7 +83,6 @@ def load_delta_model(model, delta_ckpt=None, verbose=True, freeze=True): for param in model.parameters(): param.requires_grad = False - model.cuda() model.eval() return model, msg @@ -290,7 +288,7 @@ def process_camera_json(camera_json, example_cam): def load_and_return_model_and_data(config, model, - ckpt="/data/gdsu/customization3d/stable-diffusion-xl-base-1.0/sd_xl_base_1.0.safetensors", + ckpt="pretrained-models/sd_xl_base_1.0.safetensors", delta_ckpt=None, train=False, valid=False, @@ -318,6 +316,7 @@ def load_and_return_model_and_data(config, model, # print(f"Total images in dataset: {total_images}") model, msg = load_delta_model(model, delta_ckpt,) + model = model.cuda() # change forward methods to store rendered features and use the pre-calculated reference features def register_recr(net_): diff --git a/sgm/__init__.py b/sgm/__init__.py deleted file mode 100644 index 24bc84af8b1041de34b9816e0507cb1ac207bd13..0000000000000000000000000000000000000000 --- a/sgm/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .models import AutoencodingEngine, DiffusionEngine -from .util import get_configs_path, instantiate_from_config - -__version__ = "0.1.0" diff --git a/sgm/data/__init__.py b/sgm/data/__init__.py deleted file mode 100644 index c3076683da945963a7b12d4444a382fb43ddab58..0000000000000000000000000000000000000000 --- a/sgm/data/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# from .dataset import StableDataModuleFromConfig diff --git a/sgm/data/data_co3d.py b/sgm/data/data_co3d.py deleted file mode 100644 index fb338d3441c4a41c0e4604919f1f9d5655b8319b..0000000000000000000000000000000000000000 --- a/sgm/data/data_co3d.py +++ /dev/null @@ -1,762 +0,0 @@ -# code taken and modified from https://github.com/amyxlase/relpose-plus-plus/blob/b33f7d5000cf2430bfcda6466c8e89bc2dcde43f/relpose/dataset/co3d_v2.py#L346) -import os.path as osp -import random - -import numpy as np -import torch -import pytorch_lightning as pl - -from PIL import Image, ImageFile -import json -import gzip -from torch.utils.data import DataLoader, Dataset -from torchvision import transforms -from pytorch3d.renderer.cameras import PerspectiveCameras -from pytorch3d.renderer.camera_utils import join_cameras_as_batch -from pytorch3d.implicitron.dataset.utils import adjust_camera_to_bbox_crop_, adjust_camera_to_image_scale_ -from pytorch3d.transforms import Rotate, Translate - - -CO3D_DIR = "data/training/" - -Image.MAX_IMAGE_PIXELS = None -ImageFile.LOAD_TRUNCATED_IMAGES = True - - -# Added: normalize camera poses -def intersect_skew_line_groups(p, r, mask): - # p, r both of shape (B, N, n_intersected_lines, 3) - # mask of shape (B, N, n_intersected_lines) - p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask) - _, p_line_intersect = _point_line_distance( - p, r, p_intersect[..., None, :].expand_as(p) - ) - intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum( - dim=-1 - ) - return p_intersect, p_line_intersect, intersect_dist_squared, r - - -def intersect_skew_lines_high_dim(p, r, mask=None): - # Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions - dim = p.shape[-1] - # make sure the heading vectors are l2-normed - if mask is None: - mask = torch.ones_like(p[..., 0]) - r = torch.nn.functional.normalize(r, dim=-1) - - eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None] - I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None] - sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3) - p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0] - - if torch.any(torch.isnan(p_intersect)): - print(p_intersect) - assert False - return p_intersect, r - - -def _point_line_distance(p1, r1, p2): - df = p2 - p1 - proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1) - line_pt_nearest = p2 - proj_vector - d = (proj_vector).norm(dim=-1) - return d, line_pt_nearest - - -def compute_optical_axis_intersection(cameras): - centers = cameras.get_camera_center() - principal_points = cameras.principal_point - - one_vec = torch.ones((len(cameras), 1)) - optical_axis = torch.cat((principal_points, one_vec), -1) - - pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True) - - pp2 = torch.zeros((pp.shape[0], 3)) - for i in range(0, pp.shape[0]): - pp2[i] = pp[i][i] - - directions = pp2 - centers - centers = centers.unsqueeze(0).unsqueeze(0) - directions = directions.unsqueeze(0).unsqueeze(0) - - p_intersect, p_line_intersect, _, r = intersect_skew_line_groups( - p=centers, r=directions, mask=None - ) - - p_intersect = p_intersect.squeeze().unsqueeze(0) - dist = (p_intersect - centers).norm(dim=-1) - - return p_intersect, dist, p_line_intersect, pp2, r - - -def normalize_cameras(cameras, scale=1.0): - """ - Normalizes cameras such that the optical axes point to the origin and the average - distance to the origin is 1. - - Args: - cameras (List[camera]). - """ - - # Let distance from first camera to origin be unit - new_cameras = cameras.clone() - new_transform = new_cameras.get_world_to_view_transform() - - p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection( - cameras - ) - t = Translate(p_intersect) - - # scale = dist.squeeze()[0] - scale = max(dist.squeeze()) - - # Degenerate case - if scale == 0: - print(cameras.T) - print(new_transform.get_matrix()[:, 3, :3]) - return -1 - assert scale != 0 - - new_transform = t.compose(new_transform) - new_cameras.R = new_transform.get_matrix()[:, :3, :3] - new_cameras.T = new_transform.get_matrix()[:, 3, :3] / scale - return new_cameras, p_intersect, p_line_intersect, pp, r - - -def centerandalign(cameras, scale=1.0): - """ - Normalizes cameras such that the optical axes point to the origin and the average - distance to the origin is 1. - - Args: - cameras (List[camera]). - """ - - # Let distance from first camera to origin be unit - new_cameras = cameras.clone() - new_transform = new_cameras.get_world_to_view_transform() - - p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection( - cameras - ) - t = Translate(p_intersect) - - centers = [cam.get_camera_center() for cam in new_cameras] - centers = torch.concat(centers, 0).cpu().numpy() - m = len(cameras) - - # https://math.stackexchange.com/questions/99299/best-fitting-plane-given-a-set-of-points - A = np.hstack((centers[:m, :2], np.ones((m, 1)))) - B = centers[:m, 2:] - if A.shape[0] == 2: - x = A.T @ np.linalg.inv(A @ A.T) @ B - else: - x = np.linalg.inv(A.T @ A) @ A.T @ B - a, b, c = x.flatten() - n = np.array([a, b, 1]) - n /= np.linalg.norm(n) - - # https://math.stackexchange.com/questions/180418/calculate-rotation-matrix-to-align-vector-a-to-vector-b-in-3d - v = np.cross(n, [0, 1, 0]) - s = np.linalg.norm(v) - c = np.dot(n, [0, 1, 0]) - V = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) - rot = torch.from_numpy(np.eye(3) + V + V @ V * (1 - c) / s**2).float() - - scale = dist.squeeze()[0] - - # Degenerate case - if scale == 0: - print(cameras.T) - print(new_transform.get_matrix()[:, 3, :3]) - return -1 - assert scale != 0 - - rot = Rotate(rot.T) - - new_transform = rot.compose(t).compose(new_transform) - new_cameras.R = new_transform.get_matrix()[:, :3, :3] - new_cameras.T = new_transform.get_matrix()[:, 3, :3] / scale - return new_cameras - - -def square_bbox(bbox, padding=0.0, astype=None): - """ - Computes a square bounding box, with optional padding parameters. - - Args: - bbox: Bounding box in xyxy format (4,). - - Returns: - square_bbox in xyxy format (4,). - """ - if astype is None: - astype = type(bbox[0]) - bbox = np.array(bbox) - center = ((bbox[:2] + bbox[2:]) / 2).round().astype(int) - extents = (bbox[2:] - bbox[:2]) / 2 - s = (max(extents) * (1 + padding)).round().astype(int) - square_bbox = np.array( - [center[0] - s, center[1] - s, center[0] + s, center[1] + s], - dtype=astype, - ) - - return square_bbox - - -class Co3dDataset(Dataset): - def __init__( - self, - category, - split="train", - skip=2, - img_size=1024, - num_images=4, - mask_images=False, - single_id=0, - bbox=False, - modifier_token=None, - addreg=False, - drop_ratio=0.5, - drop_txt=0.1, - categoryname=None, - aligncameras=False, - repeat=100, - addlen=False, - onlyref=False, - ): - """ - Args: - category (iterable): List of categories to use. If "all" is in the list, - all training categories are used. - num_images (int): Default number of images in each batch. - normalize_cameras (bool): If True, normalizes cameras so that the - intersection of the optical axes is placed at the origin and the norm - of the first camera translation is 1. - mask_images (bool): If True, masks out the background of the images. - """ - # category = CATEGORIES - category = sorted(category.split(',')) - self.category = category - self.single_id = single_id - self.addlen = addlen - self.onlyref = onlyref - self.categoryname = categoryname - self.bbox = bbox - self.modifier_token = modifier_token - self.addreg = addreg - self.drop_txt = drop_txt - self.skip = skip - if self.addreg: - with open(f'data/regularization/{category[0]}_sp_generated/caption.txt', "r") as f: - self.regcaptions = f.read().splitlines() - self.reglen = len(self.regcaptions) - self.regimpath = f'data/regularization/{category[0]}_sp_generated' - - self.low_quality_translations = [] - self.rotations = {} - self.category_map = {} - co3d_dir = CO3D_DIR - for c in category: - subset = 'fewview_dev' - category_dir = osp.join(co3d_dir, c) - frame_file = osp.join(category_dir, "frame_annotations.jgz") - sequence_file = osp.join(category_dir, "sequence_annotations.jgz") - subset_lists_file = osp.join(category_dir, f"set_lists/set_lists_{subset}.json") - bbox_file = osp.join(category_dir, f"{c}_bbox.jgz") - - with open(subset_lists_file) as f: - subset_lists_data = json.load(f) - - with gzip.open(sequence_file, "r") as fin: - sequence_data = json.loads(fin.read()) - - with gzip.open(bbox_file, "r") as fin: - bbox_data = json.loads(fin.read()) - - with gzip.open(frame_file, "r") as fin: - frame_data = json.loads(fin.read()) - - frame_data_processed = {} - for f_data in frame_data: - sequence_name = f_data["sequence_name"] - if sequence_name not in frame_data_processed: - frame_data_processed[sequence_name] = {} - frame_data_processed[sequence_name][f_data["frame_number"]] = f_data - - good_quality_sequences = set() - for seq_data in sequence_data: - if seq_data["viewpoint_quality_score"] > 0.5: - good_quality_sequences.add(seq_data["sequence_name"]) - - for subset in ["train"]: - for seq_name, frame_number, filepath in subset_lists_data[subset]: - if seq_name not in good_quality_sequences: - continue - - if seq_name not in self.rotations: - self.rotations[seq_name] = [] - self.category_map[seq_name] = c - - mask_path = filepath.replace("images", "masks").replace(".jpg", ".png") - - frame_data = frame_data_processed[seq_name][frame_number] - - self.rotations[seq_name].append( - { - "filepath": filepath, - "R": frame_data["viewpoint"]["R"], - "T": frame_data["viewpoint"]["T"], - "focal_length": frame_data["viewpoint"]["focal_length"], - "principal_point": frame_data["viewpoint"]["principal_point"], - "mask": mask_path, - "txt": "a car", - "bbox": bbox_data[mask_path] - } - ) - - for seq_name in self.rotations: - seq_data = self.rotations[seq_name] - cameras = PerspectiveCameras( - focal_length=[data["focal_length"] for data in seq_data], - principal_point=[data["principal_point"] for data in seq_data], - R=[data["R"] for data in seq_data], - T=[data["T"] for data in seq_data], - ) - - normalized_cameras, _, _, _, _ = normalize_cameras(cameras) - if aligncameras: - normalized_cameras = centerandalign(cameras) - - if normalized_cameras == -1: - print("Error in normalizing cameras: camera scale was 0") - del self.rotations[seq_name] - continue - - for i, data in enumerate(seq_data): - self.rotations[seq_name][i]["R"] = normalized_cameras.R[i] - self.rotations[seq_name][i]["T"] = normalized_cameras.T[i] - self.rotations[seq_name][i]["R_original"] = torch.from_numpy(np.array(seq_data[i]["R"])) - self.rotations[seq_name][i]["T_original"] = torch.from_numpy(np.array(seq_data[i]["T"])) - - # Make sure translations are not ridiculous - if self.rotations[seq_name][i]["T"][0] + self.rotations[seq_name][i]["T"][1] + self.rotations[seq_name][i]["T"][2] > 1e5: - bad_seq = True - self.low_quality_translations.append(seq_name) - break - - for seq_name in self.low_quality_translations: - if seq_name in self.rotations: - del self.rotations[seq_name] - - self.sequence_list = list(self.rotations.keys()) - - self.transform = transforms.Compose( - [ - transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.ToTensor(), - transforms.Lambda(lambda x: x * 2.0 - 1.0) - ] - ) - self.transformim = transforms.Compose( - [ - transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(img_size), - transforms.ToTensor(), - transforms.Lambda(lambda x: x * 2.0 - 1.0) - ] - ) - self.transformmask = transforms.Compose( - [ - transforms.Resize(img_size // 8), - transforms.ToTensor(), - ] - ) - - self.num_images = num_images - self.image_size = img_size - self.normalize_cameras = normalize_cameras - self.mask_images = mask_images - self.drop_ratio = drop_ratio - self.kernel_tensor = torch.ones((1, 1, 7, 7)) - self.repeat = repeat - print(self.sequence_list, "$$$$$$$$$$$$$$$$$$$$$") - self.valid_ids = np.arange(0, len(self.rotations[self.sequence_list[self.single_id]]), skip).tolist() - if split == 'test': - self.valid_ids = list(set(np.arange(0, len(self.rotations[self.sequence_list[self.single_id]])).tolist()).difference(self.valid_ids)) - - print( - f"Low quality translation sequences, not used: {self.low_quality_translations}" - ) - print(f"Data size: {len(self)}") - - def __len__(self): - return (len(self.valid_ids))*self.repeat + (1 if self.addlen else 0) - - def _padded_bbox(self, bbox, w, h): - if w < h: - bbox = np.array([0, 0, w, h]) - else: - bbox = np.array([0, 0, w, h]) - return square_bbox(bbox.astype(np.float32)) - - def _crop_bbox(self, bbox, w, h): - bbox = square_bbox(bbox.astype(np.float32)) - - side_length = bbox[2] - bbox[0] - center = (bbox[:2] + bbox[2:]) / 2 - extent = side_length / 2 - - # Final coordinates need to be integer for cropping. - ul = (center - extent).round().astype(int) - lr = ul + np.round(2 * extent).astype(int) - return np.concatenate((ul, lr)) - - def _crop_image(self, image, bbox, white_bg=False): - if white_bg: - # Only support PIL Images - image_crop = Image.new( - "RGB", (bbox[2] - bbox[0], bbox[3] - bbox[1]), (255, 255, 255) - ) - image_crop.paste(image, (-bbox[0], -bbox[1])) - else: - image_crop = transforms.functional.crop( - image, - top=bbox[1], - left=bbox[0], - height=bbox[3] - bbox[1], - width=bbox[2] - bbox[0], - ) - return image_crop - - def __getitem__(self, index, specific_id=None, validation=False): - sequence_name = self.sequence_list[self.single_id] - - metadata = self.rotations[sequence_name] - - if validation: - drop_text = False - drop_im = False - else: - drop_im = np.random.uniform(0, 1) < self.drop_ratio - if not drop_im: - drop_text = np.random.uniform(0, 1) < self.drop_txt - else: - drop_text = False - - size = self.image_size - - # sample reference ids - listofindices = self.valid_ids.copy() - max_diff = len(listofindices) // (self.num_images-1) - if (index*self.skip) % len(metadata) in listofindices: - listofindices.remove((index*self.skip) % len(metadata)) - references = np.random.choice(np.arange(0, len(listofindices)+1, max_diff), self.num_images-1, replace=False) - rem = np.random.randint(0, max_diff) - references = [listofindices[(x + rem) % len(listofindices)] for x in references] - ids = [(index*self.skip) % len(metadata)] + references - - # special case to save features corresponding to ref image as part of model buffer - if self.onlyref: - ids = references + [(index*self.skip) % len(metadata)] - if specific_id is not None: # remove this later - ids = specific_id - - # get data - batch = self.get_data(index=self.single_id, ids=ids) - - # text prompt - if self.modifier_token is not None: - name = self.category[0] if self.categoryname is None else self.categoryname - batch['txt'] = [f'photo of a {self.modifier_token} {name}' for _ in range(len(batch['txt']))] - - # replace with regularization image if drop_im - if drop_im and self.addreg: - select_id = np.random.randint(0, self.reglen) - batch["image"] = [self.transformim(Image.open(f'{self.regimpath}/images/{select_id}.png').convert('RGB'))] - batch['txt'] = [self.regcaptions[select_id]] - batch["original_size_as_tuple"] = torch.ones_like(batch["original_size_as_tuple"])*1024 - - # create camera class and adjust intrinsics for crop - cameras = [PerspectiveCameras(R=batch['R'][i].unsqueeze(0), - T=batch['T'][i].unsqueeze(0), - focal_length=batch['focal_lengths'][i].unsqueeze(0), - principal_point=batch['principal_points'][i].unsqueeze(0), - image_size=self.image_size - ) - for i in range(len(ids))] - for i, cam in enumerate(cameras): - adjust_camera_to_bbox_crop_(cam, batch["original_size_as_tuple"][i, :2], batch["crop_coords"][i]) - adjust_camera_to_image_scale_(cam, batch["original_size_as_tuple"][i, 2:], torch.tensor([self.image_size, self.image_size])) - - # create mask and dilated mask for mask based losses - batch["depth"] = batch["mask"].clone() - batch["mask"] = torch.clamp(torch.nn.functional.conv2d(batch["mask"], self.kernel_tensor, padding='same'), 0, 1) - if not self.mask_images: - batch["mask"] = [None for i in range(len(ids))] - - # special case to save features corresponding to zero image - if index == self.__len__()-1 and self.addlen: - batch["image"][0] *= 0. - - return {"jpg": batch["image"][0], - "txt": batch["txt"][0] if not drop_text else "", - "jpg_ref": batch["image"][1:] if not drop_im else torch.stack([2*torch.rand_like(batch["image"][0])-1. for _ in range(len(ids)-1)], dim=0), - "txt_ref": batch["txt"][1:] if not drop_im else ["" for _ in range(len(ids)-1)], - "pose": cameras, - "mask": batch["mask"][0] if not drop_im else torch.ones_like(batch["mask"][0]), - "mask_ref": batch["masks_padding"][1:], - "depth": batch["depth"][0] if len(batch["depth"]) > 0 else None, - "filepaths": batch["filepaths"], - "original_size_as_tuple": batch["original_size_as_tuple"][0][2:], - "target_size_as_tuple": torch.ones_like(batch["original_size_as_tuple"][0][2:])*size, - "crop_coords_top_left": torch.zeros_like(batch["crop_coords"][0][:2]), - "original_size_as_tuple_ref": batch["original_size_as_tuple"][1:][:, 2:], - "target_size_as_tuple_ref": torch.ones_like(batch["original_size_as_tuple"][1:][:, 2:])*size, - "crop_coords_top_left_ref": torch.zeros_like(batch["crop_coords"][1:][:, :2]), - "drop_im": torch.Tensor([1-drop_im*1.]) - } - - def get_data(self, index=None, sequence_name=None, ids=(0, 1)): - if sequence_name is None: - sequence_name = self.sequence_list[index] - metadata = self.rotations[sequence_name] - category = self.category_map[sequence_name] - annos = [metadata[i] for i in ids] - images = [] - rotations = [] - translations = [] - focal_lengths = [] - principal_points = [] - txts = [] - masks = [] - filepaths = [] - images_transformed = [] - masks_transformed = [] - original_size_as_tuple = [] - crop_parameters = [] - masks_padding = [] - depths = [] - - for counter, anno in enumerate(annos): - filepath = anno["filepath"] - filepaths.append(filepath) - image = Image.open(osp.join(CO3D_DIR, filepath)).convert("RGB") - - mask_name = osp.basename(filepath.replace(".jpg", ".png")) - - mask_path = osp.join( - CO3D_DIR, category, sequence_name, "masks", mask_name - ) - mask = Image.open(mask_path).convert("L") - - if mask.size != image.size: - mask = mask.resize(image.size) - - mask_padded = Image.fromarray((np.ones_like(mask) > 0)) - mask = Image.fromarray((np.array(mask) > 125)) - masks.append(mask) - - # crop image around object - w, h = image.width, image.height - bbox = np.array(anno["bbox"]) - if len(bbox) == 0: - bbox = np.array([0, 0, w, h]) - - if self.bbox and counter > 0: - bbox = self._crop_bbox(bbox, w, h) - else: - bbox = self._padded_bbox(None, w, h) - image = self._crop_image(image, bbox) - mask = self._crop_image(mask, bbox) - mask_padded = self._crop_image(mask_padded, bbox) - masks_padding.append(self.transformmask(mask_padded)) - images_transformed.append(self.transform(image)) - masks_transformed.append(self.transformmask(mask)) - - crop_parameters.append(torch.tensor([bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1] ]).int()) - original_size_as_tuple.append(torch.tensor([w, h, bbox[2] - bbox[0], bbox[3] - bbox[1]])) - images.append(image) - rotations.append(anno["R"]) - translations.append(anno["T"]) - focal_lengths.append(torch.tensor(anno["focal_length"])) - principal_points.append(torch.tensor(anno["principal_point"])) - txts.append(anno["txt"]) - - images = images_transformed - batch = { - "model_id": sequence_name, - "category": category, - "original_size_as_tuple": torch.stack(original_size_as_tuple), - "crop_coords": torch.stack(crop_parameters), - "n": len(metadata), - "ind": torch.tensor(ids), - "txt": txts, - "filepaths": filepaths, - "masks_padding": torch.stack(masks_padding) if len(masks_padding) > 0 else [], - "depth": torch.stack(depths) if len(depths) > 0 else [], - } - - batch["R"] = torch.stack(rotations) - batch["T"] = torch.stack(translations) - batch["focal_lengths"] = torch.stack(focal_lengths) - batch["principal_points"] = torch.stack(principal_points) - - # Add images - if self.transform is None: - batch["image"] = images - else: - batch["image"] = torch.stack(images) - batch["mask"] = torch.stack(masks_transformed) - - return batch - - @staticmethod - def collate_fn(batch): - """A function to collate the data across batches. This function must be passed to pytorch's DataLoader to collate batches. - Args: - batch(list): List of objects returned by this class' __getitem__ function. This is given by pytorch's dataloader that calls __getitem__ - multiple times and expects a collated batch. - Returns: - dict: The collated dictionary representing the data in the batch. - """ - result = { - "jpg": [], - "txt": [], - "jpg_ref": [], - "txt_ref": [], - "pose": [], - "original_size_as_tuple": [], - "original_size_as_tuple_ref": [], - "crop_coords_top_left": [], - "crop_coords_top_left_ref": [], - "target_size_as_tuple_ref": [], - "target_size_as_tuple": [], - "drop_im": [], - "mask_ref": [], - } - if batch[0]["mask"] is not None: - result["mask"] = [] - if batch[0]["depth"] is not None: - result["depth"] = [] - - for batch_obj in batch: - for key in result.keys(): - result[key].append(batch_obj[key]) - for key in result.keys(): - if not (key == 'pose' or 'txt' in key or 'size_as_tuple_ref' in key or 'coords_top_left_ref' in key): - result[key] = torch.stack(result[key], dim=0) - elif 'txt_ref' in key: - result[key] = [item for sublist in result[key] for item in sublist] - elif 'size_as_tuple_ref' in key or 'coords_top_left_ref' in key: - result[key] = torch.cat(result[key], dim=0) - elif 'pose' in key: - result[key] = [join_cameras_as_batch(cameras) for cameras in result[key]] - - return result - - -class CustomDataDictLoader(pl.LightningDataModule): - def __init__( - self, - category, - batch_size, - mask_images=False, - skip=1, - img_size=1024, - num_images=4, - num_workers=0, - shuffle=True, - single_id=0, - modifier_token=None, - bbox=False, - addreg=False, - drop_ratio=0.5, - jitter=False, - drop_txt=0.1, - categoryname=None, - ): - super().__init__() - - self.batch_size = batch_size - self.num_workers = num_workers - self.shuffle = shuffle - self.train_dataset = Co3dDataset(category, - img_size=img_size, - mask_images=mask_images, - skip=skip, - num_images=num_images, - single_id=single_id, - modifier_token=modifier_token, - bbox=bbox, - addreg=addreg, - drop_ratio=drop_ratio, - drop_txt=drop_txt, - categoryname=categoryname, - ) - self.val_dataset = Co3dDataset(category, - img_size=img_size, - mask_images=mask_images, - skip=skip, - num_images=2, - single_id=single_id, - modifier_token=modifier_token, - bbox=bbox, - addreg=addreg, - drop_ratio=0., - drop_txt=0., - categoryname=categoryname, - repeat=1, - addlen=True, - onlyref=True, - ) - self.test_dataset = Co3dDataset(category, - img_size=img_size, - mask_images=mask_images, - split="test", - skip=skip, - num_images=2, - single_id=single_id, - modifier_token=modifier_token, - bbox=False, - addreg=addreg, - drop_ratio=0., - drop_txt=0., - categoryname=categoryname, - repeat=1, - ) - self.collate_fn = Co3dDataset.collate_fn - - def prepare_data(self): - pass - - def train_dataloader(self): - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - collate_fn=self.collate_fn, - drop_last=True, - ) - - def test_dataloader(self): - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_workers, - collate_fn=self.collate_fn, - ) - - def val_dataloader(self): - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_workers, - collate_fn=self.collate_fn, - drop_last=True - ) diff --git a/sgm/lr_scheduler.py b/sgm/lr_scheduler.py deleted file mode 100644 index b2f4d384c1fcaff0df13e0564450d3fa972ace42..0000000000000000000000000000000000000000 --- a/sgm/lr_scheduler.py +++ /dev/null @@ -1,135 +0,0 @@ -import numpy as np - - -class LambdaWarmUpCosineScheduler: - """ - note: use with a base_lr of 1.0 - """ - - def __init__( - self, - warm_up_steps, - lr_min, - lr_max, - lr_start, - max_decay_steps, - verbosity_interval=0, - ): - self.lr_warm_up_steps = warm_up_steps - self.lr_start = lr_start - self.lr_min = lr_min - self.lr_max = lr_max - self.lr_max_decay_steps = max_decay_steps - self.last_lr = 0.0 - self.verbosity_interval = verbosity_interval - - def schedule(self, n, **kwargs): - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") - if n < self.lr_warm_up_steps: - lr = ( - self.lr_max - self.lr_start - ) / self.lr_warm_up_steps * n + self.lr_start - self.last_lr = lr - return lr - else: - t = (n - self.lr_warm_up_steps) / ( - self.lr_max_decay_steps - self.lr_warm_up_steps - ) - t = min(t, 1.0) - lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( - 1 + np.cos(t * np.pi) - ) - self.last_lr = lr - return lr - - def __call__(self, n, **kwargs): - return self.schedule(n, **kwargs) - - -class LambdaWarmUpCosineScheduler2: - """ - supports repeated iterations, configurable via lists - note: use with a base_lr of 1.0. - """ - - def __init__( - self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 - ): - assert ( - len(warm_up_steps) - == len(f_min) - == len(f_max) - == len(f_start) - == len(cycle_lengths) - ) - self.lr_warm_up_steps = warm_up_steps - self.f_start = f_start - self.f_min = f_min - self.f_max = f_max - self.cycle_lengths = cycle_lengths - self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) - self.last_f = 0.0 - self.verbosity_interval = verbosity_interval - - def find_in_interval(self, n): - interval = 0 - for cl in self.cum_cycles[1:]: - if n <= cl: - return interval - interval += 1 - - def schedule(self, n, **kwargs): - cycle = self.find_in_interval(n) - n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - print( - f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}" - ) - if n < self.lr_warm_up_steps[cycle]: - f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ - cycle - ] * n + self.f_start[cycle] - self.last_f = f - return f - else: - t = (n - self.lr_warm_up_steps[cycle]) / ( - self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] - ) - t = min(t, 1.0) - f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( - 1 + np.cos(t * np.pi) - ) - self.last_f = f - return f - - def __call__(self, n, **kwargs): - return self.schedule(n, **kwargs) - - -class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): - def schedule(self, n, **kwargs): - cycle = self.find_in_interval(n) - n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - print( - f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}" - ) - - if n < self.lr_warm_up_steps[cycle]: - f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ - cycle - ] * n + self.f_start[cycle] - self.last_f = f - return f - else: - f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( - self.cycle_lengths[cycle] - n - ) / (self.cycle_lengths[cycle]) - self.last_f = f - return f diff --git a/sgm/models/__init__.py b/sgm/models/__init__.py deleted file mode 100644 index c410b3747afc208e4204c8f140170e0a7808eace..0000000000000000000000000000000000000000 --- a/sgm/models/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .autoencoder import AutoencodingEngine -from .diffusion import DiffusionEngine diff --git a/sgm/models/autoencoder.py b/sgm/models/autoencoder.py deleted file mode 100644 index 78fb551a230bfdbbaeb3106fad41ce034bf5f3c9..0000000000000000000000000000000000000000 --- a/sgm/models/autoencoder.py +++ /dev/null @@ -1,335 +0,0 @@ -import re -from abc import abstractmethod -from contextlib import contextmanager -from typing import Any, Dict, Tuple, Union - -import pytorch_lightning as pl -import torch -from omegaconf import ListConfig -from packaging import version -from safetensors.torch import load_file as load_safetensors - -from ..modules.diffusionmodules.model import Decoder, Encoder -from ..modules.distributions.distributions import DiagonalGaussianDistribution -from ..modules.ema import LitEma -from ..util import default, get_obj_from_str, instantiate_from_config - - -class AbstractAutoencoder(pl.LightningModule): - """ - This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators, - unCLIP models, etc. Hence, it is fairly general, and specific features - (e.g. discriminator training, encoding, decoding) must be implemented in subclasses. - """ - - def __init__( - self, - ema_decay: Union[None, float] = None, - monitor: Union[None, str] = None, - input_key: str = "jpg", - ckpt_path: Union[None, str] = None, - ignore_keys: Union[Tuple, list, ListConfig] = (), - ): - super().__init__() - self.input_key = input_key - self.use_ema = ema_decay is not None - if monitor is not None: - self.monitor = monitor - - if self.use_ema: - self.model_ema = LitEma(self, decay=ema_decay) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - - if version.parse(torch.__version__) >= version.parse("2.0.0"): - self.automatic_optimization = False - - def init_from_ckpt( - self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple() - ) -> None: - if path.endswith("ckpt"): - sd = torch.load(path, map_location="cpu")["state_dict"] - elif path.endswith("safetensors"): - sd = load_safetensors(path) - else: - raise NotImplementedError - - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if re.match(ik, k): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - missing, unexpected = self.load_state_dict(sd, strict=False) - print( - f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" - ) - if len(missing) > 0: - print(f"Missing Keys: {missing}") - if len(unexpected) > 0: - print(f"Unexpected Keys: {unexpected}") - - @abstractmethod - def get_input(self, batch) -> Any: - raise NotImplementedError() - - def on_train_batch_end(self, *args, **kwargs): - # for EMA computation - if self.use_ema: - self.model_ema(self) - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.parameters()) - self.model_ema.copy_to(self) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - @abstractmethod - def encode(self, *args, **kwargs) -> torch.Tensor: - raise NotImplementedError("encode()-method of abstract base class called") - - @abstractmethod - def decode(self, *args, **kwargs) -> torch.Tensor: - raise NotImplementedError("decode()-method of abstract base class called") - - def instantiate_optimizer_from_config(self, params, lr, cfg): - print(f"loading >>> {cfg['target']} <<< optimizer from config") - return get_obj_from_str(cfg["target"])( - params, lr=lr, **cfg.get("params", dict()) - ) - - def configure_optimizers(self) -> Any: - raise NotImplementedError() - - -class AutoencodingEngine(AbstractAutoencoder): - """ - Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL - (we also restore them explicitly as special cases for legacy reasons). - Regularizations such as KL or VQ are moved to the regularizer class. - """ - - def __init__( - self, - *args, - encoder_config: Dict, - decoder_config: Dict, - loss_config: Dict, - regularizer_config: Dict, - optimizer_config: Union[Dict, None] = None, - lr_g_factor: float = 1.0, - **kwargs, - ): - super().__init__(*args, **kwargs) - # todo: add options to freeze encoder/decoder - self.encoder = instantiate_from_config(encoder_config) - self.decoder = instantiate_from_config(decoder_config) - self.loss = instantiate_from_config(loss_config) - self.regularization = instantiate_from_config(regularizer_config) - self.optimizer_config = default( - optimizer_config, {"target": "torch.optim.Adam"} - ) - self.lr_g_factor = lr_g_factor - - def get_input(self, batch: Dict) -> torch.Tensor: - # assuming unified data format, dataloader returns a dict. - # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead if bhwc) - return batch[self.input_key] - - def get_autoencoder_params(self) -> list: - params = ( - list(self.encoder.parameters()) - + list(self.decoder.parameters()) - + list(self.regularization.get_trainable_parameters()) - + list(self.loss.get_trainable_autoencoder_parameters()) - ) - return params - - def get_discriminator_params(self) -> list: - params = list(self.loss.get_trainable_parameters()) # e.g., discriminator - return params - - def get_last_layer(self): - return self.decoder.get_last_layer() - - def encode(self, x: Any, return_reg_log: bool = False) -> Any: - z = self.encoder(x) - z, reg_log = self.regularization(z) - if return_reg_log: - return z, reg_log - return z - - def decode(self, z: Any) -> torch.Tensor: - x = self.decoder(z) - return x - - def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - z, reg_log = self.encode(x, return_reg_log=True) - dec = self.decode(z) - return z, dec, reg_log - - def training_step(self, batch, batch_idx, optimizer_idx) -> Any: - x = self.get_input(batch) - z, xrec, regularization_log = self(x) - - if optimizer_idx == 0: - # autoencode - aeloss, log_dict_ae = self.loss( - regularization_log, - x, - xrec, - optimizer_idx, - self.global_step, - last_layer=self.get_last_layer(), - split="train", - ) - - self.log_dict( - log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True - ) - return aeloss - - if optimizer_idx == 1: - # discriminator - discloss, log_dict_disc = self.loss( - regularization_log, - x, - xrec, - optimizer_idx, - self.global_step, - last_layer=self.get_last_layer(), - split="train", - ) - self.log_dict( - log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True - ) - return discloss - - def validation_step(self, batch, batch_idx) -> Dict: - log_dict = self._validation_step(batch, batch_idx) - with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") - log_dict.update(log_dict_ema) - return log_dict - - def _validation_step(self, batch, batch_idx, postfix="") -> Dict: - x = self.get_input(batch) - - z, xrec, regularization_log = self(x) - aeloss, log_dict_ae = self.loss( - regularization_log, - x, - xrec, - 0, - self.global_step, - last_layer=self.get_last_layer(), - split="val" + postfix, - ) - - discloss, log_dict_disc = self.loss( - regularization_log, - x, - xrec, - 1, - self.global_step, - last_layer=self.get_last_layer(), - split="val" + postfix, - ) - self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) - log_dict_ae.update(log_dict_disc) - self.log_dict(log_dict_ae) - return log_dict_ae - - def configure_optimizers(self) -> Any: - ae_params = self.get_autoencoder_params() - disc_params = self.get_discriminator_params() - - opt_ae = self.instantiate_optimizer_from_config( - ae_params, - default(self.lr_g_factor, 1.0) * self.learning_rate, - self.optimizer_config, - ) - opt_disc = self.instantiate_optimizer_from_config( - disc_params, self.learning_rate, self.optimizer_config - ) - - return [opt_ae, opt_disc], [] - - @torch.no_grad() - def log_images(self, batch: Dict, **kwargs) -> Dict: - log = dict() - x = self.get_input(batch) - _, xrec, _ = self(x) - log["inputs"] = x - log["reconstructions"] = xrec - with self.ema_scope(): - _, xrec_ema, _ = self(x) - log["reconstructions_ema"] = xrec_ema - return log - - -class AutoencoderKL(AutoencodingEngine): - def __init__(self, embed_dim: int, **kwargs): - ddconfig = kwargs.pop("ddconfig") - ckpt_path = kwargs.pop("ckpt_path", None) - ignore_keys = kwargs.pop("ignore_keys", ()) - super().__init__( - encoder_config={"target": "torch.nn.Identity"}, - decoder_config={"target": "torch.nn.Identity"}, - regularizer_config={"target": "torch.nn.Identity"}, - loss_config=kwargs.pop("lossconfig"), - **kwargs, - ) - assert ddconfig["double_z"] - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - self.embed_dim = embed_dim - - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - - def encode(self, x): - assert ( - not self.training - ), f"{self.__class__.__name__} only supports inference currently" - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior - - def decode(self, z, **decoder_kwargs): - z = self.post_quant_conv(z) - dec = self.decoder(z, **decoder_kwargs) - return dec - - -class AutoencoderKLInferenceWrapper(AutoencoderKL): - def encode(self, x): - return super().encode(x).sample() - - -class IdentityFirstStage(AbstractAutoencoder): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def get_input(self, x: Any) -> Any: - return x - - def encode(self, x: Any, *args, **kwargs) -> Any: - return x - - def decode(self, x: Any, *args, **kwargs) -> Any: - return x diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py deleted file mode 100644 index 81238ae0ee7f730b45aff2ae76539c725c4f9c1e..0000000000000000000000000000000000000000 --- a/sgm/models/diffusion.py +++ /dev/null @@ -1,556 +0,0 @@ -from contextlib import contextmanager -from typing import Any, Dict, List, Tuple, Union, DefaultDict - -import pytorch_lightning as pl -import torch -from omegaconf import ListConfig, OmegaConf -from safetensors.torch import load_file as load_safetensors -from torch.optim.lr_scheduler import LambdaLR -from einops import rearrange -import math -import torch.nn as nn -from ..modules import UNCONDITIONAL_CONFIG -from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER -from ..modules.ema import LitEma -from ..util import ( - default, - disabled_train, - get_obj_from_str, - instantiate_from_config, - log_txt_as_img, -) - - -import collections -from functools import partial - - -def save_activations( - activations: DefaultDict, - name: str, - module: nn.Module, - inp: Tuple, - out: torch.Tensor -) -> None: - """PyTorch Forward hook to save outputs at each forward - pass. Mutates specified dict objects with each fwd pass. - """ - if isinstance(out, tuple): - if out[1] is None: - activations[name].append(out[0].detach()) - -class DiffusionEngine(pl.LightningModule): - def __init__( - self, - network_config, - denoiser_config, - first_stage_config, - conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None, - sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None, - optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None, - scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None, - loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None, - network_wrapper: Union[None, str] = None, - ckpt_path: Union[None, str] = None, - use_ema: bool = False, - ema_decay_rate: float = 0.9999, - scale_factor: float = 1.0, - disable_first_stage_autocast=False, - input_key: str = "jpg", - log_keys: Union[List, None] = None, - no_cond_log: bool = False, - compile_model: bool = False, - trainkeys='pose', - multiplier=0.05, - loss_rgb_lambda=20., - loss_fg_lambda=10., - loss_bg_lambda=20., - ): - super().__init__() - self.log_keys = log_keys - self.input_key = input_key - self.trainkeys = trainkeys - self.multiplier = multiplier - self.loss_rgb_lambda = loss_rgb_lambda - self.loss_fg_lambda = loss_fg_lambda - self.loss_bg_lambda = loss_bg_lambda - self.rgb = network_config.params.rgb - self.rgb_predict = network_config.params.rgb_predict - self.add_token = ('modifier_token' in conditioner_config.params.emb_models[1].params) - self.optimizer_config = default( - optimizer_config, {"target": "torch.optim.AdamW"} - ) - model = instantiate_from_config(network_config) - self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( - model, compile_model=compile_model - ) - - self.denoiser = instantiate_from_config(denoiser_config) - self.sampler = ( - instantiate_from_config(sampler_config) - if sampler_config is not None - else None - ) - self.conditioner = instantiate_from_config( - default(conditioner_config, UNCONDITIONAL_CONFIG) - ) - self.scheduler_config = scheduler_config - self._init_first_stage(first_stage_config) - - self.loss_fn = ( - instantiate_from_config(loss_fn_config) - if loss_fn_config is not None - else None - ) - - self.use_ema = use_ema - if self.use_ema: - self.model_ema = LitEma(self.model, decay=ema_decay_rate) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - self.scale_factor = scale_factor - self.disable_first_stage_autocast = disable_first_stage_autocast - self.no_cond_log = no_cond_log - - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path) - - blocks = [] - if self.trainkeys == 'poseattn': - for x in self.model.diffusion_model.named_parameters(): - if not ('pose' in x[0] or 'transformer_blocks' in x[0]): - x[1].requires_grad = False - else: - if 'pose' in x[0]: - x[1].requires_grad = True - blocks.append(x[0].split('.pose')[0]) - - blocks = set(blocks) - for x in self.model.diffusion_model.named_parameters(): - if 'transformer_blocks' in x[0]: - reqgrad = False - for each in blocks: - if each in x[0] and ('attn1' in x[0] or 'attn2' in x[0] or 'pose' in x[0]): - reqgrad = True - x[1].requires_grad = True - if not reqgrad: - x[1].requires_grad = False - elif self.trainkeys == 'pose': - for x in self.model.diffusion_model.named_parameters(): - if not ('pose' in x[0]): - x[1].requires_grad = False - else: - x[1].requires_grad = True - elif self.trainkeys == 'all': - for x in self.model.diffusion_model.named_parameters(): - x[1].requires_grad = True - - self.model = self.model.to(memory_format=torch.channels_last) - - def register_activation_hooks( - self, - ) -> None: - self.activations_dict = collections.defaultdict(list) - handles = [] - for name, module in self.model.diffusion_model.named_modules(): - if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks': - if hasattr(module, 'pose_emb_layers'): - handle = module.register_forward_hook( - partial(save_activations, self.activations_dict, name) - ) - handles.append(handle) - self.handles = handles - - def clear_rendered_feat(self,): - for name, module in self.model.diffusion_model.named_modules(): - if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks': - if hasattr(module, 'pose_emb_layers'): - module.rendered_feat = None - - def remove_activation_hooks( - self, handles - ) -> None: - for handle in handles: - handle.remove() - - def init_from_ckpt( - self, - path: str, - ) -> None: - if path.endswith("ckpt"): - sd = torch.load(path, map_location="cpu")["state_dict"] - elif path.endswith("safetensors"): - sd = load_safetensors(path) - else: - raise NotImplementedError - - missing, unexpected = self.load_state_dict(sd, strict=False) - print( - f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" - ) - if len(missing) > 0: - print(f"Missing Keys: {missing}") - if len(unexpected) > 0: - print(f"Unexpected Keys: {unexpected}") - - def _init_first_stage(self, config): - model = instantiate_from_config(config).eval() - model.train = disabled_train - for param in model.parameters(): - param.requires_grad = False - self.first_stage_model = model - - def get_input(self, batch): - return batch[self.input_key], batch[self.input_key + '_ref'] if self.input_key + '_ref' in batch else None, batch['pose'] if 'pose' in batch else None, batch['mask'] if "mask" in batch else None, batch['mask_ref'] if "mask_ref" in batch else None, batch['depth'] if "depth" in batch else None, batch['drop_im'] if "drop_im" in batch else 0. - - @torch.no_grad() - def decode_first_stage(self, z): - z = 1.0 / self.scale_factor * z - with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): - out = self.first_stage_model.decode(z) - return out - - @torch.no_grad() - def encode_first_stage(self, x): - with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): - z = self.first_stage_model.encode(x) - z = self.scale_factor * z - return z - - def forward(self, x, x_rgb, xr, pose, mask, mask_ref, opacity, drop_im, batch): - loss, loss_fg, loss_bg, loss_rgb = self.loss_fn(self.model, self.denoiser, self.conditioner, x, x_rgb, xr, pose, mask, mask_ref, opacity, batch) - loss_mean = loss.mean() - loss_dict = {"loss": loss_mean.item()} - if self.rgb and self.global_step > 0: - loss_fg = (loss_fg.mean(1)*drop_im.reshape(-1)).sum()/(drop_im.sum() + 1e-12) - loss_bg = (loss_bg.mean(1)*drop_im.reshape(-1)).sum()/(drop_im.sum() + 1e-12) - loss_mean += self.loss_fg_lambda*loss_fg - loss_mean += self.loss_bg_lambda*loss_bg - loss_dict["loss_fg"] = loss_fg.item() - loss_dict["loss_bg"] = loss_bg.item() - if self.rgb_predict and loss_rgb.mean() > 0: - loss_rgb = (loss_rgb.mean(1)*drop_im.reshape(-1)).sum()/(drop_im.sum() + 1e-12) - loss_mean += self.loss_rgb_lambda*loss_rgb - loss_dict["loss_rgb"] = loss_rgb.item() - return loss_mean, loss_dict - - def shared_step(self, batch: Dict) -> Any: - x, xr, pose, mask, mask_ref, opacity, drop_im = self.get_input(batch) - x_rgb = x.clone().detach() - x = self.encode_first_stage(x) - x = x.to(memory_format=torch.channels_last) - if xr is not None: - b, n = xr.shape[0], xr.shape[1] - xr = rearrange(self.encode_first_stage(rearrange(xr, "b n ... -> (b n) ...")), "(b n) ... -> b n ...", b=b, n=n) - xr = drop_im.reshape(b, 1, 1, 1, 1)*xr + (1-drop_im.reshape(b, 1, 1, 1, 1))*torch.zeros_like(xr) - batch["global_step"] = self.global_step - loss, loss_dict = self(x, x_rgb, xr, pose, mask, mask_ref, opacity, drop_im, batch) - return loss, loss_dict - - def training_step(self, batch, batch_idx): - loss, loss_dict = self.shared_step(batch) - - self.log_dict( - loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False - ) - - self.log( - "global_step", - self.global_step, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=False, - ) - - if self.scheduler_config is not None: - lr = self.optimizers().param_groups[0]["lr"] - self.log( - "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False - ) - return loss - - def validation_step(self, batch, batch_idx): - # print("validation data", len(self.trainer.val_dataloaders)) - loss, loss_dict = self.shared_step(batch) - return loss - - def on_train_start(self, *args, **kwargs): - if self.sampler is None or self.loss_fn is None: - raise ValueError("Sampler and loss function need to be set for training.") - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self.model) - - def optimizer_zero_grad(self, epoch, batch_idx, optimizer): - optimizer.zero_grad(set_to_none=True) - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.model.parameters()) - self.model_ema.copy_to(self.model) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.model.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - def instantiate_optimizer_from_config(self, params, lr, cfg): - return get_obj_from_str(cfg["target"])( - params, lr=lr, **cfg.get("params", dict()) - ) - - def configure_optimizers(self): - lr = self.learning_rate - params = [] - blocks = [] - lowlrparams = [] - if self.trainkeys == 'poseattn': - lowlrparams = [] - for x in self.model.diffusion_model.named_parameters(): - if ('pose' in x[0]): - params += [x[1]] - blocks.append(x[0].split('.pose')[0]) - print(x[0]) - blocks = set(blocks) - for x in self.model.diffusion_model.named_parameters(): - if 'transformer_blocks' in x[0]: - for each in blocks: - if each in x[0] and not ('pose' in x[0]) and ('attn1' in x[0] or 'attn2' in x[0]): - lowlrparams += [x[1]] - elif self.trainkeys == 'pose': - for x in self.model.diffusion_model.named_parameters(): - if ('pose' in x[0]): - params += [x[1]] - print(x[0]) - elif self.trainkeys == 'all': - lowlrparams = [] - for x in self.model.diffusion_model.named_parameters(): - if ('pose' in x[0]): - params += [x[1]] - print(x[0]) - else: - lowlrparams += [x[1]] - - for i, embedder in enumerate(self.conditioner.embedders[:2]): - if embedder.is_trainable: - params = params + list(embedder.parameters()) - if self.add_token: - if i == 0: - for name, param in embedder.transformer.get_input_embeddings().named_parameters(): - param.requires_grad = True - print(name, "conditional model param") - params += [param] - else: - for name, param in embedder.model.token_embedding.named_parameters(): - param.requires_grad = True - print(name, "conditional model param") - params += [param] - - if len(lowlrparams) > 0: - print("different optimizer groups") - opt = self.instantiate_optimizer_from_config([{'params': params}, {'params': lowlrparams, 'lr': self.multiplier*lr}], lr, self.optimizer_config) - else: - opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config) - if self.scheduler_config is not None: - scheduler = instantiate_from_config(self.scheduler_config) - print("Setting up LambdaLR scheduler...") - scheduler = [ - { - "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), - "interval": "step", - "frequency": 1, - } - ] - return [opt], scheduler - return opt - - @torch.no_grad() - def sample( - self, - cond: Dict, - uc: Union[Dict, None] = None, - batch_size: int = 16, - num_steps=None, - randn=None, - shape: Union[None, Tuple, List] = None, - return_rgb=False, - mask=None, - init_im=None, - **kwargs, - ): - if randn is None: - randn = torch.randn(batch_size, *shape) - - denoiser = lambda input, sigma, c: self.denoiser( - self.model, input, sigma, c, **kwargs - ) - if mask is not None: - samples, rgb_list = self.sampler(denoiser, randn.to(self.device), cond, uc=uc, mask=mask, init_im=init_im, num_steps=num_steps) - else: - samples, rgb_list = self.sampler(denoiser, randn.to(self.device), cond, uc=uc, num_steps=num_steps) - if return_rgb: - return samples, rgb_list - return samples - - @torch.no_grad() - def samplemulti( - self, - cond, - uc=None, - batch_size: int = 16, - num_steps=None, - randn=None, - shape: Union[None, Tuple, List] = None, - return_rgb=False, - mask=None, - init_im=None, - multikwargs=None, - ): - if randn is None: - randn = torch.randn(batch_size, *shape) - - samples, rgb_list = self.sampler(self.denoiser, self.model, randn.to(self.device), cond, uc=uc, num_steps=num_steps, multikwargs=multikwargs) - if return_rgb: - return samples, rgb_list - return samples - - @torch.no_grad() - def log_conditionings(self, batch: Dict, n: int, refernce: bool = True) -> Dict: - """ - Defines heuristics to log different conditionings. - These can be lists of strings (text-to-image), tensors, ints, ... - """ - image_h, image_w = batch[self.input_key].shape[2:] - log = dict() - - for embedder in self.conditioner.embedders: - if refernce: - check = (embedder.input_keys[0] in self.log_keys) - else: - check = (embedder.input_key in self.log_keys) - if ( - (self.log_keys is None) or check - ) and not self.no_cond_log: - if refernce: - x = batch[embedder.input_keys[0]][:n] - else: - x = batch[embedder.input_key][:n] - if isinstance(x, torch.Tensor): - if x.dim() == 1: - # class-conditional, convert integer to string - x = [str(x[i].item()) for i in range(x.shape[0])] - xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4) - elif x.dim() == 2: - # size and crop cond and the like - x = [ - "x".join([str(xx) for xx in x[i].tolist()]) - for i in range(x.shape[0]) - ] - xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) - else: - raise NotImplementedError() - elif isinstance(x, (List, ListConfig)): - if isinstance(x[0], str): - # strings - xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) - else: - raise NotImplementedError() - else: - raise NotImplementedError() - if refernce: - log[embedder.input_keys[0]] = xc - else: - log[embedder.input_key] = xc - return log - - @torch.no_grad() - def log_images( - self, - batch: Dict, - N: int = 8, - sample: bool = True, - ucg_keys: List[str] = None, - **kwargs, - ) -> Dict: - log = dict() - - x, xr, pose, mask, mask_ref, depth, drop_im = self.get_input(batch) - - if xr is not None: - conditioner_input_keys = [e.input_keys for e in self.conditioner.embedders] - else: - conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] - - if ucg_keys: - assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( - "Each defined ucg key for sampling must be in the provided conditioner input keys," - f"but we have {ucg_keys} vs. {conditioner_input_keys}" - ) - else: - ucg_keys = conditioner_input_keys - - c, uc = self.conditioner.get_unconditional_conditioning( - batch, - force_uc_zero_embeddings=ucg_keys - if len(self.conditioner.embedders) > 0 - else [], - ) - - N = min(x.shape[0], N) - x = x.to(self.device)[:N] - zr = None - if xr is not None: - xr = xr.to(self.device)[:N] - b, n = xr.shape[0], xr.shape[1] - log["reference"] = rearrange(xr, "b n ... -> (b n) ...", b=b, n=n) - zr = rearrange(self.encode_first_stage(rearrange(xr, "b n ... -> (b n) ...", b=b, n=n)), "(b n) ... -> b n ...", b=b, n=n) - - log["inputs"] = x - b = x.shape[0] - if mask is not None: - log["mask"] = mask - if depth is not None: - log["depth"] = depth - z = self.encode_first_stage(x) - - if uc is not None: - if xr is not None: - zr = torch.cat([torch.zeros_like(zr), zr]) - drop_im = torch.cat([drop_im, drop_im]) - if isinstance(pose, list): - pose = pose[:N]*2 - else: - pose = torch.cat([pose[:N]] * 2) - - sampling_kwargs = {'input_ref':zr} - sampling_kwargs['pose'] = pose - sampling_kwargs['mask_ref'] = None - sampling_kwargs['drop_im'] = drop_im - - log["reconstructions"] = self.decode_first_stage(z) - log.update(self.log_conditionings(batch, N, refernce=True if xr is not None else False)) - - for k in c: - if isinstance(c[k], torch.Tensor): - if xr is not None: - c[k], uc[k] = map(lambda y: y[k][:(n+1)*N].to(self.device), (c, uc)) - else: - c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc)) - if sample: - with self.ema_scope("Plotting"): - samples, rgb_list = self.sample( - c, shape=z.shape[1:], uc=uc, batch_size=N, return_rgb=True, **sampling_kwargs - ) - samples = self.decode_first_stage(samples) - log["samples"] = samples - if len(rgb_list) > 0: - size = int(math.sqrt(rgb_list[0].size(1))) - log["predicted_rgb"] = rgb_list[0].reshape(-1, size, size, 3).permute(0, 3, 1, 2) - return log diff --git a/sgm/modules/__init__.py b/sgm/modules/__init__.py deleted file mode 100644 index 0db1d7716a6e48f77b86a4b59c9289d6fb76b50b..0000000000000000000000000000000000000000 --- a/sgm/modules/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .encoders.modules import GeneralConditioner - -UNCONDITIONAL_CONFIG = { - "target": "sgm.modules.GeneralConditioner", - "params": {"emb_models": []}, -} diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py deleted file mode 100644 index f39d849205b54e6dc82425fb94c8b662e4d0c293..0000000000000000000000000000000000000000 --- a/sgm/modules/attention.py +++ /dev/null @@ -1,1202 +0,0 @@ -import logging -import math -import itertools -from inspect import isfunction -from typing import Any, Optional -import numpy as np -import torch -import torch.nn.functional as F -from einops import rearrange, repeat -from packaging import version -from torch import nn -from .diffusionmodules.util import checkpoint -from torch.autograd import Function -from torch.cuda.amp import custom_bwd, custom_fwd - -from ..modules.diffusionmodules.util import zero_module -from ..modules.nerfsd_pytorch3d import NerfSDModule, VolRender - -logpy = logging.getLogger(__name__) - -if version.parse(torch.__version__) >= version.parse("2.0.0"): - SDP_IS_AVAILABLE = True - from torch.backends.cuda import SDPBackend, sdp_kernel - - BACKEND_MAP = { - SDPBackend.MATH: { - "enable_math": True, - "enable_flash": False, - "enable_mem_efficient": False, - }, - SDPBackend.FLASH_ATTENTION: { - "enable_math": False, - "enable_flash": True, - "enable_mem_efficient": False, - }, - SDPBackend.EFFICIENT_ATTENTION: { - "enable_math": False, - "enable_flash": False, - "enable_mem_efficient": True, - }, - None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True}, - } -else: - from contextlib import nullcontext - - SDP_IS_AVAILABLE = False - sdp_kernel = nullcontext - BACKEND_MAP = {} - logpy.warn( - f"No SDP backend available, likely because you are running in pytorch " - f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. " - f"You might want to consider upgrading." - ) - -try: - import xformers - import xformers.ops - - XFORMERS_IS_AVAILABLE = True -except: - XFORMERS_IS_AVAILABLE = False - logpy.warn("no module 'xformers'. Processing without...") - - -def exists(val): - return val is not None - - -def uniq(arr): - return {el: True for el in arr}.keys() - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def max_neg_value(t): - return -torch.finfo(t.dtype).max - - -def init_(tensor): - dim = tensor.shape[-1] - std = 1 / math.sqrt(dim) - tensor.uniform_(-std, std) - return tensor - - -# feedforward -class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = ( - nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) - if not glu - else GEGLU(dim, inner_dim) - ) - - self.net = nn.Sequential( - project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.net(x) - - -def Normalize(in_channels): - return torch.nn.GroupNorm( - num_groups=32, num_channels=in_channels, eps=1e-6, affine=True - ) - - -class LinearAttention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32): - super().__init__() - self.heads = heads - hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) - self.to_out = nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x) - q, k, v = rearrange( - qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 - ) - k = k.softmax(dim=-1) - context = torch.einsum("bhdn,bhen->bhde", k, v) - out = torch.einsum("bhde,bhdn->bhen", context, q) - out = rearrange( - out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w - ) - return self.to_out(out) - - -class SpatialSelfAttention(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = rearrange(q, "b c h w -> b (h w) c") - k = rearrange(k, "b c h w -> b c (h w)") - w_ = torch.einsum("bij,bjk->bik", q, k) - - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = rearrange(v, "b c h w -> b c (h w)") - w_ = rearrange(w_, "b i j -> b j i") - h_ = torch.einsum("bij,bjk->bik", v, w_) - h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) - h_ = self.proj_out(h_) - - return x + h_ - - -class _TruncExp(Function): # pylint: disable=abstract-method - # Implementation from torch-ngp: - # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py - @staticmethod - @custom_fwd(cast_inputs=torch.float32) - def forward(ctx, x): # pylint: disable=arguments-differ - ctx.save_for_backward(x) - return torch.exp(x) - - @staticmethod - @custom_bwd - def backward(ctx, g): # pylint: disable=arguments-differ - x = ctx.saved_tensors[0] - return g * torch.exp(x.clamp(-15, 15)) - - -trunc_exp = _TruncExp.apply -"""Same as torch.exp, but with the backward pass clipped to prevent vanishing/exploding -gradients.""" - - -class CrossAttention(nn.Module): - def __init__( - self, - query_dim, - context_dim=None, - heads=8, - dim_head=64, - dropout=0.0, - backend=None, - ): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.scale = dim_head**-0.5 - self.heads = heads - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) - ) - self.backend = backend - - def forward( - self, - x, - context=None, - mask=None, - additional_tokens=None, - n_times_crossframe_attn_in_self=0, - ): - h = self.heads - - if additional_tokens is not None: - # get the number of masked tokens at the beginning of the output sequence - n_tokens_to_mask = additional_tokens.shape[1] - # add additional token - x = torch.cat([additional_tokens, x], dim=1) - - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - - if n_times_crossframe_attn_in_self: - # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 - assert x.shape[0] % n_times_crossframe_attn_in_self == 0 - n_cp = x.shape[0] // n_times_crossframe_attn_in_self - k = repeat( - k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp - ) - v = repeat( - v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp - ) - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - - ## old - """ - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - del q, k - - if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - sim = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', sim, v) - """ - ## new - with sdp_kernel(**BACKEND_MAP[self.backend]): - # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape) - out = F.scaled_dot_product_attention( - q, k, v, attn_mask=mask - ) # scale is dim_head ** -0.5 per default - - del q, k, v - out = rearrange(out, "b h n d -> b n (h d)", h=h) - - if additional_tokens is not None: - # remove additional token - out = out[:, n_tokens_to_mask:] - return self.to_out(out) - - -class MemoryEfficientCrossAttention(nn.Module): - # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - def __init__( - self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, add_lora=False, **kwargs - ): - super().__init__() - logpy.debug( - f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, " - f"context_dim is {context_dim} and using {heads} heads with a " - f"dimension of {dim_head}." - ) - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.heads = heads - self.dim_head = dim_head - self.add_lora = add_lora - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) - ) - if add_lora: - r = 32 - self.to_q_attn3_down = nn.Linear(query_dim, r, bias=False) - self.to_q_attn3_up = zero_module(nn.Linear(r, inner_dim, bias=False)) - self.to_k_attn3_down = nn.Linear(context_dim, r, bias=False) - self.to_k_attn3_up = zero_module(nn.Linear(r, inner_dim, bias=False)) - self.to_v_attn3_down = nn.Linear(context_dim, r, bias=False) - self.to_v_attn3_up = zero_module(nn.Linear(r, inner_dim, bias=False)) - self.to_o_attn3_down = nn.Linear(inner_dim, r, bias=False) - self.to_o_attn3_up = zero_module(nn.Linear(r, query_dim, bias=False)) - self.dropoutq = nn.Dropout(0.1) - self.dropoutk = nn.Dropout(0.1) - self.dropoutv = nn.Dropout(0.1) - self.dropouto = nn.Dropout(0.1) - - nn.init.normal_(self.to_q_attn3_down.weight, std=1 / r) - nn.init.normal_(self.to_k_attn3_down.weight, std=1 / r) - nn.init.normal_(self.to_v_attn3_down.weight, std=1 / r) - nn.init.normal_(self.to_o_attn3_down.weight, std=1 / r) - - self.attention_op: Optional[Any] = None - - def forward( - self, - x, - context=None, - mask=None, - additional_tokens=None, - n_times_crossframe_attn_in_self=0, - ): - if additional_tokens is not None: - # get the number of masked tokens at the beginning of the output sequence - n_tokens_to_mask = additional_tokens.shape[1] - # add additional token - x = torch.cat([additional_tokens, x], dim=1) - - context_k = context # b, n, c, h, w - - q = self.to_q(x) - context = default(context, x) - context_k = default(context_k, x) - k = self.to_k(context_k) - v = self.to_v(context_k) - if self.add_lora: - q += self.dropoutq(self.to_q_attn3_up(self.to_q_attn3_down(x))) - k += self.dropoutk(self.to_k_attn3_up(self.to_k_attn3_down(context_k))) - v += self.dropoutv(self.to_v_attn3_up(self.to_v_attn3_down(context_k))) - - if n_times_crossframe_attn_in_self: - # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 - assert x.shape[0] % n_times_crossframe_attn_in_self == 0 - # n_cp = x.shape[0]//n_times_crossframe_attn_in_self - k = repeat( - k[::n_times_crossframe_attn_in_self], - "b ... -> (b n) ...", - n=n_times_crossframe_attn_in_self, - ) - v = repeat( - v[::n_times_crossframe_attn_in_self], - "b ... -> (b n) ...", - n=n_times_crossframe_attn_in_self, - ) - - b, _, _ = q.shape - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), - (q, k, v), - ) - - attn_bias = None - - # actually compute the attention, what we cannot get enough of - out = xformers.ops.memory_efficient_attention( - q, k, v, attn_bias=attn_bias, op=self.attention_op - ) - - # TODO: Use this directly in the attention operation, as a bias - if exists(mask): - raise NotImplementedError - out = ( - out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) - ) - if additional_tokens is not None: - # remove additional token - out = out[:, n_tokens_to_mask:] - final = self.to_out(out) - if self.add_lora: - final += self.dropouto(self.to_o_attn3_up(self.to_o_attn3_down(out))) - return final - - -class BasicTransformerBlock(nn.Module): - ATTENTION_MODES = { - "softmax": CrossAttention, # vanilla attention - "softmax-xformers": MemoryEfficientCrossAttention, # ampere - } - - def __init__( - self, - dim, - n_heads, - d_head, - dropout=0.0, - context_dim=None, - gated_ff=True, - checkpoint=True, - disable_self_attn=False, - attn_mode="softmax", - sdp_backend=None, - image_cross=False, - far=2, - num_samples=32, - add_lora=False, - rgb_predict=False, - mode='pixel-nerf', - average=False, - num_freqs=16, - use_prev_weights_imp_sample=False, - imp_sample_next_step=False, - stratified=False, - imp_sampling_percent=0.9, - near_plane=0. - ): - - super().__init__() - assert attn_mode in self.ATTENTION_MODES - self.add_lora = add_lora - self.image_cross = image_cross - self.rgb_predict = rgb_predict - self.use_prev_weights_imp_sample = use_prev_weights_imp_sample - self.imp_sample_next_step = imp_sample_next_step - self.rendered_feat = None - if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE: - logpy.warn( - f"Attention mode '{attn_mode}' is not available. Falling " - f"back to native attention. This is not a problem in " - f"Pytorch >= 2.0. FYI, you are running with PyTorch " - f"version {torch.__version__}." - ) - attn_mode = "softmax" - elif attn_mode == "softmax" and not SDP_IS_AVAILABLE: - logpy.warn( - "We do not support vanilla attention anymore, as it is too " - "expensive. Sorry." - ) - if not XFORMERS_IS_AVAILABLE: - assert ( - False - ), "Please install xformers via e.g. 'pip install xformers==0.0.16'" - else: - logpy.info("Falling back to xformers efficient attention.") - attn_mode = "softmax-xformers" - attn_cls = self.ATTENTION_MODES[attn_mode] - if version.parse(torch.__version__) >= version.parse("2.0.0"): - assert sdp_backend is None or isinstance(sdp_backend, SDPBackend) - else: - assert sdp_backend is None - self.disable_self_attn = disable_self_attn - self.attn1 = attn_cls( - query_dim=dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - add_lora=self.add_lora, - context_dim=context_dim if self.disable_self_attn else None, - backend=sdp_backend, - ) # is a self-attention if not self.disable_self_attn - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = attn_cls( - query_dim=dim, - context_dim=context_dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - add_lora=self.add_lora, - backend=sdp_backend, - ) # is self-attn if context is none - if image_cross: - self.pose_emb_layers = nn.Linear(2*dim, dim, bias=False) - nn.init.eye_(self.pose_emb_layers.weight) - self.pose_featurenerf = NerfSDModule(mode=mode, - out_channels=dim, - far_plane=far, - num_samples=num_samples, - rgb_predict=rgb_predict, - average=average, - num_freqs=num_freqs, - stratified=stratified, - imp_sampling_percent=imp_sampling_percent, - near_plane=near_plane, - ) - - self.renderer = VolRender() - - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - if self.checkpoint: - logpy.debug(f"{self.__class__.__name__} is using checkpointing") - - def forward( - self, x, context=None, context_ref=None, pose=None, mask_ref=None, prev_weights=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 - ): - kwargs = {"x": x} - - if context is not None: - kwargs.update({"context": context}) - - if context_ref is not None: - kwargs.update({"context_ref": context_ref}) - - if pose is not None: - kwargs.update({"pose": pose}) - - if mask_ref is not None: - kwargs.update({"mask_ref": mask_ref}) - - if prev_weights is not None: - kwargs.update({"prev_weights": prev_weights}) - - if additional_tokens is not None: - kwargs.update({"additional_tokens": additional_tokens}) - - if n_times_crossframe_attn_in_self: - kwargs.update( - {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self} - ) - - # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint) - return checkpoint( - self._forward, (x, context, context_ref, pose, mask_ref, prev_weights), self.parameters(), self.checkpoint - ) - - def reference_attn(self, x, context_ref, context, pose, prev_weights, mask_ref): - feats, sigmas, dists, _, predicted_rgb, sigmas_uniform, dists_uniform = self.pose_featurenerf(pose, - context_ref, - mask_ref, - prev_weights=prev_weights if self.use_prev_weights_imp_sample else None, - imp_sample_next_step=self.imp_sample_next_step) - - b, hw, d = feats.size()[:3] - feats = rearrange(feats, "b hw d ... -> b (hw d) ...") - - feats = ( - self.attn2( - self.norm2(feats), context=context, - ) - + feats - ) - - feats = rearrange(feats, "b (hw d) ... -> b hw d ...", hw=hw, d=d) - - sigmas_ = trunc_exp(sigmas) - if sigmas_uniform is not None: - sigmas_uniform = trunc_exp(sigmas_uniform) - - context_ref, fg_mask, alphas, weights_uniform, predicted_rgb = self.renderer(feats, sigmas_, dists, densities_uniform=sigmas_uniform, dists_uniform=dists_uniform, return_weights_uniform=True, rgb=F.sigmoid(predicted_rgb) if predicted_rgb is not None else None) - if self.use_prev_weights_imp_sample: - prev_weights = weights_uniform - - return context_ref, fg_mask, prev_weights, alphas, predicted_rgb - - def _forward( - self, x, context=None, context_ref=None, pose=None, mask_ref=None, prev_weights=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 - ): - fg_mask = None - weights = None - alphas = None - predicted_rgb = None - xref = None - - x = ( - self.attn1( - self.norm1(x), - context=context if self.disable_self_attn else None, - additional_tokens=additional_tokens, - n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self - if not self.disable_self_attn - else 0, - ) - + x - ) - x = ( - self.attn2( - self.norm2(x), context=context, additional_tokens=additional_tokens - ) - + x - ) - with torch.amp.autocast(device_type='cuda', dtype=torch.float32): - if context_ref is not None: - xref, fg_mask, weights, alphas, predicted_rgb = self.reference_attn(x, - rearrange(context_ref, "(b n) ... -> b n ...", b=x.size(0), n=context_ref.size(0) // x.size(0)), - context, - pose, - prev_weights, - mask_ref) - x = self.pose_emb_layers(torch.cat([x, xref], -1)) - - x = self.ff(self.norm3(x)) + x - return x, fg_mask, weights, alphas, predicted_rgb - - -class BasicTransformerSingleLayerBlock(nn.Module): - ATTENTION_MODES = { - "softmax": CrossAttention, # vanilla attention - "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version - # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128]) - } - - def __init__( - self, - dim, - n_heads, - d_head, - dropout=0.0, - context_dim=None, - gated_ff=True, - checkpoint=True, - attn_mode="softmax", - ): - super().__init__() - assert attn_mode in self.ATTENTION_MODES - attn_cls = self.ATTENTION_MODES[attn_mode] - self.attn1 = attn_cls( - query_dim=dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - context_dim=context_dim, - ) - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - - def forward(self, x, context=None): - return checkpoint( - self._forward, (x, context), self.parameters(), self.checkpoint - ) - - def _forward(self, x, context=None): - x = self.attn1(self.norm1(x), context=context) + x - x = self.ff(self.norm2(x)) + x - return x - - -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. - First, project the input (aka embedding) - and reshape to b, t, d. - Then apply standard transformer action. - Finally, reshape to image - NEW: use_linear for more efficiency instead of the 1x1 convs - """ - - def __init__( - self, - in_channels, - n_heads, - d_head, - depth=1, - dropout=0.0, - context_dim=None, - disable_self_attn=False, - use_linear=False, - attn_type="softmax", - use_checkpoint=True, - # sdp_backend=SDPBackend.FLASH_ATTENTION - sdp_backend=None, - image_cross=True, - rgb_predict=False, - far=2, - num_samples=32, - add_lora=False, - mode='feature-nerf', - average=False, - num_freqs=16, - use_prev_weights_imp_sample=False, - stratified=False, - poscontrol_interval=4, - imp_sampling_percent=0.9, - near_plane=0. - ): - super().__init__() - logpy.debug( - f"constructing {self.__class__.__name__} of depth {depth} w/ " - f"{in_channels} channels and {n_heads} heads." - ) - from omegaconf import ListConfig - - if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)): - context_dim = [context_dim] - if exists(context_dim) and isinstance(context_dim, list): - if depth != len(context_dim): - logpy.warn( - f"{self.__class__.__name__}: Found context dims " - f"{context_dim} of depth {len(context_dim)}, which does not " - f"match the specified 'depth' of {depth}. Setting context_dim " - f"to {depth * [context_dim[0]]} now." - ) - # depth does not match context dims. - assert all( - map(lambda x: x == context_dim[0], context_dim) - ), "need homogenous context_dim to match depth automatically" - context_dim = depth * [context_dim[0]] - elif context_dim is None: - context_dim = [None] * depth - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = Normalize(in_channels) - - self.image_cross = image_cross - self.poscontrol_interval = poscontrol_interval - - if not use_linear: - self.proj_in = nn.Conv2d( - in_channels, inner_dim, kernel_size=1, stride=1, padding=0 - ) - else: - self.proj_in = nn.Linear(in_channels, inner_dim) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - n_heads, - d_head, - dropout=dropout, - context_dim=context_dim[d], - disable_self_attn=disable_self_attn, - attn_mode=attn_type, - checkpoint=use_checkpoint, - sdp_backend=sdp_backend, - image_cross=self.image_cross and (d % poscontrol_interval == 0), - far=far, - num_samples=num_samples, - add_lora=add_lora and self.image_cross and (d % poscontrol_interval == 0), - rgb_predict=rgb_predict, - mode=mode, - average=average, - num_freqs=num_freqs, - use_prev_weights_imp_sample=use_prev_weights_imp_sample, - imp_sample_next_step=(use_prev_weights_imp_sample and self.image_cross and (d % poscontrol_interval == 0) and depth >= poscontrol_interval and d < (depth // poscontrol_interval) * poscontrol_interval ), - stratified=stratified, - imp_sampling_percent=imp_sampling_percent, - near_plane=near_plane, - ) - for d in range(depth) - ] - ) - if not use_linear: - self.proj_out = zero_module( - nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - ) - else: - # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) - self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) - self.use_linear = use_linear - - def forward(self, x, xr, context=None, contextr=None, pose=None, mask_ref=None, prev_weights=None): - # note: if no context is given, cross-attention defaults to self-attention - if xr is None: - if not isinstance(context, list): - context = [context] - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - if not self.use_linear: - x = self.proj_in(x) - x = rearrange(x, "b c h w -> b (h w) c").contiguous() - if self.use_linear: - x = self.proj_in(x) - for i, block in enumerate(self.transformer_blocks): - if i > 0 and len(context) == 1: - i = 0 # use same context for each block - x, _, _, _, _ = block(x, context=context[i]) - if self.use_linear: - x = self.proj_out(x) - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() - if not self.use_linear: - x = self.proj_out(x) - return x + x_in, None, None, None, None, None - else: - if not isinstance(context, list): - context = [context] - contextr = [contextr] - b, c, h, w = x.shape - b1, _, _, _ = xr.shape - x_in = x - xr_in = xr - fg_masks = [] - alphas = [] - rgbs = [] - - x = self.norm(x) - with torch.no_grad(): - xr = self.norm(xr) - - if not self.use_linear: - x = self.proj_in(x) - with torch.no_grad(): - xr = self.proj_in(xr) - - x = rearrange(x, "b c h w -> b (h w) c").contiguous() - xr = rearrange(xr, "b1 c h w -> b1 (h w) c").contiguous() - if self.use_linear: - x = self.proj_in(x) - with torch.no_grad(): - xr = self.proj_in(xr) - - prev_weights = None - counter = 0 - for i, block in enumerate(self.transformer_blocks): - if i > 0 and len(context) == 1: - i = 0 # use same context for each block - if self.image_cross and (counter % self.poscontrol_interval == 0): - with torch.no_grad(): - xr, _, _, _, _ = block(xr, context=contextr[i]) - x, fg_mask, weights, alpha, rgb = block(x, context=context[i], context_ref=xr.detach(), pose=pose, mask_ref=mask_ref, prev_weights=prev_weights) - prev_weights = weights - fg_masks.append(fg_mask) - if alpha is not None: - alphas.append(alpha) - if rgb is not None: - rgbs.append(rgb) - else: - with torch.no_grad(): - xr, _, _, _, _ = block(xr, context=contextr[i]) - x, _, _, _, _ = block(x, context=context[i]) - counter += 1 - if self.use_linear: - x = self.proj_out(x) - with torch.no_grad(): - xr = self.proj_out(xr) - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() - xr = rearrange(xr, "b1 (h w) c -> b1 c h w", h=h, w=w).contiguous() - if not self.use_linear: - x = self.proj_out(x) - with torch.no_grad(): - xr = self.proj_out(xr) - if len(fg_masks) > 0: - if len(rgbs) <= 0: - rgbs = None - if len(alphas) <= 0: - alphas = None - return x + x_in, (xr + xr_in).detach(), fg_masks, prev_weights, alphas, rgbs - else: - return x + x_in, (xr + xr_in).detach(), None, prev_weights, None, None - - -def benchmark_attn(): - # Lets define a helpful benchmarking function: - # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html - device = "cuda" if torch.cuda.is_available() else "cpu" - import torch.nn.functional as F - import torch.utils.benchmark as benchmark - - def benchmark_torch_function_in_microseconds(f, *args, **kwargs): - t0 = benchmark.Timer( - stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} - ) - return t0.blocked_autorange().mean * 1e6 - - # Lets define the hyper-parameters of our input - batch_size = 32 - max_sequence_len = 1024 - num_heads = 32 - embed_dimension = 32 - - dtype = torch.float16 - - query = torch.rand( - batch_size, - num_heads, - max_sequence_len, - embed_dimension, - device=device, - dtype=dtype, - ) - key = torch.rand( - batch_size, - num_heads, - max_sequence_len, - embed_dimension, - device=device, - dtype=dtype, - ) - value = torch.rand( - batch_size, - num_heads, - max_sequence_len, - embed_dimension, - device=device, - dtype=dtype, - ) - - print(f"q/k/v shape:", query.shape, key.shape, value.shape) - - # Lets explore the speed of each of the 3 implementations - from torch.backends.cuda import SDPBackend, sdp_kernel - - # Helpful arguments mapper - backend_map = { - SDPBackend.MATH: { - "enable_math": True, - "enable_flash": False, - "enable_mem_efficient": False, - }, - SDPBackend.FLASH_ATTENTION: { - "enable_math": False, - "enable_flash": True, - "enable_mem_efficient": False, - }, - SDPBackend.EFFICIENT_ATTENTION: { - "enable_math": False, - "enable_flash": False, - "enable_mem_efficient": True, - }, - } - - from torch.profiler import ProfilerActivity, profile, record_function - - activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] - - print( - f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" - ) - with profile( - activities=activities, record_shapes=False, profile_memory=True - ) as prof: - with record_function("Default detailed stats"): - for _ in range(25): - o = F.scaled_dot_product_attention(query, key, value) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - - print( - f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" - ) - with sdp_kernel(**backend_map[SDPBackend.MATH]): - with profile( - activities=activities, record_shapes=False, profile_memory=True - ) as prof: - with record_function("Math implmentation stats"): - for _ in range(25): - o = F.scaled_dot_product_attention(query, key, value) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - - with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]): - try: - print( - f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" - ) - except RuntimeError: - print("FlashAttention is not supported. See warnings for reasons.") - with profile( - activities=activities, record_shapes=False, profile_memory=True - ) as prof: - with record_function("FlashAttention stats"): - for _ in range(25): - o = F.scaled_dot_product_attention(query, key, value) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - - with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): - try: - print( - f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" - ) - except RuntimeError: - print("EfficientAttention is not supported. See warnings for reasons.") - with profile( - activities=activities, record_shapes=False, profile_memory=True - ) as prof: - with record_function("EfficientAttention stats"): - for _ in range(25): - o = F.scaled_dot_product_attention(query, key, value) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - - -def run_model(model, x, context): - return model(x, context) - - -def benchmark_transformer_blocks(): - device = "cuda" if torch.cuda.is_available() else "cpu" - import torch.utils.benchmark as benchmark - - def benchmark_torch_function_in_microseconds(f, *args, **kwargs): - t0 = benchmark.Timer( - stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} - ) - return t0.blocked_autorange().mean * 1e6 - - checkpoint = True - compile = False - - batch_size = 32 - h, w = 64, 64 - context_len = 77 - embed_dimension = 1024 - context_dim = 1024 - d_head = 64 - - transformer_depth = 4 - - n_heads = embed_dimension // d_head - - dtype = torch.float16 - - model_native = SpatialTransformer( - embed_dimension, - n_heads, - d_head, - context_dim=context_dim, - use_linear=True, - use_checkpoint=checkpoint, - attn_type="softmax", - depth=transformer_depth, - sdp_backend=SDPBackend.FLASH_ATTENTION, - ).to(device) - model_efficient_attn = SpatialTransformer( - embed_dimension, - n_heads, - d_head, - context_dim=context_dim, - use_linear=True, - depth=transformer_depth, - use_checkpoint=checkpoint, - attn_type="softmax-xformers", - ).to(device) - if not checkpoint and compile: - print("compiling models") - model_native = torch.compile(model_native) - model_efficient_attn = torch.compile(model_efficient_attn) - - x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype) - c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype) - - from torch.profiler import ProfilerActivity, profile, record_function - - activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] - - with torch.autocast("cuda"): - print( - f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds" - ) - print( - f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds" - ) - - print(75 * "+") - print("NATIVE") - print(75 * "+") - torch.cuda.reset_peak_memory_stats() - with profile( - activities=activities, record_shapes=False, profile_memory=True - ) as prof: - with record_function("NativeAttention stats"): - for _ in range(25): - model_native(x, c) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block") - - print(75 * "+") - print("Xformers") - print(75 * "+") - torch.cuda.reset_peak_memory_stats() - with profile( - activities=activities, record_shapes=False, profile_memory=True - ) as prof: - with record_function("xformers stats"): - for _ in range(25): - model_efficient_attn(x, c) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block") - - -def test01(): - # conv1x1 vs linear - from ..util import count_params - - conv = nn.Conv2d(3, 32, kernel_size=1).cuda() - print(count_params(conv)) - linear = torch.nn.Linear(3, 32).cuda() - print(count_params(linear)) - - print(conv.weight.shape) - - # use same initialization - linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1)) - linear.bias = torch.nn.Parameter(conv.bias) - - print(linear.weight.shape) - - x = torch.randn(11, 3, 64, 64).cuda() - - xr = rearrange(x, "b c h w -> b (h w) c").contiguous() - print(xr.shape) - out_linear = linear(xr) - print(out_linear.mean(), out_linear.shape) - - out_conv = conv(x) - print(out_conv.mean(), out_conv.shape) - print("done with test01.\n") - - -def test02(): - # try cosine flash attention - import time - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - torch.backends.cudnn.benchmark = True - print("testing cosine flash attention...") - DIM = 1024 - SEQLEN = 4096 - BS = 16 - - print(" softmax (vanilla) first...") - model = BasicTransformerBlock( - dim=DIM, - n_heads=16, - d_head=64, - dropout=0.0, - context_dim=None, - attn_mode="softmax", - ).cuda() - try: - x = torch.randn(BS, SEQLEN, DIM).cuda() - tic = time.time() - y = model(x) - toc = time.time() - print(y.shape, toc - tic) - except RuntimeError as e: - # likely oom - print(str(e)) - - print("\n now flash-cosine...") - model = BasicTransformerBlock( - dim=DIM, - n_heads=16, - d_head=64, - dropout=0.0, - context_dim=None, - attn_mode="flash-cosine", - ).cuda() - x = torch.randn(BS, SEQLEN, DIM).cuda() - tic = time.time() - y = model(x) - toc = time.time() - print(y.shape, toc - tic) - print("done with test02.\n") - - -if __name__ == "__main__": - # test01() - # test02() - # test03() - - # benchmark_attn() - benchmark_transformer_blocks() - - print("done.") diff --git a/sgm/modules/autoencoding/__init__.py b/sgm/modules/autoencoding/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sgm/modules/autoencoding/lpips/__init__.py b/sgm/modules/autoencoding/lpips/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sgm/modules/autoencoding/lpips/loss.py b/sgm/modules/autoencoding/lpips/loss.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sgm/modules/autoencoding/lpips/loss/LICENSE b/sgm/modules/autoencoding/lpips/loss/LICENSE deleted file mode 100644 index 924cfc85b8d63ef538f5676f830a2a8497932108..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/lpips/loss/LICENSE +++ /dev/null @@ -1,23 +0,0 @@ -Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/sgm/modules/autoencoding/lpips/loss/__init__.py b/sgm/modules/autoencoding/lpips/loss/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sgm/modules/autoencoding/lpips/loss/lpips.py b/sgm/modules/autoencoding/lpips/loss/lpips.py deleted file mode 100644 index 3e34f3d083674f675a5ca024e9bd27fb77e2b6b5..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/lpips/loss/lpips.py +++ /dev/null @@ -1,147 +0,0 @@ -"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" - -from collections import namedtuple - -import torch -import torch.nn as nn -from torchvision import models - -from ..util import get_ckpt_path - - -class LPIPS(nn.Module): - # Learned perceptual metric - def __init__(self, use_dropout=True): - super().__init__() - self.scaling_layer = ScalingLayer() - self.chns = [64, 128, 256, 512, 512] # vg16 features - self.net = vgg16(pretrained=True, requires_grad=False) - self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) - self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) - self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) - self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) - self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) - self.load_from_pretrained() - for param in self.parameters(): - param.requires_grad = False - - def load_from_pretrained(self, name="vgg_lpips"): - ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss") - self.load_state_dict( - torch.load(ckpt, map_location=torch.device("cpu")), strict=False - ) - print("loaded pretrained LPIPS loss from {}".format(ckpt)) - - @classmethod - def from_pretrained(cls, name="vgg_lpips"): - if name != "vgg_lpips": - raise NotImplementedError - model = cls() - ckpt = get_ckpt_path(name) - model.load_state_dict( - torch.load(ckpt, map_location=torch.device("cpu")), strict=False - ) - return model - - def forward(self, input, target): - in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) - outs0, outs1 = self.net(in0_input), self.net(in1_input) - feats0, feats1, diffs = {}, {}, {} - lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] - for kk in range(len(self.chns)): - feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( - outs1[kk] - ) - diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 - - res = [ - spatial_average(lins[kk].model(diffs[kk]), keepdim=True) - for kk in range(len(self.chns)) - ] - val = res[0] - for l in range(1, len(self.chns)): - val += res[l] - return val - - -class ScalingLayer(nn.Module): - def __init__(self): - super(ScalingLayer, self).__init__() - self.register_buffer( - "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] - ) - self.register_buffer( - "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] - ) - - def forward(self, inp): - return (inp - self.shift) / self.scale - - -class NetLinLayer(nn.Module): - """A single linear layer which does a 1x1 conv""" - - def __init__(self, chn_in, chn_out=1, use_dropout=False): - super(NetLinLayer, self).__init__() - layers = ( - [ - nn.Dropout(), - ] - if (use_dropout) - else [] - ) - layers += [ - nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), - ] - self.model = nn.Sequential(*layers) - - -class vgg16(torch.nn.Module): - def __init__(self, requires_grad=False, pretrained=True): - super(vgg16, self).__init__() - vgg_pretrained_features = models.vgg16(pretrained=pretrained).features - self.slice1 = torch.nn.Sequential() - self.slice2 = torch.nn.Sequential() - self.slice3 = torch.nn.Sequential() - self.slice4 = torch.nn.Sequential() - self.slice5 = torch.nn.Sequential() - self.N_slices = 5 - for x in range(4): - self.slice1.add_module(str(x), vgg_pretrained_features[x]) - for x in range(4, 9): - self.slice2.add_module(str(x), vgg_pretrained_features[x]) - for x in range(9, 16): - self.slice3.add_module(str(x), vgg_pretrained_features[x]) - for x in range(16, 23): - self.slice4.add_module(str(x), vgg_pretrained_features[x]) - for x in range(23, 30): - self.slice5.add_module(str(x), vgg_pretrained_features[x]) - if not requires_grad: - for param in self.parameters(): - param.requires_grad = False - - def forward(self, X): - h = self.slice1(X) - h_relu1_2 = h - h = self.slice2(h) - h_relu2_2 = h - h = self.slice3(h) - h_relu3_3 = h - h = self.slice4(h) - h_relu4_3 = h - h = self.slice5(h) - h_relu5_3 = h - vgg_outputs = namedtuple( - "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] - ) - out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) - return out - - -def normalize_tensor(x, eps=1e-10): - norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) - return x / (norm_factor + eps) - - -def spatial_average(x, keepdim=True): - return x.mean([2, 3], keepdim=keepdim) diff --git a/sgm/modules/autoencoding/lpips/model/LICENSE b/sgm/modules/autoencoding/lpips/model/LICENSE deleted file mode 100644 index 4b356e66b5aa689b339f1a80a9f1b5ba378003bb..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/lpips/model/LICENSE +++ /dev/null @@ -1,58 +0,0 @@ -Copyright (c) 2017, Jun-Yan Zhu and Taesung Park -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - ---------------------------- LICENSE FOR pix2pix -------------------------------- -BSD License - -For pix2pix software -Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - ------------------------------ LICENSE FOR DCGAN -------------------------------- -BSD License - -For dcgan.torch software - -Copyright (c) 2015, Facebook, Inc. All rights reserved. - -Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - -Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - -Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - -Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/sgm/modules/autoencoding/lpips/model/__init__.py b/sgm/modules/autoencoding/lpips/model/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sgm/modules/autoencoding/lpips/model/model.py b/sgm/modules/autoencoding/lpips/model/model.py deleted file mode 100644 index 66357d4e627f9a69a5abbbad15546c96fcd758fe..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/lpips/model/model.py +++ /dev/null @@ -1,88 +0,0 @@ -import functools - -import torch.nn as nn - -from ..util import ActNorm - - -def weights_init(m): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("BatchNorm") != -1: - nn.init.normal_(m.weight.data, 1.0, 0.02) - nn.init.constant_(m.bias.data, 0) - - -class NLayerDiscriminator(nn.Module): - """Defines a PatchGAN discriminator as in Pix2Pix - --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py - """ - - def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): - """Construct a PatchGAN discriminator - Parameters: - input_nc (int) -- the number of channels in input images - ndf (int) -- the number of filters in the last conv layer - n_layers (int) -- the number of conv layers in the discriminator - norm_layer -- normalization layer - """ - super(NLayerDiscriminator, self).__init__() - if not use_actnorm: - norm_layer = nn.BatchNorm2d - else: - norm_layer = ActNorm - if ( - type(norm_layer) == functools.partial - ): # no need to use bias as BatchNorm2d has affine parameters - use_bias = norm_layer.func != nn.BatchNorm2d - else: - use_bias = norm_layer != nn.BatchNorm2d - - kw = 4 - padw = 1 - sequence = [ - nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), - nn.LeakyReLU(0.2, True), - ] - nf_mult = 1 - nf_mult_prev = 1 - for n in range(1, n_layers): # gradually increase the number of filters - nf_mult_prev = nf_mult - nf_mult = min(2**n, 8) - sequence += [ - nn.Conv2d( - ndf * nf_mult_prev, - ndf * nf_mult, - kernel_size=kw, - stride=2, - padding=padw, - bias=use_bias, - ), - norm_layer(ndf * nf_mult), - nn.LeakyReLU(0.2, True), - ] - - nf_mult_prev = nf_mult - nf_mult = min(2**n_layers, 8) - sequence += [ - nn.Conv2d( - ndf * nf_mult_prev, - ndf * nf_mult, - kernel_size=kw, - stride=1, - padding=padw, - bias=use_bias, - ), - norm_layer(ndf * nf_mult), - nn.LeakyReLU(0.2, True), - ] - - sequence += [ - nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) - ] # output 1 channel prediction map - self.main = nn.Sequential(*sequence) - - def forward(self, input): - """Standard forward.""" - return self.main(input) diff --git a/sgm/modules/autoencoding/lpips/util.py b/sgm/modules/autoencoding/lpips/util.py deleted file mode 100644 index 49c76e370bf16888ab61f42844b3c9f14ad9014c..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/lpips/util.py +++ /dev/null @@ -1,128 +0,0 @@ -import hashlib -import os - -import requests -import torch -import torch.nn as nn -from tqdm import tqdm - -URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} - -CKPT_MAP = {"vgg_lpips": "vgg.pth"} - -MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} - - -def download(url, local_path, chunk_size=1024): - os.makedirs(os.path.split(local_path)[0], exist_ok=True) - with requests.get(url, stream=True) as r: - total_size = int(r.headers.get("content-length", 0)) - with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: - with open(local_path, "wb") as f: - for data in r.iter_content(chunk_size=chunk_size): - if data: - f.write(data) - pbar.update(chunk_size) - - -def md5_hash(path): - with open(path, "rb") as f: - content = f.read() - return hashlib.md5(content).hexdigest() - - -def get_ckpt_path(name, root, check=False): - assert name in URL_MAP - path = os.path.join(root, CKPT_MAP[name]) - if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): - print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) - download(URL_MAP[name], path) - md5 = md5_hash(path) - assert md5 == MD5_MAP[name], md5 - return path - - -class ActNorm(nn.Module): - def __init__( - self, num_features, logdet=False, affine=True, allow_reverse_init=False - ): - assert affine - super().__init__() - self.logdet = logdet - self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) - self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) - self.allow_reverse_init = allow_reverse_init - - self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) - - def initialize(self, input): - with torch.no_grad(): - flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) - mean = ( - flatten.mean(1) - .unsqueeze(1) - .unsqueeze(2) - .unsqueeze(3) - .permute(1, 0, 2, 3) - ) - std = ( - flatten.std(1) - .unsqueeze(1) - .unsqueeze(2) - .unsqueeze(3) - .permute(1, 0, 2, 3) - ) - - self.loc.data.copy_(-mean) - self.scale.data.copy_(1 / (std + 1e-6)) - - def forward(self, input, reverse=False): - if reverse: - return self.reverse(input) - if len(input.shape) == 2: - input = input[:, :, None, None] - squeeze = True - else: - squeeze = False - - _, _, height, width = input.shape - - if self.training and self.initialized.item() == 0: - self.initialize(input) - self.initialized.fill_(1) - - h = self.scale * (input + self.loc) - - if squeeze: - h = h.squeeze(-1).squeeze(-1) - - if self.logdet: - log_abs = torch.log(torch.abs(self.scale)) - logdet = height * width * torch.sum(log_abs) - logdet = logdet * torch.ones(input.shape[0]).to(input) - return h, logdet - - return h - - def reverse(self, output): - if self.training and self.initialized.item() == 0: - if not self.allow_reverse_init: - raise RuntimeError( - "Initializing ActNorm in reverse direction is " - "disabled by default. Use allow_reverse_init=True to enable." - ) - else: - self.initialize(output) - self.initialized.fill_(1) - - if len(output.shape) == 2: - output = output[:, :, None, None] - squeeze = True - else: - squeeze = False - - h = output / self.scale - self.loc - - if squeeze: - h = h.squeeze(-1).squeeze(-1) - return h diff --git a/sgm/modules/autoencoding/lpips/vqperceptual.py b/sgm/modules/autoencoding/lpips/vqperceptual.py deleted file mode 100644 index 6195f0a6ed7ee6fd32c1bccea071e6075e95ee43..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/lpips/vqperceptual.py +++ /dev/null @@ -1,17 +0,0 @@ -import torch -import torch.nn.functional as F - - -def hinge_d_loss(logits_real, logits_fake): - loss_real = torch.mean(F.relu(1.0 - logits_real)) - loss_fake = torch.mean(F.relu(1.0 + logits_fake)) - d_loss = 0.5 * (loss_real + loss_fake) - return d_loss - - -def vanilla_d_loss(logits_real, logits_fake): - d_loss = 0.5 * ( - torch.mean(torch.nn.functional.softplus(-logits_real)) - + torch.mean(torch.nn.functional.softplus(logits_fake)) - ) - return d_loss diff --git a/sgm/modules/autoencoding/regularizers/__init__.py b/sgm/modules/autoencoding/regularizers/__init__.py deleted file mode 100644 index ff2b1815a5ba88892375e8ec9bedacea49024113..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/regularizers/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -from abc import abstractmethod -from typing import Any, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ....modules.distributions.distributions import \ - DiagonalGaussianDistribution -from .base import AbstractRegularizer - - -class DiagonalGaussianRegularizer(AbstractRegularizer): - def __init__(self, sample: bool = True): - super().__init__() - self.sample = sample - - def get_trainable_parameters(self) -> Any: - yield from () - - def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: - log = dict() - posterior = DiagonalGaussianDistribution(z) - if self.sample: - z = posterior.sample() - else: - z = posterior.mode() - kl_loss = posterior.kl() - kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] - log["kl_loss"] = kl_loss - return z, log diff --git a/sgm/modules/autoencoding/regularizers/base.py b/sgm/modules/autoencoding/regularizers/base.py deleted file mode 100644 index fca681bb3c1f4818b57e956e31b98f76077ccb67..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/regularizers/base.py +++ /dev/null @@ -1,40 +0,0 @@ -from abc import abstractmethod -from typing import Any, Tuple - -import torch -import torch.nn.functional as F -from torch import nn - - -class AbstractRegularizer(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: - raise NotImplementedError() - - @abstractmethod - def get_trainable_parameters(self) -> Any: - raise NotImplementedError() - - -class IdentityRegularizer(AbstractRegularizer): - def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: - return z, dict() - - def get_trainable_parameters(self) -> Any: - yield from () - - -def measure_perplexity( - predicted_indices: torch.Tensor, num_centroids: int -) -> Tuple[torch.Tensor, torch.Tensor]: - # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py - # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally - encodings = ( - F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) - ) - avg_probs = encodings.mean(0) - perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() - cluster_use = torch.sum(avg_probs > 0) - return perplexity, cluster_use diff --git a/sgm/modules/autoencoding/regularizers/quantize.py b/sgm/modules/autoencoding/regularizers/quantize.py deleted file mode 100644 index 86a4dbdd10101b24f03bba134c4f8d2ab007f0db..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/regularizers/quantize.py +++ /dev/null @@ -1,487 +0,0 @@ -import logging -from abc import abstractmethod -from typing import Dict, Iterator, Literal, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch import einsum - -from .base import AbstractRegularizer, measure_perplexity - -logpy = logging.getLogger(__name__) - - -class AbstractQuantizer(AbstractRegularizer): - def __init__(self): - super().__init__() - # Define these in your init - # shape (N,) - self.used: Optional[torch.Tensor] - self.re_embed: int - self.unknown_index: Union[Literal["random"], int] - - def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor: - assert self.used is not None, "You need to define used indices for remap" - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - match = (inds[:, :, None] == used[None, None, ...]).long() - new = match.argmax(-1) - unknown = match.sum(2) < 1 - if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to( - device=new.device - ) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor: - assert self.used is not None, "You need to define used indices for remap" - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds >= self.used.shape[0]] = 0 # simply set to zero - back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) - return back.reshape(ishape) - - @abstractmethod - def get_codebook_entry( - self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None - ) -> torch.Tensor: - raise NotImplementedError() - - def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]: - yield from self.parameters() - - -class GumbelQuantizer(AbstractQuantizer): - """ - credit to @karpathy: - https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) - Gumbel Softmax trick quantizer - Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 - https://arxiv.org/abs/1611.01144 - """ - - def __init__( - self, - num_hiddens: int, - embedding_dim: int, - n_embed: int, - straight_through: bool = True, - kl_weight: float = 5e-4, - temp_init: float = 1.0, - remap: Optional[str] = None, - unknown_index: str = "random", - loss_key: str = "loss/vq", - ) -> None: - super().__init__() - - self.loss_key = loss_key - self.embedding_dim = embedding_dim - self.n_embed = n_embed - - self.straight_through = straight_through - self.temperature = temp_init - self.kl_weight = kl_weight - - self.proj = nn.Conv2d(num_hiddens, n_embed, 1) - self.embed = nn.Embedding(n_embed, embedding_dim) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - else: - self.used = None - self.re_embed = n_embed - if unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - else: - assert unknown_index == "random" or isinstance( - unknown_index, int - ), "unknown index needs to be 'random', 'extra' or any integer" - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.remap is not None: - logpy.info( - f"Remapping {self.n_embed} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices." - ) - - def forward( - self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False - ) -> Tuple[torch.Tensor, Dict]: - # force hard = True when we are in eval mode, as we must quantize. - # actually, always true seems to work - hard = self.straight_through if self.training else True - temp = self.temperature if temp is None else temp - out_dict = {} - logits = self.proj(z) - if self.remap is not None: - # continue only with used logits - full_zeros = torch.zeros_like(logits) - logits = logits[:, self.used, ...] - - soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) - if self.remap is not None: - # go back to all entries but unused set to zero - full_zeros[:, self.used, ...] = soft_one_hot - soft_one_hot = full_zeros - z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) - - # + kl divergence to the prior loss - qy = F.softmax(logits, dim=1) - diff = ( - self.kl_weight - * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() - ) - out_dict[self.loss_key] = diff - - ind = soft_one_hot.argmax(dim=1) - out_dict["indices"] = ind - if self.remap is not None: - ind = self.remap_to_used(ind) - - if return_logits: - out_dict["logits"] = logits - - return z_q, out_dict - - def get_codebook_entry(self, indices, shape): - # TODO: shape not yet optional - b, h, w, c = shape - assert b * h * w == indices.shape[0] - indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w) - if self.remap is not None: - indices = self.unmap_to_all(indices) - one_hot = ( - F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() - ) - z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight) - return z_q - - -class VectorQuantizer(AbstractQuantizer): - """ - ____________________________________________ - Discretization bottleneck part of the VQ-VAE. - Inputs: - - n_e : number of embeddings - - e_dim : dimension of embedding - - beta : commitment cost used in loss term, - beta * ||z_e(x)-sg[e]||^2 - _____________________________________________ - """ - - def __init__( - self, - n_e: int, - e_dim: int, - beta: float = 0.25, - remap: Optional[str] = None, - unknown_index: str = "random", - sane_index_shape: bool = False, - log_perplexity: bool = False, - embedding_weight_norm: bool = False, - loss_key: str = "loss/vq", - ): - super().__init__() - self.n_e = n_e - self.e_dim = e_dim - self.beta = beta - self.loss_key = loss_key - - if not embedding_weight_norm: - self.embedding = nn.Embedding(self.n_e, self.e_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - else: - self.embedding = torch.nn.utils.weight_norm( - nn.Embedding(self.n_e, self.e_dim), dim=1 - ) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - else: - self.used = None - self.re_embed = n_e - if unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - else: - assert unknown_index == "random" or isinstance( - unknown_index, int - ), "unknown index needs to be 'random', 'extra' or any integer" - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.remap is not None: - logpy.info( - f"Remapping {self.n_e} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices." - ) - - self.sane_index_shape = sane_index_shape - self.log_perplexity = log_perplexity - - def forward( - self, - z: torch.Tensor, - ) -> Tuple[torch.Tensor, Dict]: - do_reshape = z.ndim == 4 - if do_reshape: - # # reshape z -> (batch, height, width, channel) and flatten - z = rearrange(z, "b c h w -> b h w c").contiguous() - - else: - assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined" - z = z.contiguous() - - z_flattened = z.view(-1, self.e_dim) - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - - d = ( - torch.sum(z_flattened**2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight**2, dim=1) - - 2 - * torch.einsum( - "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n") - ) - ) - - min_encoding_indices = torch.argmin(d, dim=1) - z_q = self.embedding(min_encoding_indices).view(z.shape) - loss_dict = {} - if self.log_perplexity: - perplexity, cluster_usage = measure_perplexity( - min_encoding_indices.detach(), self.n_e - ) - loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage}) - - # compute loss for embedding - loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean( - (z_q - z.detach()) ** 2 - ) - loss_dict[self.loss_key] = loss - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - if do_reshape: - z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() - - if self.remap is not None: - min_encoding_indices = min_encoding_indices.reshape( - z.shape[0], -1 - ) # add batch axis - min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten - - if self.sane_index_shape: - if do_reshape: - min_encoding_indices = min_encoding_indices.reshape( - z_q.shape[0], z_q.shape[2], z_q.shape[3] - ) - else: - min_encoding_indices = rearrange( - min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0] - ) - - loss_dict["min_encoding_indices"] = min_encoding_indices - - return z_q, loss_dict - - def get_codebook_entry( - self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None - ) -> torch.Tensor: - # shape specifying (batch, height, width, channel) - if self.remap is not None: - assert shape is not None, "Need to give shape for remap" - indices = indices.reshape(shape[0], -1) # add batch axis - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again - - # get quantized latent vectors - z_q = self.embedding(indices) - - if shape is not None: - z_q = z_q.view(shape) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q - - -class EmbeddingEMA(nn.Module): - def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): - super().__init__() - self.decay = decay - self.eps = eps - weight = torch.randn(num_tokens, codebook_dim) - self.weight = nn.Parameter(weight, requires_grad=False) - self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) - self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) - self.update = True - - def forward(self, embed_id): - return F.embedding(embed_id, self.weight) - - def cluster_size_ema_update(self, new_cluster_size): - self.cluster_size.data.mul_(self.decay).add_( - new_cluster_size, alpha=1 - self.decay - ) - - def embed_avg_ema_update(self, new_embed_avg): - self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) - - def weight_update(self, num_tokens): - n = self.cluster_size.sum() - smoothed_cluster_size = ( - (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n - ) - # normalize embedding average with smoothed cluster size - embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) - self.weight.data.copy_(embed_normalized) - - -class EMAVectorQuantizer(AbstractQuantizer): - def __init__( - self, - n_embed: int, - embedding_dim: int, - beta: float, - decay: float = 0.99, - eps: float = 1e-5, - remap: Optional[str] = None, - unknown_index: str = "random", - loss_key: str = "loss/vq", - ): - super().__init__() - self.codebook_dim = embedding_dim - self.num_tokens = n_embed - self.beta = beta - self.loss_key = loss_key - - self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - else: - self.used = None - self.re_embed = n_embed - if unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - else: - assert unknown_index == "random" or isinstance( - unknown_index, int - ), "unknown index needs to be 'random', 'extra' or any integer" - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.remap is not None: - logpy.info( - f"Remapping {self.n_embed} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices." - ) - - def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: - # reshape z -> (batch, height, width, channel) and flatten - # z, 'b c h w -> b h w c' - z = rearrange(z, "b c h w -> b h w c") - z_flattened = z.reshape(-1, self.codebook_dim) - - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - d = ( - z_flattened.pow(2).sum(dim=1, keepdim=True) - + self.embedding.weight.pow(2).sum(dim=1) - - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight) - ) # 'n d -> d n' - - encoding_indices = torch.argmin(d, dim=1) - - z_q = self.embedding(encoding_indices).view(z.shape) - encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) - avg_probs = torch.mean(encodings, dim=0) - perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) - - if self.training and self.embedding.update: - # EMA cluster size - encodings_sum = encodings.sum(0) - self.embedding.cluster_size_ema_update(encodings_sum) - # EMA embedding average - embed_sum = encodings.transpose(0, 1) @ z_flattened - self.embedding.embed_avg_ema_update(embed_sum) - # normalize embed_avg and update weight - self.embedding.weight_update(self.num_tokens) - - # compute loss for embedding - loss = self.beta * F.mse_loss(z_q.detach(), z) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - # z_q, 'b h w c -> b c h w' - z_q = rearrange(z_q, "b h w c -> b c h w") - - out_dict = { - self.loss_key: loss, - "encodings": encodings, - "encoding_indices": encoding_indices, - "perplexity": perplexity, - } - - return z_q, out_dict - - -class VectorQuantizerWithInputProjection(VectorQuantizer): - def __init__( - self, - input_dim: int, - n_codes: int, - codebook_dim: int, - beta: float = 1.0, - output_dim: Optional[int] = None, - **kwargs, - ): - super().__init__(n_codes, codebook_dim, beta, **kwargs) - self.proj_in = nn.Linear(input_dim, codebook_dim) - self.output_dim = output_dim - if output_dim is not None: - self.proj_out = nn.Linear(codebook_dim, output_dim) - else: - self.proj_out = nn.Identity() - - def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: - rearr = False - in_shape = z.shape - - if z.ndim > 3: - rearr = self.output_dim is not None - z = rearrange(z, "b c ... -> b (...) c") - z = self.proj_in(z) - z_q, loss_dict = super().forward(z) - - z_q = self.proj_out(z_q) - if rearr: - if len(in_shape) == 4: - z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1]) - elif len(in_shape) == 5: - z_q = rearrange( - z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2] - ) - else: - raise NotImplementedError( - f"rearranging not available for {len(in_shape)}-dimensional input." - ) - - return z_q, loss_dict diff --git a/sgm/modules/diffusionmodules/__init__.py b/sgm/modules/diffusionmodules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sgm/modules/diffusionmodules/denoiser.py b/sgm/modules/diffusionmodules/denoiser.py deleted file mode 100644 index 52a1e33476904408013ec472967b42299d857fce..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/denoiser.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch.nn as nn -import torch -from ...util import append_dims, instantiate_from_config - - -class Denoiser(nn.Module): - def __init__(self, weighting_config, scaling_config): - super().__init__() - - self.weighting = instantiate_from_config(weighting_config) - self.scaling = instantiate_from_config(scaling_config) - - def possibly_quantize_sigma(self, sigma): - return sigma - - def possibly_quantize_c_noise(self, c_noise): - return c_noise - - def w(self, sigma): - return self.weighting(sigma) - - def __call__(self, network, input, sigma, cond, sigmas_ref=None, **kwargs): - sigma = self.possibly_quantize_sigma(sigma) - sigma_shape = sigma.shape - sigma = append_dims(sigma, input.ndim) - if sigmas_ref is not None: - if kwargs is not None: - kwargs['sigmas_ref'] = sigmas_ref - else: - kwargs = {'sigmas_ref': sigmas_ref} - - if kwargs['input_ref'] is not None: - noise = torch.randn_like(kwargs['input_ref']) - kwargs['input_ref'] = kwargs['input_ref'] + noise * append_dims(sigmas_ref, kwargs['input_ref'].ndim) - - if 'input_ref' in kwargs and kwargs['input_ref'] is not None and 'sigmas_ref' in kwargs: - _, _, c_in_ref, c_noise_ref = self.scaling(append_dims(kwargs['sigmas_ref'], kwargs['input_ref'].ndim)) - kwargs['input_ref'] = kwargs['input_ref']*c_in_ref - kwargs['sigmas_ref'] = self.possibly_quantize_c_noise(kwargs['sigmas_ref']) - - c_skip, c_out, c_in, c_noise = self.scaling(sigma) - c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) - predict, fg_mask_list, alphas_list, rgb_list = network(input * c_in, c_noise, cond, **kwargs) - return predict * c_out + input * c_skip, fg_mask_list, alphas_list, rgb_list - - -class DiscreteDenoiser(Denoiser): - def __init__( - self, - weighting_config, - scaling_config, - num_idx, - discretization_config, - do_append_zero=False, - quantize_c_noise=True, - flip=True, - ): - super().__init__(weighting_config, scaling_config) - sigmas = instantiate_from_config(discretization_config)( - num_idx, do_append_zero=do_append_zero, flip=flip - ) - self.register_buffer("sigmas", sigmas) - self.quantize_c_noise = quantize_c_noise - - def sigma_to_idx(self, sigma): - dists = sigma - self.sigmas[:, None] - return dists.abs().argmin(dim=0).view(sigma.shape) - - def idx_to_sigma(self, idx): - return self.sigmas[idx] - - def possibly_quantize_sigma(self, sigma): - return self.idx_to_sigma(self.sigma_to_idx(sigma)) - - def possibly_quantize_c_noise(self, c_noise): - if self.quantize_c_noise: - return self.sigma_to_idx(c_noise) - else: - return c_noise diff --git a/sgm/modules/diffusionmodules/denoiser_scaling.py b/sgm/modules/diffusionmodules/denoiser_scaling.py deleted file mode 100644 index 63075088683ba264f87cd6be3b728e539645fc7a..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/denoiser_scaling.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch -from abc import ABC, abstractmethod -from typing import Tuple - - -class DenoiserScaling(ABC): - @abstractmethod - def __call__( - self, sigma: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - pass - - -class EDMScaling: - def __init__(self, sigma_data=0.5): - self.sigma_data = sigma_data - - def __call__(self, sigma): - c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) - c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 - c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 - c_noise = 0.25 * sigma.log() - return c_skip, c_out, c_in, c_noise - - -class EpsScaling: - def __call__(self, sigma): - c_skip = torch.ones_like(sigma, device=sigma.device) - c_out = -sigma - c_in = 1 / (sigma**2 + 1.0) ** 0.5 - c_noise = sigma.clone() - return c_skip, c_out, c_in, c_noise - - -class VScaling: - def __call__(self, sigma): - c_skip = 1.0 / (sigma**2 + 1.0) - c_out = -sigma / (sigma**2 + 1.0) ** 0.5 - c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 - c_noise = sigma.clone() - return c_skip, c_out, c_in, c_noise diff --git a/sgm/modules/diffusionmodules/denoiser_weighting.py b/sgm/modules/diffusionmodules/denoiser_weighting.py deleted file mode 100644 index b8b03ca58f17ea3d7374f4bbb7bf1d2994755e00..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/denoiser_weighting.py +++ /dev/null @@ -1,24 +0,0 @@ -import torch - - -class UnitWeighting: - def __call__(self, sigma): - return torch.ones_like(sigma, device=sigma.device) - - -class EDMWeighting: - def __init__(self, sigma_data=0.5): - self.sigma_data = sigma_data - - def __call__(self, sigma): - return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 - - -class VWeighting(EDMWeighting): - def __init__(self): - super().__init__(sigma_data=1.0) - - -class EpsWeighting: - def __call__(self, sigma): - return sigma**-2.0 diff --git a/sgm/modules/diffusionmodules/discretizer.py b/sgm/modules/diffusionmodules/discretizer.py deleted file mode 100644 index 02add6081c5e3164d4402619b44d5be235d3ec58..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/discretizer.py +++ /dev/null @@ -1,69 +0,0 @@ -from abc import abstractmethod -from functools import partial - -import numpy as np -import torch - -from ...modules.diffusionmodules.util import make_beta_schedule -from ...util import append_zero - - -def generate_roughly_equally_spaced_steps( - num_substeps: int, max_step: int -) -> np.ndarray: - return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] - - -class Discretization: - def __call__(self, n, do_append_zero=True, device="cpu", flip=False): - sigmas = self.get_sigmas(n, device=device) - sigmas = append_zero(sigmas) if do_append_zero else sigmas - return sigmas if not flip else torch.flip(sigmas, (0,)) - - @abstractmethod - def get_sigmas(self, n, device): - pass - - -class EDMDiscretization(Discretization): - def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): - self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.rho = rho - - def get_sigmas(self, n, device="cpu"): - ramp = torch.linspace(0, 1, n, device=device) - min_inv_rho = self.sigma_min ** (1 / self.rho) - max_inv_rho = self.sigma_max ** (1 / self.rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho - return sigmas - - -class LegacyDDPMDiscretization(Discretization): - def __init__( - self, - linear_start=0.00085, - linear_end=0.0120, - num_timesteps=1000, - ): - super().__init__() - self.num_timesteps = num_timesteps - betas = make_beta_schedule( - "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end - ) - alphas = 1.0 - betas - self.alphas_cumprod = np.cumprod(alphas, axis=0) - self.to_torch = partial(torch.tensor, dtype=torch.float32) - - def get_sigmas(self, n, device="cpu"): - if n < self.num_timesteps: - timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) - alphas_cumprod = self.alphas_cumprod[timesteps] - elif n == self.num_timesteps: - alphas_cumprod = self.alphas_cumprod - else: - raise ValueError - - to_torch = partial(torch.tensor, dtype=torch.float32, device=device) - sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 - return torch.flip(sigmas, (0,)) diff --git a/sgm/modules/diffusionmodules/guiders.py b/sgm/modules/diffusionmodules/guiders.py deleted file mode 100644 index cd62f44f989ff3bc22ab149c89236e0324b7adc8..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/guiders.py +++ /dev/null @@ -1,167 +0,0 @@ -import logging -from abc import ABC, abstractmethod -from typing import Dict, List, Literal, Optional, Tuple, Union - -import torch -from einops import rearrange, repeat - -from ...util import append_dims, default - -logpy = logging.getLogger(__name__) - - -class Guider(ABC): - @abstractmethod - def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: - pass - - def prepare_inputs( - self, x: torch.Tensor, s: float, c: Dict, uc: Dict - ) -> Tuple[torch.Tensor, float, Dict]: - pass - - -class VanillaCFG(Guider): - def __init__(self, scale: float): - self.scale = scale - - def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - x_u, x_c = x.chunk(2) - x_pred = x_u + self.scale * (x_c - x_u) - return x_pred - - def prepare_inputs(self, x, s, c, uc): - c_out = dict() - - for k in c: - if k in ["vector", "crossattn", "concat"]: - c_out[k] = torch.cat((uc[k], c[k]), 0) - else: - assert c[k] == uc[k] - c_out[k] = c[k] - return torch.cat([x] * 2), torch.cat([s] * 2), c_out - - -class IdentityGuider(Guider): - def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: - return x - - def prepare_inputs( - self, x: torch.Tensor, s: float, c: Dict, uc: Dict - ) -> Tuple[torch.Tensor, float, Dict]: - c_out = dict() - - for k in c: - c_out[k] = c[k] - - return x, s, c_out - - -class LinearPredictionGuider(Guider): - def __init__( - self, - max_scale: float, - num_frames: int, - min_scale: float = 1.0, - additional_cond_keys: Optional[Union[List[str], str]] = None, - ): - self.min_scale = min_scale - self.max_scale = max_scale - self.num_frames = num_frames - self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0) - - additional_cond_keys = default(additional_cond_keys, []) - if isinstance(additional_cond_keys, str): - additional_cond_keys = [additional_cond_keys] - self.additional_cond_keys = additional_cond_keys - - def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - x_u, x_c = x.chunk(2) - - x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames) - x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames) - scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0]) - scale = append_dims(scale, x_u.ndim).to(x_u.device) - - return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...") - - def prepare_inputs( - self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict - ) -> Tuple[torch.Tensor, torch.Tensor, dict]: - c_out = dict() - - for k in c: - if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys: - c_out[k] = torch.cat((uc[k], c[k]), 0) - else: - assert c[k] == uc[k] - c_out[k] = c[k] - return torch.cat([x] * 2), torch.cat([s] * 2), c_out - - -class ScheduledCFGImgTextRef(Guider): - """ - From InstructPix2Pix - """ - - def __init__(self, scale: float, scale_im: float): - self.scale = scale - self.scale_im = scale_im - - def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - x_u, x_ic, x_c = x.chunk(3) - x_pred = x_u + self.scale * (x_c - x_ic) + self.scale_im*(x_ic - x_u) - return x_pred - - def prepare_inputs(self, x, s, c, uc): - c_out = dict() - - for k in c: - if k in ["vector", "crossattn", "concat"]: - b = uc[k].shape[0] - if k == "crossattn": - uc1, uc2 = uc[k].split([x.size(0), b - x.size(0)]) - c1, c2 = c[k].split([x.size(0), b - x.size(0)]) - c_out[k] = torch.cat((uc1, uc1, c1, uc2, c2, c2), 0) - else: - uc1, uc2 = uc[k].split([x.size(0), b - x.size(0)]) - c1, c2 = c[k].split([x.size(0), b - x.size(0)]) - c_out[k] = torch.cat((uc1, uc1, c1, uc2, c2, c2), 0) - else: - assert c[k] == uc[k] - c_out[k] = c[k] - return torch.cat([x] * 3), torch.cat([s] * 3), c_out - - -class VanillaCFGImgRef(Guider): - """ - implements parallelized CFG - """ - - def __init__(self, scale: float): - self.scale = scale - - def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - x_u, x_c = x.chunk(2) - x_pred = x_u + self.scale * (x_c - x_u) - return x_pred - - def prepare_inputs(self, x, s, c, uc): - c_out = dict() - - for k in c: - if k in ["vector", "crossattn", "concat"]: - b = uc[k].shape[0] - if k == "crossattn": - uc1, uc2 = uc[k].split([x.size(0), b - x.size(0)]) - c1, c2 = c[k].split([x.size(0), b - x.size(0)]) - c_out[k] = torch.cat((uc1, c1, uc2, c2), 0) - else: - uc1, uc2 = uc[k].split([x.size(0), b - x.size(0)]) - c1, c2 = c[k].split([x.size(0), b - x.size(0)]) - c_out[k] = torch.cat((uc1, c1, uc2, c2), 0) - else: - assert c[k] == uc[k] - c_out[k] = c[k] - return torch.cat([x] * 2), torch.cat([s] * 2), c_out - diff --git a/sgm/modules/diffusionmodules/loss.py b/sgm/modules/diffusionmodules/loss.py deleted file mode 100644 index 89663f60d190233571b973be43db948f9726cc7a..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/loss.py +++ /dev/null @@ -1,216 +0,0 @@ -from typing import Dict, List, Optional, Tuple, Union -import math -import torch -import torch.nn as nn - -from ...modules.autoencoding.lpips.loss.lpips import LPIPS -from ...modules.encoders.modules import GeneralConditioner -from ...util import append_dims, instantiate_from_config -from .denoiser import Denoiser - - -class StandardDiffusionLoss(nn.Module): - def __init__( - self, - sigma_sampler_config: dict, - loss_weighting_config: dict, - loss_type: str = "l2", - offset_noise_level: float = 0.0, - batch2model_keys: Optional[Union[str, List[str]]] = None, - ): - super().__init__() - - assert loss_type in ["l2", "l1", "lpips"] - - self.sigma_sampler = instantiate_from_config(sigma_sampler_config) - self.loss_weighting = instantiate_from_config(loss_weighting_config) - - self.loss_type = loss_type - self.offset_noise_level = offset_noise_level - - if loss_type == "lpips": - self.lpips = LPIPS().eval() - - if not batch2model_keys: - batch2model_keys = [] - - if isinstance(batch2model_keys, str): - batch2model_keys = [batch2model_keys] - - self.batch2model_keys = set(batch2model_keys) - - def get_noised_input( - self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor - ) -> torch.Tensor: - noised_input = input + noise * sigmas_bc - return noised_input - - def forward( - self, - network: nn.Module, - denoiser: Denoiser, - conditioner: GeneralConditioner, - input: torch.Tensor, - batch: Dict, - ) -> torch.Tensor: - cond = conditioner(batch) - return self._forward(network, denoiser, cond, input, batch) - - def _forward( - self, - network: nn.Module, - denoiser: Denoiser, - cond: Dict, - input: torch.Tensor, - batch: Dict, - ) -> Tuple[torch.Tensor, Dict]: - additional_model_inputs = { - key: batch[key] for key in self.batch2model_keys.intersection(batch) - } - sigmas = self.sigma_sampler(input.shape[0]).to(input) - - noise = torch.randn_like(input) - if self.offset_noise_level > 0.0: - offset_shape = ( - (input.shape[0], 1, input.shape[2]) - if self.n_frames is not None - else (input.shape[0], input.shape[1]) - ) - noise = noise + self.offset_noise_level * append_dims( - torch.randn(offset_shape, device=input.device), - input.ndim, - ) - sigmas_bc = append_dims(sigmas, input.ndim) - noised_input = self.get_noised_input(sigmas_bc, noise, input) - - model_output = denoiser( - network, noised_input, sigmas, cond, **additional_model_inputs - ) - w = append_dims(self.loss_weighting(sigmas), input.ndim) - return self.get_loss(model_output, input, w) - - def get_loss(self, model_output, target, w): - if self.loss_type == "l2": - return torch.mean( - (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1 - ) - elif self.loss_type == "l1": - return torch.mean( - (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1 - ) - elif self.loss_type == "lpips": - loss = self.lpips(model_output, target).reshape(-1) - return loss - else: - raise NotImplementedError(f"Unknown loss type {self.loss_type}") - - -class StandardDiffusionLossImgRef(nn.Module): - def __init__( - self, - sigma_sampler_config: dict, - sigma_sampler_config_ref: dict, - type: str = "l2", - offset_noise_level: float = 0.0, - batch2model_keys: Optional[Union[str, List[str]]] = None, - ): - super().__init__() - - assert type in ["l2", "l1", "lpips"] - - self.sigma_sampler = instantiate_from_config(sigma_sampler_config) - self.sigma_sampler_ref = None - if sigma_sampler_config_ref is not None: - self.sigma_sampler_ref = instantiate_from_config(sigma_sampler_config_ref) - - self.type = type - self.offset_noise_level = offset_noise_level - - if type == "lpips": - self.lpips = LPIPS().eval() - - if not batch2model_keys: - batch2model_keys = [] - - if isinstance(batch2model_keys, str): - batch2model_keys = [batch2model_keys] - - self.batch2model_keys = set(batch2model_keys) - - def __call__(self, network, denoiser, conditioner, input, input_rgb, input_ref, pose, mask, mask_ref, opacity, batch): - cond = conditioner(batch) - additional_model_inputs = { - key: batch[key] for key in self.batch2model_keys.intersection(batch) - } - - sigmas = self.sigma_sampler(input.shape[0]).to(input.device) - noise = torch.randn_like(input) - if self.offset_noise_level > 0.0: - noise = noise + self.offset_noise_level * append_dims( - torch.randn(input.shape[0], device=input.device), input.ndim - ) - - additional_model_inputs['pose'] = pose - additional_model_inputs['mask_ref'] = mask_ref - - noised_input = input + noise * append_dims(sigmas, input.ndim) - if self.sigma_sampler_ref is not None: - sigmas_ref = self.sigma_sampler_ref(input.shape[0]).to(input.device) - if input_ref is not None: - noise = torch.randn_like(input_ref) - if self.offset_noise_level > 0.0: - noise = noise + self.offset_noise_level * append_dims( - torch.randn(input_ref.shape[0], device=input_ref.device), input_ref.ndim - ) - input_ref = input_ref + noise * append_dims(sigmas_ref, input_ref.ndim) - additional_model_inputs['sigmas_ref'] = sigmas_ref - - additional_model_inputs['input_ref'] = input_ref - - model_output, fg_mask_list, alphas, predicted_rgb_list = denoiser( - network, noised_input, sigmas, cond, **additional_model_inputs - ) - - w = append_dims(denoiser.w(sigmas), input.ndim) - return self.get_loss(model_output, fg_mask_list, predicted_rgb_list, input, input_rgb, w, mask, mask_ref, opacity, alphas) - - def get_loss(self, model_output, fg_mask_list, predicted_rgb_list, target, target_rgb, w, mask, mask_ref, opacity, alphas_list): - loss_rgb = [] - loss_fg = [] - loss_bg = [] - with torch.amp.autocast(device_type='cuda', dtype=torch.float32): - if self.type == "l2": - loss = (w * (model_output - target) ** 2) - if mask is not None: - loss_l2 = (loss*mask).sum([1, 2, 3])/(mask.sum([1, 2, 3]) + 1e-6) - else: - loss_l2 = torch.mean(loss.reshape(target.shape[0], -1), 1) - if len(fg_mask_list) > 0 and len(alphas_list) > 0: - for fg_mask, alphas in zip(fg_mask_list, alphas_list): - size = int(math.sqrt(fg_mask.size(1))) - opacity = torch.nn.functional.interpolate(opacity, size=size, antialias=True, mode='bilinear').detach() - fg_mask = torch.clamp(fg_mask.reshape(-1, size*size), 0., 1.) - loss_fg_ = ((fg_mask - opacity.reshape(-1, size*size))**2).mean(1) #torch.nn.functional.binary_cross_entropy(rgb, torch.clip(mask.reshape(-1, size*size), 0., 1.), reduce=False) - loss_bg_ = (alphas - opacity.reshape(-1, size*size, 1, 1)).abs()*(1-opacity.reshape(-1, size*size, 1, 1)) #alpahs : b hw d 1 - loss_bg_ = (loss_bg_*((opacity.reshape(-1, size*size, 1, 1) < 0.1)*1)).mean([1, 2, 3]) - loss_fg.append(loss_fg_) - loss_bg.append(loss_bg_) - loss_fg = torch.stack(loss_fg, 1) - loss_bg = torch.stack(loss_bg, 1) - - if len(predicted_rgb_list) > 0: - for rgb in predicted_rgb_list: - size = int(math.sqrt(rgb.size(1))) - mask_ = torch.nn.functional.interpolate(mask, size=size, antialias=True, mode='bilinear').detach() - loss_rgb_ = ((torch.nn.functional.interpolate(target_rgb*0.5+0.5, size=size, antialias=True, mode='bilinear').detach() - rgb.reshape(-1, size, size, 3).permute(0, 3, 1, 2)) ** 2) - loss_rgb.append((loss_rgb_*mask_).sum([1, 2, 3])/(mask.sum([1, 2, 3]) + 1e-6)) - loss_rgb = torch.stack(loss_rgb, 1) - # print(loss_l2, loss_fg, loss_bg, loss_rgb) - return loss_l2, loss_fg, loss_bg, loss_rgb - elif self.type == "l1": - return torch.mean( - (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1 - ), loss_rgb - elif self.type == "lpips": - loss = self.lpips(model_output, target).reshape(-1) - return loss, loss_rgb diff --git a/sgm/modules/diffusionmodules/loss_weighting.py b/sgm/modules/diffusionmodules/loss_weighting.py deleted file mode 100644 index e12c0a76635435babd1af33969e82fa284525af8..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/loss_weighting.py +++ /dev/null @@ -1,32 +0,0 @@ -from abc import ABC, abstractmethod - -import torch - - -class DiffusionLossWeighting(ABC): - @abstractmethod - def __call__(self, sigma: torch.Tensor) -> torch.Tensor: - pass - - -class UnitWeighting(DiffusionLossWeighting): - def __call__(self, sigma: torch.Tensor) -> torch.Tensor: - return torch.ones_like(sigma, device=sigma.device) - - -class EDMWeighting(DiffusionLossWeighting): - def __init__(self, sigma_data: float = 0.5): - self.sigma_data = sigma_data - - def __call__(self, sigma: torch.Tensor) -> torch.Tensor: - return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 - - -class VWeighting(EDMWeighting): - def __init__(self): - super().__init__(sigma_data=1.0) - - -class EpsWeighting(DiffusionLossWeighting): - def __call__(self, sigma: torch.Tensor) -> torch.Tensor: - return sigma**-2.0 diff --git a/sgm/modules/diffusionmodules/model.py b/sgm/modules/diffusionmodules/model.py deleted file mode 100644 index 4cf9d92140dee8443a0ea6b5cf218f2879ad88f4..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/model.py +++ /dev/null @@ -1,748 +0,0 @@ -# pytorch_diffusion + derived encoder decoder -import logging -import math -from typing import Any, Callable, Optional - -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange -from packaging import version - -logpy = logging.getLogger(__name__) - -try: - import xformers - import xformers.ops - - XFORMERS_IS_AVAILABLE = True -except: - XFORMERS_IS_AVAILABLE = False - logpy.warning("no module 'xformers'. Processing without...") - -from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention - - -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. - This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". - """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def nonlinearity(x): - # swish - return x * torch.sigmoid(x) - - -def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm( - num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True - ) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=2, padding=0 - ) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class ResnetBlock(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout, - temb_channels=512, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - else: - self.nin_shortcut = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - -class LinAttnBlock(LinearAttention): - """to match AttnBlock usage""" - - def __init__(self, in_channels): - super().__init__(dim=in_channels, heads=1, dim_head=in_channels) - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - - def attention(self, h_: torch.Tensor) -> torch.Tensor: - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - b, c, h, w = q.shape - q, k, v = map( - lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v) - ) - h_ = torch.nn.functional.scaled_dot_product_attention( - q, k, v - ) # scale is dim ** -0.5 per default - # compute attention - - return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) - - def forward(self, x, **kwargs): - h_ = x - h_ = self.attention(h_) - h_ = self.proj_out(h_) - return x + h_ - - -class MemoryEfficientAttnBlock(nn.Module): - """ - Uses xformers efficient implementation, - see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - Note: this is a single-head self-attention operation - """ - - # - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.attention_op: Optional[Any] = None - - def attention(self, h_: torch.Tensor) -> torch.Tensor: - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) - - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(B, t.shape[1], 1, C) - .permute(0, 2, 1, 3) - .reshape(B * 1, t.shape[1], C) - .contiguous(), - (q, k, v), - ) - out = xformers.ops.memory_efficient_attention( - q, k, v, attn_bias=None, op=self.attention_op - ) - - out = ( - out.unsqueeze(0) - .reshape(B, 1, out.shape[1], C) - .permute(0, 2, 1, 3) - .reshape(B, out.shape[1], C) - ) - return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) - - def forward(self, x, **kwargs): - h_ = x - h_ = self.attention(h_) - h_ = self.proj_out(h_) - return x + h_ - - -class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): - def forward(self, x, context=None, mask=None, **unused_kwargs): - b, c, h, w = x.shape - x = rearrange(x, "b c h w -> b (h w) c") - out = super().forward(x, context=context, mask=mask) - out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) - return x + out - - -def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): - assert attn_type in [ - "vanilla", - "vanilla-xformers", - "memory-efficient-cross-attn", - "linear", - "none", - ], f"attn_type {attn_type} unknown" - if ( - version.parse(torch.__version__) < version.parse("2.0.0") - and attn_type != "none" - ): - assert XFORMERS_IS_AVAILABLE, ( - f"We do not support vanilla attention in {torch.__version__} anymore, " - f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'" - ) - attn_type = "vanilla-xformers" - logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": - assert attn_kwargs is None - return AttnBlock(in_channels) - elif attn_type == "vanilla-xformers": - logpy.info( - f"building MemoryEfficientAttnBlock with {in_channels} in_channels..." - ) - return MemoryEfficientAttnBlock(in_channels) - elif type == "memory-efficient-cross-attn": - attn_kwargs["query_dim"] = in_channels - return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) - elif attn_type == "none": - return nn.Identity(in_channels) - else: - return LinAttnBlock(in_channels) - - -class Model(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - use_linear_attn=False, - attn_type="vanilla", - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = self.ch * 4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList( - [ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ] - ) - - # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - skip_in = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - if i_block == self.num_res_blocks: - skip_in = ch * in_ch_mult[i_level] - block.append( - ResnetBlock( - in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x, t=None, context=None): - # assert x.shape[2] == x.shape[3] == self.resolution - if context is not None: - # assume aligned context, cat along channel axis - x = torch.cat((x, context), dim=1) - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb - ) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - def get_last_layer(self): - return self.conv_out.weight - - -class Encoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - use_linear_attn=False, - attn_type="vanilla", - **ignore_kwargs, - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, - 2 * z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1, - ) - - def forward(self, x): - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - tanh_out=False, - use_linear_attn=False, - attn_type="vanilla", - **ignorekwargs, - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - self.tanh_out = tanh_out - - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - logpy.info( - "Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape) - ) - ) - - make_attn_cls = self._make_attn() - make_resblock_cls = self._make_resblock() - make_conv_cls = self._make_conv() - # z to block_in - self.conv_in = torch.nn.Conv2d( - z_channels, block_in, kernel_size=3, stride=1, padding=1 - ) - - # middle - self.mid = nn.Module() - self.mid.block_1 = make_resblock_cls( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type) - self.mid.block_2 = make_resblock_cls( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - make_resblock_cls( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn_cls(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = make_conv_cls( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) - - def _make_attn(self) -> Callable: - return make_attn - - def _make_resblock(self) -> Callable: - return ResnetBlock - - def _make_conv(self) -> Callable: - return torch.nn.Conv2d - - def get_last_layer(self, **kwargs): - return self.conv_out.weight - - def forward(self, z, **kwargs): - # assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb, **kwargs) - h = self.mid.attn_1(h, **kwargs) - h = self.mid.block_2(h, temb, **kwargs) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb, **kwargs) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h, **kwargs) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h, **kwargs) - if self.tanh_out: - h = torch.tanh(h) - return h diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py deleted file mode 100644 index 909455a19858dc5484df7d3a0786f1262988627d..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/openaimodel.py +++ /dev/null @@ -1,1352 +0,0 @@ -import logging -import math -from abc import abstractmethod -from functools import partial -from typing import Iterable, List, Optional, Tuple, Union - -import numpy as np -import torch as th -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -import torch -from torch.profiler import profile, record_function, ProfilerActivity -from ...modules.attention import SpatialTransformer -from ...modules.diffusionmodules.util import ( - avg_pool_nd, - checkpoint, - conv_nd, - linear, - normalization, - timestep_embedding, - zero_module, -) -from ...util import default, exists - - -logpy = logging.getLogger(__name__) - - -class AttentionPool2d(nn.Module): - """ - Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py - """ - - def __init__( - self, - spacial_dim: int, - embed_dim: int, - num_heads_channels: int, - output_dim: int = None, - ): - super().__init__() - self.positional_embedding = nn.Parameter( - th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 - ) - self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) - self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) - self.num_heads = embed_dim // num_heads_channels - self.attention = QKVAttention(self.num_heads) - - def forward(self, x): - b, c, *_spatial = x.shape - x = x.reshape(b, c, -1) # NC(HW) - x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) - x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) - x = self.qkv_proj(x) - x = self.attention(x) - x = self.c_proj(x) - return x[:, :, 0] - - -class TimestepBlock(nn.Module): - """ - Any module where forward() takes timestep embeddings as a second argument. - """ - - @abstractmethod - def forward(self, x, emb): - """ - Apply the module to `x` given `emb` timestep embeddings. - """ - -class TimestepEmbedSequential(nn.Sequential, TimestepBlock): - """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. - """ - - def forward( - self, - x, - emb, - context=None, - xr=None, - embr=None, - contextr=None, - pose=None, - mask_ref=None, - prev_weights=None, - ): - weights = None - fg_mask = None - alphas = None - predicted_rgb = None - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - if xr is not None: - with torch.no_grad(): - xr = layer(xr, embr) - xr = xr.detach() - elif isinstance(layer, SpatialTransformer): - x, xr, fg_mask, weights, alphas, predicted_rgb = layer(x, xr, context, contextr, pose, mask_ref, prev_weights=prev_weights) - else: - x = layer(x) - if xr is not None: - with torch.no_grad(): - xr = layer(xr) - xr = xr.detach() - - return x, xr, fg_mask, weights, alphas, predicted_rgb - - -class Upsample(nn.Module): - """ - An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__( - self, - channels: int, - use_conv: bool, - dims: int = 2, - out_channels: Optional[int] = None, - padding: int = 1, - third_up: bool = False, - kernel_size: int = 3, - scale_factor: int = 2, - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - self.third_up = third_up - self.scale_factor = scale_factor - if use_conv: - self.conv = conv_nd( - dims, self.channels, self.out_channels, kernel_size, padding=padding - ) - - def forward(self, x: th.Tensor) -> th.Tensor: - assert x.shape[1] == self.channels - - if self.dims == 3: - t_factor = 1 if not self.third_up else self.scale_factor - x = F.interpolate( - x, - ( - t_factor * x.shape[2], - x.shape[3] * self.scale_factor, - x.shape[4] * self.scale_factor, - ), - mode="nearest", - ) - else: - x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - -class TransposedUpsample(nn.Module): - "Learned 2x upsampling without padding" - - def __init__(self, channels, out_channels=None, ks=5): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - - self.up = nn.ConvTranspose2d( - self.channels, self.out_channels, kernel_size=ks, stride=2 - ) - - def forward(self, x): - return self.up(x) - - -class Downsample(nn.Module): - """ - A downsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. - """ - - def __init__( - self, - channels: int, - use_conv: bool, - dims: int = 2, - out_channels: Optional[int] = None, - padding: int = 1, - third_down: bool = False, - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2)) - if use_conv: - logpy.info(f"Building a Downsample layer with {dims} dims.") - logpy.info( - f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, " - f"kernel-size: 3, stride: {stride}, padding: {padding}" - ) - if dims == 3: - logpy.info(f" --> Downsampling third axis (time): {third_down}") - self.op = conv_nd( - dims, - self.channels, - self.out_channels, - 3, - stride=stride, - padding=padding, - ) - else: - assert self.channels == self.out_channels - self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) - - def forward(self, x: th.Tensor) -> th.Tensor: - assert x.shape[1] == self.channels - - return self.op(x) - - -class ResBlock(TimestepBlock): - """ - A residual block that can optionally change the number of channels. - :param channels: the number of input channels. - :param emb_channels: the number of timestep embedding channels. - :param dropout: the rate of dropout. - :param out_channels: if specified, the number of out channels. - :param use_conv: if True and out_channels is specified, use a spatial - convolution instead of a smaller 1x1 convolution to change the - channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param use_checkpoint: if True, use gradient checkpointing on this module. - :param up: if True, use this block for upsampling. - :param down: if True, use this block for downsampling. - """ - - def __init__( - self, - channels: int, - emb_channels: int, - dropout: float, - out_channels: Optional[int] = None, - use_conv: bool = False, - use_scale_shift_norm: bool = False, - dims: int = 2, - use_checkpoint: bool = False, - up: bool = False, - down: bool = False, - kernel_size: int = 3, - exchange_temb_dims: bool = False, - skip_t_emb: bool = False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_checkpoint = use_checkpoint - self.use_scale_shift_norm = use_scale_shift_norm - self.exchange_temb_dims = exchange_temb_dims - - if isinstance(kernel_size, Iterable): - padding = [k // 2 for k in kernel_size] - else: - padding = kernel_size // 2 - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, False, dims) - self.x_upd = Upsample(channels, False, dims) - elif down: - self.h_upd = Downsample(channels, False, dims) - self.x_upd = Downsample(channels, False, dims) - else: - self.h_upd = self.x_upd = nn.Identity() - - self.skip_t_emb = skip_t_emb - self.emb_out_channels = ( - 2 * self.out_channels if use_scale_shift_norm else self.out_channels - ) - if self.skip_t_emb: - logpy.info(f"Skipping timestep embedding in {self.__class__.__name__}") - assert not self.use_scale_shift_norm - self.emb_layers = None - self.exchange_temb_dims = False - else: - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - self.emb_out_channels, - ), - ) - - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd( - dims, - self.out_channels, - self.out_channels, - kernel_size, - padding=padding, - ) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, kernel_size, padding=padding - ) - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor: - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) - - def _forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor: - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - - if self.skip_t_emb: - emb_out = th.zeros_like(h) - else: - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = th.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - if self.exchange_temb_dims: - emb_out = rearrange(emb_out, "b t c ... -> b c t ...") - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. - Originally ported from here, but adapted to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - """ - - def __init__( - self, - channels: int, - num_heads: int = 1, - num_head_channels: int = -1, - use_checkpoint: bool = False, - use_new_attention_order: bool = False, - ): - super().__init__() - self.channels = channels - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - self.use_checkpoint = use_checkpoint - self.norm = normalization(channels) - self.qkv = conv_nd(1, channels, channels * 3, 1) - if use_new_attention_order: - # split qkv before split heads - self.attention = QKVAttention(self.num_heads) - else: - # split heads before split qkv - self.attention = QKVAttentionLegacy(self.num_heads) - - self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) - - def forward(self, x, **kwargs): - # TODO add crossframe attention and use mixed checkpoint - return checkpoint( - self._forward, (x,), self.parameters(), True - ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! - - def _forward(self, x: th.Tensor) -> th.Tensor: - b, c, *spatial = x.shape - x = x.reshape(b, c, -1) - qkv = self.qkv(self.norm(x)) - h = self.attention(qkv) - h = self.proj_out(h) - return (x + h).reshape(b, c, *spatial) - - -def count_flops_attn(model, _x, y): - """ - A counter for the `thop` package to count the operations in an - attention operation. - Meant to be used like: - macs, params = thop.profile( - model, - inputs=(inputs, timestamps), - custom_ops={QKVAttention: QKVAttention.count_flops}, - ) - """ - b, c, *spatial = y[0].shape - num_spatial = int(np.prod(spatial)) - # We perform two matmuls with the same number of ops. - # The first computes the weight matrix, the second computes - # the combination of the value vectors. - matmul_ops = 2 * b * (num_spatial**2) * c - model.total_ops += th.DoubleTensor([matmul_ops]) - - -class QKVAttentionLegacy(nn.Module): - """ - A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping - """ - - def __init__(self, n_heads: int): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv: th.Tensor) -> th.Tensor: - """ - Apply QKV attention. - :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v) - return a.reshape(bs, -1, length) - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class QKVAttention(nn.Module): - """ - A module which performs QKV attention and splits in a different order. - """ - - def __init__(self, n_heads: int): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv: th.Tensor) -> th.Tensor: - """ - Apply QKV attention. - :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.chunk(3, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", - (q * scale).view(bs * self.n_heads, ch, length), - (k * scale).view(bs * self.n_heads, ch, length), - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) - return a.reshape(bs, -1, length) - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class Timestep(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.dim = dim - - def forward(self, t: th.Tensor) -> th.Tensor: - return timestep_embedding(t, self.dim) - - -class UNetModel(nn.Module): - """ - The full UNet model with attention and timestep embedding. - :param in_channels: channels in the input Tensor. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_classes: if specified (as an int), then this model will be - class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. - """ - - def __init__( - self, - in_channels: int, - model_channels: int, - out_channels: int, - num_res_blocks: int, - attention_resolutions: int, - dropout: float = 0.0, - channel_mult: Union[List, Tuple] = (1, 2, 4, 8), - conv_resample: bool = True, - dims: int = 2, - num_classes: Optional[Union[int, str]] = None, - use_checkpoint: bool = False, - num_heads: int = -1, - num_head_channels: int = -1, - num_heads_upsample: int = -1, - use_scale_shift_norm: bool = False, - resblock_updown: bool = False, - transformer_depth: int = 1, - context_dim: Optional[int] = None, - disable_self_attentions: Optional[List[bool]] = None, - num_attention_blocks: Optional[List[int]] = None, - disable_middle_self_attn: bool = False, - use_linear_in_transformer: bool = False, - spatial_transformer_attn_type: str = "softmax", - adm_in_channels: Optional[int] = None, - use_fairscale_checkpoint=False, - offload_to_cpu=False, - transformer_depth_middle: Optional[int] = None, - ## new args - image_cross_blocks: Union[List, Tuple] = None, - rgb: bool = False, - far: float = 2., - num_samples: float = 32, - not_add_context_in_triplane: bool = False, - rgb_predict: bool = False, - add_lora: bool = False, - mode: str = 'feature-nerf', - average: bool = False, - num_freqs: int = 16, - use_prev_weights_imp_sample: bool = False, - stratified: bool = False, - poscontrol_interval: int = 4, - imp_sampling_percent: float = 0.9, - near_plane: float = 0. - ): - super().__init__() - from omegaconf.listconfig import ListConfig - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - if num_heads == -1: - assert ( - num_head_channels != -1 - ), "Either num_heads or num_head_channels has to be set" - - if num_head_channels == -1: - assert ( - num_heads != -1 - ), "Either num_heads or num_head_channels has to be set" - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.rgb = rgb - self.rgb_predict = rgb_predict - if image_cross_blocks is None: - image_cross_blocks = [] - if isinstance(transformer_depth, int): - transformer_depth = len(channel_mult) * [transformer_depth] - elif isinstance(transformer_depth, ListConfig): - transformer_depth = list(transformer_depth) - transformer_depth_middle = default( - transformer_depth_middle, transformer_depth[-1] - ) - - if isinstance(num_res_blocks, int): - self.num_res_blocks = len(channel_mult) * [num_res_blocks] - else: - if len(num_res_blocks) != len(channel_mult): - raise ValueError( - "provide num_res_blocks either as an int (globally constant) or " - "as a list/tuple (per-level) with the same length as channel_mult" - ) - self.num_res_blocks = num_res_blocks - if disable_self_attentions is not None: - assert len(disable_self_attentions) == len(channel_mult) - if num_attention_blocks is not None: - assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all( - map( - lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], - range(len(num_attention_blocks)), - ) - ) - logpy.info( - f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set." - ) - - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.num_classes = num_classes - self.use_checkpoint = use_checkpoint - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - - assert use_fairscale_checkpoint != use_checkpoint or not ( - use_checkpoint or use_fairscale_checkpoint - ) - - self.use_fairscale_checkpoint = False - checkpoint_wrapper_fn = ( - partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu) - if self.use_fairscale_checkpoint - else lambda x: x - ) - - time_embed_dim = model_channels * 4 - self.time_embed = checkpoint_wrapper_fn( - nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - ) - - if self.num_classes is not None: - if isinstance(self.num_classes, int): - self.label_emb = nn.Embedding(num_classes, time_embed_dim) - elif self.num_classes == "continuous": - logpy.info("setting up linear c_adm embedding layer") - self.label_emb = nn.Linear(1, time_embed_dim) - elif self.num_classes == "timestep": - self.label_emb = checkpoint_wrapper_fn( - nn.Sequential( - Timestep(model_channels), - nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ), - ) - ) - elif self.num_classes == "sequential": - assert adm_in_channels is not None - self.label_emb = nn.Sequential( - nn.Sequential( - linear(adm_in_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - ) - else: - raise ValueError - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - id_attention = 0 - for level, mult in enumerate(channel_mult): - for nr in range(self.num_res_blocks[level]): - layers = [ - checkpoint_wrapper_fn( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - - if context_dim is not None and exists(disable_self_attentions): - disabled_sa = disable_self_attentions[level] - else: - disabled_sa = False - - if ( - not exists(num_attention_blocks) - or nr < num_attention_blocks[level] - ): - layers.append( - checkpoint_wrapper_fn( - SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth[level], - context_dim=context_dim, - disable_self_attn=disabled_sa, - use_linear=use_linear_in_transformer, - attn_type=spatial_transformer_attn_type, - use_checkpoint=use_checkpoint, - # image_cross=False, - image_cross=(id_attention in image_cross_blocks), - rgb_predict=self.rgb_predict, - far=far, - num_samples=num_samples, - add_lora=add_lora, - mode=mode, - average=average, - num_freqs=num_freqs, - use_prev_weights_imp_sample=use_prev_weights_imp_sample, - stratified=stratified, - poscontrol_interval=poscontrol_interval, - imp_sampling_percent=imp_sampling_percent, - near_plane=near_plane, - ) - ) - ) - print("({}) in Encoder".format(id_attention)) - id_attention += 1 - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - checkpoint_wrapper_fn( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - self.middle_block = TimestepEmbedSequential( - checkpoint_wrapper_fn( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ), - checkpoint_wrapper_fn( - SpatialTransformer( # always uses a self-attn - ch, - num_heads, - dim_head, - depth=transformer_depth_middle, - context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, - use_linear=use_linear_in_transformer, - attn_type=spatial_transformer_attn_type, - use_checkpoint=use_checkpoint, - image_cross=(id_attention in image_cross_blocks), - rgb_predict=self.rgb_predict, - far=far, - num_samples=num_samples, - add_lora=add_lora, - mode=mode, - average=average, - num_freqs=num_freqs, - use_prev_weights_imp_sample=use_prev_weights_imp_sample, - stratified=stratified, - poscontrol_interval=poscontrol_interval, - imp_sampling_percent=imp_sampling_percent, - near_plane=near_plane, - ) - ), - checkpoint_wrapper_fn( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ), - ) - - print("({}) in Middle".format(id_attention)) - id_attention += 1 - - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(self.num_res_blocks[level] + 1): - ich = input_block_chans.pop() - layers = [ - checkpoint_wrapper_fn( - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=model_channels * mult, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ) - ] - ch = model_channels * mult - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - - if exists(disable_self_attentions): - disabled_sa = disable_self_attentions[level] - else: - disabled_sa = False - - if ( - not exists(num_attention_blocks) - or i < num_attention_blocks[level] - ): - layers.append( - checkpoint_wrapper_fn( - SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth[level], - context_dim=context_dim, - disable_self_attn=disabled_sa, - use_linear=use_linear_in_transformer, - attn_type=spatial_transformer_attn_type, - use_checkpoint=use_checkpoint, - image_cross=(id_attention in image_cross_blocks), - rgb_predict=self.rgb_predict, - far=far, - num_samples=num_samples, - add_lora=add_lora, - mode=mode, - average=average, - num_freqs=num_freqs, - use_prev_weights_imp_sample=use_prev_weights_imp_sample, - stratified=stratified, - poscontrol_interval=poscontrol_interval, - imp_sampling_percent=imp_sampling_percent, - near_plane=near_plane, - ) - ) - ) - print("({}) in Decoder".format(id_attention)) - id_attention += 1 - if level and i == self.num_res_blocks[level]: - out_ch = ch - layers.append( - checkpoint_wrapper_fn( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - up=True, - ) - ) - if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = checkpoint_wrapper_fn( - nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), - ) - ) - - def forward( - self, - x: th.Tensor, - timesteps: Optional[th.Tensor] = None, - context: Optional[th.Tensor] = None, - y: Optional[th.Tensor] = None, - timesteps2: Optional[th.Tensor] = None, - **kwargs, - ): - """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param context: conditioning plugged in via crossattn - :param y: an [N] Tensor of labels, if class-conditional. - :return: an [N x C x ...] Tensor of outputs. - """ - with torch.amp.autocast(device_type='cuda', dtype=torch.float32 if self.training else torch.float16): - b = x.size(0) - contextr = None - reference_image = False - pose = None - mask_ref = None - embr = None - fg_mask_list = [] - use_img_cond = True - alphas_list = [] - predicted_rgb_list = [] - - if 'pose' in kwargs: - pose = kwargs['pose'] - if 'mask_ref' in kwargs: - mask_ref = kwargs['mask_ref'] - if 'input_ref' in kwargs: - reference_image = True - contextr = context[b:] - if y is not None: - yr = y[b:] - xr = kwargs['input_ref'] - if xr is not None: - b, n = xr.shape[:2] - - context = context[: b] - - if y is not None: - y = y[:b] - - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" - - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - emb = self.time_embed(t_emb) - - if self.num_classes is not None: - assert y.shape[0] == x.shape[0] - emb = emb + self.label_emb(y) - - h = x - hr = None - hs = [] - hrs = [] - use_img_cond = True - - if reference_image: - with torch.no_grad(): - if 'sigmas_ref' in kwargs: - t_embr = timestep_embedding(kwargs['sigmas_ref'], self.model_channels, repeat_only=False) - elif timesteps2 is not None: - t_embr = timestep_embedding(timesteps2, self.model_channels, repeat_only=False) - else: - t_embr = timestep_embedding(torch.zeros_like(timesteps), self.model_channels, repeat_only=False) - embr = (self.time_embed(t_embr)[:, None].expand(-1, xr.size(1), -1)).reshape(b*n, -1) - if self.num_classes is not None: - embr = embr + self.label_emb(yr.reshape(b*n, -1)) - hr = rearrange(xr, "b n ... -> (b n) ...", b=b, n=n) - hr = hr.to(memory_format=torch.channels_last) - - for module in self.input_blocks: - h, hr, fg_mask, weights, alphas, predicted_rgb = module(h, emb, context, hr, embr, contextr, pose, mask_ref=mask_ref, prev_weights=None) - if fg_mask is not None: - fg_mask_list += fg_mask - if alphas is not None: - alphas_list += alphas - if predicted_rgb is not None: - predicted_rgb_list.extend(predicted_rgb) - hs.append(h) - hrs.append(hr) - - h, hr, fg_mask, weights, alphas, predicted_rgb = self.middle_block(h, emb, context, hr, embr, contextr, pose, mask_ref=mask_ref, prev_weights=None) - - if fg_mask is not None: - fg_mask_list += fg_mask - if alphas is not None: - alphas_list += alphas - if predicted_rgb is not None: - predicted_rgb_list.extend(predicted_rgb) - - for module in self.output_blocks: - h = th.cat([h, hs.pop()], dim=1) - if reference_image: - hr = th.cat([hr, hrs.pop()], dim=1) - h, hr, fg_mask, weights, alphas, predicted_rgb = module(h, emb, context, hr, embr, contextr, pose, mask_ref=mask_ref, prev_weights=None) - if fg_mask is not None: - fg_mask_list += fg_mask - if alphas is not None: - alphas_list += alphas - if predicted_rgb is not None: - predicted_rgb_list.extend(predicted_rgb) - - h = h.type(x.dtype) - if reference_image: - hr = hr.type(xr.dtype) - out = self.out(h) - - if use_img_cond: - return out, fg_mask_list, alphas_list, predicted_rgb_list - else: - return out - - -class NoTimeUNetModel(UNetModel): - def forward(self, x, timesteps=None, context=None, y=None, **kwargs): - timesteps = th.zeros_like(timesteps) - return super().forward(x, timesteps, context, y, **kwargs) - - -class EncoderUNetModel(nn.Module): - """ - The half UNet model with attention and timestep embedding. - For usage, see UNet. - """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - pool="adaptive", - *args, - **kwargs, - ): - super().__init__() - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - self.pool = pool - if pool == "adaptive": - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - nn.AdaptiveAvgPool2d((1, 1)), - zero_module(conv_nd(dims, ch, out_channels, 1)), - nn.Flatten(), - ) - elif pool == "attention": - assert num_head_channels != -1 - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - AttentionPool2d( - (image_size // ds), ch, num_head_channels, out_channels - ), - ) - elif pool == "spatial": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - nn.ReLU(), - nn.Linear(2048, self.out_channels), - ) - elif pool == "spatial_v2": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - normalization(2048), - nn.SiLU(), - nn.Linear(2048, self.out_channels), - ) - else: - raise NotImplementedError(f"Unexpected {pool} pooling") - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - - def forward(self, x, timesteps): - """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :return: an [N x K] Tensor of outputs. - """ - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - - results = [] - # h = x.type(self.dtype) - h = x - for module in self.input_blocks: - h = module(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = self.middle_block(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = th.cat(results, axis=-1) - return self.out(h) - else: - h = h.type(x.dtype) - return self.out(h) - - -if __name__ == "__main__": - - class Dummy(nn.Module): - def __init__(self, in_channels=3, model_channels=64): - super().__init__() - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(2, in_channels, model_channels, 3, padding=1) - ) - ] - ) - - model = UNetModel( - use_checkpoint=True, - image_size=64, - in_channels=4, - out_channels=4, - model_channels=128, - attention_resolutions=[4, 2], - num_res_blocks=2, - channel_mult=[1, 2, 4], - num_head_channels=64, - use_spatial_transformer=False, - use_linear_in_transformer=True, - transformer_depth=1, - legacy=False, - ).cuda() - x = th.randn(11, 4, 64, 64).cuda() - t = th.randint(low=0, high=10, size=(11,), device="cuda") - o = model(x, t) - print("done.") diff --git a/sgm/modules/diffusionmodules/sampling.py b/sgm/modules/diffusionmodules/sampling.py deleted file mode 100644 index e82cc26b2d338dc05a0edd82c560892c108118e7..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/sampling.py +++ /dev/null @@ -1,465 +0,0 @@ -""" - Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py -""" - - -from typing import Dict, Union - -import torch -from omegaconf import ListConfig, OmegaConf -from tqdm import tqdm - -from ...modules.diffusionmodules.sampling_utils import ( - get_ancestral_step, - linear_multistep_coeff, - to_d, - to_neg_log_sigma, - to_sigma, -) -from ...util import append_dims, default, instantiate_from_config - -DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} - - -class BaseDiffusionSampler: - def __init__( - self, - discretization_config: Union[Dict, ListConfig, OmegaConf], - num_steps: Union[int, None] = None, - guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, - verbose: bool = False, - device: str = "cuda", - ): - self.num_steps = num_steps - self.discretization = instantiate_from_config(discretization_config) - self.guider = instantiate_from_config( - default( - guider_config, - DEFAULT_GUIDER, - ) - ) - self.verbose = verbose - self.device = device - - def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): - sigmas = self.discretization( - self.num_steps if num_steps is None else num_steps, device=self.device - ) - uc = default(uc, cond) - - x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) - num_sigmas = len(sigmas) - - s_in = x.new_ones([x.shape[0]]) - - return x, s_in, sigmas, num_sigmas, cond, uc - - def denoise(self, x, denoiser, sigma, cond, uc): - denoised, _, _, rgb_list = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc)) - denoised = self.guider(denoised, sigma) - return denoised, rgb_list - - def get_sigma_gen(self, num_sigmas): - sigma_generator = range(num_sigmas - 1) - if self.verbose: - print("#" * 30, " Sampling setting ", "#" * 30) - print(f"Sampler: {self.__class__.__name__}") - print(f"Discretization: {self.discretization.__class__.__name__}") - print(f"Guider: {self.guider.__class__.__name__}") - sigma_generator = tqdm( - sigma_generator, - total=num_sigmas, - desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps", - ) - return sigma_generator - - -class SingleStepDiffusionSampler(BaseDiffusionSampler): - def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs): - raise NotImplementedError - - def euler_step(self, x, d, dt): - return x + dt * d - - -class EDMSampler(SingleStepDiffusionSampler): - def __init__( - self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs - ): - super().__init__(*args, **kwargs) - - self.s_churn = s_churn - self.s_tmin = s_tmin - self.s_tmax = s_tmax - self.s_noise = s_noise - - def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): - sigma_hat = sigma * (gamma + 1.0) - if gamma > 0: - eps = torch.randn_like(x) * self.s_noise - x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 - - denoised, rgb_list = self.denoise(x, denoiser, sigma_hat, cond, uc) - d = to_d(x, sigma_hat, denoised) - dt = append_dims(next_sigma - sigma_hat, x.ndim) - - euler_step = self.euler_step(x, d, dt) - x = self.possible_correction_step( - euler_step, x, d, dt, next_sigma, denoiser, cond, uc - ) - return x, rgb_list - - def __call__(self, denoiser, x, cond, uc=None, num_steps=None, mask=None, init_im=None): - return self.forward(denoiser, x, cond, uc=uc, num_steps=num_steps, mask=mask, init_im=init_im) - - def forward(self, denoiser, x, cond, uc=None, num_steps=None, mask=None, init_im=None): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( - x, cond, uc, num_steps - ) - for i in self.get_sigma_gen(num_sigmas): - gamma = ( - min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) - if self.s_tmin <= sigmas[i] <= self.s_tmax - else 0.0 - ) - x_new, rgb_list = self.sampler_step( - s_in * sigmas[i], - s_in * sigmas[i + 1], - denoiser, - x, - cond, - uc, - gamma, - ) - x = x_new - - return x, rgb_list - - -def get_views(panorama_height, panorama_width, window_size=64, stride=48): - # panorama_height /= 8 - # panorama_width /= 8 - num_blocks_height = (panorama_height - window_size) // stride + 1 - num_blocks_width = (panorama_width - window_size) // stride + 1 - total_num_blocks = int(num_blocks_height * num_blocks_width) - views = [] - for i in range(total_num_blocks): - h_start = int((i // num_blocks_width) * stride) - h_end = h_start + window_size - w_start = int((i % num_blocks_width) * stride) - w_end = w_start + window_size - views.append((h_start, h_end, w_start, w_end)) - return views - - -class EDMMultidiffusionSampler(SingleStepDiffusionSampler): - def __init__( - self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs - ): - super().__init__(*args, **kwargs) - - self.s_churn = s_churn - self.s_tmin = s_tmin - self.s_tmax = s_tmax - self.s_noise = s_noise - - def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): - sigma_hat = sigma * (gamma + 1.0) - if gamma > 0: - eps = torch.randn_like(x) * self.s_noise - x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 - - denoised, rgb_list = self.denoise(x, denoiser, sigma_hat, cond, uc) - d = to_d(x, sigma_hat, denoised) - dt = append_dims(next_sigma - sigma_hat, x.ndim) - - euler_step = self.euler_step(x, d, dt) - x = self.possible_correction_step( - euler_step, x, d, dt, next_sigma, denoiser, cond, uc - ) - return x, rgb_list - - def __call__(self, denoiser, model, x, cond, uc=None, num_steps=None, multikwargs=None): - return self.forward(denoiser, model, x, cond, uc=uc, num_steps=num_steps, multikwargs=multikwargs) - - def forward(self, denoiser, model, x, cond, uc=None, num_steps=None, multikwargs=None): - views = get_views(x.shape[-2], 48*(len(multikwargs)+1)) - shape = x.shape - x = torch.randn(shape[0], shape[1], shape[2], 48*(len(multikwargs)+1)).to(x.device) - count = torch.zeros_like(x, device=x.device) - value = torch.zeros_like(x, device=x.device) - - x, s_in, sigmas, num_sigmas, cond_, uc = self.prepare_sampling_loop( - x, cond[0], uc[0], num_steps - ) - - for i in self.get_sigma_gen(num_sigmas): - gamma = ( - min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) - if self.s_tmin <= sigmas[i] <= self.s_tmax - else 0.0 - ) - count.zero_() - value.zero_() - - for j, (h_start, h_end, w_start, w_end) in enumerate(views): - # TODO we can support batches, and pass multiple views at once to the unet - latent_view = x[:, :, h_start:h_end, w_start:w_end] - # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. - kwargs = {'pose': multikwargs[j]['pose'], 'mask_ref':None, 'drop_im':j} - x_new, rgb_list = self.sampler_step( - s_in * sigmas[i], - s_in * sigmas[i + 1], - lambda input, sigma, c: denoiser( - model, input, sigma, c, **kwargs - ), - latent_view, - cond[j], - uc, - gamma, - ) - # compute the denoising step with the reference model - value[:, :, h_start:h_end, w_start:w_end] += x_new - count[:, :, h_start:h_end, w_start:w_end] += 1 - - # take the MultiDiffusion step - x = torch.where(count > 0, value / count, value) - - return x, rgb_list - - def possible_correction_step( - self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc - ): - return euler_step - - -class AncestralSampler(SingleStepDiffusionSampler): - def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.eta = eta - self.s_noise = s_noise - self.noise_sampler = lambda x: torch.randn_like(x) - - def ancestral_euler_step(self, x, denoised, sigma, sigma_down): - d = to_d(x, sigma, denoised) - dt = append_dims(sigma_down - sigma, x.ndim) - - return self.euler_step(x, d, dt) - - def ancestral_step(self, x, sigma, next_sigma, sigma_up): - x = torch.where( - append_dims(next_sigma, x.ndim) > 0.0, - x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim), - x, - ) - return x - - def __call__(self, denoiser, x, cond, uc=None, num_steps=None): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( - x, cond, uc, num_steps - ) - - for i in self.get_sigma_gen(num_sigmas): - x = self.sampler_step( - s_in * sigmas[i], - s_in * sigmas[i + 1], - denoiser, - x, - cond, - uc, - ) - - return x - - -class LinearMultistepSampler(BaseDiffusionSampler): - def __init__( - self, - order=4, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - - self.order = order - - def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( - x, cond, uc, num_steps - ) - - ds = [] - sigmas_cpu = sigmas.detach().cpu().numpy() - for i in self.get_sigma_gen(num_sigmas): - sigma = s_in * sigmas[i] - denoised, _ = denoiser( - *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs - ) - denoised = self.guider(denoised, sigma) - d = to_d(x, sigma, denoised) - ds.append(d) - if len(ds) > self.order: - ds.pop(0) - cur_order = min(i + 1, self.order) - coeffs = [ - linear_multistep_coeff(cur_order, sigmas_cpu, i, j) - for j in range(cur_order) - ] - x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) - - return x - - -class EulerEDMSampler(EDMSampler): - def possible_correction_step( - self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc - ): - return euler_step - - -class HeunEDMSampler(EDMSampler): - def possible_correction_step( - self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc - ): - if torch.sum(next_sigma) < 1e-14: - # Save a network evaluation if all noise levels are 0 - return euler_step - else: - denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc) - d_new = to_d(euler_step, next_sigma, denoised) - d_prime = (d + d_new) / 2.0 - - # apply correction if noise level is not 0 - x = torch.where( - append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step - ) - return x - - -class EulerAncestralSampler(AncestralSampler): - def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc): - sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) - denoised = self.denoise(x, denoiser, sigma, cond, uc) - x = self.ancestral_euler_step(x, denoised, sigma, sigma_down) - x = self.ancestral_step(x, sigma, next_sigma, sigma_up) - - return x - - -class DPMPP2SAncestralSampler(AncestralSampler): - def get_variables(self, sigma, sigma_down): - t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)] - h = t_next - t - s = t + 0.5 * h - return h, s, t, t_next - - def get_mult(self, h, s, t, t_next): - mult1 = to_sigma(s) / to_sigma(t) - mult2 = (-0.5 * h).expm1() - mult3 = to_sigma(t_next) / to_sigma(t) - mult4 = (-h).expm1() - - return mult1, mult2, mult3, mult4 - - def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs): - sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) - denoised = self.denoise(x, denoiser, sigma, cond, uc) - x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down) - - if torch.sum(sigma_down) < 1e-14: - # Save a network evaluation if all noise levels are 0 - x = x_euler - else: - h, s, t, t_next = self.get_variables(sigma, sigma_down) - mult = [ - append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next) - ] - - x2 = mult[0] * x - mult[1] * denoised - denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) - x_dpmpp2s = mult[2] * x - mult[3] * denoised2 - - # apply correction if noise level is not 0 - x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler) - - x = self.ancestral_step(x, sigma, next_sigma, sigma_up) - return x - - -class DPMPP2MSampler(BaseDiffusionSampler): - def get_variables(self, sigma, next_sigma, previous_sigma=None): - t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] - h = t_next - t - - if previous_sigma is not None: - h_last = t - to_neg_log_sigma(previous_sigma) - r = h_last / h - return h, r, t, t_next - else: - return h, None, t, t_next - - def get_mult(self, h, r, t, t_next, previous_sigma): - mult1 = to_sigma(t_next) / to_sigma(t) - mult2 = (-h).expm1() - - if previous_sigma is not None: - mult3 = 1 + 1 / (2 * r) - mult4 = 1 / (2 * r) - return mult1, mult2, mult3, mult4 - else: - return mult1, mult2 - - def sampler_step( - self, - old_denoised, - previous_sigma, - sigma, - next_sigma, - denoiser, - x, - cond, - uc=None, - ): - denoised = self.denoise(x, denoiser, sigma, cond, uc) - - h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) - mult = [ - append_dims(mult, x.ndim) - for mult in self.get_mult(h, r, t, t_next, previous_sigma) - ] - - x_standard = mult[0] * x - mult[1] * denoised - if old_denoised is None or torch.sum(next_sigma) < 1e-14: - # Save a network evaluation if all noise levels are 0 or on the first step - return x_standard, denoised - else: - denoised_d = mult[2] * denoised - mult[3] * old_denoised - x_advanced = mult[0] * x - mult[1] * denoised_d - - # apply correction if noise level is not 0 and not first step - x = torch.where( - append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard - ) - - return x, denoised - - def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( - x, cond, uc, num_steps - ) - - old_denoised = None - for i in self.get_sigma_gen(num_sigmas): - x, old_denoised = self.sampler_step( - old_denoised, - None if i == 0 else s_in * sigmas[i - 1], - s_in * sigmas[i], - s_in * sigmas[i + 1], - denoiser, - x, - cond, - uc=uc, - ) - - return x diff --git a/sgm/modules/diffusionmodules/sampling_utils.py b/sgm/modules/diffusionmodules/sampling_utils.py deleted file mode 100644 index 7cca6361c2c6aeb97940b314eea5a607f1cd6a59..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/sampling_utils.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -from scipy import integrate - -from ...util import append_dims - - -class NoDynamicThresholding: - def __call__(self, uncond, cond, scale): - return uncond + scale * (cond - uncond) - - -def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): - if order - 1 > i: - raise ValueError(f"Order {order} too high for step {i}") - - def fn(tau): - prod = 1.0 - for k in range(order): - if j == k: - continue - prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) - return prod - - return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] - - -def get_ancestral_step(sigma_from, sigma_to, eta=1.0): - if not eta: - return sigma_to, 0.0 - sigma_up = torch.minimum( - sigma_to, - eta - * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, - ) - sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 - return sigma_down, sigma_up - - -def to_d(x, sigma, denoised): - return (x - denoised) / append_dims(sigma, x.ndim) - - -def to_neg_log_sigma(sigma): - return sigma.log().neg() - - -def to_sigma(neg_log_sigma): - return neg_log_sigma.neg().exp() diff --git a/sgm/modules/diffusionmodules/sigma_sampling.py b/sgm/modules/diffusionmodules/sigma_sampling.py deleted file mode 100644 index 154732eab91b23dad08663e03edb5c44616d2724..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/sigma_sampling.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch - -from ...util import default, instantiate_from_config - - -class EDMSampling: - def __init__(self, p_mean=-1.2, p_std=1.2): - self.p_mean = p_mean - self.p_std = p_std - - def __call__(self, n_samples, rand=None): - log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) - return log_sigma.exp() - - -class DiscreteSampling: - def __init__(self, discretization_config, num_idx, num_idx_start=0, do_append_zero=False, flip=True): - self.num_idx = num_idx - self.num_idx_start = num_idx_start - self.sigmas = instantiate_from_config(discretization_config)( - num_idx, do_append_zero=do_append_zero, flip=flip - ) - - def idx_to_sigma(self, idx): - return self.sigmas[idx] - - def __call__(self, n_samples, rand=None): - idx = default( - rand, - torch.randint(self.num_idx_start, self.num_idx, (n_samples,)), - ) - return self.idx_to_sigma(idx) - - -class CubicSampling: - def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True): - self.num_idx = num_idx - self.sigmas = instantiate_from_config(discretization_config)( - num_idx, do_append_zero=do_append_zero, flip=flip - ) - - def idx_to_sigma(self, idx): - return self.sigmas[idx] - - def __call__(self, n_samples, rand=None): - t = torch.rand((n_samples,)) - t = (1 - t ** 3) * (self.num_idx-1) - t = t.long() - idx = default( - rand, - t, - ) - return self.idx_to_sigma(idx) - diff --git a/sgm/modules/diffusionmodules/util.py b/sgm/modules/diffusionmodules/util.py deleted file mode 100644 index 15b03dfa355de7a3c4ae0f1211a31c7412f438a1..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/util.py +++ /dev/null @@ -1,344 +0,0 @@ -""" -adopted from -https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py -and -https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -and -https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py - -thanks! -""" - -import math - -import torch -import torch.nn as nn -from einops import repeat - - -def make_beta_schedule( - schedule, - n_timestep, - linear_start=1e-4, - linear_end=2e-2, -): - if schedule == "linear": - betas = ( - torch.linspace( - linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 - ) - ** 2 - ) - return betas.numpy() - - -def extract_into_tensor(a, t, x_shape): - b, *_ = t.shape - out = a.gather(-1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) - - -def mixed_checkpoint(func, inputs: dict, params, flag): - """ - Evaluate a function without caching intermediate activations, allowing for - reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function - borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that - it also works with non-tensor inputs - :param func: the function to evaluate. - :param inputs: the argument dictionary to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. - :param flag: if False, disable gradient checkpointing. - """ - if flag: - tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)] - tensor_inputs = [ - inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor) - ] - non_tensor_keys = [ - key for key in inputs if not isinstance(inputs[key], torch.Tensor) - ] - non_tensor_inputs = [ - inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor) - ] - args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params) - return MixedCheckpointFunction.apply( - func, - len(tensor_inputs), - len(non_tensor_inputs), - tensor_keys, - non_tensor_keys, - *args, - ) - else: - return func(**inputs) - - -class MixedCheckpointFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - run_function, - length_tensors, - length_non_tensors, - tensor_keys, - non_tensor_keys, - *args, - ): - ctx.end_tensors = length_tensors - ctx.end_non_tensors = length_tensors + length_non_tensors - ctx.gpu_autocast_kwargs = { - "enabled": torch.is_autocast_enabled(), - "dtype": torch.get_autocast_gpu_dtype(), - "cache_enabled": torch.is_autocast_cache_enabled(), - } - assert ( - len(tensor_keys) == length_tensors - and len(non_tensor_keys) == length_non_tensors - ) - - ctx.input_tensors = { - key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors])) - } - ctx.input_non_tensors = { - key: val - for (key, val) in zip( - non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]) - ) - } - ctx.run_function = run_function - ctx.input_params = list(args[ctx.end_non_tensors :]) - - with torch.no_grad(): - output_tensors = ctx.run_function( - **ctx.input_tensors, **ctx.input_non_tensors - ) - return output_tensors - - @staticmethod - def backward(ctx, *output_grads): - # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)} - ctx.input_tensors = { - key: ctx.input_tensors[key].detach().requires_grad_(True) - for key in ctx.input_tensors - } - - with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): - # Fixes a bug where the first op in run_function modifies the - # Tensor storage in place, which is not allowed for detach()'d - # Tensors. - shallow_copies = { - key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) - for key in ctx.input_tensors - } - # shallow_copies.update(additional_args) - output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors) - input_grads = torch.autograd.grad( - output_tensors, - list(ctx.input_tensors.values()) + ctx.input_params, - output_grads, - allow_unused=True, - ) - del ctx.input_tensors - del ctx.input_params - del output_tensors - return ( - (None, None, None, None, None) - + input_grads[: ctx.end_tensors] - + (None,) * (ctx.end_non_tensors - ctx.end_tensors) - + input_grads[ctx.end_tensors :] - ) - - -def checkpoint(func, inputs, params, flag): - """ - Evaluate a function without caching intermediate activations, allowing for - reduced memory at the expense of extra compute in the backward pass. - :param func: the function to evaluate. - :param inputs: the argument sequence to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. - :param flag: if False, disable gradient checkpointing. - """ - if flag: - args = tuple(inputs) + tuple(params) - return CheckpointFunction.apply(func, len(inputs), *args) - else: - return func(*inputs) - - -class CheckpointFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, run_function, length, *args): - ctx.run_function = run_function - ctx.input_tensors = list(args[:length]) - ctx.input_params = list(args[length:]) - ctx.gpu_autocast_kwargs = { - "enabled": torch.is_autocast_enabled(), - "dtype": torch.get_autocast_gpu_dtype(), - "cache_enabled": torch.is_autocast_cache_enabled(), - } - with torch.no_grad(): - output_tensors = ctx.run_function(*ctx.input_tensors) - return output_tensors - - @staticmethod - def backward(ctx, *output_grads): - ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): - # Fixes a bug where the first op in run_function modifies the - # Tensor storage in place, which is not allowed for detach()'d - # Tensors. - shallow_copies = [x.view_as(x) for x in ctx.input_tensors] - output_tensors = ctx.run_function(*shallow_copies) - input_grads = torch.autograd.grad( - output_tensors, - ctx.input_tensors + ctx.input_params, - output_grads, - allow_unused=True, - ) - del ctx.input_tensors - del ctx.input_params - del output_tensors - return (None, None) + input_grads - - -def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - if not repeat_only: - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32) - / half - ).to(device=timesteps.device) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat( - [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 - ) - else: - embedding = repeat(timesteps, "b -> b d", d=dim) - return embedding - - -def timestep_embedding_pose(timesteps, dim, max_period=10000, repeat_only=False): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - if not repeat_only: - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32) - / half - ).to(device=timesteps.device) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat( - [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 - ) - else: - embedding = repeat(timesteps, "b -> b d", d=dim) - return embedding - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def ones_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().data.fill_(1.) - return module - - -def scale_module(module, scale): - """ - Scale the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().mul_(scale) - return module - - -def mean_flat(tensor): - """ - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def normalization(channels): - """ - Make a standard normalization layer. - :param channels: number of input channels. - :return: an nn.Module for normalization. - """ - return GroupNorm32(32, channels) - - -# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. -class SiLU(nn.Module): - def forward(self, x): - return x * torch.sigmoid(x) - - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - return super().forward(x.float()).type(x.dtype) - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return nn.Linear(*args, **kwargs) - - -def avg_pool_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D average pooling module. - """ - if dims == 1: - return nn.AvgPool1d(*args, **kwargs) - elif dims == 2: - return nn.AvgPool2d(*args, **kwargs) - elif dims == 3: - return nn.AvgPool3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") diff --git a/sgm/modules/diffusionmodules/wrappers.py b/sgm/modules/diffusionmodules/wrappers.py deleted file mode 100644 index 7cd3b52c648ba623fce59512e3366c0053fa3286..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/wrappers.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch -import torch.nn as nn -from packaging import version - -OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" - - -class IdentityWrapper(nn.Module): - def __init__(self, diffusion_model, compile_model: bool = False): - super().__init__() - compile = ( - torch.compile - if (version.parse(torch.__version__) >= version.parse("2.0.0")) - and compile_model - else lambda x: x - ) - self.diffusion_model = compile(diffusion_model) - - def forward(self, *args, **kwargs): - return self.diffusion_model(*args, **kwargs) - - -class OpenAIWrapper(IdentityWrapper): - def forward( - self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs - ) -> torch.Tensor: - x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) - return self.diffusion_model( - x, - timesteps=t, - context=c.get("crossattn", None), - y=c.get("vector", None), - **kwargs - ) - diff --git a/sgm/modules/distributions/__init__.py b/sgm/modules/distributions/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sgm/modules/distributions/distributions.py b/sgm/modules/distributions/distributions.py deleted file mode 100644 index 016be35523187ea366db9ade391fe8ee276db60b..0000000000000000000000000000000000000000 --- a/sgm/modules/distributions/distributions.py +++ /dev/null @@ -1,102 +0,0 @@ -import numpy as np -import torch - - -class AbstractDistribution: - def sample(self): - raise NotImplementedError() - - def mode(self): - raise NotImplementedError() - - -class DiracDistribution(AbstractDistribution): - def __init__(self, value): - self.value = value - - def sample(self): - return self.value - - def mode(self): - return self.value - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to( - device=self.parameters.device - ) - - def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to( - device=self.parameters.device - ) - return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.sum( - torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, - dim=[1, 2, 3], - ) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - - def nll(self, sample, dims=[1, 2, 3]): - if self.deterministic: - return torch.Tensor([0.0]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims, - ) - - def mode(self): - return self.mean - - -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 - Compute the KL divergence between two gaussians. - Shapes are automatically broadcasted, so batches can be compared to - scalars, among other use cases. - """ - tensor = None - for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, torch.Tensor): - tensor = obj - break - assert tensor is not None, "at least one argument must be a Tensor" - - # Force variances to be Tensors. Broadcasting helps convert scalars to - # Tensors, but it does not work for torch.exp(). - logvar1, logvar2 = [ - x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] - - return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + torch.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) - ) diff --git a/sgm/modules/distributions/distributions1.py b/sgm/modules/distributions/distributions1.py deleted file mode 100644 index 0b61f03077358ce4737c85842d9871f70dabb656..0000000000000000000000000000000000000000 --- a/sgm/modules/distributions/distributions1.py +++ /dev/null @@ -1,102 +0,0 @@ -import torch -import numpy as np - - -class AbstractDistribution: - def sample(self): - raise NotImplementedError() - - def mode(self): - raise NotImplementedError() - - -class DiracDistribution(AbstractDistribution): - def __init__(self, value): - self.value = value - - def sample(self): - return self.value - - def mode(self): - return self.value - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to( - device=self.parameters.device - ) - - def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to( - device=self.parameters.device - ) - return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.sum( - torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, - dim=[1, 2, 3], - ) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - - def nll(self, sample, dims=[1, 2, 3]): - if self.deterministic: - return torch.Tensor([0.0]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims, - ) - - def mode(self): - return self.mean - - -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 - Compute the KL divergence between two gaussians. - Shapes are automatically broadcasted, so batches can be compared to - scalars, among other use cases. - """ - tensor = None - for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, torch.Tensor): - tensor = obj - break - assert tensor is not None, "at least one argument must be a Tensor" - - # Force variances to be Tensors. Broadcasting helps convert scalars to - # Tensors, but it does not work for torch.exp(). - logvar1, logvar2 = [ - x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] - - return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + torch.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) - ) diff --git a/sgm/modules/ema.py b/sgm/modules/ema.py deleted file mode 100644 index 97b5ae2b230f89b4dba57e44c4f851478ad86f68..0000000000000000000000000000000000000000 --- a/sgm/modules/ema.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -from torch import nn - - -class LitEma(nn.Module): - def __init__(self, model, decay=0.9999, use_num_upates=True): - super().__init__() - if decay < 0.0 or decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.m_name2s_name = {} - self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) - self.register_buffer( - "num_updates", - torch.tensor(0, dtype=torch.int) - if use_num_upates - else torch.tensor(-1, dtype=torch.int), - ) - - for name, p in model.named_parameters(): - if p.requires_grad: - # remove as '.'-character is not allowed in buffers - s_name = name.replace(".", "") - self.m_name2s_name.update({name: s_name}) - self.register_buffer(s_name, p.clone().detach().data) - - self.collected_params = [] - - def reset_num_updates(self): - del self.num_updates - self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) - - def forward(self, model): - decay = self.decay - - if self.num_updates >= 0: - self.num_updates += 1 - decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) - - one_minus_decay = 1.0 - decay - - with torch.no_grad(): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - - for key in m_param: - if m_param[key].requires_grad: - sname = self.m_name2s_name[key] - shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) - shadow_params[sname].sub_( - one_minus_decay * (shadow_params[sname] - m_param[key]) - ) - else: - assert not key in self.m_name2s_name - - def copy_to(self, model): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - for key in m_param: - if m_param[key].requires_grad: - m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) - else: - assert not key in self.m_name2s_name - - def store(self, parameters): - """ - Save the current parameters for restoring later. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. - """ - self.collected_params = [param.clone() for param in parameters] - - def restore(self, parameters): - """ - Restore the parameters stored with the `store` method. - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before the - `copy_to` method. After validation (or model saving), use this to - restore the former parameters. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. - """ - for c_param, param in zip(self.collected_params, parameters): - param.data.copy_(c_param.data) diff --git a/sgm/modules/encoders/__init__.py b/sgm/modules/encoders/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py deleted file mode 100644 index 9e7e13eaa418cbf062ecdcaa404a5e7c528deb3e..0000000000000000000000000000000000000000 --- a/sgm/modules/encoders/modules.py +++ /dev/null @@ -1,1154 +0,0 @@ -from contextlib import nullcontext -from functools import partial -from typing import Dict, List, Optional, Tuple, Union -from packaging import version - -import kornia -import numpy as np -import open_clip -from open_clip.tokenizer import SimpleTokenizer -import torch -import torch.nn as nn -from einops import rearrange, repeat -from omegaconf import ListConfig -from torch.utils.checkpoint import checkpoint -import transformers -from transformers import (ByT5Tokenizer, CLIPTextModel, CLIPTokenizer, - T5EncoderModel, T5Tokenizer) - -from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer -from ...modules.diffusionmodules.model import Encoder -from ...modules.diffusionmodules.openaimodel import Timestep -from ...modules.diffusionmodules.util import (extract_into_tensor, - make_beta_schedule) -from ...modules.distributions.distributions import DiagonalGaussianDistribution -from ...util import (append_dims, autocast, count_params, default, - disabled_train, expand_dims_like, instantiate_from_config) - - -class AbstractEmbModel(nn.Module): - def __init__(self): - super().__init__() - self._is_trainable = None - self._ucg_rate = None - self._input_key = None - - @property - def is_trainable(self) -> bool: - return self._is_trainable - - @property - def ucg_rate(self) -> Union[float, torch.Tensor]: - return self._ucg_rate - - @property - def input_key(self) -> str: - return self._input_key - - @is_trainable.setter - def is_trainable(self, value: bool): - self._is_trainable = value - - @ucg_rate.setter - def ucg_rate(self, value: Union[float, torch.Tensor]): - self._ucg_rate = value - - @input_key.setter - def input_key(self, value: str): - self._input_key = value - - @is_trainable.deleter - def is_trainable(self): - del self._is_trainable - - @ucg_rate.deleter - def ucg_rate(self): - del self._ucg_rate - - @input_key.deleter - def input_key(self): - del self._input_key - - -class GeneralConditioner(nn.Module): - OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"} - KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1} - - def __init__(self, emb_models: Union[List, ListConfig]): - super().__init__() - embedders = [] - for n, embconfig in enumerate(emb_models): - embedder = instantiate_from_config(embconfig) - assert isinstance( - embedder, AbstractEmbModel - ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel" - embedder.is_trainable = embconfig.get("is_trainable", False) - embedder.ucg_rate = embconfig.get("ucg_rate", 0.0) - if not embedder.is_trainable: - embedder.train = disabled_train - for param in embedder.parameters(): - param.requires_grad = False - embedder.eval() - print( - f"Initialized embedder #{n}: {embedder.__class__.__name__} " - f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}" - ) - - if "input_key" in embconfig: - embedder.input_key = embconfig["input_key"] - elif "input_keys" in embconfig: - embedder.input_keys = embconfig["input_keys"].split(',') - else: - raise KeyError( - f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}" - ) - - embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None) - if embedder.legacy_ucg_val is not None: - embedder.ucg_prng = np.random.RandomState() - - embedders.append(embedder) - self.embedders = nn.ModuleList(embedders) - - def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict: - assert embedder.legacy_ucg_val is not None - p = embedder.ucg_rate - val = embedder.legacy_ucg_val - for i in range(len(batch[embedder.input_key])): - if embedder.ucg_prng.choice(2, p=[1 - p, p]): - batch[embedder.input_key][i] = val - return batch - - def forward( - self, batch: Dict, force_zero_embeddings: Optional[List] = None, force_ref_zero_embeddings: bool = False - ) -> Dict: - output = dict() - if force_zero_embeddings is None: - force_zero_embeddings = [] - for embedder in self.embedders: - embedding_context = nullcontext if (embedder.is_trainable or embedder.modifier_token is not None) else torch.no_grad - with embedding_context(): - if hasattr(embedder, "input_key") and (embedder.input_key is not None): - if embedder.legacy_ucg_val is not None: - batch = self.possibly_get_ucg_val(embedder, batch) - emb_out = embedder(batch[embedder.input_key]) - elif hasattr(embedder, "input_keys"): - if force_ref_zero_embeddings: - emb_out = embedder(batch[embedder.input_keys[0]]) - else: - emb_out = [embedder(batch[k]) for k in embedder.input_keys] - if isinstance(emb_out[0], tuple): - emb_out = [torch.cat([x[0] for x in emb_out]), torch.cat([x[1] for x in emb_out])] - else: - emb_out = torch.cat(emb_out) - - assert isinstance( - emb_out, (torch.Tensor, list, tuple) - ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}" - if not isinstance(emb_out, (list, tuple)): - emb_out = [emb_out] - for emb in emb_out: - out_key = self.OUTPUT_DIM2KEYS[emb.dim()] - if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: - emb = ( - expand_dims_like( - torch.bernoulli( - (1.0 - embedder.ucg_rate) - * torch.ones(emb.shape[0], device=emb.device) - ), - emb, - ) - * emb - ) - if ( - hasattr(embedder, "input_key") - and embedder.input_key in force_zero_embeddings - ): - emb = torch.zeros_like(emb) - if ( - hasattr(embedder, "input_keys") - and embedder.input_keys in force_zero_embeddings - ): - emb = torch.zeros_like(emb) - if out_key in output: - if hasattr(embedder, "input_keys"): - catdim = 1 if ('pose' in embedder.input_keys) else self.KEY2CATDIM[out_key] - if not force_ref_zero_embeddings: - c, c1 = emb.chunk(2) - output[out_key] = torch.cat( - (output[out_key], c), catdim - ) - output[out_key+'_ref'] = torch.cat( - (output[out_key+'_ref'], c1), catdim - ) - else: - # print(output[out_key].size(), emb.size(), "$") - output[out_key] = torch.cat( - (output[out_key], emb), catdim - ) - else: - catdim = 1 if ('pose' in embedder.input_key and emb.size(1) != 77) else self.KEY2CATDIM[out_key] - output[out_key] = torch.cat( - (output[out_key], emb), catdim - ) - else: - if hasattr(embedder, "input_keys"): - if not force_ref_zero_embeddings: - c, c1 = emb.chunk(2) - output[out_key] = c - output[out_key+'_ref'] = c1 - else: - output[out_key] = emb - else: - output[out_key] = emb - - for out_key in self.OUTPUT_DIM2KEYS.values(): - if out_key+'_ref' in output and not force_ref_zero_embeddings: - output[out_key] = torch.cat([output[out_key], output[out_key+'_ref']], 0) - del output[out_key+'_ref'] - - return output - - def get_unconditional_conditioning( - self, - batch_c: Dict, - batch_uc: Optional[Dict] = None, - force_uc_zero_embeddings: Optional[List[str]] = None, - force_ref_zero_embeddings: Optional[List[str]] = None, - ): - if force_uc_zero_embeddings is None: - force_uc_zero_embeddings = [] - ucg_rates = list() - for embedder in self.embedders: - ucg_rates.append(embedder.ucg_rate) - embedder.ucg_rate = 0.0 - c = self(batch_c, force_ref_zero_embeddings=force_ref_zero_embeddings) - uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings, force_ref_zero_embeddings) - - for embedder, rate in zip(self.embedders, ucg_rates): - embedder.ucg_rate = rate - return c, uc - - -class InceptionV3(nn.Module): - """Wrapper around the https://github.com/mseitzer/pytorch-fid inception - port with an additional squeeze at the end""" - - def __init__(self, normalize_input=False, **kwargs): - super().__init__() - from pytorch_fid import inception - - kwargs["resize_input"] = True - self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs) - - def forward(self, inp): - outp = self.model(inp) - - if len(outp) == 1: - return outp[0].squeeze() - - return outp - - -class IdentityEncoder(AbstractEmbModel): - def encode(self, x): - return x - - def forward(self, x): - return x - - -class ClassEmbedder(AbstractEmbModel): - def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False): - super().__init__() - self.embedding = nn.Embedding(n_classes, embed_dim) - self.n_classes = n_classes - self.add_sequence_dim = add_sequence_dim - - def forward(self, c): - c = self.embedding(c) - if self.add_sequence_dim: - c = c[:, None, :] - return c - - def get_unconditional_conditioning(self, bs, device="cuda"): - uc_class = ( - self.n_classes - 1 - ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) - uc = torch.ones((bs,), device=device) * uc_class - uc = {self.key: uc.long()} - return uc - - -class ClassEmbedderForMultiCond(ClassEmbedder): - def forward(self, batch, key=None, disable_dropout=False): - out = batch - key = default(key, self.key) - islist = isinstance(batch[key], list) - if islist: - batch[key] = batch[key][0] - c_out = super().forward(batch, key, disable_dropout) - out[key] = [c_out] if islist else c_out - return out - - -class FrozenT5Embedder(AbstractEmbModel): - """Uses the T5 transformer encoder for text""" - - def __init__( - self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True - ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl - super().__init__() - self.tokenizer = T5Tokenizer.from_pretrained(version) - self.transformer = T5EncoderModel.from_pretrained(version) - self.device = device - self.max_length = max_length - if freeze: - self.freeze() - - def freeze(self): - self.transformer = self.transformer.eval() - - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text): - batch_encoding = self.tokenizer( - text, - truncation=True, - max_length=self.max_length, - return_length=True, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt", - ) - tokens = batch_encoding["input_ids"].to(self.device) - with torch.autocast("cuda", enabled=False): - outputs = self.transformer(input_ids=tokens) - z = outputs.last_hidden_state - return z - - def encode(self, text): - return self(text) - - -class FrozenByT5Embedder(AbstractEmbModel): - """ - Uses the ByT5 transformer encoder for text. Is character-aware. - """ - - def __init__( - self, version="google/byt5-base", device="cuda", max_length=77, freeze=True - ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl - super().__init__() - self.tokenizer = ByT5Tokenizer.from_pretrained(version) - self.transformer = T5EncoderModel.from_pretrained(version) - self.device = device - self.max_length = max_length - if freeze: - self.freeze() - - def freeze(self): - self.transformer = self.transformer.eval() - - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text): - batch_encoding = self.tokenizer( - text, - truncation=True, - max_length=self.max_length, - return_length=True, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt", - ) - tokens = batch_encoding["input_ids"].to(self.device) - with torch.autocast("cuda", enabled=False): - outputs = self.transformer(input_ids=tokens) - z = outputs.last_hidden_state - return z - - def encode(self, text): - return self(text) - - -class FrozenCLIPEmbedder(AbstractEmbModel): - """Uses the CLIP transformer encoder for text (from huggingface)""" - - LAYERS = ["last", "pooled", "hidden"] - - def __init__( - self, - modifier_token=None, - version="openai/clip-vit-large-patch14", - device="cuda", - max_length=77, - freeze=True, - layer="last", - layer_idx=None, - always_return_pooled=False, - ): # clip-vit-base-patch32 - super().__init__() - assert layer in self.LAYERS - self.tokenizer = CLIPTokenizer.from_pretrained(version) - self.transformer = CLIPTextModel.from_pretrained(version) - self.device = device - self.max_length = max_length - self.modifier_token = modifier_token - if self.modifier_token is not None: - if '+' in self.modifier_token: - self.modifier_token = self.modifier_token.split('+') - else: - self.modifier_token = [self.modifier_token] - - self.add_token() - - if freeze: - self.freeze() - self.layer = layer - self.layer_idx = layer_idx - self.return_pooled = always_return_pooled - if layer == "hidden": - assert layer_idx is not None - assert 0 <= abs(layer_idx) <= 12 - - def add_token(self): - self.modifier_token_id = [] - for each_modifier_token in self.modifier_token: - print(each_modifier_token, "adding new token") - _ = self.tokenizer.add_tokens(each_modifier_token) - modifier_token_id = self.tokenizer.convert_tokens_to_ids(each_modifier_token) - self.modifier_token_id.append(modifier_token_id) - - self.transformer.resize_token_embeddings(len(self.tokenizer)) - token_embeds = self.transformer.get_input_embeddings().weight.data - token_embeds[self.modifier_token_id[-1]] = torch.nn.Parameter(token_embeds[42170], requires_grad=True) - if len(self.modifier_token) == 2: - token_embeds[self.modifier_token_id[-2]] = torch.nn.Parameter(token_embeds[47629], requires_grad=True) - if len(self.modifier_token) == 3: - token_embeds[self.modifier_token_id[-3]] = torch.nn.Parameter(token_embeds[43514], requires_grad=True) - - def freeze(self): - if self.modifier_token is not None: - self.transformer = self.transformer.eval() - for param in self.transformer.text_model.encoder.parameters(): - param.requires_grad = False - for param in self.transformer.text_model.final_layer_norm.parameters(): - param.requires_grad = False - for param in self.transformer.text_model.embeddings.parameters(): - param.requires_grad = False - for param in self.transformer.get_input_embeddings().parameters(): - param.requires_grad = True - print("making grad true") - else: - self.transformer = self.transformer.eval() - - for param in self.parameters(): - param.requires_grad = False - - def _build_causal_attention_mask(self, bsz, seq_len, dtype): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) - mask.fill_(torch.tensor(torch.finfo(dtype).min)) - mask.triu_(1) # zero out the lower diagonal - mask = mask.unsqueeze(1) # expand mask - return mask - - @autocast - def custom_forward(self, hidden_states, input_ids): - r""" - Returns: - """ - input_shape = hidden_states.size() - bsz, seq_len = input_shape[:2] - if version.parse(transformers.__version__) >= version.parse('4.21'): - causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( - hidden_states.device - ) - else: - causal_attention_mask = self.transformer.text_model._build_causal_attention_mask(bsz, seq_len).to( - hidden_states.device - ) - - encoder_outputs = self.transformer.text_model.encoder( - inputs_embeds=hidden_states, - causal_attention_mask=causal_attention_mask, - ) - - last_hidden_state = encoder_outputs[0] - last_hidden_state = self.transformer.text_model.final_layer_norm(last_hidden_state) - - return last_hidden_state - - @autocast - def forward(self, text): - batch_encoding = self.tokenizer( - text, - truncation=True, - max_length=self.max_length, - return_length=True, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt" - ) - tokens = batch_encoding["input_ids"].to(self.device) - - if self.modifier_token is not None: - indices = tokens == self.modifier_token_id[-1] - for token_id in self.modifier_token_id: - indices |= tokens == token_id - - indices = (indices*1).unsqueeze(-1) - - input_shape = tokens.size() - tokens = tokens.view(-1, input_shape[-1]) - - hidden_states = self.transformer.text_model.embeddings(input_ids=tokens) - if self.modifier_token is not None: - hidden_states = (1-indices)*hidden_states.detach() + indices*hidden_states - z = self.custom_forward(hidden_states, tokens) - return z - - def encode(self, text): - return self(text) - - -class FrozenOpenCLIPEmbedder2(AbstractEmbModel): - """ - Uses the OpenCLIP transformer encoder for text - """ - - LAYERS = ["pooled", "last", "penultimate"] - - def __init__( - self, - arch="ViT-H-14", - version="laion2b_s32b_b79k", - device="cuda", - max_length=77, - freeze=True, - layer="last", - always_return_pooled=False, - legacy=True, - ): - super().__init__() - assert layer in self.LAYERS - model, _, _ = open_clip.create_model_and_transforms( - arch, - device=torch.device("cpu"), - pretrained=version, - ) - del model.visual - self.model = model - self.modifier_token = None - - self.device = device - self.max_length = max_length - self.return_pooled = always_return_pooled - if freeze: - self.freeze() - self.layer = layer - if self.layer == "last": - self.layer_idx = 0 - elif self.layer == "penultimate": - self.layer_idx = 1 - else: - raise NotImplementedError() - self.legacy = legacy - - def freeze(self): - self.model = self.model.eval() - for param in self.parameters(): - param.requires_grad = False - - @autocast - def forward(self, text): - tokens = open_clip.tokenize(text) - z = self.encode_with_transformer(tokens.to(self.device)) - if not self.return_pooled and self.legacy: - return z - if self.return_pooled: - assert not self.legacy - return z[self.layer], z["pooled"] - return z[self.layer] - - def encode_with_transformer(self, text): - x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] - x = x + self.model.positional_embedding - x = x.permute(1, 0, 2) # NLD -> LND - x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) - if self.legacy: - x = x[self.layer] - x = self.model.ln_final(x) - return x - else: - # x is a dict and will stay a dict - o = x["last"] - o = self.model.ln_final(o) - pooled = self.pool(o, text) - x["pooled"] = pooled - return x - - def pool(self, x, text): - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = ( - x[torch.arange(x.shape[0]), text.argmax(dim=-1)] - @ self.model.text_projection - ) - return x - - def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): - outputs = {} - for i, r in enumerate(self.model.transformer.resblocks): - if i == len(self.model.transformer.resblocks) - 1: - outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD - if ( - self.model.transformer.grad_checkpointing - and not torch.jit.is_scripting() - ): - x = checkpoint(r, x, attn_mask) - else: - x = r(x, attn_mask=attn_mask) - outputs["last"] = x.permute(1, 0, 2) # LND -> NLD - return outputs - - def encode(self, text): - return self(text) - - -class FrozenOpenCLIPEmbedder(AbstractEmbModel): - LAYERS = [ - # "pooled", - "last", - "penultimate", - ] - - def __init__( - self, - modifier_token=None, - arch="ViT-H-14", - version="laion2b_s32b_b79k", - device="cuda", - max_length=77, - freeze=True, - layer="last", - always_return_pooled=False, - legacy=True, - ): - super().__init__() - assert layer in self.LAYERS - model, _, _ = open_clip.create_model_and_transforms( - arch, device=torch.device("cpu"), pretrained=version - ) - del model.visual - self.model = model - - self.device = device - self.max_length = max_length - self.modifier_token = modifier_token - self.return_pooled = always_return_pooled - if self.modifier_token is not None: - if '+' in self.modifier_token: - self.modifier_token = self.modifier_token.split('+') - else: - self.modifier_token = [self.modifier_token] - self.tokenizer = SimpleTokenizer(additional_special_tokens=self.modifier_token) - - self.add_token() - else: - self.tokenizer = SimpleTokenizer() - - if freeze: - self.freeze() - self.layer = layer - if self.layer == "last": - self.layer_idx = 0 - elif self.layer == "penultimate": - self.layer_idx = 1 - else: - raise NotImplementedError() - self.legacy = legacy - - def tokenize(self, texts, context_length=77): - return self.tokenizer(texts, context_length=context_length) - - def add_token(self): - self.modifier_token_id = [] - - token_embeds1 = self.model.token_embedding.weight.data - for each_modifier_token in self.modifier_token: - modifier_token_id = self.tokenizer.encoder[each_modifier_token] - self.modifier_token_id.append(modifier_token_id) - - self.model.token_embedding = nn.Embedding(token_embeds1.shape[0] + len(self.modifier_token), token_embeds1.shape[1]) - self.model.token_embedding.weight.data[:token_embeds1.shape[0]] = token_embeds1 - - self.model.token_embedding.weight.data[self.modifier_token_id[-1]] = token_embeds1[42170] - if len(self.modifier_token) == 2: - self.model.token_embedding.weight.data[self.modifier_token_id[-2]] = token_embeds1[47629] - - def freeze(self): - if self.modifier_token is not None: - self.model = self.model.eval() - for param in self.model.transformer.parameters(): - param.requires_grad = False - for param in self.model.ln_final.parameters(): - param.requires_grad = False - self.model.text_projection.requires_grad = False - self.model.positional_embedding.requires_grad = False - for param in self.model.token_embedding.parameters(): - param.requires_grad = True - print("making grad true") - else: - self.model = self.model.eval() - for param in self.parameters(): - param.requires_grad = False - - @autocast - def forward(self, text): - tokens = self.tokenize(text) - z = self.encode_with_transformer(tokens.to(self.device)) - if not self.return_pooled and self.legacy: - return z - if self.return_pooled: - assert not self.legacy - return z[self.layer], z["pooled"] - return z[self.layer] - - def encode_with_transformer(self, text): - x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] - - if self.modifier_token is not None: - indices = text == self.modifier_token_id[-1] - for token_id in self.modifier_token_id: - indices |= text == token_id - - indices = (indices*1).unsqueeze(-1) - x = ((1-indices)*x.detach() + indices*x) + self.model.positional_embedding - else: - x = x + self.model.positional_embedding - x = x.permute(1, 0, 2) # NLD -> LND - x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) - if self.legacy: - x = x[self.layer] - x = self.model.ln_final(x) - return x - else: - # x is a dict and will stay a dict - o = x["last"] - o = self.model.ln_final(o) - pooled = self.pool(o, text) - x["pooled"] = pooled - return x - - def pool(self, x, text): - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = ( - x[torch.arange(x.shape[0]), text.argmax(dim=-1)] - @ self.model.text_projection - ) - return x - - def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): - outputs = {} - for i, r in enumerate(self.model.transformer.resblocks): - if i == len(self.model.transformer.resblocks) - 1: - outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD - if ( - self.model.transformer.grad_checkpointing - and not torch.jit.is_scripting() - ): - x = checkpoint(r, x, attn_mask) - else: - x = r(x, attn_mask=attn_mask) - outputs["last"] = x.permute(1, 0, 2) # LND -> NLD - return outputs - - def encode(self, text): - return self(text) - - -class FrozenOpenCLIPImageEmbedder(AbstractEmbModel): - """ - Uses the OpenCLIP vision transformer encoder for images - """ - - def __init__( - self, - arch="ViT-H-14", - version="laion2b_s32b_b79k", - device="cuda", - max_length=77, - freeze=True, - antialias=True, - ucg_rate=0.0, - unsqueeze_dim=False, - repeat_to_max_len=False, - num_image_crops=0, - output_tokens=False, - init_device=None, - ): - super().__init__() - model, _, _ = open_clip.create_model_and_transforms( - arch, - device=torch.device(default(init_device, "cpu")), - pretrained=version, - ) - del model.transformer - self.model = model - self.max_crops = num_image_crops - self.pad_to_max_len = self.max_crops > 0 - self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len) - self.device = device - self.max_length = max_length - if freeze: - self.freeze() - - self.antialias = antialias - - self.register_buffer( - "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False - ) - self.register_buffer( - "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False - ) - self.ucg_rate = ucg_rate - self.unsqueeze_dim = unsqueeze_dim - self.stored_batch = None - self.model.visual.output_tokens = output_tokens - self.output_tokens = output_tokens - - def preprocess(self, x): - # normalize to [0,1] - x = kornia.geometry.resize( - x, - (224, 224), - interpolation="bicubic", - align_corners=True, - antialias=self.antialias, - ) - x = (x + 1.0) / 2.0 - # renormalize according to clip - x = kornia.enhance.normalize(x, self.mean, self.std) - return x - - def freeze(self): - self.model = self.model.eval() - for param in self.parameters(): - param.requires_grad = False - - @autocast - def forward(self, image, no_dropout=False): - z = self.encode_with_vision_transformer(image) - tokens = None - if self.output_tokens: - z, tokens = z[0], z[1] - z = z.to(image.dtype) - if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0): - z = ( - torch.bernoulli( - (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device) - )[:, None] - * z - ) - if tokens is not None: - tokens = ( - expand_dims_like( - torch.bernoulli( - (1.0 - self.ucg_rate) - * torch.ones(tokens.shape[0], device=tokens.device) - ), - tokens, - ) - * tokens - ) - if self.unsqueeze_dim: - z = z[:, None, :] - if self.output_tokens: - assert not self.repeat_to_max_len - assert not self.pad_to_max_len - return tokens, z - if self.repeat_to_max_len: - if z.dim() == 2: - z_ = z[:, None, :] - else: - z_ = z - return repeat(z_, "b 1 d -> b n d", n=self.max_length), z - elif self.pad_to_max_len: - assert z.dim() == 3 - z_pad = torch.cat( - ( - z, - torch.zeros( - z.shape[0], - self.max_length - z.shape[1], - z.shape[2], - device=z.device, - ), - ), - 1, - ) - return z_pad, z_pad[:, 0, ...] - return z - - def encode_with_vision_transformer(self, img): - # if self.max_crops > 0: - # img = self.preprocess_by_cropping(img) - if img.dim() == 5: - assert self.max_crops == img.shape[1] - img = rearrange(img, "b n c h w -> (b n) c h w") - img = self.preprocess(img) - if not self.output_tokens: - assert not self.model.visual.output_tokens - x = self.model.visual(img) - tokens = None - else: - assert self.model.visual.output_tokens - x, tokens = self.model.visual(img) - if self.max_crops > 0: - x = rearrange(x, "(b n) d -> b n d", n=self.max_crops) - # drop out between 0 and all along the sequence axis - x = ( - torch.bernoulli( - (1.0 - self.ucg_rate) - * torch.ones(x.shape[0], x.shape[1], 1, device=x.device) - ) - * x - ) - if tokens is not None: - tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops) - print( - f"You are running very experimental token-concat in {self.__class__.__name__}. " - f"Check what you are doing, and then remove this message." - ) - if self.output_tokens: - return x, tokens - return x - - def encode(self, text): - return self(text) - - -class FrozenCLIPT5Encoder(AbstractEmbModel): - def __init__( - self, - clip_version="openai/clip-vit-large-patch14", - t5_version="google/t5-v1_1-xl", - device="cuda", - clip_max_length=77, - t5_max_length=77, - ): - super().__init__() - self.clip_encoder = FrozenCLIPEmbedder( - clip_version, device, max_length=clip_max_length - ) - self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) - print( - f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " - f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params." - ) - - def encode(self, text): - return self(text) - - def forward(self, text): - clip_z = self.clip_encoder.encode(text) - t5_z = self.t5_encoder.encode(text) - return [clip_z, t5_z] - - -class SpatialRescaler(nn.Module): - def __init__( - self, - n_stages=1, - method="bilinear", - multiplier=0.5, - in_channels=3, - out_channels=None, - bias=False, - wrap_video=False, - kernel_size=1, - remap_output=False, - ): - super().__init__() - self.n_stages = n_stages - assert self.n_stages >= 0 - assert method in [ - "nearest", - "linear", - "bilinear", - "trilinear", - "bicubic", - "area", - ] - self.multiplier = multiplier - self.interpolator = partial(torch.nn.functional.interpolate, mode=method) - self.remap_output = out_channels is not None or remap_output - if self.remap_output: - print( - f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing." - ) - self.channel_mapper = nn.Conv2d( - in_channels, - out_channels, - kernel_size=kernel_size, - bias=bias, - padding=kernel_size // 2, - ) - self.wrap_video = wrap_video - - def forward(self, x): - if self.wrap_video and x.ndim == 5: - B, C, T, H, W = x.shape - x = rearrange(x, "b c t h w -> b t c h w") - x = rearrange(x, "b t c h w -> (b t) c h w") - - for stage in range(self.n_stages): - x = self.interpolator(x, scale_factor=self.multiplier) - - if self.wrap_video: - x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C) - x = rearrange(x, "b t c h w -> b c t h w") - if self.remap_output: - x = self.channel_mapper(x) - return x - - def encode(self, x): - return self(x) - - -class LowScaleEncoder(nn.Module): - def __init__( - self, - model_config, - linear_start, - linear_end, - timesteps=1000, - max_noise_level=250, - output_size=64, - scale_factor=1.0, - ): - super().__init__() - self.max_noise_level = max_noise_level - self.model = instantiate_from_config(model_config) - self.augmentation_schedule = self.register_schedule( - timesteps=timesteps, linear_start=linear_start, linear_end=linear_end - ) - self.out_size = output_size - self.scale_factor = scale_factor - - def register_schedule( - self, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - ): - betas = make_beta_schedule( - beta_schedule, - timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s, - ) - alphas = 1.0 - betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - - (timesteps,) = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - assert ( - alphas_cumprod.shape[0] == self.num_timesteps - ), "alphas have to be defined for each timestep" - - to_torch = partial(torch.tensor, dtype=torch.float32) - - self.register_buffer("betas", to_torch(betas)) - self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) - self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer( - "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) - ) - self.register_buffer( - "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) - ) - self.register_buffer( - "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) - ) - self.register_buffer( - "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) - ) - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) - * noise - ) - - def forward(self, x): - z = self.model.encode(x) - if isinstance(z, DiagonalGaussianDistribution): - z = z.sample() - z = z * self.scale_factor - noise_level = torch.randint( - 0, self.max_noise_level, (x.shape[0],), device=x.device - ).long() - z = self.q_sample(z, noise_level) - if self.out_size is not None: - z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") - return z, noise_level - - def decode(self, z): - z = z / self.scale_factor - return self.model.decode(z) - - -class ConcatTimestepEmbedderND(AbstractEmbModel): - """embeds each dimension independently and concatenates them""" - - def __init__(self, outdim): - super().__init__() - self.timestep = Timestep(outdim) - self.outdim = outdim - self.modifier_token = None - - def forward(self, x): - if x.ndim == 1: - x = x[:, None] - assert len(x.shape) == 2 - b, dims = x.shape[0], x.shape[1] - x = rearrange(x, "b d -> (b d)") - emb = self.timestep(x) - emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) - return emb - - -class GaussianEncoder(Encoder, AbstractEmbModel): - def __init__( - self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs - ): - super().__init__(*args, **kwargs) - self.posterior = DiagonalGaussianRegularizer() - self.weight = weight - self.flatten_output = flatten_output - - def forward(self, x) -> Tuple[Dict, torch.Tensor]: - z = super().forward(x) - z, log = self.posterior(z) - log["loss"] = log["kl_loss"] - log["weight"] = self.weight - if self.flatten_output: - z = rearrange(z, "b c h w -> b (h w ) c") - return log, z - diff --git a/sgm/modules/nerfsd_pytorch3d.py b/sgm/modules/nerfsd_pytorch3d.py deleted file mode 100644 index abe5c862771b4f7fa4207be90d813dcf3432c589..0000000000000000000000000000000000000000 --- a/sgm/modules/nerfsd_pytorch3d.py +++ /dev/null @@ -1,468 +0,0 @@ -import math -import sys -import itertools - -import numpy as np -import torch.nn as nn -import torch.nn.functional as F -import torch -from einops import rearrange -from ..modules.utils_cameraray import ( - get_patch_rays, - get_plucker_parameterization, - positional_encoding, - convert_to_view_space, - convert_to_view_space_points, - convert_to_target_space, -) - - -from pytorch3d.renderer import ray_bundle_to_ray_points -from pytorch3d.renderer.implicit.raysampling import RayBundle as RayBundle -from pytorch3d import _C - -from ..modules.diffusionmodules.util import zero_module - - -class FeatureNeRFEncoding(nn.Module): - def __init__( - self, - in_channels, - out_channels, - far_plane: float = 2.0, - rgb_predict=False, - average=False, - num_freqs=16, - ) -> None: - super().__init__() - - self.far_plane = far_plane - self.rgb_predict = rgb_predict - self.average = average - self.num_freqs = num_freqs - dim = 3 - self.plane_coefs = nn.Sequential( - nn.Linear(in_channels + self.num_freqs * dim * 4 + 2 * dim, out_channels), - nn.SiLU(), - nn.Linear(out_channels, out_channels), - ) - if not self.average: - self.nviews = nn.Linear( - in_channels + self.num_freqs * dim * 4 + 2 * dim, 1 - ) - self.decoder = zero_module( - nn.Linear(out_channels, 1 + (3 if rgb_predict else 0), bias=False) - ) - - def forward(self, pose, xref, ray_points, rays, mask_ref): - # xref : [b, n, hw, c] - # ray_points: [b, n+1, hw, d, 3] - # rays: [b, n+1, hw, 6] - - b, n, hw, c = xref.shape - d = ray_points.shape[3] - res = int(math.sqrt(hw)) - if mask_ref is not None: - mask_ref = torch.nn.functional.interpolate( - rearrange( - mask_ref, - "b n ... -> (b n) ...", - ), - size=[res, res], - mode="nearest", - ).reshape(b, n, -1, 1) - xref = xref * mask_ref - - volume = [] - for i, cam in enumerate(pose): - volume.append( - cam.transform_points_ndc(ray_points[i, 0].reshape(-1, 3)).reshape(n + 1, hw, d, 3)[..., :2] - ) - volume = torch.stack(volume) - - plane_features = F.grid_sample( - rearrange( - xref, - "b n (h w) c -> (b n) c h w", - b=b, - h=int(math.sqrt(hw)), - w=int(math.sqrt(hw)), - c=c, - n=n, - ), - torch.clip( - torch.nan_to_num( - rearrange(-1 * volume[:, 1:].detach(), "b n ... -> (b n) ...") - ), - -1.2, - 1.2, - ), - align_corners=True, - padding_mode="zeros", - ) # [bn, c, hw, d] - - plane_features = rearrange(plane_features, "(b n) ... -> b n ...", b=b, n=n) - - xyz_grid_features_inviewframe = convert_to_view_space_points(pose, ray_points[:, 0]) - xyz_grid_features_inviewframe_encoding = positional_encoding(xyz_grid_features_inviewframe, self.num_freqs) - camera_features_inviewframe = ( - convert_to_view_space(pose, rays[:, 0])[:, 1:] - .reshape(b, n, hw, 1, -1) - .expand(-1, -1, -1, d, -1) - ) - camera_features_inviewframe_encoding = positional_encoding( - get_plucker_parameterization(camera_features_inviewframe), - self.num_freqs // 2, - ) - xyz_grid_features = xyz_grid_features_inviewframe_encoding[:, :1].expand( - -1, n, -1, -1, -1 - ) - camera_features = ( - (convert_to_target_space(pose, rays[:, 1:])[..., :3]) - .reshape(b, n, hw, 1, -1) - .expand(-1, -1, -1, d, -1) - ) - camera_features_encoding = positional_encoding( - camera_features, self.num_freqs - ) - plane_features_final = self.plane_coefs( - torch.cat( - [ - plane_features.permute(0, 1, 3, 4, 2), - xyz_grid_features_inviewframe_encoding[:, 1:], - xyz_grid_features_inviewframe[:, 1:], - camera_features_inviewframe_encoding, - camera_features_inviewframe[..., 3:], - ], - dim=-1, - ) - ) # b, n, hw, d, c - - # plane_features = torch.cat([plane_features, xyz_grid_features, camera_features], dim=1) - if not self.average: - plane_features_attn = nn.functional.softmax( - self.nviews( - torch.cat( - [ - plane_features.permute(0, 1, 3, 4, 2), - xyz_grid_features, - xyz_grid_features_inviewframe[:, :1].expand(-1, n, -1, -1, -1), - camera_features, - camera_features_encoding, - ], - dim=-1, - ) - ), - dim=1, - ) # b, n, hw, d, 1 - - plane_features_final = (plane_features_final * plane_features_attn).sum(1) - else: - plane_features_final = plane_features_final.mean(1) - plane_features_attn = None - - out = self.decoder(plane_features_final) - return torch.cat([plane_features_final, out], dim=-1), plane_features_attn - - -class VolRender(nn.Module): - def __init__( - self, - ): - super().__init__() - - def get_weights(self, densities, deltas): - """Return weights based on predicted densities - - Args: - densities: Predicted densities for samples along ray - - Returns: - Weights for each sample - """ - delta_density = deltas * densities # [b, hw, "num_samples", 1] - alphas = 1 - torch.exp(-delta_density) - transmittance = torch.cumsum(delta_density[..., :-1, :], dim=-2) - transmittance = torch.cat( - [ - torch.zeros((*transmittance.shape[:2], 1, 1), device=densities.device), - transmittance, - ], - dim=-2, - ) - transmittance = torch.exp(-transmittance) # [b, hw, "num_samples", 1] - - weights = alphas * transmittance # [b, hw, "num_samples", 1] - weights = torch.nan_to_num(weights) - # opacities = 1.0 - torch.prod(1.0 - alphas, dim=-2, keepdim=True) - return weights, alphas, transmittance - - def forward( - self, - features, - densities, - dists=None, - return_weight=False, - densities_uniform=None, - dists_uniform=None, - return_weights_uniform=False, - rgb=None - ): - alphas = None - fg_mask = None - if dists is not None: - weights, alphas, transmittance = self.get_weights(densities, dists) - fg_mask = torch.sum(weights, -2) - else: - weights = densities # used when we have a pretraind nerf with direct weights as output - - rendered_feats = torch.sum(weights * features, dim=-2) + torch.sum( - (1 - weights) * torch.zeros_like(features), dim=-2 - ) - if rgb is not None: - rgb = torch.sum(weights * rgb, dim=-2) + torch.sum( - (1 - weights) * torch.zeros_like(rgb), dim=-2 - ) - # print("RENDER", fg_mask.shape, weights.shape) - weights_uniform = None - if return_weight: - return rendered_feats, fg_mask, alphas, weights, rgb - elif return_weights_uniform: - if densities_uniform is not None: - weights_uniform, _, transmittance = self.get_weights(densities_uniform, dists_uniform) - return rendered_feats, fg_mask, alphas, weights_uniform, rgb - else: - return rendered_feats, fg_mask, alphas, None, rgb - - -class Raymarcher(nn.Module): - def __init__( - self, - num_samples=32, - far_plane=2.0, - stratified=False, - training=True, - imp_sampling_percent=0.9, - near_plane=0., - ): - super().__init__() - self.num_samples = num_samples - self.far_plane = far_plane - self.near_plane = near_plane - u_max = 1. / (self.num_samples) - u = torch.linspace(0, 1 - u_max, self.num_samples, device="cuda") - self.register_buffer("u", u) - lengths = torch.linspace(self.near_plane, self.near_plane+self.far_plane, self.num_samples+1, device="cuda") - # u = (u[..., :-1] + u[..., 1:]) / 2.0 - lengths_center = (lengths[..., 1:] + lengths[..., :-1]) / 2.0 - lengths_upper = torch.cat([lengths_center, lengths[..., -1:]], -1) - lengths_lower = torch.cat([lengths[..., :1], lengths_center], -1) - self.register_buffer("lengths", lengths) - self.register_buffer("lengths_center", lengths_center) - self.register_buffer("lengths_upper", lengths_upper) - self.register_buffer("lengths_lower", lengths_lower) - self.stratified = stratified - self.training = training - self.imp_sampling_percent = imp_sampling_percent - - @torch.no_grad() - def importance_sampling(self, cdf, num_rays, num_samples, device): - # sample target rays for each reference view - cdf = cdf[..., 0] + 0.01 - if cdf.shape[1] != num_rays: - size = int(math.sqrt(num_rays)) - size_ = int(math.sqrt(cdf.size(1))) - cdf = rearrange( - torch.nn.functional.interpolate( - rearrange( - cdf.permute(0, 2, 1), "... (h w) -> ... h w", h=size_, w=size_ - ), - size=[size, size], - antialias=True, - mode="bilinear", - ), - "... h w -> ... (h w)", - h=size, - w=size, - ).permute(0, 2, 1) - - lengths = self.lengths[None, None, :].expand(cdf.shape[0], num_rays, -1) - - cdf_sum = torch.sum(cdf, dim=-1, keepdim=True) - padding = torch.relu(1e-5 - cdf_sum) - cdf = cdf + padding / cdf.shape[-1] - cdf_sum += padding - - pdf = cdf / cdf_sum - - # sample_pdf function - u_max = 1. / (num_samples) - u = self.u[None, None, :].expand(cdf.shape[0], num_rays, -1) - if self.stratified and self.training: - u += torch.rand((cdf.shape[0], num_rays, num_samples), dtype=cdf.dtype, device=cdf.device,) * u_max - - _C.sample_pdf( - lengths.reshape(-1, num_samples + 1), - pdf.reshape(-1, num_samples), - u.reshape(-1, num_samples), - 1e-5, - ) - return u, torch.cat([u[..., 1:] - u[..., :-1], lengths[..., -1:] - u[..., -1:] ], -1) - - @torch.no_grad() - def stratified_sampling(self, num_rays, device, uniform=False): - lengths_uniform = self.lengths[None, None, :].expand(-1, num_rays, -1) - - if uniform: - return ( - (lengths_uniform[..., 1:] + lengths_uniform[..., :-1]) / 2.0, - lengths_uniform[..., 1:] - lengths_uniform[..., :-1], - ) - if self.stratified and self.training: - t_rand = torch.rand( - (num_rays, self.num_samples + 1), - dtype=lengths_uniform.dtype, - device=lengths_uniform.device, - ) - jittered = self.lengths_lower[None, None, :].expand(-1, num_rays, -1) + \ - (self.lengths_upper[None, None, :].expand(-1, num_rays, -1) - self.lengths_lower[None, None, :].expand(-1, num_rays, -1)) * t_rand - return ((jittered[..., :-1] + jittered[..., 1:])/2., jittered[..., 1:] - jittered[..., :-1]) - else: - return ( - (lengths_uniform[..., 1:] + lengths_uniform[..., :-1]) / 2.0, - lengths_uniform[..., 1:] - lengths_uniform[..., :-1], - ) - - @torch.no_grad() - def forward(self, pose, resolution, weights, imp_sample_next_step=False, device='cuda', pytorch3d=True): - input_patch_rays, xys = get_patch_rays( - pose, - num_patches_x=resolution, - num_patches_y=resolution, - device=device, - return_xys=True, - stratified=self.stratified and self.training, - ) # (b, n, h*w, 6) - - num_rays = resolution**2 - # sample target rays for each reference view - if weights is not None: - if self.imp_sampling_percent <= 0: - lengths, dists = self.stratified_sampling(num_rays, device) - elif (torch.rand(1) < (1.-self.imp_sampling_percent)) and self.training: - lengths, dists = self.stratified_sampling(num_rays, device) - else: - lengths, dists = self.importance_sampling( - weights, num_rays, self.num_samples, device=device - ) - else: - lengths, dists = self.stratified_sampling(num_rays, device) - - dists_uniform = None - ray_points_uniform = None - if imp_sample_next_step: - lengths_uniform, dists_uniform = self.stratified_sampling( - num_rays, device, uniform=True - ) - - target_patch_raybundle_uniform = RayBundle( - origins=input_patch_rays[:, :1, :, :3], - directions=input_patch_rays[:, :1, :, 3:], - lengths=lengths_uniform, - xys=xys.to(device), - ) - ray_points_uniform = ray_bundle_to_ray_points(target_patch_raybundle_uniform).detach() - dists_uniform = dists_uniform.detach() - - # print( - # "SAMPLING", - # lengths.shape, - # lengths_uniform.shape, - # dists.shape, - # dists_uniform.shape, - # input_patch_rays.shape, - # ) - target_patch_raybundle = RayBundle( - origins=input_patch_rays[:, :1, :, :3], - directions=input_patch_rays[:, :1, :, 3:], - lengths=lengths.to(device), - xys=xys.to(device), - ) - ray_points = ray_bundle_to_ray_points(target_patch_raybundle) - return ( - input_patch_rays.detach(), - ray_points.detach(), - dists.detach(), - ray_points_uniform, - dists_uniform, - ) - - -class NerfSDModule(nn.Module): - def __init__( - self, - mode="feature-nerf", - out_channels=None, - far_plane=2.0, - num_samples=32, - rgb_predict=False, - average=False, - num_freqs=16, - stratified=False, - imp_sampling_percent=0.9, - near_plane=0. - ): - MODES = { - "feature-nerf": FeatureNeRFEncoding, # ampere - } - super().__init__() - self.rgb_predict = rgb_predict - - self.raymarcher = Raymarcher( - num_samples=num_samples, - far_plane=near_plane + far_plane, - stratified=stratified, - imp_sampling_percent=imp_sampling_percent, - near_plane=near_plane, - ) - model_class = MODES[mode] - self.model = model_class( - out_channels, - out_channels, - far_plane=near_plane + far_plane, - rgb_predict=rgb_predict, - average=average, - num_freqs=num_freqs, - ) - - def forward(self, pose, xref=None, mask_ref=None, prev_weights=None, imp_sample_next_step=False,): - # xref: b n h w c or b n hw c - # pose: a list of pytorch3d cameras - # mask_ref: mask corresponding to black regions because of padding non square images. - rgb = None - dists_uniform = None - weights_uniform = None - resolution = (int(math.sqrt(xref.size(2))) if len(xref.shape) == 4 else xref.size(3)) - input_patch_rays, ray_points, dists, ray_points_uniform, dists_uniform = (self.raymarcher(pose, resolution, weights=prev_weights, device=xref.device)) - output, plane_features_attn = self.model(pose, xref, ray_points, input_patch_rays, mask_ref) - weights = output[..., -1:] - features = output[..., :-1] - if self.rgb_predict: - rgb = features[..., -3:] - features = features[..., :-3] - dists = dists.unsqueeze(-1) - with torch.no_grad(): - if ray_points_uniform is not None: - output_uniform, _ = self.model(pose, xref, ray_points_uniform, input_patch_rays, mask_ref) - weights_uniform = output_uniform[..., -1:] - dists_uniform = dists_uniform.unsqueeze(-1) - - return ( - features, - weights, - dists, - plane_features_attn, - rgb, - weights_uniform, - dists_uniform, - ) diff --git a/sgm/modules/utils_cameraray.py b/sgm/modules/utils_cameraray.py deleted file mode 100644 index 8814022357f4ad69c5984f0569a33b09e1275052..0000000000000000000000000000000000000000 --- a/sgm/modules/utils_cameraray.py +++ /dev/null @@ -1,391 +0,0 @@ -#### Code taken from: https://github.com/mayankgrwl97/gbt -"""Utils for ray manipulation""" - -import numpy as np -import torch -from pytorch3d.renderer.implicit.raysampling import RayBundle as RayBundle -from pytorch3d.renderer.camera_utils import join_cameras_as_batch -from pytorch3d.renderer.cameras import PerspectiveCameras - - -############################# RAY BUNDLE UTILITIES ############################# - -def is_scalar(x): - """Returns True if the provided variable is a scalar - - Args: - x: scalar or array-like (numpy array or torch tensor) - - Returns: - bool: True if x is of the type scalar, or array-like with 0 dimension. False, otherwise - - """ - if isinstance(x, float) or isinstance(x, int): - return True - - if isinstance(x, np.ndarray) and np.ndim(x) == 0: - return True - - if isinstance(x, torch.Tensor) and x.dim() == 0: - return True - - return False - - -def transform_rays(reference_R, reference_T, rays): - """ - PyTorch3D Convention is used: X_cam = X_world @ R + T - - Args: - reference_R: world2cam rotation matrix for reference camera (B, 3, 3) - reference_T: world2cam translation vector for reference camera (B, 3) - rays: (origin, direction) defined in world reference frame (B, V, N, 6) - Returns: - torch.Tensor: Transformed rays w.r.t. reference camera (B, V, N, 6) - """ - batch, num_views, num_rays, ray_dim = rays.shape - assert ( - ray_dim == 6 - ), "First 3 dimensions should be origin; Last 3 dimensions should be direction" - - rays = rays.reshape(batch, num_views * num_rays, ray_dim) - rays_out = rays.clone() - rays_out[..., :3] = torch.bmm(rays[..., :3], reference_R) + reference_T.unsqueeze( - -2 - ) - rays_out[..., 3:] = torch.bmm(rays[..., 3:], reference_R) - rays_out = rays_out.reshape(batch, num_views, num_rays, ray_dim) - return rays_out - - -def get_directional_raybundle(cameras, x_pos_ndc, y_pos_ndc, max_depth=1): - if is_scalar(x_pos_ndc): - x_pos_ndc = [x_pos_ndc] - if is_scalar(y_pos_ndc): - y_pos_ndc = [y_pos_ndc] - assert is_scalar(max_depth) - - if not isinstance(x_pos_ndc, torch.Tensor): - x_pos_ndc = torch.tensor(x_pos_ndc) # (N, ) - if not isinstance(y_pos_ndc, torch.Tensor): - y_pos_ndc = torch.tensor(y_pos_ndc) # (N, ) - - xy_depth = torch.stack( - (x_pos_ndc, y_pos_ndc, torch.ones_like(x_pos_ndc) * max_depth), dim=-1 - ) # (N, 3) - - num_points = xy_depth.shape[0] - - unprojected = cameras.unproject_points( - xy_depth.to(cameras.device), world_coordinates=True, from_ndc=True - ) # (N, 3) - unprojected = unprojected.unsqueeze(0).to("cpu") # (B, N, 3) - - origins = ( - cameras.get_camera_center()[:, None, :].expand(-1, num_points, -1).to("cpu") - ) # (B, N, 3) - directions = unprojected - origins # (B, N, 3) - directions = directions / directions.norm(dim=-1).unsqueeze(-1) # (B, N, 3) - lengths = ( - torch.tensor([[0, 3]]).unsqueeze(0).expand(-1, num_points, -1).to("cpu") - ) # (B, N, 2) - xys = xy_depth[:, :2].unsqueeze(0).to("cpu") # (B, N, 2) - - raybundle = RayBundle( - origins=origins.to("cpu"), - directions=directions.to("cpu"), - lengths=lengths.to("cpu"), - xys=xys.to("cpu"), - ) - return raybundle - - -def get_patch_raybundle( - cameras, num_patches_x, num_patches_y, max_depth=1, stratified=False -): - horizontal_patch_edges = torch.linspace(1, -1, num_patches_x + 1) - # horizontal_positions = horizontal_patch_edges[:-1] # (num_patches_x,): Top left corner of patch - - vertical_patch_edges = torch.linspace(1, -1, num_patches_y + 1) - # vertical_positions = vertical_patch_edges[:-1] # (num_patches_y,): Top left corner of patch - if stratified: - horizontal_patch_edges_center = ( - horizontal_patch_edges[..., 1:] + horizontal_patch_edges[..., :-1] - ) / 2.0 - horizontal_patch_edges_upper = torch.cat( - [horizontal_patch_edges_center, horizontal_patch_edges[..., -1:]], -1 - ) - horizontal_patch_edges_lower = torch.cat( - [horizontal_patch_edges[..., :1], horizontal_patch_edges_center], -1 - ) - horizontal_positions = ( - horizontal_patch_edges_lower - + (horizontal_patch_edges_upper - horizontal_patch_edges_lower) - * torch.rand_like(horizontal_patch_edges_lower) - )[..., :-1] - - vertical_patch_edges_center = ( - vertical_patch_edges[..., 1:] + vertical_patch_edges[..., :-1] - ) / 2.0 - vertical_patch_edges_upper = torch.cat( - [vertical_patch_edges_center, vertical_patch_edges[..., -1:]], -1 - ) - vertical_patch_edges_lower = torch.cat( - [vertical_patch_edges[..., :1], vertical_patch_edges_center], -1 - ) - vertical_positions = ( - vertical_patch_edges_lower - + (vertical_patch_edges_upper - vertical_patch_edges_lower) - * torch.rand_like(vertical_patch_edges_lower) - )[..., :-1] - else: - horizontal_positions = ( - horizontal_patch_edges[:-1] + horizontal_patch_edges[1:] - ) / 2 # (num_patches_x, ) # Center of patch - vertical_positions = ( - vertical_patch_edges[:-1] + vertical_patch_edges[1:] - ) / 2 # (num_patches_y, ) # Center of patch - - h_pos, v_pos = torch.meshgrid( - horizontal_positions, vertical_positions, indexing='xy' - ) # (num_patches_y, num_patches_x), (num_patches_y, num_patches_x) - h_pos = h_pos.reshape(-1) # (num_patches_y * num_patches_x) - v_pos = v_pos.reshape(-1) # (num_patches_y * num_patches_x) - - raybundle = get_directional_raybundle( - cameras=cameras, x_pos_ndc=h_pos, y_pos_ndc=v_pos, max_depth=max_depth - ) - return raybundle - - -def get_patch_rays( - cameras_list, - num_patches_x, - num_patches_y, - device, - return_xys=False, - stratified=False, -): - """Returns patch rays given the camera viewpoints - - Args: - cameras_list(list[pytorch3d.renderer.cameras.BaseCameras]): List of list of cameras (len (batch_size, num_input_views,)) - num_patches_x: Number of patches in the x-direction (horizontal) - num_patches_y: Number of patches in the y-direction (vertical) - - Returns: - torch.tensor: Patch rays of shape (batch_size, num_views, num_patches, 6) - """ - batch, numviews = len(cameras_list), len(cameras_list[0]) - cameras_list = join_cameras_as_batch([cam for cam_batch in cameras_list for cam in cam_batch]) - patch_rays = get_patch_raybundle( - cameras_list, - num_patches_y=num_patches_y, - num_patches_x=num_patches_x, - stratified=stratified, - ) - if return_xys: - xys = patch_rays.xys - - patch_rays = torch.cat((patch_rays.origins.unsqueeze(0), patch_rays.directions), dim=-1) - patch_rays = patch_rays.reshape( - batch, numviews, num_patches_x * num_patches_y, 6 - ).to(device) - if return_xys: - return patch_rays, xys - return patch_rays - -############################ RAY PARAMETERIZATION ############################## - - -def get_plucker_parameterization(ray): - """Returns the plucker representation of the rays given the (origin, direction) representation - - Args: - ray(torch.Tensor): Tensor of shape (..., 6) with the (origin, direction) representation - - Returns: - torch.Tensor: Tensor of shape (..., 6) with the plucker (D, OxD) representation - """ - ray = ray.clone() # Create a clone - ray_origins = ray[..., :3] - ray_directions = ray[..., 3:] - ray_directions = ray_directions / ray_directions.norm(dim=-1).unsqueeze( - -1 - ) # Normalize ray directions to unit vectors - plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1) - plucker_parameterization = torch.cat([ray_directions, plucker_normal], dim=-1) - - return plucker_parameterization - - -def positional_encoding(ray, n_freqs=10, start_freq=0): - """ - Positional Embeddings. For more details see Section 5.1 of - NeRFs: https://arxiv.org/pdf/2003.08934.pdf - - Args: - ray: (B,P,d) - n_freqs: num of frequency bands - parameterize(str|None): Parameterization used for rays. Recommended: use 'plucker'. Default=None. - - Returns: - pos_embeddings: Mapping input ray from R to R^{2*n_freqs}. - """ - start_freq = -1 * (n_freqs / 2) - freq_bands = 2.0 ** torch.arange(start_freq, start_freq + n_freqs) * np.pi - sin_encodings = [torch.sin(ray * freq) for freq in freq_bands] - cos_encodings = [torch.cos(ray * freq) for freq in freq_bands] - pos_embeddings = torch.cat( - sin_encodings + cos_encodings, dim=-1 - ) # B, P, d * 2n_freqs - return pos_embeddings - - -def convert_to_target_space(input_cameras, input_rays): - input_rays_transformed = [] - # input_cameras: b, N - # input_rays: b, N, hw, 6 - # return: b, N, hw, 6 - for i in range(len(input_cameras[0])): - reference_cameras = [cameras[0] for cameras in input_cameras] - reference_R = [ - camera.R.to(input_rays.device) for camera in reference_cameras - ] # List (length=batch_size) of Rs(shape: 1, 3, 3) - reference_R = torch.cat(reference_R, dim=0) # (B, 3, 3) - reference_T = [ - camera.T.to(input_rays.device) for camera in reference_cameras - ] # List (length=batch_size) of Ts(shape: 1, 3) - reference_T = torch.cat(reference_T, dim=0) # (B, 3) - input_rays_transformed.append( - transform_rays( - reference_R=reference_R, - reference_T=reference_T, - rays=input_rays[:, i : i + 1], - ) - ) - return torch.cat(input_rays_transformed, 1) - - -def convert_to_view_space(input_cameras, input_rays): - input_rays_transformed = [] - # input_cameras: b, N - # input_rays: b, hw, 6 - # return: b, n, hw, 6 - for i in range(len(input_cameras[0])): - reference_cameras = [cameras[i] for cameras in input_cameras] - reference_R = [ - camera.R.to(input_rays.device) for camera in reference_cameras - ] # List (length=batch_size) of Rs(shape: 1, 3, 3) - reference_R = torch.cat(reference_R, dim=0) # (B, 3, 3) - reference_T = [ - camera.T.to(input_rays.device) for camera in reference_cameras - ] # List (length=batch_size) of Ts(shape: 1, 3) - reference_T = torch.cat(reference_T, dim=0) # (B, 3) - input_rays_transformed.append( - transform_rays( - reference_R=reference_R, - reference_T=reference_T, - rays=input_rays.unsqueeze(1), - ) - ) - return torch.cat(input_rays_transformed, 1) - - -def convert_to_view_space_points(input_cameras, input_points): - input_rays_transformed = [] - # input_cameras: b, N - # ipput_points: b, hw, d, 3 - # returns: b, N, hw, d, 3 [target points transformed in the reference view frame] - for i in range(len(input_cameras[0])): - reference_cameras = [cameras[i] for cameras in input_cameras] - reference_R = [ - camera.R.to(input_points.device) for camera in reference_cameras - ] # List (length=batch_size) of Rs(shape: 1, 3, 3) - reference_R = torch.cat(reference_R, dim=0) # (B, 3, 3) - reference_T = [ - camera.T.to(input_points.device) for camera in reference_cameras - ] # List (length=batch_size) of Ts(shape: 1, 3) - reference_T = torch.cat(reference_T, dim=0) # (B, 3) - input_points_clone = torch.einsum( - "bsdj,bjk->bsdk", input_points, reference_R - ) + reference_T.reshape(-1, 1, 1, 3) - input_rays_transformed.append(input_points_clone.unsqueeze(1)) - return torch.cat(input_rays_transformed, dim=1) - - -def interpolate_translate_interpolate_xaxis(cam1, interp_start, interp_end, interp_step): - cameras = [] - for i in np.arange(interp_start, interp_end, interp_step): - viewtoworld = cam1.get_world_to_view_transform().inverse() - - x_axis = torch.from_numpy(np.array([i, 0., 0.0])).reshape(1,3).float().to(cam1.device) - newc = viewtoworld.transform_points(x_axis) - rt = cam1.R[0] - # t = cam1.T - new_t = -rt.T@newc.T - - cameras.append(PerspectiveCameras(R=cam1.R, - T=new_t.T, - focal_length=cam1.focal_length, - principal_point=cam1.principal_point, - image_size=512, - ) - ) - return cameras - - -def interpolate_translate_interpolate_yaxis(cam1, interp_start, interp_end, interp_step): - cameras = [] - for i in np.arange(interp_start, interp_end, interp_step): - # i = np.clip(i, -0.2, 0.18) - viewtoworld = cam1.get_world_to_view_transform().inverse() - - x_axis = torch.from_numpy(np.array([0, i, 0.0])).reshape(1,3).float().to(cam1.device) - newc = viewtoworld.transform_points(x_axis) - rt = cam1.R[0] - # t = cam1.T - new_t = -rt.T@newc.T - - cameras.append(PerspectiveCameras(R=cam1.R, - T=new_t.T, - focal_length=cam1.focal_length, - principal_point=cam1.principal_point, - image_size=512, - ) - ) - return cameras - - -def interpolate_translate_interpolate_zaxis(cam1, interp_start, interp_end, interp_step): - cameras = [] - for i in np.arange(interp_start, interp_end, interp_step): - viewtoworld = cam1.get_world_to_view_transform().inverse() - - x_axis = torch.from_numpy(np.array([0, 0., i])).reshape(1,3).float().to(cam1.device) - newc = viewtoworld.transform_points(x_axis) - rt = cam1.R[0] - # t = cam1.T - new_t = -rt.T@newc.T - - cameras.append(PerspectiveCameras(R=cam1.R, - T=new_t.T, - focal_length=cam1.focal_length, - principal_point=cam1.principal_point, - image_size=512, - ) - ) - return cameras - - -def interpolatefocal(cam1, interp_start, interp_end, interp_step): - cameras = [] - for i in np.arange(interp_start, interp_end, interp_step): - cameras.append(PerspectiveCameras(R=cam1.R, - T=cam1.T, - focal_length=cam1.focal_length*i, - principal_point=cam1.principal_point, - image_size=512, - ) - ) - return cameras diff --git a/sgm/util.py b/sgm/util.py deleted file mode 100644 index c0fbfeb210580c526c17fcdb190b166bfc6378dc..0000000000000000000000000000000000000000 --- a/sgm/util.py +++ /dev/null @@ -1,297 +0,0 @@ -import functools -import importlib -import os -from functools import partial -from inspect import isfunction - -import fsspec -import numpy as np -import torch -from PIL import Image, ImageDraw, ImageFont -from safetensors.torch import load_file as load_safetensors - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -def get_string_from_tuple(s): - try: - # Check if the string starts and ends with parentheses - if s[0] == "(" and s[-1] == ")": - # Convert the string to a tuple - t = eval(s) - # Check if the type of t is tuple - if type(t) == tuple: - return t[0] - else: - pass - except: - pass - return s - - -def is_power_of_two(n): - """ - chat.openai.com/chat - Return True if n is a power of 2, otherwise return False. - - The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False. - The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False. - If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise. - Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False. - - """ - if n <= 0: - return False - return (n & (n - 1)) == 0 - - -def autocast(f, enabled=True): - def do_autocast(*args, **kwargs): - with torch.cuda.amp.autocast( - enabled=enabled, - dtype=torch.get_autocast_gpu_dtype(), - cache_enabled=torch.is_autocast_cache_enabled(), - ): - return f(*args, **kwargs) - - return do_autocast - - -def load_partial_from_config(config): - return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) - - -def log_txt_as_img(wh, xc, size=10): - # wh a tuple of (width, height) - # xc a list of captions to plot - b = len(xc) - txts = list() - for bi in range(b): - txt = Image.new("RGB", wh, color="white") - draw = ImageDraw.Draw(txt) - font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) - nc = int(40 * (wh[0] / 256)) - if isinstance(xc[bi], list): - text_seq = xc[bi][0] - else: - text_seq = xc[bi] - lines = "\n".join( - text_seq[start : start + nc] for start in range(0, len(text_seq), nc) - ) - - try: - draw.text((0, 0), lines, fill="black", font=font) - except UnicodeEncodeError: - print("Cant encode string for logging. Skipping.") - - txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 - txts.append(txt) - txts = np.stack(txts) - txts = torch.tensor(txts) - return txts - - -def partialclass(cls, *args, **kwargs): - class NewCls(cls): - __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) - - return NewCls - - -def make_path_absolute(path): - fs, p = fsspec.core.url_to_fs(path) - if fs.protocol == "file": - return os.path.abspath(p) - return path - - -def ismap(x): - if not isinstance(x, torch.Tensor): - return False - return (len(x.shape) == 4) and (x.shape[1] > 3) - - -def isimage(x): - if not isinstance(x, torch.Tensor): - return False - return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) - - -def isheatmap(x): - if not isinstance(x, torch.Tensor): - return False - - return x.ndim == 2 - - -def isneighbors(x): - if not isinstance(x, torch.Tensor): - return False - return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) - - -def exists(x): - return x is not None - - -def expand_dims_like(x, y): - while x.dim() != y.dim(): - x = x.unsqueeze(-1) - return x - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def mean_flat(tensor): - """ - https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def count_params(model, verbose=False): - total_params = sum(p.numel() for p in model.parameters()) - if verbose: - print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") - return total_params - - -def instantiate_from_config(config): - if not "target" in config: - if config == "__is_first_stage__": - return None - elif config == "__is_unconditional__": - return None - raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) - - -def get_obj_from_str(string, reload=False, invalidate_cache=True): - module, cls = string.rsplit(".", 1) - if invalidate_cache: - importlib.invalidate_caches() - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) - - -def append_zero(x): - return torch.cat([x, x.new_zeros([1])]) - - -def append_dims(x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError( - f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" - ) - return x[(...,) + (None,) * dims_to_append] - - -def load_model_from_config(config, ckpt, delta_ckpt=None, verbose=True, freeze=True): - config.model.params.first_stage_config.params.ckpt_path = "pretrained-models/sdxl_vae.safetensors" - print(f"Loading model from {ckpt}") - if ckpt.endswith("ckpt"): - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - elif ckpt.endswith("safetensors"): - sd = load_safetensors(ckpt) - else: - raise NotImplementedError - - model = instantiate_from_config(config.model) - - if delta_ckpt is not None: - token_weights1 = sd['conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight'] - token_weights2 = sd['conditioner.embedders.1.model.token_embedding.weight'] - del sd['conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight'] - del sd['conditioner.embedders.1.model.token_embedding.weight'] - - m, u = model.load_state_dict(sd, strict=False) - - ## Load delta ckpt - if delta_ckpt is not None: - pl_sd_delta = torch.load(delta_ckpt, map_location="cpu") - sd_delta = pl_sd_delta["delta_state_dict"] - model.conditioner.embedders[0].transformer.text_model.embeddings.token_embedding.weight.data = torch.cat([token_weights1, sd_delta['embed'][0]], 0).to(model.device) - model.conditioner.embedders[1].model.token_embedding.weight.data = torch.cat([token_weights2, sd_delta['embed'][1]], 0).to(model.device) - del sd_delta['embed'] - for name, module in model.model.diffusion_model.named_modules(): - if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks': - if hasattr(module, 'pose_emb_layers'): - module.register_buffer('references', sd_delta[f'model.diffusion_model.{name}.references']) - del sd_delta[f'model.diffusion_model.{name}.references'] - - m, u = model.load_state_dict(sd_delta, strict=False) - - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - - if freeze: - for param in model.parameters(): - param.requires_grad = False - - model.eval() - return model - - -def get_configs_path() -> str: - """ - Get the `configs` directory. - For a working copy, this is the one in the root of the repository, - but for an installed copy, it's in the `sgm` package (see pyproject.toml). - """ - this_dir = os.path.dirname(__file__) - candidates = ( - os.path.join(this_dir, "configs"), - os.path.join(this_dir, "..", "configs"), - ) - for candidate in candidates: - candidate = os.path.abspath(candidate) - if os.path.isdir(candidate): - return candidate - raise FileNotFoundError(f"Could not find SGM configs in {candidates}") - - -def get_nested_attribute(obj, attribute_path, depth=None, return_key=False): - """ - Will return the result of a recursive get attribute call. - E.g.: - a.b.c - = getattr(getattr(a, "b"), "c") - = get_nested_attribute(a, "b.c") - If any part of the attribute call is an integer x with current obj a, will - try to call a[x] instead of a.x first. - """ - attributes = attribute_path.split(".") - if depth is not None and depth > 0: - attributes = attributes[:depth] - assert len(attributes) > 0, "At least one attribute should be selected" - current_attribute = obj - current_key = None - for level, attribute in enumerate(attributes): - current_key = ".".join(attributes[: level + 1]) - try: - id_ = int(attribute) - current_attribute = current_attribute[id_] - except ValueError: - current_attribute = getattr(current_attribute, attribute) - - return (current_attribute, current_key) if return_key else current_attribute