diff --git a/README.md b/README.md index 8fb44f7c6c6045c9106e20accfa1d3e6df0dc4f0..4e741fb61c446b6e724d747412c633b173e1a1c6 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ --- title: LiDAR Diffusion -emoji: 📚 -colorFrom: blue -colorTo: green +emoji: 🚙🛞🚨 +colorFrom: green +colorTo: indigo sdk: gradio sdk_version: 4.26.0 app_file: app.py diff --git a/app.py b/app.py index d860a45d46dcefd21abcdb62e4f2b5237a938485..3eb1fe7039b19818959a83ab682e32b92e0d730c 100644 --- a/app.py +++ b/app.py @@ -1,9 +1,81 @@ import gradio as gr +import spaces +import tempfile +import os +import torch +import numpy as np +from matplotlib.colors import LinearSegmentedColormap +from app_config import CSS, TITLE, DESCRIPTION, DEVICE +import sample_cond -def greet(name): - return "Hello " + name + "!!" +model = sample_cond.load_model() -iface = gr.Interface(fn=greet, inputs="text", outputs="text") -iface.launch() +def create_custom_colormap(): + colors = [(0, 1, 0), (0, 1, 1), (0, 0, 1), (1, 0, 1), (1, 1, 0)] + positions = [0, 0.38, 0.6, 0.7, 1] + + custom_cmap = LinearSegmentedColormap.from_list('custom_colormap', list(zip(positions, colors)), N=256) + return custom_cmap + + +def colorize_depth(depth, log_scale): + if log_scale: + depth = ((np.log2((depth / 255.) * 56. + 1) / 5.84) * 255.).astype(np.uint8) + mask = depth == 0 + colormap = create_custom_colormap() + rgb = colormap(depth)[:, :, :3] + rgb[mask] = 0. + return rgb + + +@spaces.GPU +@torch.no_grad() +def generate_lidar(model, cond): + img, pcd = sample_cond.sample(model, cond) + return img, pcd + + +def load_camera(image): + split_per_view = 4 + camera = np.array(image).astype(np.float32) / 255. + camera = camera.transpose(2, 0, 1) + camera_list = np.split(camera, split_per_view, axis=2) # split into n chunks as different views + camera_cond = torch.from_numpy(np.stack(camera_list, axis=0)).unsqueeze(0).to(DEVICE) + return camera_cond + + +with gr.Blocks(css=CSS) as demo: + gr.Markdown(TITLE) + gr.Markdown(DESCRIPTION) + gr.Markdown("### Camera-to-LiDAR Demo") + # gr.Markdown("You can slide the output to compare the depth prediction with input image") + + with gr.Row(): + input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input') + output_image = gr.Image(label="Range Map", elem_id='img-display-output') + raw_file = gr.File(label="Point Cloud (.txt file). Can be viewed through Meshlab") + submit = gr.Button("Submit") + + def on_submit(image): + cond = load_camera(image) + img, pcd = generate_lidar(model, cond) + + tmp = tempfile.NamedTemporaryFile(suffix='.txt', delete=False) + pcd.save(tmp.name) + + rgb_img = colorize_depth(img, log_scale=True) + + return [rgb_img, tmp.name] + + submit.click(on_submit, inputs=[input_image], outputs=[output_image, raw_file]) + + example_files = sorted(os.listdir('cam_examples')) + example_files = [os.path.join('cam_examples', filename) for filename in example_files] + examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[output_image, raw_file], + fn=on_submit, cache_examples=True) + + +if __name__ == '__main__': + demo.queue().launch() diff --git a/app_config.py b/app_config.py new file mode 100644 index 0000000000000000000000000000000000000000..505160872e64d4fbdab7e564d8388a0afe647eed --- /dev/null +++ b/app_config.py @@ -0,0 +1,17 @@ +import torch + +CSS = """ +#img-display-container { + max-height: 100vh; + } +#img-display-input { + max-height: 80vh; + } +#img-display-output { + max-height: 80vh; + } +""" +DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' +TITLE = "# LiDAR Diffusion" +DESCRIPTION = """Official demo for **LiDAR Diffusion: Towards Realistic Scene Generation with LiDAR Diffusion Models**. +Please refer to our [paper](https://arxiv.org/abs/2404.00815), [project page](https://lidar-diffusion.github.io/), or [github](https://github.com/hancyran/LiDAR-Diffusion) for more details.""" diff --git a/cam_examples/conditioning_000011.png b/cam_examples/conditioning_000011.png new file mode 100644 index 0000000000000000000000000000000000000000..f8e13c1e65b515ee85ababb900fcb5316ac37086 Binary files /dev/null and b/cam_examples/conditioning_000011.png differ diff --git a/cam_examples/conditioning_000153.png b/cam_examples/conditioning_000153.png new file mode 100644 index 0000000000000000000000000000000000000000..81f0c60659b3f31b2c01cfc750780c72cd344f9f Binary files /dev/null and b/cam_examples/conditioning_000153.png differ diff --git a/cam_examples/conditioning_000354.png b/cam_examples/conditioning_000354.png new file mode 100644 index 0000000000000000000000000000000000000000..02b534675afac8bb83cf944c9098884fb6a74298 Binary files /dev/null and b/cam_examples/conditioning_000354.png differ diff --git a/cam_examples/conditioning_000555.png b/cam_examples/conditioning_000555.png new file mode 100644 index 0000000000000000000000000000000000000000..0d01f5c9be3b6118fad8fba41e4ffd11837d5676 Binary files /dev/null and b/cam_examples/conditioning_000555.png differ diff --git a/cam_examples/conditioning_001026.png b/cam_examples/conditioning_001026.png new file mode 100644 index 0000000000000000000000000000000000000000..266b1de4f9cd659d15ac6d176b3411753cd42159 Binary files /dev/null and b/cam_examples/conditioning_001026.png differ diff --git a/data/config/semantic-kitti.yaml b/data/config/semantic-kitti.yaml new file mode 100644 index 0000000000000000000000000000000000000000..628106553b4b964920c825af42a53b4e39b73cfd --- /dev/null +++ b/data/config/semantic-kitti.yaml @@ -0,0 +1,211 @@ +# This file is covered by the LICENSE file in the root of this project. +labels: + 0 : "unlabeled" + 1 : "outlier" + 10: "car" + 11: "bicycle" + 13: "bus" + 15: "motorcycle" + 16: "on-rails" + 18: "truck" + 20: "other-vehicle" + 30: "person" + 31: "bicyclist" + 32: "motorcyclist" + 40: "road" + 44: "parking" + 48: "sidewalk" + 49: "other-ground" + 50: "building" + 51: "fence" + 52: "other-structure" + 60: "lane-marking" + 70: "vegetation" + 71: "trunk" + 72: "terrain" + 80: "pole" + 81: "traffic-sign" + 99: "other-object" + 252: "moving-car" + 253: "moving-bicyclist" + 254: "moving-person" + 255: "moving-motorcyclist" + 256: "moving-on-rails" + 257: "moving-bus" + 258: "moving-truck" + 259: "moving-other-vehicle" +color_map: # bgr + 0 : [0, 0, 0] + 1 : [0, 0, 255] + 10: [245, 150, 100] + 11: [245, 230, 100] + 13: [250, 80, 100] + 15: [150, 60, 30] + 16: [255, 0, 0] + 18: [180, 30, 80] + 20: [255, 0, 0] + 30: [30, 30, 255] + 31: [200, 40, 255] + 32: [90, 30, 150] + 40: [255, 0, 255] + 44: [255, 150, 255] + 48: [75, 0, 75] + 49: [75, 0, 175] + 50: [0, 200, 255] + 51: [50, 120, 255] + 52: [0, 150, 255] + 60: [170, 255, 150] + 70: [0, 175, 0] + 71: [0, 60, 135] + 72: [80, 240, 150] + 80: [150, 240, 255] + 81: [0, 0, 255] + 99: [255, 255, 50] + 252: [245, 150, 100] + 256: [255, 0, 0] + 253: [200, 40, 255] + 254: [30, 30, 255] + 255: [90, 30, 150] + 257: [250, 80, 100] + 258: [180, 30, 80] + 259: [255, 0, 0] +content: # as a ratio with the total number of points + 0: 0.018889854628292943 + 1: 0.0002937197336781505 + 10: 0.040818519255974316 + 11: 0.00016609538710764618 + 13: 2.7879693665067774e-05 + 15: 0.00039838616015114444 + 16: 0.0 + 18: 0.0020633612104619787 + 20: 0.0016218197275284021 + 30: 0.00017698551338515307 + 31: 1.1065903904919655e-08 + 32: 5.532951952459828e-09 + 40: 0.1987493871255525 + 44: 0.014717169549888214 + 48: 0.14392298360372 + 49: 0.0039048553037472045 + 50: 0.1326861944777486 + 51: 0.0723592229456223 + 52: 0.002395131480328884 + 60: 4.7084144280367186e-05 + 70: 0.26681502148037506 + 71: 0.006035012012626033 + 72: 0.07814222006271769 + 80: 0.002855498193863172 + 81: 0.0006155958086189918 + 99: 0.009923127583046915 + 252: 0.001789309418528068 + 253: 0.00012709999297008662 + 254: 0.00016059776092534436 + 255: 3.745553104802113e-05 + 256: 0.0 + 257: 0.00011351574470342043 + 258: 0.00010157861367183268 + 259: 4.3840131989471124e-05 +# classes that are indistinguishable from single scan or inconsistent in +# ground truth are mapped to their closest equivalent +learning_map: + 0 : 0 # "unlabeled" + 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped + 10: 1 # "car" + 11: 2 # "bicycle" + 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped + 15: 3 # "motorcycle" + 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped + 18: 4 # "truck" + 20: 5 # "other-vehicle" + 30: 6 # "person" + 31: 7 # "bicyclist" + 32: 8 # "motorcyclist" + 40: 9 # "road" + 44: 10 # "parking" + 48: 11 # "sidewalk" + 49: 12 # "other-ground" + 50: 13 # "building" + 51: 14 # "fence" + 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped + 60: 9 # "lane-marking" to "road" ---------------------------------mapped + 70: 15 # "vegetation" + 71: 16 # "trunk" + 72: 17 # "terrain" + 80: 18 # "pole" + 81: 19 # "traffic-sign" + 99: 0 # "other-object" to "unlabeled" ----------------------------mapped + 252: 1 # "moving-car" to "car" ------------------------------------mapped + 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped + 254: 6 # "moving-person" to "person" ------------------------------mapped + 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped + 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped + 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped + 258: 4 # "moving-truck" to "truck" --------------------------------mapped + 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped +learning_map_inv: # inverse of previous map + 0: 0 # "unlabeled", and others ignored + 1: 10 # "car" + 2: 11 # "bicycle" + 3: 15 # "motorcycle" + 4: 18 # "truck" + 5: 20 # "other-vehicle" + 6: 30 # "person" + 7: 31 # "bicyclist" + 8: 32 # "motorcyclist" + 9: 40 # "road" + 10: 44 # "parking" + 11: 48 # "sidewalk" + 12: 49 # "other-ground" + 13: 50 # "building" + 14: 51 # "fence" + 15: 70 # "vegetation" + 16: 71 # "trunk" + 17: 72 # "terrain" + 18: 80 # "pole" + 19: 81 # "traffic-sign" +learning_ignore: # Ignore classes + 0: True # "unlabeled", and others ignored + 1: False # "car" + 2: False # "bicycle" + 3: False # "motorcycle" + 4: False # "truck" + 5: False # "other-vehicle" + 6: False # "person" + 7: False # "bicyclist" + 8: False # "motorcyclist" + 9: False # "road" + 10: False # "parking" + 11: False # "sidewalk" + 12: False # "other-ground" + 13: False # "building" + 14: False # "fence" + 15: False # "vegetation" + 16: False # "trunk" + 17: False # "terrain" + 18: False # "pole" + 19: False # "traffic-sign" +split: # sequence numbers + train: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 9 + - 10 + valid: + - 8 + test: + - 11 + - 12 + - 13 + - 14 + - 15 + - 16 + - 17 + - 18 + - 19 + - 20 + - 21 diff --git a/lidm/data/__init__.py b/lidm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/data/annotated_dataset.py b/lidm/data/annotated_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a81bf6fb186a145b1a654e23291329683d37d918 --- /dev/null +++ b/lidm/data/annotated_dataset.py @@ -0,0 +1,48 @@ +from pathlib import Path +from typing import Optional, List, Dict, Union, Any +import warnings + +from torch.utils.data import Dataset + +from .conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder +from .conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder + + +class Annotated3DObjectsDataset(Dataset): + def __init__(self, min_objects_per_image: int, + max_objects_per_image: int, no_tokens: int, num_beams: int, cats: List[str], + cat_blacklist: Optional[List[str]] = None, **kwargs): + self.min_objects_per_image = min_objects_per_image + self.max_objects_per_image = max_objects_per_image + self.no_tokens = no_tokens + self.num_beams = num_beams + + self.categories = [c for c in cats if c not in cat_blacklist] if cat_blacklist is not None else cats + self._conditional_builders = None + + @property + def no_classes(self) -> int: + return len(self.categories) + + @property + def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder: + # cannot set this up in init because no_classes is only known after loading data in init of superclass + if self._conditional_builders is None: + self._conditional_builders = { + 'center': ObjectsCenterPointsConditionalBuilder( + self.no_classes, + self.max_objects_per_image, + self.no_tokens, + self.num_beams + ), + 'bbox': ObjectsBoundingBoxConditionalBuilder( + self.no_classes, + self.max_objects_per_image, + self.no_tokens, + self.num_beams + ) + } + return self._conditional_builders + + def get_textual_label_for_category_id(self, category_id: int) -> str: + return self.categories[category_id] diff --git a/lidm/data/base.py b/lidm/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ea9986df81f98aec686be929c3b620f660d02a92 --- /dev/null +++ b/lidm/data/base.py @@ -0,0 +1,121 @@ +import pdb +from abc import abstractmethod +from functools import partial + +import PIL +import numpy as np +from PIL import Image + +import torchvision.transforms.functional as TF +from torch.utils.data import Dataset, IterableDataset + +from ..utils.aug_utils import get_lidar_transform, get_camera_transform, get_anno_transform + + +class DatasetBase(Dataset): + def __init__(self, data_root, split, dataset_config, aug_config, return_pcd=False, condition_key=None, + scale_factors=None, degradation=None, **kwargs): + self.data_root = data_root + self.split = split + self.data = [] + self.aug_config = aug_config + + self.img_size = dataset_config.size + self.fov = dataset_config.fov + self.depth_range = dataset_config.depth_range + self.filtered_map_cats = dataset_config.filtered_map_cats + self.depth_scale = dataset_config.depth_scale + self.log_scale = dataset_config.log_scale + + if self.log_scale: + self.depth_thresh = (np.log2(1./255. + 1) / self.depth_scale) * 2. - 1 + 1e-6 + else: + self.depth_thresh = (1./255. / self.depth_scale) * 2. - 1 + 1e-6 + self.return_pcd = return_pcd + + if degradation is not None and scale_factors is not None: + scaled_img_size = (int(self.img_size[0] / scale_factors[0]), int(self.img_size[1] / scale_factors[1])) + degradation_fn = { + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, + }[degradation] + self.degradation_transform = partial(TF.resize, size=scaled_img_size, interpolation=degradation_fn) + else: + self.degradation_transform = None + self.condition_key = condition_key + + self.lidar_transform = get_lidar_transform(aug_config, split) + self.anno_transform = get_anno_transform(aug_config, split) if condition_key in ['bbox', 'center'] else None + self.view_transform = get_camera_transform(aug_config, split) if condition_key in ['camera'] else None + + self.prepare_data() + + def prepare_data(self): + raise NotImplementedError + + def process_scan(self, range_img): + range_img = np.where(range_img < 0, 0, range_img) + + if self.log_scale: + # log scale + range_img = np.log2(range_img + 0.0001 + 1) + + range_img = range_img / self.depth_scale + range_img = range_img * 2. - 1. + + range_img = np.clip(range_img, -1, 1) + range_img = np.expand_dims(range_img, axis=0) + + # mask + range_mask = np.ones_like(range_img) + range_mask[range_img < self.depth_thresh] = -1 + + return range_img, range_mask + + @staticmethod + def load_lidar_sweep(*args, **kwargs): + raise NotImplementedError + + @staticmethod + def load_semantic_map(*args, **kwargs): + raise NotImplementedError + + @staticmethod + def load_camera(*args, **kwargs): + raise NotImplementedError + + @staticmethod + def load_annotation(*args, **kwargs): + raise NotImplementedError + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + example = dict() + return example + + +class Txt2ImgIterableBaseDataset(IterableDataset): + """ + Define an interface to make the IterableDatasets for text2img data chainable + """ + def __init__(self, num_records=0, valid_ids=None, size=256): + super().__init__() + self.num_records = num_records + self.valid_ids = valid_ids + self.sample_ids = valid_ids + self.size = size + + print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + + def __len__(self): + return self.num_records + + @abstractmethod + def __iter__(self): + pass \ No newline at end of file diff --git a/lidm/data/conditional_builder/__init__.py b/lidm/data/conditional_builder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/data/conditional_builder/objects_bbox.py b/lidm/data/conditional_builder/objects_bbox.py new file mode 100644 index 0000000000000000000000000000000000000000..982da90403dbcd7bcada38651cf13efdfb89bc3d --- /dev/null +++ b/lidm/data/conditional_builder/objects_bbox.py @@ -0,0 +1,53 @@ +from itertools import cycle +from typing import List, Tuple, Callable, Optional + +from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont +from more_itertools.recipes import grouper +from torch import LongTensor, Tensor + +from ..helper_types import BoundingBox, Annotation +from .objects_center_points import ObjectsCenterPointsConditionalBuilder, convert_pil_to_tensor +from .utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ + pad_list, get_plot_font_size, absolute_bbox + + +class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): + @property + def object_descriptor_length(self) -> int: + return 3 # 3/5: object_representation (1) + corners (2/4) + + def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: + object_tuples = [ + (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) + for ann in annotations + ] + object_tuples = pad_list(object_tuples, self.empty_tuple, self.no_max_objects) + return object_tuples + + def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: + conditional_list = conditional.tolist() + object_triples = grouper(conditional_list, 3) + assert conditional.shape[0] == self.embedding_dim + return [(object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) for object_triple in object_triples if object_triple[0] != self.none], None + + def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], + line_width: int = 3, font_size: Optional[int] = None) -> Tensor: + plot = pil_image.new('RGB', figure_size, WHITE) + draw = pil_img_draw.Draw(plot) + # font = ImageFont.truetype( + # "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", + # size=get_plot_font_size(font_size, figure_size) + # ) + font = ImageFont.load_default() + width, height = plot.size + description, crop_coordinates = self.inverse_build(conditional) + for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): + annotation = self.representation_to_annotation(representation) + # class_label = label_for_category_no(annotation.category_id) + ' ' + additional_parameters_string(annotation) + class_label = label_for_category_no(annotation.category_id) + bbox = absolute_bbox(bbox, width, height) + draw.rectangle(bbox, outline=color, width=line_width) + draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) + if crop_coordinates is not None: + draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) + return convert_pil_to_tensor(plot) / 127.5 - 1. diff --git a/lidm/data/conditional_builder/objects_center_points.py b/lidm/data/conditional_builder/objects_center_points.py new file mode 100644 index 0000000000000000000000000000000000000000..e9b75b85d10ae01b24f7e959fadd9a9b4632023b --- /dev/null +++ b/lidm/data/conditional_builder/objects_center_points.py @@ -0,0 +1,150 @@ +import math +import random +import warnings +from itertools import cycle +from typing import List, Optional, Tuple, Callable + +from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont +from more_itertools.recipes import grouper +from .utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, pad_list, get_circle_size, \ + get_plot_font_size, absolute_bbox +from ..helper_types import BoundingBox, Annotation, Image +from torch import LongTensor, Tensor +from torchvision.transforms import PILToTensor + + +pil_to_tensor = PILToTensor() + + +def convert_pil_to_tensor(image: Image) -> Tensor: + with warnings.catch_warnings(): + # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 + warnings.simplefilter("ignore") + return pil_to_tensor(image) + + +class ObjectsCenterPointsConditionalBuilder: + def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, num_beams: int): + self.no_object_classes = no_object_classes + self.no_max_objects = no_max_objects + self.no_tokens = no_tokens + # self.no_sections = int(math.sqrt(self.no_tokens)) + self.no_sections = (self.no_tokens // num_beams, num_beams) # (width, height) + + @property + def none(self) -> int: + return self.no_tokens - 1 + + @property + def object_descriptor_length(self) -> int: + return 2 + + @property + def empty_tuple(self) -> Tuple: + return (self.none,) * self.object_descriptor_length + + @property + def embedding_dim(self) -> int: + return self.no_max_objects * self.object_descriptor_length + + def tokenize_coordinates(self, x: float, y: float) -> int: + """ + Express 2d coordinates with one number. + Example: assume self.no_tokens = 16, then no_sections = 4: + 0 0 0 0 + 0 0 # 0 + 0 0 0 0 + 0 0 0 x + Then the # position corresponds to token 6, the x position to token 15. + @param x: float in [0, 1] + @param y: float in [0, 1] + @return: discrete tokenized coordinate + """ + x_discrete = int(round(x * (self.no_sections[0] - 1))) + y_discrete = int(round(y * (self.no_sections[1] - 1))) + return y_discrete * self.no_sections[0] + x_discrete + + def coordinates_from_token(self, token: int) -> (float, float): + x = token % self.no_sections[0] + y = token // self.no_sections[0] + return x / (self.no_sections[0] - 1), y / (self.no_sections[1] - 1) + + def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox: + x0, y0 = self.coordinates_from_token(token1) + x1, y1 = self.coordinates_from_token(token2) + # x2, y2 = self.coordinates_from_token(token3) + # x3, y3 = self.coordinates_from_token(token4) + return x0, y0, x1, y1 + + def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple: + # return self.tokenize_coordinates(bbox[0], bbox[1]), self.tokenize_coordinates(bbox[2], bbox[3]), self.tokenize_coordinates(bbox[4], bbox[5]), self.tokenize_coordinates(bbox[6], bbox[7]) + return self.tokenize_coordinates(bbox[0], bbox[1]), self.tokenize_coordinates(bbox[4], bbox[5]) + + def inverse_build(self, conditional: LongTensor) \ + -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]: + conditional_list = conditional.tolist() + table_of_content = grouper(conditional_list, self.object_descriptor_length) + assert conditional.shape[0] == self.embedding_dim + return [ + (object_tuple[0], self.coordinates_from_token(object_tuple[1])) + for object_tuple in table_of_content if object_tuple[0] != self.none + ], None + + def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], + line_width: int = 3, font_size: Optional[int] = None) -> Tensor: + plot = pil_image.new('RGB', figure_size, WHITE) + draw = pil_img_draw.Draw(plot) + circle_size = get_circle_size(figure_size) + # font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf', + # size=get_plot_font_size(font_size, figure_size)) + font = ImageFont.load_default() + width, height = plot.size + description, crop_coordinates = self.inverse_build(conditional) + for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)): + x_abs, y_abs = x * width, y * height + ann = self.representation_to_annotation(representation) + label = label_for_category_no(ann.category_id) + ' ' + additional_parameters_string(ann) + ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size] + draw.ellipse(ellipse_bbox, fill=color, width=0) + draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font) + if crop_coordinates is not None: + draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) + return convert_pil_to_tensor(plot) / 127.5 - 1. + + def object_representation(self, annotation: Annotation) -> int: + return annotation.category_id + + def representation_to_annotation(self, representation: int) -> Annotation: + category_id = representation % self.no_object_classes + # noinspection PyTypeChecker + return Annotation( + bbox=None, + category_id=category_id, + ) + + def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: + object_tuples = [ + (self.object_representation(a), + self.tokenize_coordinates(a.center[0], a.center[1])) + for a in annotations + ] + empty_tuple = (self.none, self.none) + object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects) + return object_tuples + + def build(self, annotations: List[Annotation]) \ + -> LongTensor: + if len(annotations) == 0: + warnings.warn('Did not receive any annotations.') + + random.shuffle(annotations) + if len(annotations) > self.no_max_objects: + warnings.warn('Received more annotations than allowed.') + annotations = annotations[:self.no_max_objects] + + object_tuples = self._make_object_descriptors(annotations) + flattened = [token for tuple_ in object_tuples for token in tuple_] + assert len(flattened) == self.embedding_dim + assert all(0 <= value < self.no_tokens for value in flattened) + + return LongTensor(flattened) diff --git a/lidm/data/conditional_builder/utils.py b/lidm/data/conditional_builder/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e162ccdde4e61b535c85f7a0ca58c9101d184d79 --- /dev/null +++ b/lidm/data/conditional_builder/utils.py @@ -0,0 +1,188 @@ +import importlib +from typing import List, Any, Tuple, Optional + +import numpy as np +from ..helper_types import BoundingBox, Annotation + +# source: seaborn, color palette tab10 +COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), + (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)] +BLACK = (0, 0, 0) +GRAY_75 = (63, 63, 63) +GRAY_50 = (127, 127, 127) +GRAY_25 = (191, 191, 191) +WHITE = (255, 255, 255) +FULL_CROP = (0., 0., 1., 1.) + + +def corners_3d_to_2d(corners3d): + """ + Args: + corners3d: (N, 8, 2) + Returns: + corners2d: (N, 4, 2) + """ + # select pairs to reorganize + mask_0_3 = corners3d[:, 0:4, 0].argmax(1) // 2 != 0 + mask_4_7 = corners3d[:, 4:8, 0].argmin(1) // 2 != 0 + + # reorganize corners in the order of (bottom-right, bottom-left) + corners3d[mask_0_3, 0:4] = corners3d[mask_0_3][:, [2, 3, 0, 1]] + # reorganize corners in the order of (top-left, top-right) + corners3d[mask_4_7, 4:8] = corners3d[mask_4_7][:, [2, 3, 0, 1]] + + # calculate corners in order + bot_r = np.stack([corners3d[:, 0:2, 0].max(1), corners3d[:, 0:2, 1].min(1)], axis=-1) + bot_l = np.stack([corners3d[:, 2:4, 0].min(1), corners3d[:, 2:4, 1].min(1)], axis=-1) + top_l = np.stack([corners3d[:, 4:6, 0].min(1), corners3d[:, 4:6, 1].max(1)], axis=-1) + top_r = np.stack([corners3d[:, 6:8, 0].max(1), corners3d[:, 6:8, 1].max(1)], axis=-1) + + return np.stack([bot_r, bot_l, top_l, top_r], axis=1) + + +def rotate_points_along_z(points, angle): + """ + Args: + points: (N, 3 + C) + angle: angle along z-axis, angle increases x ==> y + Returns: + + """ + cosa = np.cos(angle) + sina = np.sin(angle) + zeros = np.zeros(points.shape[0]) + ones = np.ones(points.shape[0]) + rot_matrix = np.stack(( + cosa, sina, zeros, + -sina, cosa, zeros, + zeros, zeros, ones)).reshape((-1, 3, 3)) + points_rot = np.matmul(points[:, :, 0:3], rot_matrix) + points_rot = np.concatenate((points_rot, points[:, :, 3:]), axis=-1) + return points_rot + + +def boxes_to_corners_3d(boxes3d): + """ + 7 -------- 4 + /| /| + 6 -------- 5 . + | | | | + . 3 -------- 0 + |/ |/ + 2 -------- 1 + Args: + boxes3d: (N, 7) [x, y, z, dx, dy, dz, heading], (x, y, z) is the box center + + Returns: + corners3d: (N, 8, 3) + """ + template = np.array( + [[1, 1, -1], [1, -1, -1], [-1, -1, -1], [-1, 1, -1], + [1, 1, 1], [1, -1, 1], [-1, -1, 1], [-1, 1, 1]], + ) / 2 + + # corners3d = boxes3d[:, None, 3:6].repeat(1, 8, 1) * template[None, :, :] + corners3d = np.tile(boxes3d[:, None, 3:6], (1, 8, 1)) * template[None, :, :] + corners3d = rotate_points_along_z(corners3d.reshape((-1, 8, 3)), boxes3d[:, 6]).reshape((-1, 8, 3)) + corners3d += boxes3d[:, None, 0:3] + + return corners3d + + +def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: + """ + Give intersection area of two rectangles. + @param rectangle1: (x0, y0, w, h) of first rectangle + @param rectangle2: (x0, y0, w, h) of second rectangle + """ + rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] + rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] + x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) + y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) + return x_overlap * y_overlap + + +def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: + return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] + + +def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: + bbox = relative_bbox + # bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height + bbox = bbox[0] * width, bbox[1] * height, bbox[2] * width, bbox[3] * height + # return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) + x1, x2 = min(int(bbox[2]), int(bbox[0])), max(int(bbox[2]), int(bbox[0])) + y1, y2 = min(int(bbox[3]), int(bbox[1])), max(int(bbox[3]), int(bbox[1])) + if x1 == x2: + x2 += 1 + if y1 == y2: + y2 += 1 + return x1, y1, x2, y2 + + +def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: + return list_ + [pad_element for _ in range(pad_to_length - len(list_))] + + +def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ + List[Annotation]: + def clamp(x: float): + return max(min(x, 1.), 0.) + + def rescale_bbox(bbox: BoundingBox) -> BoundingBox: + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + if flip: + x0 = 1 - (x0 + w) + return x0, y0, w, h + + return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] + + +def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: + return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] + + +def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: + sl = slice(1) if short else slice(None) + string = '' + if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): + return string + if annotation.is_group_of: + string += 'group'[sl] + ',' + if annotation.is_occluded: + string += 'occluded'[sl] + ',' + if annotation.is_depiction: + string += 'depiction'[sl] + ',' + if annotation.is_inside: + string += 'inside'[sl] + return '(' + string.strip(",") + ')' + + +def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: + if font_size is None: + font_size = 10 + if max(figure_size) >= 256: + font_size = 12 + if max(figure_size) >= 512: + font_size = 15 + return font_size + + +def get_circle_size(figure_size: Tuple[int, int]) -> int: + circle_size = 2 + if max(figure_size) >= 256: + circle_size = 3 + if max(figure_size) >= 512: + circle_size = 4 + return circle_size + + +def load_object_from_string(object_string: str) -> Any: + """ + Source: https://stackoverflow.com/a/10773699 + """ + module_name, class_name = object_string.rsplit(".", 1) + return getattr(importlib.import_module(module_name), class_name) diff --git a/lidm/data/helper_types.py b/lidm/data/helper_types.py new file mode 100644 index 0000000000000000000000000000000000000000..ae525c7775ad55090511ef021db5867ce030e421 --- /dev/null +++ b/lidm/data/helper_types.py @@ -0,0 +1,20 @@ +from typing import Tuple, Optional, NamedTuple, Union, List +from PIL.Image import Image as pil_image +from torch import Tensor + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +Image = Union[Tensor, pil_image] +# BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h | x0, y0, x1, y1 +# BoundingBox3D = Tuple[float, float, float, float, float, float] # x0, y0, z0, l, w, h +BoundingBox = Tuple[float, float, float, float] # corner coordinates (x,y) in the order of bottom-right -> bottom-left -> top-left -> top-right +Center = Tuple[float, float] + + +class Annotation(NamedTuple): + category_id: int + bbox: Optional[BoundingBox] = None + center: Optional[Center] = None diff --git a/lidm/data/kitti.py b/lidm/data/kitti.py new file mode 100644 index 0000000000000000000000000000000000000000..103426449bbda265d817ea2641dd48c3194d2c4e --- /dev/null +++ b/lidm/data/kitti.py @@ -0,0 +1,345 @@ +import glob +import os +import pickle +import numpy as np +import yaml +from PIL import Image +import xml.etree.ElementTree as ET + +from lidm.data.base import DatasetBase +from .annotated_dataset import Annotated3DObjectsDataset +from .conditional_builder.utils import corners_3d_to_2d +from .helper_types import Annotation +from ..utils.lidar_utils import pcd2range, pcd2coord2d, range2pcd + +# TODO add annotation categories and semantic categories +CATEGORIES = ['ignore', 'car', 'bicycle', 'motorcycle', 'truck', 'other-vehicle', 'person', 'bicyclist', 'motorcyclist', + 'road', 'parking', 'sidewalk', 'other-ground', 'building', 'fence', 'vegetation', 'trunk', 'terrain', + 'pole', 'traffic-sign'] +CATE2LABEL = {k: v for v, k in enumerate(CATEGORIES)} # 0: invalid, 1~10: categories +LABEL2RGB = np.array([(0, 0, 0), (0, 0, 142), (119, 11, 32), (0, 0, 230), (0, 0, 70), (0, 0, 90), (220, 20, 60), + (255, 0, 0), (0, 0, 110), (128, 64, 128), (250, 170, 160), (244, 35, 232), (230, 150, 140), + (70, 70, 70), (190, 153, 153), (107, 142, 35), (0, 80, 100), (230, 150, 140), (153, 153, 153), + (220, 220, 0)]) +CAMERAS = ['CAM_FRONT'] +BBOX_CATS = ['car', 'people', 'cycle'] +BBOX_CAT2LABEL = {'car': 0, 'truck': 0, 'bus': 0, 'caravan': 0, 'person': 1, 'rider': 2, 'motorcycle': 2, 'bicycle': 2} + +# train + test +SEM_KITTI_TRAIN_SET = ['00', '01', '02', '03', '04', '05', '06', '07', '09', '10'] +KITTI_TRAIN_SET = SEM_KITTI_TRAIN_SET + ['11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21'] +KITTI360_TRAIN_SET = ['00', '02', '04', '05', '06', '07', '09', '10'] + ['08'] # partial test data at '02' sequence +CAM_KITTI360_TRAIN_SET = ['00', '04', '05', '06', '07', '08', '09', '10'] # cam mismatch lidar in '02' + +# validation +SEM_KITTI_VAL_SET = KITTI_VAL_SET = ['08'] +CAM_KITTI360_VAL_SET = KITTI360_VAL_SET = ['03'] + + +class KITTIBase(DatasetBase): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dataset_name = 'kitti' + self.num_sem_cats = kwargs['dataset_config'].num_sem_cats + 1 + + @staticmethod + def load_lidar_sweep(path): + scan = np.fromfile(path, dtype=np.float32) + scan = scan.reshape((-1, 4)) + points = scan[:, 0:3] # get xyz + return points + + def load_semantic_map(self, path, pcd): + raise NotImplementedError + + def load_camera(self, path): + raise NotImplementedError + + def __getitem__(self, idx): + example = dict() + data_path = self.data[idx] + # lidar point cloud + sweep = self.load_lidar_sweep(data_path) + + if self.lidar_transform: + sweep, _ = self.lidar_transform(sweep, None) + + if self.condition_key == 'segmentation': + # semantic maps + proj_range, sem_map = self.load_semantic_map(data_path, sweep) + example[self.condition_key] = sem_map + else: + proj_range, _ = pcd2range(sweep, self.img_size, self.fov, self.depth_range) + proj_range, proj_mask = self.process_scan(proj_range) + example['image'], example['mask'] = proj_range, proj_mask + if self.return_pcd: + reproj_sweep, _, _ = range2pcd(proj_range[0] * .5 + .5, self.fov, self.depth_range, self.depth_scale, self.log_scale) + example['raw'] = sweep + example['reproj'] = reproj_sweep.astype(np.float32) + + # image degradation + if self.degradation_transform: + degraded_proj_range = self.degradation_transform(proj_range) + example['degraded_image'] = degraded_proj_range + + # cameras + if self.condition_key == 'camera': + cameras = self.load_camera(data_path) + example[self.condition_key] = cameras + + return example + + +class SemanticKITTIBase(KITTIBase): + def __init__(self, **kwargs): + super().__init__(**kwargs) + assert self.condition_key in ['segmentation'] # for segmentation input only + self.label2rgb = LABEL2RGB + + def prepare_data(self): + # read data paths from KITTI + for seq_id in eval('SEM_KITTI_%s_SET' % self.split.upper()): + self.data.extend(glob.glob(os.path.join( + self.data_root, f'dataset/sequences/{seq_id}/velodyne/*.bin'))) + # read label mapping + data_config = yaml.safe_load(open('./data/config/semantic-kitti.yaml', 'r')) + remap_dict = data_config["learning_map"] + max_key = max(remap_dict.keys()) + self.learning_map = np.zeros((max_key + 100), dtype=np.int32) + self.learning_map[list(remap_dict.keys())] = list(remap_dict.values()) + + def load_semantic_map(self, path, pcd): + label_path = path.replace('velodyne', 'labels').replace('.bin', '.label') + labels = np.fromfile(label_path, dtype=np.uint32) + labels = labels.reshape((-1)) + labels = labels & 0xFFFF # semantic label in lower half + labels = self.learning_map[labels] + + proj_range, sem_map = pcd2range(pcd, self.img_size, self.fov, self.depth_range, labels=labels) + # sem_map = np.expand_dims(sem_map, axis=0).astype(np.int64) + sem_map = sem_map.astype(np.int64) + if self.filtered_map_cats is not None: + sem_map[np.isin(sem_map, self.filtered_map_cats)] = 0 # set filtered category as noise + onehot = np.eye(self.num_sem_cats, dtype=np.float32)[sem_map].transpose(2, 0, 1) + return proj_range, onehot + + +class SemanticKITTITrain(SemanticKITTIBase): + def __init__(self, **kwargs): + super().__init__(data_root='./dataset/SemanticKITTI', split='train', **kwargs) + + +class SemanticKITTIValidation(SemanticKITTIBase): + def __init__(self, **kwargs): + super().__init__(data_root='./dataset/SemanticKITTI', split='val', **kwargs) + + +class KITTI360Base(KITTIBase): + def __init__(self, split_per_view=None, **kwargs): + super().__init__(**kwargs) + self.split_per_view = split_per_view + if self.condition_key == 'camera': + assert self.split_per_view is not None, 'For camera-to-lidar, need to specify split_per_view' + + def prepare_data(self): + # read data paths + self.data = [] + if self.condition_key == 'camera': + seq_list = eval('CAM_KITTI360_%s_SET' % self.split.upper()) + else: + seq_list = eval('KITTI360_%s_SET' % self.split.upper()) + for seq_id in seq_list: + self.data.extend(glob.glob(os.path.join( + self.data_root, f'data_3d_raw/2013_05_28_drive_00{seq_id}_sync/velodyne_points/data/*.bin'))) + + def random_drop_camera(self, camera_list): + if np.random.rand() < self.aug_config['camera_drop'] and self.split == 'train': + camera_list = [np.zeros_like(c) if i != len(camera_list) // 2 else c for i, c in enumerate(camera_list)] # keep the middle view only + return camera_list + + def load_camera(self, path): + camera_path = path.replace('data_3d_raw', 'data_2d_camera').replace('velodyne_points/data', 'image_00/data_rect').replace('.bin', '.png') + camera = np.array(Image.open(camera_path)).astype(np.float32) / 255. + camera = camera.transpose(2, 0, 1) + if self.view_transform: + camera = self.view_transform(camera) + camera_list = np.split(camera, self.split_per_view, axis=2) # split into n chunks as different views + camera_list = self.random_drop_camera(camera_list) + return camera_list + + +class KITTI360Train(KITTI360Base): + def __init__(self, **kwargs): + super().__init__(data_root='./dataset/KITTI-360', split='train', **kwargs) + + +class KITTI360Validation(KITTI360Base): + def __init__(self, **kwargs): + super().__init__(data_root='./dataset/KITTI-360', split='val', **kwargs) + + +class AnnotatedKITTI360Base(Annotated3DObjectsDataset, KITTI360Base): + def __init__(self, **kwargs): + self.id_bbox_dict = dict() + self.id_label_dict = dict() + + Annotated3DObjectsDataset.__init__(self, **kwargs) + KITTI360Base.__init__(self, **kwargs) + assert self.condition_key in ['center', 'bbox'] # for annotated images only + + @staticmethod + def parseOpencvMatrix(node): + rows = int(node.find('rows').text) + cols = int(node.find('cols').text) + data = node.find('data').text.split(' ') + + mat = [] + for d in data: + d = d.replace('\n', '') + if len(d) < 1: + continue + mat.append(float(d)) + mat = np.reshape(mat, [rows, cols]) + return mat + + def parseVertices(self, child): + transform = self.parseOpencvMatrix(child.find('transform')) + R = transform[:3, :3] + T = transform[:3, 3] + vertices = self.parseOpencvMatrix(child.find('vertices')) + vertices = np.matmul(R, vertices.transpose()).transpose() + T + return vertices + + def parse_bbox_xml(self, path): + tree = ET.parse(path) + root = tree.getroot() + + bbox_dict = dict() + label_dict = dict() + for child in root: + if child.find('transform') is None: + continue + + label_name = child.find('label').text + if label_name not in BBOX_CAT2LABEL: + continue + + label = BBOX_CAT2LABEL[label_name] + timestamp = int(child.find('timestamp').text) + # verts = self.parseVertices(child) + verts = self.parseOpencvMatrix(child.find('vertices'))[:8] + if timestamp in bbox_dict: + bbox_dict[timestamp].append(verts) + label_dict[timestamp].append(label) + else: + bbox_dict[timestamp] = [verts] + label_dict[timestamp] = [label] + return bbox_dict, label_dict + + def prepare_data(self): + KITTI360Base.prepare_data(self) + + self.data = [p for p in self.data if '2013_05_28_drive_0008_sync' not in p] # remove unlabeled sequence 08 + seq_list = eval('KITTI360_%s_SET' % self.split.upper()) + for seq_id in seq_list: + if seq_id != '08': + xml_path = os.path.join(self.data_root, f'data_3d_bboxes/train/2013_05_28_drive_00{seq_id}_sync.xml') + bbox_dict, label_dict = self.parse_bbox_xml(xml_path) + self.id_bbox_dict[seq_id] = bbox_dict + self.id_label_dict[seq_id] = label_dict + + def load_annotation(self, path): + seq_id = path.split('/')[-4].split('_')[-2][-2:] + timestamp = int(path.split('/')[-1].replace('.bin', '')) + verts_list = self.id_bbox_dict[seq_id][timestamp] + label_list = self.id_label_dict[seq_id][timestamp] + + if self.condition_key == 'bbox': + points = np.stack(verts_list) + elif self.condition_key == 'center': + points = (verts_list[0] + verts_list[6]) / 2. + else: + raise NotImplementedError + labels = np.array([label_list]) + if self.anno_transform: + points, labels = self.anno_transform(points, labels) + return points, labels + + def __getitem__(self, idx): + example = dict() + data_path = self.data[idx] + + # lidar point cloud + sweep = self.load_lidar_sweep(data_path) + + # annotations + bbox_points, bbox_labels = self.load_annotation(data_path) + + if self.lidar_transform: + sweep, bbox_points = self.lidar_transform(sweep, bbox_points) + + # point cloud -> range + proj_range, _ = pcd2range(sweep, self.img_size, self.fov, self.depth_range) + proj_range, proj_mask = self.process_scan(proj_range) + example['image'], example['mask'] = proj_range, proj_mask + if self.return_pcd: + example['reproj'] = sweep + + # annotation -> range + # NOTE: do not need to transform bbox points along with lidar, since their coordinates are based on range-image space instead of 3D space + proj_bbox_points, proj_bbox_labels = pcd2coord2d(bbox_points, self.fov, self.depth_range, labels=bbox_labels) + builder = self.conditional_builders[self.condition_key] + if self.condition_key == 'bbox': + proj_bbox_points = corners_3d_to_2d(proj_bbox_points) + annotations = [Annotation(bbox=bbox.flatten(), category_id=label) for bbox, label in + zip(proj_bbox_points, proj_bbox_labels)] + else: + annotations = [Annotation(center=center, category_id=label) for center, label in + zip(proj_bbox_points, proj_bbox_labels)] + example[self.condition_key] = builder.build(annotations) + + return example + + +class AnnotatedKITTI360Train(AnnotatedKITTI360Base): + def __init__(self, **kwargs): + super().__init__(data_root='./dataset/KITTI-360', split='train', cats=BBOX_CATS, **kwargs) + + +class AnnotatedKITTI360Validation(AnnotatedKITTI360Base): + def __init__(self, **kwargs): + super().__init__(data_root='./dataset/KITTI-360', split='train', cats=BBOX_CATS, **kwargs) + + +class KITTIImageBase(KITTIBase): + """ + Range ImageSet only combining KITTI-360 and SemanticKITTI + + #Samples (Training): 98014, #Samples (Val): 3511 + + """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + assert self.condition_key in [None, 'image'] # for image input only + + def prepare_data(self): + # read data paths from KITTI-360 + self.data = [] + for seq_id in eval('KITTI360_%s_SET' % self.split.upper()): + self.data.extend(glob.glob(os.path.join( + self.data_root, f'KITTI-360/data_3d_raw/2013_05_28_drive_00{seq_id}_sync/velodyne_points/data/*.bin'))) + + # read data paths from KITTI + for seq_id in eval('KITTI_%s_SET' % self.split.upper()): + self.data.extend(glob.glob(os.path.join( + self.data_root, f'SemanticKITTI/dataset/sequences/{seq_id}/velodyne/*.bin'))) + + +class KITTIImageTrain(KITTIImageBase): + def __init__(self, **kwargs): + super().__init__(data_root='./dataset', split='train', **kwargs) + + +class KITTIImageValidation(KITTIImageBase): + def __init__(self, **kwargs): + super().__init__(data_root='./dataset', split='val', **kwargs) diff --git a/lidm/eval/README.md b/lidm/eval/README.md new file mode 100644 index 0000000000000000000000000000000000000000..33e8f8998cf905ac22e4b64a1b4ba87153ca880e --- /dev/null +++ b/lidm/eval/README.md @@ -0,0 +1,95 @@ +# Evaluation Toolbox for LiDAR Generation + +This directory is a **self-contained**, **memory-friendly** and mostly **CUDA-accelerated** toolbox of multiple evaluation metrics for LiDAR generative models, including: +* Perceptual metrics (our proposed): + * Fréchet Range Image Distance (**FRID**) + * Fréchet Sparse Volume Distance (**FSVD**) + * Fréchet Point-based Volume Distance (**FPVD**) +* Statistical metrics (proposed in [Learning Representations and Generative Models for 3D Point Clouds](https://arxiv.org/abs/1707.02392)): + * Minimum Matching Distance (**MMD**) + * Jensen-Shannon Divergence (**JSD**) +* Statistical pairwise metrics (for reconstruction only): + * Chamfer Distance (**CD**) + * Earth Mover's Distance (**EMD**) + +## Citation + +If you find this project useful in your research, please consider citing: +``` +@article{ran2024towards, + title={Towards Realistic Scene Generation with LiDAR Diffusion Models}, + author={Ran, Haoxi and Guizilini, Vitor and Wang, Yue}, + journal={arXiv preprint arXiv:2404.00815}, + year={2024} +} +``` + + +## Dependencies + +### Basic (install through **pip**): +* scipy +* numpy +* torch +* pyyaml + +### Required by FSVD and FPVD: +* [Torchsparse v1.4.0](https://github.com/mit-han-lab/torchsparse/tree/v1.4.0) (pip install git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0) +* [Google Sparse Hash library](https://github.com/sparsehash/sparsehash) (apt-get install libsparsehash-dev **or** compile locally and update variable CPLUS_INCLUDE_PATH with directory path) + + +## Model Zoo + +To evaluate with perceptual metrics on different types of LiDAR data, you can download all models through: +* this [google drive link](https://drive.google.com/file/d/1Ml4p4_nMlwLkSp7JB528GJv2_HxO8v1i/view?usp=drive_link) in the .zip file + +or +* the **full directory** of one specific model: + +### 64-beam LiDAR (trained on [SemanticKITTI](http://semantic-kitti.org/dataset.html)): + +| Metric | Model | Arch | Link | Code | Comments | +|:------:|:-------------------------------------------------------------------------------------------:|:-----------------------:|:-------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------|---------------------------------------------------------------------------| +| FRID | [RangeNet++](https://www.ipb.uni-bonn.de/wp-content/papercite-data/pdf/milioto2019iros.pdf) | DarkNet21-based UNet | [Google Drive](https://drive.google.com/drive/folders/1ZS8KOoxB9hjB6kwKbH5Zfc8O5qJlKsbl?usp=drive_link) | [./models/rangenet/model.py](./models/rangenet/model.py) | range image input (our trained model without the need of remission input) | +| FSVD | [MinkowskiNet](https://arxiv.org/abs/1904.08755) | Sparse UNet | [Google Drive](https://drive.google.com/drive/folders/1zN12ZEvjIvo4PCjAsncgC22yvtRrCCMe?usp=drive_link) | [./models/minkowskinet/model.py](./models/minkowskinet/model.py) | point cloud input | +| FPVD | [SPVCNN](https://arxiv.org/abs/2007.16100) | Point-Voxel Sparse UNet | [Google Drive](https://drive.google.com/drive/folders/1oEm3qpxfGetiVAfXIvecawEiFqW79M6B?usp=drive_link) | [./models/spvcnn/model.py](./models/spvcnn/model.py) | point cloud input | + + +### 32-beam LiDAR (trained on [nuScenes](https://www.nuscenes.org/nuscenes)): + +| Metric | Model | Arch | Link | Code | Comments | +|:------:|:------------------------------------------------:|:-----------------------:|:-------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------|-------------------| +| FSVD | [MinkowskiNet](https://arxiv.org/abs/1904.08755) | Sparse UNet | [Google Drive](https://drive.google.com/drive/folders/1oZIS9FlklCQ6dlh3TZ8Junir7QwgT-Me?usp=drive_link) | [./models/minkowskinet/model.py](./models/minkowskinet/model.py) | point cloud input | +| FPVD | [SPVCNN](https://arxiv.org/abs/2007.16100) | Point-Voxel Sparse UNet | [Google Drive](https://drive.google.com/drive/folders/1F69RbprAoT6MOJ7iI0KHjxuq-tbeqGiR?usp=drive_link) | [./models/spvcnn/model.py](./models/spvcnn/model.py) | point cloud input | + + +## Usage + +1. Place the unzipped `pretrained_weights` folder under the root python directory **or** modify the `DEFAULT_ROOT` variable in the `__init__.py`. +2. Prepare input data, including the synthesized samples and the reference dataset. **Note**: The reference data should be the **point clouds projected back from range images** instead of raw point clouds. +3. Specify the data type (`32` or `64`) and the metrics to evaluate. Options: `mmd`, `jsd`, `frid`, `fsvd`, `fpvd`, `cd`, `emd`. +4. (Optional) If you want to compute `frid`, `fsvd` or `fpvd` metric, adjust the corresponding batch size through the `MODAL2BATCHSIZE` in file `__init__.py` according to your max GPU memory (default: ~24GB). +5. Start evaluation and all results will print out! + +### Example: + +``` +from .eval_utils import evaluate + +data = '64' # specify data type to evaluate +metrics = ['mmd', 'jsd', 'frid', 'fsvd', 'fpvd'] # specify metrics to evaluate + +# list of np.float32 array +# shape of each array: (#points, #dim=3), #dim: xyz coordinate (NOTE: no need to input remission) +reference = ... +samples = ... + +evaluate(reference, samples, metrics, data) +``` + + +## Acknowledgement + +- The implementation of MinkowskiNet and SPVCNN is borrowed from [2DPASS](https://github.com/yanx27/2DPASS). +- The implementation of RangeNet++ is borrowed from [the official RangeNet++ codebase](https://github.com/PRBonn/lidar-bonnetal). +- The implementation of Chamfer Distance is adapted from [CD Pytorch Implementation](https://github.com/ThibaultGROUEIX/ChamferDistancePytorch) and Earth Mover's Distance from [MSN official repo](https://github.com/Colin97/MSN-Point-Cloud-Completion). diff --git a/lidm/eval/__init__.py b/lidm/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b66df061c883b417576c2c1f6ae26f991c6c64c4 --- /dev/null +++ b/lidm/eval/__init__.py @@ -0,0 +1,62 @@ +""" +@Author: Haoxi Ran +@Date: 01/03/2024 +@Citation: Towards Realistic Scene Generation with LiDAR Diffusion Models + +""" + +import os + +import torch +import yaml + +from lidm.utils.misc_utils import dict2namespace +from ..modules.rangenet.model import Model as rangenet + +try: + from ..modules.spvcnn.model import Model as spvcnn + from ..modules.minkowskinet.model import Model as minkowskinet +except: + print('To install torchsparse 1.4.0, please refer to https://github.com/mit-han-lab/torchsparse/tree/74099d10a51c71c14318bce63d6421f698b24f24') + +# user settings +DEFAULT_ROOT = './pretrained_weights' +MODAL2BATCHSIZE = {'range': 100, 'voxel': 50, 'point_voxel': 25} +OUTPUT_TEMPLATE = 50 * '-' + '\n|' + 16 * ' ' + '{}:{:.4E}' + 17 * ' ' + '|\n' + 50 * '-' + +# eval settings (do not modify) +VOXEL_SIZE = 0.05 +NUM_SECTORS = 16 +AGG_TYPE = 'depth' +TYPE2DATASET = {'32': 'nuscenes', '64': 'kitti'} +DATA_CONFIG = {'64': {'x': [-50, 50], 'y': [-50, 50], 'z': [-3, 1]}, + '32': {'x': [-30, 30], 'y': [-30, 30], 'z': [-3, 6]}} +MODALITY2MODEL = {'range': 'rangenet', 'voxel': 'minkowskinet', 'point_voxel': 'spvcnn'} +DATASET_CONFIG = {'kitti': {'size': [64, 1024], 'fov': [3, -25], 'depth_range': [1.0, 56.0], 'depth_scale': 6}, + 'nuscenes': {'size': [32, 1024], 'fov': [10, -30], 'depth_range': [1.0, 45.0]}} + + +def build_model(dataset_name, model_name, device='cpu'): + # config + model_folder = os.path.join(DEFAULT_ROOT, dataset_name, model_name) + + if not os.path.isdir(model_folder): + raise Exception('Not Available Pretrained Weights!') + + config = yaml.safe_load(open(os.path.join(model_folder, 'config.yaml'), 'r')) + if model_name != 'rangenet': + config = dict2namespace(config) + + # build model + model = eval(model_name)(config) + + # load checkpoint + if model_name == 'rangenet': + model.load_pretrained_weights(model_folder) + else: + ckpt = torch.load(os.path.join(model_folder, 'model.ckpt'), map_location="cpu") + model.load_state_dict(ckpt['state_dict'], strict=False) + model.to(device) + model.eval() + + return model diff --git a/lidm/eval/compile.sh b/lidm/eval/compile.sh new file mode 100644 index 0000000000000000000000000000000000000000..4d805f490ef9044309635b31b663c45bceaa4233 --- /dev/null +++ b/lidm/eval/compile.sh @@ -0,0 +1,9 @@ +#!/bin/sh + +cd modules/chamfer +python setup.py build_ext --inplace + +cd ../emd +python setup.py build_ext --inplace + +cd .. diff --git a/lidm/eval/eval_utils.py b/lidm/eval/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bdfef3d58d5462d135d5f4518848530d826a2cad --- /dev/null +++ b/lidm/eval/eval_utils.py @@ -0,0 +1,138 @@ +""" +@Author: Haoxi Ran +@Date: 01/03/2024 +@Citation: Towards Realistic Scene Generation with LiDAR Diffusion Models + +""" +import multiprocessing +from functools import partial + +import numpy as np +from scipy.spatial.distance import jensenshannon +from tqdm import tqdm + +from . import OUTPUT_TEMPLATE +from .metric_utils import compute_logits, compute_pairwise_cd, \ + compute_pairwise_emd, pcd2bev_sum, compute_pairwise_cd_batch, pcd2bev_bin +from .fid_score import calculate_frechet_distance + + +def evaluate(reference, samples, metrics, data): + # perceptual + if 'frid' in metrics: + compute_frid(reference, samples, data) + if 'fsvd' in metrics: + compute_fsvd(reference, samples, data) + if 'fpvd' in metrics: + compute_fpvd(reference, samples, data) + + # reconstruction + if 'cd' in metrics: + compute_cd(reference, samples) + if 'emd' in metrics: + compute_emd(reference, samples) + + # statistical + if 'jsd' in metrics: + compute_jsd(reference, samples, data) + if 'mmd' in metrics: + compute_mmd(reference, samples, data) + + +def compute_cd(reference, samples): + """ + Calculate score of Chamfer Distance (CD) + + """ + print('Evaluating (CD) ...') + results = [] + for x, y in zip(reference, samples): + d = compute_pairwise_cd(x, y) + results.append(d) + score = sum(results) / len(results) + print(OUTPUT_TEMPLATE.format('CD ', score)) + + +def compute_emd(reference, samples): + """ + Calculate score of Earth Mover's Distance (EMD) + + """ + print('Evaluating (EMD) ...') + results = [] + for x, y in zip(reference, samples): + d = compute_pairwise_emd(x, y) + results.append(d) + score = sum(results) / len(results) + print(OUTPUT_TEMPLATE.format('EMD ', score)) + + +def compute_mmd(reference, samples, data, dist='cd', verbose=True): + """ + Calculate the score of Minimum Matching Distance (MMD) + + """ + print('Evaluating (MMD) ...') + assert dist in ['cd', 'emd'] + reference, samples = pcd2bev_bin(data, reference, samples) + compute_dist_func = compute_pairwise_cd_batch if dist == 'cd' else compute_pairwise_emd + results = [] + for r in tqdm(reference, disable=not verbose): + dists = compute_dist_func(r, samples) + results.append(min(dists)) + score = sum(results) / len(results) + print(OUTPUT_TEMPLATE.format('MMD ', score)) + + +def compute_jsd(reference, samples, data): + """ + Calculate the score of Jensen-Shannon Divergence (JSD) + + """ + print('Evaluating (JSD) ...') + reference, samples = pcd2bev_sum(data, reference, samples) + reference = (reference / np.sum(reference)).flatten() + samples = (samples / np.sum(samples)).flatten() + score = jensenshannon(reference, samples) + print(OUTPUT_TEMPLATE.format('JSD ', score)) + + +def compute_fd(reference, samples): + mu1, mu2 = np.mean(reference, axis=0), np.mean(samples, axis=0) + sigma1, sigma2 = np.cov(reference, rowvar=False), np.cov(samples, rowvar=False) + distance = calculate_frechet_distance(mu1, sigma1, mu2, sigma2) + return distance + + +def compute_frid(reference, samples, data): + """ + Calculate the score of Fréchet Range Image Distance (FRID) + + """ + print('Evaluating (FRID) ...') + gt_logits, samples_logits = compute_logits(data, 'range', reference, samples) + score = compute_fd(gt_logits, samples_logits) + print(OUTPUT_TEMPLATE.format('FRID', score)) + + +def compute_fsvd(reference, samples, data): + """ + Calculate the score of Fréchet Sparse Volume Distance (FSVD) + + """ + print('Evaluating (FSVD) ...') + gt_logits, samples_logits = compute_logits(data, 'voxel', reference, samples) + score = compute_fd(gt_logits, samples_logits) + print(OUTPUT_TEMPLATE.format('FSVD', score)) + + +def compute_fpvd(reference, samples, data): + """ + Calculate the score of Fréchet Point-based Volume Distance (FPVD) + + """ + print('Evaluating (FPVD) ...') + gt_logits, samples_logits = compute_logits(data, 'point_voxel', reference, samples) + score = compute_fd(gt_logits, samples_logits) + print(OUTPUT_TEMPLATE.format('FPVD', score)) + diff --git a/lidm/eval/fid_score.py b/lidm/eval/fid_score.py new file mode 100644 index 0000000000000000000000000000000000000000..56fbb41329017636216cb6458b9cc82440152986 --- /dev/null +++ b/lidm/eval/fid_score.py @@ -0,0 +1,191 @@ +"""Calculates the Frechet Inception Distance (FID) to evalulate GANs +The FID metric calculates the distance between two distributions of images. +Typically, we have summary statistics (mean & covariance matrix) of one +of these distributions, while the 2nd distribution is given by a GAN. +When run as a stand-alone program, it compares the distribution of +images that are stored as PNG/JPEG at a specified location with a +distribution given by summary statistics (in pickle format). +The FID is calculated by assuming that X_1 and X_2 are the activations of +the pool_3 layer of the inception net for generated samples and real world +samples respectively. +See --help to see further details. +Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead +of Tensorflow +Copyright 2018 Institute of Bioinformatics, JKU Linz +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import pathlib +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser + +import numpy as np +import torch +import torchvision.transforms as TF +from PIL import Image +from scipy import linalg +from torch.nn.functional import adaptive_avg_pool2d + +try: + from tqdm import tqdm +except ImportError: + # If tqdm is not available, provide a mock version of it + def tqdm(x): + return x + +class ImagePathDataset(torch.utils.data.Dataset): + def __init__(self, files, transforms=None): + self.files = files + self.transforms = transforms + + def __len__(self): + return len(self.files) + + def __getitem__(self, i): + path = self.files[i] + img = Image.open(path).convert('RGB') + if self.transforms is not None: + img = self.transforms(img) + return img + + +def get_activations(files, model, batch_size=50, dims=2048, device='cpu', + num_workers=1): + """Calculates the activations of the pool_3 layer for all images. + Params: + -- files : List of image files paths + -- model : Instance of inception model + -- batch_size : Batch size of images for the model to process at once. + Make sure that the number of samples is a multiple of + the batch size, otherwise some samples are ignored. This + behavior is retained to match the original FID score + implementation. + -- dims : Dimensionality of features returned by Inception + -- device : Device to run calculations + -- num_workers : Number of parallel dataloader workers + Returns: + -- A numpy array of dimension (num images, dims) that contains the + activations of the given tensor when feeding inception with the + query tensor. + """ + model.eval() + + if batch_size > len(files): + print(('Warning: batch size is bigger than the data size. ' + 'Setting batch size to data size')) + batch_size = len(files) + + dataset = ImagePathDataset(files, transforms=TF.ToTensor()) + dataloader = torch.utils.data.DataLoader(dataset, + batch_size=batch_size, + shuffle=False, + drop_last=False, + num_workers=num_workers) + + pred_arr = np.empty((len(files), dims)) + + start_idx = 0 + + for batch in tqdm(dataloader): + batch = batch.to(device) + + with torch.no_grad(): + pred = model(batch)[0] + + # If model output is not scalar, apply global spatial average pooling. + # This happens if you choose a dimensionality not equal 2048. + if pred.size(2) != 1 or pred.size(3) != 1: + pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) + + pred = pred.squeeze(3).squeeze(2).cpu().numpy() + + pred_arr[start_idx:start_idx + pred.shape[0]] = pred + + start_idx = start_idx + pred.shape[0] + + return pred_arr + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + +def calculate_activation_statistics(files, model, batch_size=50, dims=2048, + device='cpu', num_workers=1): + """Calculation of the statistics used by the FID. + Params: + -- files : List of image files paths + -- model : Instance of inception model + -- batch_size : The images numpy array is split into batches with + batch size batch_size. A reasonable batch size + depends on the hardware. + -- dims : Dimensionality of features returned by Inception + -- device : Device to run calculations + -- num_workers : Number of parallel dataloader workers + Returns: + -- mu : The mean over samples of the activations of the pool_3 layer of + the inception model. + -- sigma : The covariance matrix of the activations of the pool_3 layer of + the inception model. + """ + act = get_activations(files, model, batch_size, dims, device, num_workers) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma diff --git a/lidm/eval/metric_utils.py b/lidm/eval/metric_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..30210fc653842b19956a038385a007dccf7993e0 --- /dev/null +++ b/lidm/eval/metric_utils.py @@ -0,0 +1,458 @@ +""" +@Author: Haoxi Ran +@Date: 01/03/2024 +@Citation: Towards Realistic Scene Generation with LiDAR Diffusion Models + +""" + +import math +from itertools import repeat +from typing import List, Tuple, Union +import numpy as np +import torch + +from . import build_model, VOXEL_SIZE, MODALITY2MODEL, MODAL2BATCHSIZE, DATASET_CONFIG, AGG_TYPE, NUM_SECTORS, \ + TYPE2DATASET, DATA_CONFIG + +try: + from torchsparse import SparseTensor, PointTensor + from torchsparse.utils.collate import sparse_collate_fn + from .modules.chamfer3D.dist_chamfer_3D import chamfer_3DDist + from .modules.chamfer2D.dist_chamfer_2D import chamfer_2DDist + from .modules.emd.emd_module import emdModule +except: + print( + 'To install torchsparse 1.4.0, please refer to https://github.com/mit-han-lab/torchsparse/tree/74099d10a51c71c14318bce63d6421f698b24f24') + + +def ravel_hash(x: np.ndarray) -> np.ndarray: + assert x.ndim == 2, x.shape + + x = x - np.min(x, axis=0) + x = x.astype(np.uint64, copy=False) + xmax = np.max(x, axis=0).astype(np.uint64) + 1 + + h = np.zeros(x.shape[0], dtype=np.uint64) + for k in range(x.shape[1] - 1): + h += x[:, k] + h *= xmax[k + 1] + h += x[:, -1] + return h + + +def sparse_quantize(coords, voxel_size: Union[float, Tuple[float, ...]] = 1, *, return_index: bool = False, + return_inverse: bool = False) -> List[np.ndarray]: + """ + Modified based on https://github.com/mit-han-lab/torchsparse/blob/462dea4a701f87a7545afb3616bf2cf53dd404f3/torchsparse/utils/quantize.py + + """ + if isinstance(voxel_size, (float, int)): + voxel_size = tuple(repeat(voxel_size, coords.shape[1])) + assert isinstance(voxel_size, tuple) and len(voxel_size) in [2, 3] # support 2D and 3D coordinates only + + voxel_size = np.array(voxel_size) + coords = np.floor(coords / voxel_size).astype(np.int32) + + _, indices, inverse_indices = np.unique( + ravel_hash(coords), return_index=True, return_inverse=True + ) + coords = coords[indices] + + outputs = [coords] + if return_index: + outputs += [indices] + if return_inverse: + outputs += [inverse_indices] + return outputs[0] if len(outputs) == 1 else outputs + + +def pcd2range(pcd, size, fov, depth_range, remission=None, labels=None, **kwargs): + # laser parameters + fov_up = fov[0] / 180.0 * np.pi # field of view up in rad + fov_down = fov[1] / 180.0 * np.pi # field of view down in rad + fov_range = abs(fov_down) + abs(fov_up) # get field of view total in rad + + # get depth (distance) of all points + depth = np.linalg.norm(pcd, 2, axis=1) + + # mask points out of range + mask = np.logical_and(depth > depth_range[0], depth < depth_range[1]) + depth, pcd = depth[mask], pcd[mask] + + # get scan components + scan_x, scan_y, scan_z = pcd[:, 0], pcd[:, 1], pcd[:, 2] + + # get angles of all points + yaw = -np.arctan2(scan_y, scan_x) + pitch = np.arcsin(scan_z / depth) + + # get projections in image coords + proj_x = 0.5 * (yaw / np.pi + 1.0) # in [0.0, 1.0] + proj_y = 1.0 - (pitch + abs(fov_down)) / fov_range # in [0.0, 1.0] + + # scale to image size using angular resolution + proj_x *= size[1] # in [0.0, W] + proj_y *= size[0] # in [0.0, H] + + # round and clamp for use as index + proj_x = np.maximum(0, np.minimum(size[1] - 1, np.floor(proj_x))).astype(np.int32) # in [0,W-1] + proj_y = np.maximum(0, np.minimum(size[0] - 1, np.floor(proj_y))).astype(np.int32) # in [0,H-1] + + # order in decreasing depth + order = np.argsort(depth)[::-1] + proj_x, proj_y = proj_x[order], proj_y[order] + + # project depth + depth = depth[order] + proj_range = np.full(size, -1, dtype=np.float32) + proj_range[proj_y, proj_x] = depth + + # project point feature + if remission is not None: + remission = remission[mask][order] + proj_feature = np.full(size, -1, dtype=np.float32) + proj_feature[proj_y, proj_x] = remission + elif labels is not None: + labels = labels[mask][order] + proj_feature = np.full(size, 0, dtype=np.float32) + proj_feature[proj_y, proj_x] = labels + else: + proj_feature = None + + return proj_range, proj_feature + + +def range2xyz(range_img, fov, depth_range, depth_scale, log_scale=True, **kwargs): + # laser parameters + size = range_img.shape + fov_up = fov[0] / 180.0 * np.pi # field of view up in rad + fov_down = fov[1] / 180.0 * np.pi # field of view down in rad + fov_range = abs(fov_down) + abs(fov_up) # get field of view total in rad + + # inverse transform from depth + if log_scale: + depth = (np.exp2(range_img * depth_scale) - 1) + else: + depth = range_img + + scan_x, scan_y = np.meshgrid(np.arange(size[1]), np.arange(size[0])) + scan_x = scan_x.astype(np.float64) / size[1] + scan_y = scan_y.astype(np.float64) / size[0] + + yaw = np.pi * (scan_x * 2 - 1) + pitch = (1.0 - scan_y) * fov_range - abs(fov_down) + + xyz = -np.ones((3, *size)) + xyz[0] = np.cos(yaw) * np.cos(pitch) * depth + xyz[1] = -np.sin(yaw) * np.cos(pitch) * depth + xyz[2] = np.sin(pitch) * depth + + # mask out invalid points + mask = np.logical_and(depth > depth_range[0], depth < depth_range[1]) + xyz[:, ~mask] = -1 + + return xyz + + +def pcd2voxel(pcd): + pcd_voxel = np.round(pcd / VOXEL_SIZE) + pcd_voxel = pcd_voxel - pcd_voxel.min(0, keepdims=1) + feat = np.concatenate((pcd, -np.ones((pcd.shape[0], 1))), axis=1) # -1 for remission placeholder + _, inds, inverse_map = sparse_quantize(pcd_voxel, 1, return_index=True, return_inverse=True) + + feat = torch.FloatTensor(feat[inds]) + pcd_voxel = torch.LongTensor(pcd_voxel[inds]) + lidar = SparseTensor(feat, pcd_voxel) + output = {'lidar': lidar} + return output + + +def pcd2voxel_full(data_type, *args): + config = DATA_CONFIG[data_type] + x_range, y_range, z_range = config['x'], config['y'], config['z'] + vol_shape = (math.ceil((x_range[1] - x_range[0]) / VOXEL_SIZE), math.ceil((y_range[1] - y_range[0]) / VOXEL_SIZE), + math.ceil((z_range[1] - z_range[0]) / VOXEL_SIZE)) + min_bound = (math.ceil((x_range[0]) / VOXEL_SIZE), math.ceil((y_range[0]) / VOXEL_SIZE), + math.ceil((z_range[0]) / VOXEL_SIZE)) + + output = tuple() + for data in args: + volume_list = [] + for pcd in data: + # mask out invalid points + mask_x = np.logical_and(pcd[:, 0] > x_range[0], pcd[:, 0] < x_range[1]) + mask_y = np.logical_and(pcd[:, 1] > y_range[0], pcd[:, 1] < y_range[1]) + mask_z = np.logical_and(pcd[:, 2] > z_range[0], pcd[:, 2] < z_range[1]) + mask = mask_x & mask_y & mask_z + pcd = pcd[mask] + + # voxelize + pcd_voxel = np.floor(pcd / VOXEL_SIZE) + _, indices, inverse_map = sparse_quantize(pcd_voxel, 1, return_index=True, return_inverse=True) + pcd_voxel = pcd_voxel[indices] + pcd_voxel = (pcd_voxel - min_bound).astype(np.int32) + + # 2D bev grid + vol = np.zeros(vol_shape, dtype=np.float32) + vol[pcd_voxel[:, 0], pcd_voxel[:, 1], pcd_voxel[:, 2]] = 1 + volume_list.append(vol) + output += (volume_list,) + return output + + +# def pcd2bev_full(data_type, *args, voxel_size=VOXEL_SIZE): +# config = DATA_CONFIG[data_type] +# x_range, y_range = config['x'], config['y'] +# vol_shape = (math.ceil((x_range[1] - x_range[0]) / voxel_size), math.ceil((y_range[1] - y_range[0]) / voxel_size)) +# min_bound = (math.ceil((x_range[0]) / voxel_size), math.ceil((y_range[0]) / voxel_size)) +# +# output = tuple() +# for data in args: +# volume_list = [] +# for pcd in data: +# # mask out invalid points +# mask_x = np.logical_and(pcd[:, 0] > x_range[0], pcd[:, 0] < x_range[1]) +# mask_y = np.logical_and(pcd[:, 1] > y_range[0], pcd[:, 1] < y_range[1]) +# mask = mask_x & mask_y +# pcd = pcd[mask][:, :2] # keep x,y coord +# +# # voxelize +# pcd_voxel = np.floor(pcd / voxel_size) +# _, indices, inverse_map = sparse_quantize(pcd_voxel, 1, return_index=True, return_inverse=True) +# pcd_voxel = pcd_voxel[indices] +# pcd_voxel = (pcd_voxel - min_bound).astype(np.int32) +# +# # 2D bev grid +# vol = np.zeros(vol_shape, dtype=np.float32) +# vol[pcd_voxel[:, 0], pcd_voxel[:, 1]] = 1 +# volume_list.append(vol) +# output += (volume_list,) +# return output + + +def pcd2bev_sum(data_type, *args, voxel_size=VOXEL_SIZE): + config = DATA_CONFIG[data_type] + x_range, y_range = config['x'], config['y'] + vol_shape = (math.ceil((x_range[1] - x_range[0]) / voxel_size), math.ceil((y_range[1] - y_range[0]) / voxel_size)) + min_bound = (math.ceil((x_range[0]) / voxel_size), math.ceil((y_range[0]) / voxel_size)) + + output = tuple() + for data in args: + volume_sum = np.zeros(vol_shape, np.float32) + for pcd in data: + # mask out invalid points + mask_x = np.logical_and(pcd[:, 0] > x_range[0], pcd[:, 0] < x_range[1]) + mask_y = np.logical_and(pcd[:, 1] > y_range[0], pcd[:, 1] < y_range[1]) + mask = mask_x & mask_y + pcd = pcd[mask][:, :2] # keep x,y coord + + # voxelize + pcd_voxel = np.floor(pcd / voxel_size) + _, indices, inverse_map = sparse_quantize(pcd_voxel, 1, return_index=True, return_inverse=True) + pcd_voxel = pcd_voxel[indices] + pcd_voxel = (pcd_voxel - min_bound).astype(np.int32) + + # summation + volume_sum[pcd_voxel[:, 0], pcd_voxel[:, 1]] += 1. + output += (volume_sum,) + return output + + +def pcd2bev_bin(data_type, *args, voxel_size=0.5): + config = DATA_CONFIG[data_type] + x_range, y_range = config['x'], config['y'] + vol_shape = (math.ceil((x_range[1] - x_range[0]) / voxel_size), math.ceil((y_range[1] - y_range[0]) / voxel_size)) + min_bound = (math.ceil((x_range[0]) / voxel_size), math.ceil((y_range[0]) / voxel_size)) + + output = tuple() + for data in args: + pcd_list = [] + for pcd in data: + # mask out invalid points + mask_x = np.logical_and(pcd[:, 0] > x_range[0], pcd[:, 0] < x_range[1]) + mask_y = np.logical_and(pcd[:, 1] > y_range[0], pcd[:, 1] < y_range[1]) + mask = mask_x & mask_y + pcd = pcd[mask][:, :2] # keep x,y coord + + # voxelize + pcd_voxel = np.floor(pcd / voxel_size) + _, indices, inverse_map = sparse_quantize(pcd_voxel, 1, return_index=True, return_inverse=True) + pcd_voxel = pcd_voxel[indices] + pcd_voxel = ((pcd_voxel - min_bound) / vol_shape).astype(np.float32) + pcd_list.append(pcd_voxel) + output += (pcd_list,) + return output + + +def bev_sample(data_type, *args, voxel_size=0.5): + config = DATA_CONFIG[data_type] + x_range, y_range = config['x'], config['y'] + + output = tuple() + for data in args: + pcd_list = [] + for pcd in data: + # mask out invalid points + mask_x = np.logical_and(pcd[:, 0] > x_range[0], pcd[:, 0] < x_range[1]) + mask_y = np.logical_and(pcd[:, 1] > y_range[0], pcd[:, 1] < y_range[1]) + mask = mask_x & mask_y + pcd = pcd[mask][:, :2] # keep x,y coord + + # voxelize + pcd_voxel = np.floor(pcd / voxel_size) + _, indices, inverse_map = sparse_quantize(pcd_voxel, 1, return_index=True, return_inverse=True) + pcd = pcd[indices] + pcd_list.append(pcd) + output += (pcd_list,) + return output + + +def preprocess_pcd(pcd, **kwargs): + depth = np.linalg.norm(pcd, 2, axis=1) + mask = np.logical_and(depth > kwargs['depth_range'][0], depth < kwargs['depth_range'][1]) + pcd = pcd[mask] + return pcd + + +def preprocess_range(pcd, **kwargs): + depth_img = pcd2range(pcd, **kwargs)[0] + xyz_img = range2xyz(depth_img, log_scale=False, **kwargs) + depth_img = depth_img[None] + img = np.vstack([depth_img, xyz_img]) + return img + + +def batch2list(batch_dict, agg_type='depth', **kwargs): + """ + Aggregation Type: Default 'depth', ['all', 'sector', 'depth'] + """ + output_list = [] + batch_indices = batch_dict['batch_indices'] + for b_idx in range(batch_indices.max() + 1): + # avg all + if agg_type == 'all': + logits = batch_dict['logits'][batch_indices == b_idx].mean(0) + + # avg on sectors + elif agg_type == 'sector': + logits = batch_dict['logits'][batch_indices == b_idx] + coords = batch_dict['coords'][batch_indices == b_idx].float() + coords = coords - coords.mean(0) + angle = torch.atan2(coords[:, 1], coords[:, 0]) # [-pi, pi] + sector_range = torch.linspace(-np.pi - 1e-4, np.pi + 1e-4, NUM_SECTORS + 1) + logits_list = [] + for i in range(NUM_SECTORS): + sector_indices = torch.where((angle >= sector_range[i]) & (angle < sector_range[i + 1]))[0] + sector_logits = logits[sector_indices].mean(0) + sector_logits = torch.nan_to_num(sector_logits, 0.) + logits_list.append(sector_logits) + logits = torch.cat(logits_list) # dim: 768 + + # avg by depth + elif agg_type == 'depth': + logits = batch_dict['logits'][batch_indices == b_idx] + coords = batch_dict['coords'][batch_indices == b_idx].float() + coords = coords - coords.mean(0) + bev_depth = torch.norm(coords, dim=-1) * VOXEL_SIZE + sector_range = torch.linspace(kwargs['depth_range'][0] + 3, kwargs['depth_range'][1], NUM_SECTORS + 1) + sector_range[0] = 0. + logits_list = [] + for i in range(NUM_SECTORS): + sector_indices = torch.where((bev_depth >= sector_range[i]) & (bev_depth < sector_range[i + 1]))[0] + sector_logits = logits[sector_indices].mean(0) + sector_logits = torch.nan_to_num(sector_logits, 0.) + logits_list.append(sector_logits) + logits = torch.cat(logits_list) # dim: 768 + + else: + raise NotImplementedError + + output_list.append(logits.detach().cpu().numpy()) + return output_list + + +def compute_logits(data_type, modality, *args): + assert data_type in ['32', '64'] + assert modality in ['range', 'voxel', 'point_voxel'] + is_voxel = 'voxel' in modality + dataset_name = TYPE2DATASET[data_type] + dataset_config = DATASET_CONFIG[dataset_name] + bs = MODAL2BATCHSIZE[modality] + + model = build_model(dataset_name, MODALITY2MODEL[modality], device='cuda') + + output = tuple() + for data in args: + all_logits_list = [] + for i in range(math.ceil(len(data) / bs)): + batch = data[i * bs:(i + 1) * bs] + if is_voxel: + batch = [pcd2voxel(preprocess_pcd(pcd, **dataset_config)) for pcd in batch] + batch = sparse_collate_fn(batch) + batch = {k: v.cuda() if isinstance(v, (torch.Tensor, SparseTensor, PointTensor)) else v for k, v in + batch.items()} + with torch.no_grad(): + batch_out = model(batch, return_final_logits=True) + batch_out = batch2list(batch_out, AGG_TYPE, **dataset_config) + all_logits_list.extend(batch_out) + else: + batch = [preprocess_range(pcd, **dataset_config) for pcd in batch] + batch = torch.from_numpy(np.stack(batch)).float().cuda() + with torch.no_grad(): + batch_out = model(batch, return_final_logits=True, agg_type=AGG_TYPE) + all_logits_list.append(batch_out) + if is_voxel: + all_logits = np.stack(all_logits_list) + else: + all_logits = np.vstack(all_logits_list) + output += (all_logits,) + + del model, batch, batch_out + torch.cuda.empty_cache() + return output + + +def compute_pairwise_cd(x, y, module=None): + if module is None: + module = chamfer_3DDist() + if x.ndim == 2 and y.ndim == 2: + x, y = x[None], y[None] + x, y = torch.from_numpy(x).cuda(), torch.from_numpy(y).cuda() + dist1, dist2, _, _ = module(x, y) + dist = (dist1.mean() + dist2.mean()) / 2 + return dist.item() + + +def compute_pairwise_cd_batch(reference, samples): + ndim = reference.ndim + assert ndim in [2, 3] + module = chamfer_3DDist() if ndim == 3 else chamfer_2DDist() + len_r, len_s = reference.shape[0], [s.shape[0] for s in samples] + max_len = max([len_r] + len_s) + reference = torch.from_numpy( + np.vstack([reference, np.ones((max_len - reference.shape[0], ndim), dtype=np.float32) * 1e6])).cuda() + samples = [np.vstack([s, np.ones((max_len - s.shape[0], ndim), dtype=np.float32) * 1e6]) for s in samples] + samples = torch.from_numpy(np.stack(samples)).cuda() + reference = reference.expand_as(samples) + dist_r, dist_s, _, _ = module(reference, samples) + + results = [] + for i in range(samples.shape[0]): + dist1, dist2, len1, len2 = dist_r[i], dist_s[i], len_r, len_s[i] + dist = (dist1[:len1].mean() + dist2[:len2].mean()) / 2. + results.append(dist.item()) + return results + + +def compute_pairwise_emd(x, y, module=None): + if module is None: + module = emdModule() + n_points = min(x.shape[0], y.shape[0]) + n_points = n_points - n_points % 1024 + x, y = x[:n_points], y[:n_points] + if x.ndim == 2 and y.ndim == 2: + x, y = x[None], y[None] + x, y = torch.from_numpy(x).cuda(), torch.from_numpy(y).cuda() + dist, _ = module(x, y, 0.005, 50) + dist = torch.sqrt(dist).mean() + return dist.item() diff --git a/lidm/eval/models/__init__.py b/lidm/eval/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/eval/models/minkowskinet/__init__.py b/lidm/eval/models/minkowskinet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/eval/models/minkowskinet/model.py b/lidm/eval/models/minkowskinet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..daa36a5009e96f3331653341c6b185627d688a19 --- /dev/null +++ b/lidm/eval/models/minkowskinet/model.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn + +try: + import torchsparse + import torchsparse.nn as spnn + from ..ts import basic_blocks +except ImportError: + raise Exception('Required ts lib. Reference: https://github.com/mit-han-lab/torchsparse/tree/v1.4.0') + + +class Model(nn.Module): + def __init__(self, config): + super().__init__() + + cr = config.model_params.cr + cs = config.model_params.layer_num + cs = [int(cr * x) for x in cs] + + self.pres = self.vres = config.model_params.voxel_size + self.num_classes = config.model_params.num_class + + self.stem = nn.Sequential( + spnn.Conv3d(config.model_params.input_dims, cs[0], kernel_size=3, stride=1), + spnn.BatchNorm(cs[0]), spnn.ReLU(True), + spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1), + spnn.BatchNorm(cs[0]), spnn.ReLU(True)) + + self.stage1 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1), + ) + + self.stage2 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1), + ) + + self.stage3 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1), + ) + + self.stage4 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1), + ) + + self.up1 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[4], cs[5], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[5] + cs[3], cs[5], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[5], cs[5], ks=3, stride=1, dilation=1), + ) + ]) + + self.up2 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[5], cs[6], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[6] + cs[2], cs[6], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[6], cs[6], ks=3, stride=1, dilation=1), + ) + ]) + + self.up3 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[6], cs[7], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[7] + cs[1], cs[7], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[7], cs[7], ks=3, stride=1, dilation=1), + ) + ]) + + self.up4 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[7], cs[8], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[8] + cs[0], cs[8], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[8], cs[8], ks=3, stride=1, dilation=1), + ) + ]) + + self.classifier = nn.Sequential(nn.Linear(cs[8], self.num_classes)) + + self.weight_initialization() + self.dropout = nn.Dropout(0.3, True) + + def weight_initialization(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, data_dict, return_logits=False, return_final_logits=False): + x = data_dict['lidar'] + x.C = x.C.int() + + x0 = self.stem(x) + x1 = self.stage1(x0) + x2 = self.stage2(x1) + x3 = self.stage3(x2) + x4 = self.stage4(x3) + + if return_logits: + output_dict = dict() + output_dict['logits'] = x4.F + output_dict['batch_indices'] = x4.C[:, -1] + return output_dict + + y1 = self.up1[0](x4) + y1 = torchsparse.cat([y1, x3]) + y1 = self.up1[1](y1) + + y2 = self.up2[0](y1) + y2 = torchsparse.cat([y2, x2]) + y2 = self.up2[1](y2) + + y3 = self.up3[0](y2) + y3 = torchsparse.cat([y3, x1]) + y3 = self.up3[1](y3) + + y4 = self.up4[0](y3) + y4 = torchsparse.cat([y4, x0]) + y4 = self.up4[1](y4) + if return_final_logits: + output_dict = dict() + output_dict['logits'] = y4.F + output_dict['coords'] = y4.C[:, :3] + output_dict['batch_indices'] = y4.C[:, -1] + return output_dict + + output = self.classifier(y4.F) + data_dict['output'] = output.F + + return data_dict diff --git a/lidm/eval/models/rangenet/__init__.py b/lidm/eval/models/rangenet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/eval/models/rangenet/model.py b/lidm/eval/models/rangenet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..752fae9effd476a6bff255e0063675d1cc2f72e2 --- /dev/null +++ b/lidm/eval/models/rangenet/model.py @@ -0,0 +1,372 @@ +#!/usr/bin/env python3 +# This file is covered by the LICENSE file in the root of this project. +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + def __init__(self, inplanes, planes, bn_d=0.1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes[0], kernel_size=1, + stride=1, padding=0, bias=False) + self.bn1 = nn.BatchNorm2d(planes[0], momentum=bn_d) + self.relu1 = nn.LeakyReLU(0.1) + self.conv2 = nn.Conv2d(planes[0], planes[1], kernel_size=3, + stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes[1], momentum=bn_d) + self.relu2 = nn.LeakyReLU(0.1) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu2(out) + + out += residual + return out + + +# ****************************************************************************** + +# number of layers per model +model_blocks = { + 21: [1, 1, 2, 2, 1], + 53: [1, 2, 8, 8, 4], +} + + +class Backbone(nn.Module): + """ + Class for DarknetSeg. Subclasses PyTorch's own "nn" module + """ + + def __init__(self, params): + super(Backbone, self).__init__() + self.use_range = params["input_depth"]["range"] + self.use_xyz = params["input_depth"]["xyz"] + self.use_remission = params["input_depth"]["remission"] + self.drop_prob = params["dropout"] + self.bn_d = params["bn_d"] + self.OS = params["OS"] + self.layers = params["extra"]["layers"] + + # input depth calc + self.input_depth = 0 + self.input_idxs = [] + if self.use_range: + self.input_depth += 1 + self.input_idxs.append(0) + if self.use_xyz: + self.input_depth += 3 + self.input_idxs.extend([1, 2, 3]) + if self.use_remission: + self.input_depth += 1 + self.input_idxs.append(4) + + # stride play + self.strides = [2, 2, 2, 2, 2] + # check current stride + current_os = 1 + for s in self.strides: + current_os *= s + + # make the new stride + if self.OS > current_os: + print("Can't do OS, ", self.OS, + " because it is bigger than original ", current_os) + else: + # redo strides according to needed stride + for i, stride in enumerate(reversed(self.strides), 0): + if int(current_os) != self.OS: + if stride == 2: + current_os /= 2 + self.strides[-1 - i] = 1 + if int(current_os) == self.OS: + break + + # check that darknet exists + assert self.layers in model_blocks.keys() + + # generate layers depending on darknet type + self.blocks = model_blocks[self.layers] + + # input layer + self.conv1 = nn.Conv2d(self.input_depth, 32, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(32, momentum=self.bn_d) + self.relu1 = nn.LeakyReLU(0.1) + + # encoder + self.enc1 = self._make_enc_layer(BasicBlock, [32, 64], self.blocks[0], + stride=self.strides[0], bn_d=self.bn_d) + self.enc2 = self._make_enc_layer(BasicBlock, [64, 128], self.blocks[1], + stride=self.strides[1], bn_d=self.bn_d) + self.enc3 = self._make_enc_layer(BasicBlock, [128, 256], self.blocks[2], + stride=self.strides[2], bn_d=self.bn_d) + self.enc4 = self._make_enc_layer(BasicBlock, [256, 512], self.blocks[3], + stride=self.strides[3], bn_d=self.bn_d) + self.enc5 = self._make_enc_layer(BasicBlock, [512, 1024], self.blocks[4], + stride=self.strides[4], bn_d=self.bn_d) + + # for a bit of fun + self.dropout = nn.Dropout2d(self.drop_prob) + + # last channels + self.last_channels = 1024 + + # make layer useful function + def _make_enc_layer(self, block, planes, blocks, stride, bn_d=0.1): + layers = [] + + # downsample + layers.append(("conv", nn.Conv2d(planes[0], planes[1], + kernel_size=3, + stride=[1, stride], dilation=1, + padding=1, bias=False))) + layers.append(("bn", nn.BatchNorm2d(planes[1], momentum=bn_d))) + layers.append(("relu", nn.LeakyReLU(0.1))) + + # blocks + inplanes = planes[1] + for i in range(0, blocks): + layers.append(("residual_{}".format(i), + block(inplanes, planes, bn_d))) + + return nn.Sequential(OrderedDict(layers)) + + def run_layer(self, x, layer, skips, os): + y = layer(x) + if y.shape[2] < x.shape[2] or y.shape[3] < x.shape[3]: + skips[os] = x.detach() + os *= 2 + x = y + return x, skips, os + + def forward(self, x, return_logits=False, return_list=None): + # filter input + x = x[:, self.input_idxs] + + # run cnn + # store for skip connections + skips = {} + out_dict = {} + os = 1 + + # first layer + x, skips, os = self.run_layer(x, self.conv1, skips, os) + x, skips, os = self.run_layer(x, self.bn1, skips, os) + x, skips, os = self.run_layer(x, self.relu1, skips, os) + if return_list and 'enc_0' in return_list: + out_dict['enc_0'] = x.detach().cpu() # 32, 64, 1024 + + # all encoder blocks with intermediate dropouts + x, skips, os = self.run_layer(x, self.enc1, skips, os) + if return_list and 'enc_1' in return_list: + out_dict['enc_1'] = x.detach().cpu() # 64, 64, 512 + x, skips, os = self.run_layer(x, self.dropout, skips, os) + + x, skips, os = self.run_layer(x, self.enc2, skips, os) + if return_list and 'enc_2' in return_list: + out_dict['enc_2'] = x.detach().cpu() # 128, 64, 256 + x, skips, os = self.run_layer(x, self.dropout, skips, os) + + x, skips, os = self.run_layer(x, self.enc3, skips, os) + if return_list and 'enc_3' in return_list: + out_dict['enc_3'] = x.detach().cpu() # 256, 64, 128 + x, skips, os = self.run_layer(x, self.dropout, skips, os) + + x, skips, os = self.run_layer(x, self.enc4, skips, os) + if return_list and 'enc_4' in return_list: + out_dict['enc_4'] = x.detach().cpu() # 512, 64, 64 + x, skips, os = self.run_layer(x, self.dropout, skips, os) + + x, skips, os = self.run_layer(x, self.enc5, skips, os) + if return_list and 'enc_5' in return_list: + out_dict['enc_5'] = x.detach().cpu() # 1024, 64, 32 + if return_logits: + return x + + x, skips, os = self.run_layer(x, self.dropout, skips, os) + + if return_list is not None: + return x, skips, out_dict + return x, skips + + def get_last_depth(self): + return self.last_channels + + def get_input_depth(self): + return self.input_depth + + +class Decoder(nn.Module): + """ + Class for DarknetSeg. Subclasses PyTorch's own "nn" module + """ + + def __init__(self, params, OS=32, feature_depth=1024): + super(Decoder, self).__init__() + self.backbone_OS = OS + self.backbone_feature_depth = feature_depth + self.drop_prob = params["dropout"] + self.bn_d = params["bn_d"] + self.index = 0 + + # stride play + self.strides = [2, 2, 2, 2, 2] + # check current stride + current_os = 1 + for s in self.strides: + current_os *= s + # redo strides according to needed stride + for i, stride in enumerate(self.strides): + if int(current_os) != self.backbone_OS: + if stride == 2: + current_os /= 2 + self.strides[i] = 1 + if int(current_os) == self.backbone_OS: + break + + # decoder + self.dec5 = self._make_dec_layer(BasicBlock, + [self.backbone_feature_depth, 512], + bn_d=self.bn_d, + stride=self.strides[0]) + self.dec4 = self._make_dec_layer(BasicBlock, [512, 256], bn_d=self.bn_d, + stride=self.strides[1]) + self.dec3 = self._make_dec_layer(BasicBlock, [256, 128], bn_d=self.bn_d, + stride=self.strides[2]) + self.dec2 = self._make_dec_layer(BasicBlock, [128, 64], bn_d=self.bn_d, + stride=self.strides[3]) + self.dec1 = self._make_dec_layer(BasicBlock, [64, 32], bn_d=self.bn_d, + stride=self.strides[4]) + + # layer list to execute with skips + self.layers = [self.dec5, self.dec4, self.dec3, self.dec2, self.dec1] + + # for a bit of fun + self.dropout = nn.Dropout2d(self.drop_prob) + + # last channels + self.last_channels = 32 + + def _make_dec_layer(self, block, planes, bn_d=0.1, stride=2): + layers = [] + + # downsample + if stride == 2: + layers.append(("upconv", nn.ConvTranspose2d(planes[0], planes[1], + kernel_size=[1, 4], stride=[1, 2], + padding=[0, 1]))) + else: + layers.append(("conv", nn.Conv2d(planes[0], planes[1], + kernel_size=3, padding=1))) + layers.append(("bn", nn.BatchNorm2d(planes[1], momentum=bn_d))) + layers.append(("relu", nn.LeakyReLU(0.1))) + + # blocks + layers.append(("residual", block(planes[1], planes, bn_d))) + + return nn.Sequential(OrderedDict(layers)) + + def run_layer(self, x, layer, skips, os): + feats = layer(x) # up + if feats.shape[-1] > x.shape[-1]: + os //= 2 # match skip + feats = feats + skips[os].detach() # add skip + x = feats + return x, skips, os + + def forward(self, x, skips, return_logits=False, return_list=None): + os = self.backbone_OS + out_dict = {} + + # run layers + x, skips, os = self.run_layer(x, self.dec5, skips, os) + if return_list and 'dec_4' in return_list: + out_dict['dec_4'] = x.detach().cpu() # 512, 64, 64 + x, skips, os = self.run_layer(x, self.dec4, skips, os) + if return_list and 'dec_3' in return_list: + out_dict['dec_3'] = x.detach().cpu() # 256, 64, 128 + x, skips, os = self.run_layer(x, self.dec3, skips, os) + if return_list and 'dec_2' in return_list: + out_dict['dec_2'] = x.detach().cpu() # 128, 64, 256 + x, skips, os = self.run_layer(x, self.dec2, skips, os) + if return_list and 'dec_1' in return_list: + out_dict['dec_1'] = x.detach().cpu() # 64, 64, 512 + x, skips, os = self.run_layer(x, self.dec1, skips, os) + if return_list and 'dec_0' in return_list: + out_dict['dec_0'] = x.detach().cpu() # 32, 64, 1024 + + logits = torch.clone(x).detach() + x = self.dropout(x) + + if return_logits: + return x, logits + if return_list is not None: + return out_dict + return x + + def get_last_depth(self): + return self.last_channels + + +class Model(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = Backbone(params=self.config["backbone"]) + self.decoder = Decoder(params=self.config["decoder"], OS=self.config["backbone"]["OS"], + feature_depth=self.backbone.get_last_depth()) + + def load_pretrained_weights(self, path): + w_dict = torch.load(path + "/backbone", + map_location=lambda storage, loc: storage) + self.backbone.load_state_dict(w_dict, strict=True) + w_dict = torch.load(path + "/segmentation_decoder", + map_location=lambda storage, loc: storage) + self.decoder.load_state_dict(w_dict, strict=True) + + def forward(self, x, return_logits=False, return_final_logits=False, return_list=None, agg_type='depth'): + if return_logits: + logits = self.backbone(x, return_logits) + logits = F.adaptive_avg_pool2d(logits, (1, 1)).squeeze() + logits = torch.clone(logits).detach().cpu().numpy() + return logits + elif return_list is not None: + x, skips, enc_dict = self.backbone(x, return_list=return_list) + dec_dict = self.decoder(x, skips, return_list=return_list) + out_dict = {**enc_dict, **dec_dict} + return out_dict + elif return_final_logits: + assert agg_type in ['all', 'sector', 'depth'] + y, skips = self.backbone(x) + y, logits = self.decoder(y, skips, True) + + B, C, H, W = logits.shape + N = 16 + + # avg all + if agg_type == 'all': + logits = logits.mean([2, 3]) + # avg in patch + elif agg_type == 'sector': + logits = logits.view(B, C, H, N, W // N).mean([2, 4]).reshape(B, -1) + # avg in row + elif agg_type == 'depth': + logits = logits.view(B, C, N, H // N, W).mean([3, 4]).reshape(B, -1) + + logits = torch.clone(logits).detach().cpu().numpy() + return logits + else: + y, skips = self.backbone(x) + y = self.decoder(y, skips, False) + return y diff --git a/lidm/eval/models/spvcnn/__init__.py b/lidm/eval/models/spvcnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/eval/models/spvcnn/model.py b/lidm/eval/models/spvcnn/model.py new file mode 100644 index 0000000000000000000000000000000000000000..dd7793f8e81cef35f331c0c2c70062023999e83a --- /dev/null +++ b/lidm/eval/models/spvcnn/model.py @@ -0,0 +1,179 @@ +import torch.nn as nn + +try: + import torchsparse + import torchsparse.nn as spnn + from torchsparse import PointTensor + from ..ts.utils import initial_voxelize, point_to_voxel, voxel_to_point + from ..ts import basic_blocks +except ImportError: + raise Exception('Required torchsparse lib. Reference: https://github.com/mit-han-lab/torchsparse/tree/v1.4.0') + + +class Model(nn.Module): + def __init__(self, config): + super().__init__() + cr = config.model_params.cr + cs = config.model_params.layer_num + cs = [int(cr * x) for x in cs] + + self.pres = self.vres = config.model_params.voxel_size + self.num_classes = config.model_params.num_class + + self.stem = nn.Sequential( + spnn.Conv3d(config.model_params.input_dims, cs[0], kernel_size=3, stride=1), + spnn.BatchNorm(cs[0]), spnn.ReLU(True), + spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1), + spnn.BatchNorm(cs[0]), spnn.ReLU(True)) + + self.stage1 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1), + ) + + self.stage2 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1), + ) + + self.stage3 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1), + ) + + self.stage4 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1), + ) + + self.up1 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[4], cs[5], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[5] + cs[3], cs[5], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[5], cs[5], ks=3, stride=1, dilation=1), + ) + ]) + + self.up2 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[5], cs[6], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[6] + cs[2], cs[6], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[6], cs[6], ks=3, stride=1, dilation=1), + ) + ]) + + self.up3 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[6], cs[7], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[7] + cs[1], cs[7], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[7], cs[7], ks=3, stride=1, dilation=1), + ) + ]) + + self.up4 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[7], cs[8], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[8] + cs[0], cs[8], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[8], cs[8], ks=3, stride=1, dilation=1), + ) + ]) + + self.classifier = nn.Sequential(nn.Linear(cs[8], self.num_classes)) + + self.point_transforms = nn.ModuleList([ + nn.Sequential( + nn.Linear(cs[0], cs[4]), + nn.BatchNorm1d(cs[4]), + nn.ReLU(True), + ), + nn.Sequential( + nn.Linear(cs[4], cs[6]), + nn.BatchNorm1d(cs[6]), + nn.ReLU(True), + ), + nn.Sequential( + nn.Linear(cs[6], cs[8]), + nn.BatchNorm1d(cs[8]), + nn.ReLU(True), + ) + ]) + + self.weight_initialization() + self.dropout = nn.Dropout(0.3, True) + + def weight_initialization(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, data_dict, return_logits=False, return_final_logits=False): + x = data_dict['lidar'] + + # x: SparseTensor z: PointTensor + z = PointTensor(x.F, x.C.float()) + + x0 = initial_voxelize(z, self.pres, self.vres) + + x0 = self.stem(x0) + z0 = voxel_to_point(x0, z, nearest=False) + z0.F = z0.F + + x1 = point_to_voxel(x0, z0) + x1 = self.stage1(x1) + x2 = self.stage2(x1) + x3 = self.stage3(x2) + x4 = self.stage4(x3) + z1 = voxel_to_point(x4, z0) + z1.F = z1.F + self.point_transforms[0](z0.F) + + y1 = point_to_voxel(x4, z1) + + if return_logits: + output_dict = dict() + output_dict['logits'] = y1.F + output_dict['batch_indices'] = y1.C[:, -1] + return output_dict + + y1.F = self.dropout(y1.F) + y1 = self.up1[0](y1) + y1 = torchsparse.cat([y1, x3]) + y1 = self.up1[1](y1) + + y2 = self.up2[0](y1) + y2 = torchsparse.cat([y2, x2]) + y2 = self.up2[1](y2) + z2 = voxel_to_point(y2, z1) + z2.F = z2.F + self.point_transforms[1](z1.F) + + y3 = point_to_voxel(y2, z2) + y3.F = self.dropout(y3.F) + y3 = self.up3[0](y3) + y3 = torchsparse.cat([y3, x1]) + y3 = self.up3[1](y3) + + y4 = self.up4[0](y3) + y4 = torchsparse.cat([y4, x0]) + y4 = self.up4[1](y4) + z3 = voxel_to_point(y4, z2) + z3.F = z3.F + self.point_transforms[2](z2.F) + + if return_final_logits: + output_dict = dict() + output_dict['logits'] = z3.F + output_dict['coords'] = z3.C[:, :3] + output_dict['batch_indices'] = z3.C[:, -1].long() + return output_dict + + # output = self.classifier(z3.F) + data_dict['logits'] = z3.F + + return data_dict diff --git a/lidm/eval/models/ts/__init__.py b/lidm/eval/models/ts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/eval/models/ts/basic_blocks.py b/lidm/eval/models/ts/basic_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..a18acc8eba8ad5f62ad6edb2cc02852a1536d0a3 --- /dev/null +++ b/lidm/eval/models/ts/basic_blocks.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# encoding: utf-8 +''' +@author: Xu Yan +@file: basic_blocks.py +@time: 2021/4/14 22:53 +''' +import torch.nn as nn + +try: + import torchsparse.nn as spnn +except: + print('To install torchsparse 1.4.0, please refer to https://github.com/mit-han-lab/torchsparse/tree/74099d10a51c71c14318bce63d6421f698b24f24') + + +class BasicConvolutionBlock(nn.Module): + def __init__(self, inc, outc, ks=3, stride=1, dilation=1): + super().__init__() + self.net = nn.Sequential( + spnn.Conv3d( + inc, + outc, + kernel_size=ks, + dilation=dilation, + stride=stride), spnn.BatchNorm(outc), + spnn.ReLU(True)) + + def forward(self, x): + out = self.net(x) + return out + + +class BasicDeconvolutionBlock(nn.Module): + def __init__(self, inc, outc, ks=3, stride=1): + super().__init__() + self.net = nn.Sequential( + spnn.Conv3d( + inc, + outc, + kernel_size=ks, + stride=stride, + transposed=True), + spnn.BatchNorm(outc), + spnn.ReLU(True)) + + def forward(self, x): + return self.net(x) + + +class ResidualBlock(nn.Module): + def __init__(self, inc, outc, ks=3, stride=1, dilation=1): + super().__init__() + self.net = nn.Sequential( + spnn.Conv3d( + inc, + outc, + kernel_size=ks, + dilation=dilation, + stride=stride), spnn.BatchNorm(outc), + spnn.ReLU(True), + spnn.Conv3d( + outc, + outc, + kernel_size=ks, + dilation=dilation, + stride=1), + spnn.BatchNorm(outc)) + + self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \ + nn.Sequential( + spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride), + spnn.BatchNorm(outc) + ) + + self.ReLU = spnn.ReLU(True) + + def forward(self, x): + out = self.ReLU(self.net(x) + self.downsample(x)) + return out diff --git a/lidm/eval/models/ts/utils.py b/lidm/eval/models/ts/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4a01c0e54645ceb91fa76ed18bdae552f7fbf73 --- /dev/null +++ b/lidm/eval/models/ts/utils.py @@ -0,0 +1,90 @@ +import torch + +try: + import torchsparse.nn.functional as F + from torchsparse import PointTensor, SparseTensor + from torchsparse.nn.utils import get_kernel_offsets +except: + print('To install torchsparse 1.4.0, please refer to https://github.com/mit-han-lab/torchsparse/tree/74099d10a51c71c14318bce63d6421f698b24f24') + +__all__ = ['initial_voxelize', 'point_to_voxel', 'voxel_to_point'] + + +# z: PointTensor +# return: SparseTensor +def initial_voxelize(z, init_res, after_res): + new_float_coord = torch.cat([(z.C[:, :3] * init_res) / after_res, z.C[:, -1].view(-1, 1)], 1) + + pc_hash = F.sphash(torch.floor(new_float_coord).int()) + sparse_hash = torch.unique(pc_hash) + idx_query = F.sphashquery(pc_hash, sparse_hash) + counts = F.spcount(idx_query.int(), len(sparse_hash)) + + inserted_coords = F.spvoxelize(torch.floor(new_float_coord), idx_query, counts) + inserted_coords = torch.round(inserted_coords).int() + inserted_feat = F.spvoxelize(z.F, idx_query, counts) + + new_tensor = SparseTensor(inserted_feat, inserted_coords, 1) + new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords) + z.additional_features['idx_query'][1] = idx_query + z.additional_features['counts'][1] = counts + z.C = new_float_coord + + return new_tensor + + +# x: SparseTensor, z: PointTensor +# return: SparseTensor +def point_to_voxel(x, z): + if z.additional_features is None or \ + z.additional_features.get('idx_query') is None or \ + z.additional_features['idx_query'].get(x.s) is None: + pc_hash = F.sphash( + torch.cat([torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0], z.C[:, -1].int().view(-1, 1)], 1)) + sparse_hash = F.sphash(x.C) + idx_query = F.sphashquery(pc_hash, sparse_hash) + counts = F.spcount(idx_query.int(), x.C.shape[0]) + z.additional_features['idx_query'][x.s] = idx_query + z.additional_features['counts'][x.s] = counts + else: + idx_query = z.additional_features['idx_query'][x.s] + counts = z.additional_features['counts'][x.s] + + inserted_feat = F.spvoxelize(z.F, idx_query, counts) + new_tensor = SparseTensor(inserted_feat, x.C, x.s) + new_tensor.cmaps = x.cmaps + new_tensor.kmaps = x.kmaps + + return new_tensor + + +# x: SparseTensor, z: PointTensor +# return: PointTensor +def voxel_to_point(x, z, nearest=False): + if z.idx_query is None or z.weights is None or z.idx_query.get(x.s) is None or z.weights.get(x.s) is None: + off = get_kernel_offsets(2, x.s, 1, device=z.F.device) + old_hash = F.sphash( + torch.cat([ + torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0], + z.C[:, -1].int().view(-1, 1)], 1), off) + pc_hash = F.sphash(x.C.to(z.F.device)) + idx_query = F.sphashquery(old_hash, pc_hash) + weights = F.calc_ti_weights(z.C, idx_query, scale=x.s[0]).transpose(0, 1).contiguous() + idx_query = idx_query.transpose(0, 1).contiguous() + if nearest: + weights[:, 1:] = 0. + idx_query[:, 1:] = -1 + new_feat = F.spdevoxelize(x.F, idx_query, weights) + new_tensor = PointTensor(new_feat, z.C, idx_query=z.idx_query, weights=z.weights) + new_tensor.additional_features = z.additional_features + new_tensor.idx_query[x.s] = idx_query + new_tensor.weights[x.s] = weights + z.idx_query[x.s] = idx_query + z.weights[x.s] = weights + + else: + new_feat = F.spdevoxelize(x.F, z.idx_query.get(x.s), z.weights.get(x.s)) + new_tensor = PointTensor(new_feat, z.C, idx_query=z.idx_query, weights=z.weights) + new_tensor.additional_features = z.additional_features + + return new_tensor \ No newline at end of file diff --git a/lidm/eval/modules/__init__.py b/lidm/eval/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/eval/modules/chamfer2D/__init__.py b/lidm/eval/modules/chamfer2D/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/eval/modules/chamfer2D/chamfer2D.cu b/lidm/eval/modules/chamfer2D/chamfer2D.cu new file mode 100755 index 0000000000000000000000000000000000000000..567dd1a0c041f0e11476e1e59bc65198ac227e04 --- /dev/null +++ b/lidm/eval/modules/chamfer2D/chamfer2D.cu @@ -0,0 +1,182 @@ + +#include +#include + +#include +#include + +#include + + + +__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ + const int batch=512; + __shared__ float buf[batch*2]; + for (int i=blockIdx.x;ibest){ + result[(i*n+j)]=best; + result_i[(i*n+j)]=best_i; + } + } + __syncthreads(); + } + } +} +// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ + + const auto batch_size = xyz1.size(0); + const auto n = xyz1.size(1); //num_points point cloud A + const auto m = xyz2.size(1); //num_points point cloud B + + NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); + NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + + +} +__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ + for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); + NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + +} + diff --git a/lidm/eval/modules/chamfer2D/chamfer_cuda.cpp b/lidm/eval/modules/chamfer2D/chamfer_cuda.cpp new file mode 100755 index 0000000000000000000000000000000000000000..67574e21818ae9388f44f964d606b2f00355865e --- /dev/null +++ b/lidm/eval/modules/chamfer2D/chamfer_cuda.cpp @@ -0,0 +1,33 @@ +#include +#include + +///TMP +//#include "common.h" +/// NOT TMP + + +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); + + +int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); + + + + +int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { + return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); +} + + +int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, + at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { + + return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); + m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); +} \ No newline at end of file diff --git a/lidm/eval/modules/chamfer2D/dist_chamfer_2D.py b/lidm/eval/modules/chamfer2D/dist_chamfer_2D.py new file mode 100644 index 0000000000000000000000000000000000000000..e7ecedbc731b23895ed2d6fe3fe48aaf632940da --- /dev/null +++ b/lidm/eval/modules/chamfer2D/dist_chamfer_2D.py @@ -0,0 +1,84 @@ +from torch import nn +from torch.autograd import Function +import torch +import importlib +import os + +chamfer_found = importlib.find_loader("chamfer_2D") is not None +if not chamfer_found: + ## Cool trick from https://github.com/chrdiller + print("Jitting Chamfer 2D") + cur_path = os.path.dirname(os.path.abspath(__file__)) + build_path = cur_path.replace('chamfer2D', 'tmp') + os.makedirs(build_path, exist_ok=True) + + from torch.utils.cpp_extension import load + + chamfer_2D = load(name="chamfer_2D", + sources=[ + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer2D.cu"]), + ], build_directory=build_path) + print("Loaded JIT 2D CUDA chamfer distance") + +else: + import chamfer_2D + + print("Loaded compiled 2D CUDA chamfer distance") + + +# Chamfer's distance module @thibaultgroueix +# GPU tensors only +class chamfer_2DFunction(Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + batchsize, n, dim = xyz1.size() + assert dim == 2, "Wrong last dimension for the chamfer distance 's input! Check with .size()" + _, m, dim = xyz2.size() + assert dim == 2, "Wrong last dimension for the chamfer distance 's input! Check with .size()" + device = xyz1.device + + device = xyz1.device + + dist1 = torch.zeros(batchsize, n) + dist2 = torch.zeros(batchsize, m) + + idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) + idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) + + dist1 = dist1.to(device) + dist2 = dist2.to(device) + idx1 = idx1.to(device) + idx2 = idx2.to(device) + torch.cuda.set_device(device) + + chamfer_2D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) + ctx.save_for_backward(xyz1, xyz2, idx1, idx2) + return dist1, dist2, idx1, idx2 + + @staticmethod + def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): + xyz1, xyz2, idx1, idx2 = ctx.saved_tensors + graddist1 = graddist1.contiguous() + graddist2 = graddist2.contiguous() + device = graddist1.device + + gradxyz1 = torch.zeros(xyz1.size()) + gradxyz2 = torch.zeros(xyz2.size()) + + gradxyz1 = gradxyz1.to(device) + gradxyz2 = gradxyz2.to(device) + chamfer_2D.backward( + xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 + ) + return gradxyz1, gradxyz2 + + +class chamfer_2DDist(nn.Module): + def __init__(self): + super(chamfer_2DDist, self).__init__() + + def forward(self, input1, input2): + input1 = input1.contiguous() + input2 = input2.contiguous() + return chamfer_2DFunction.apply(input1, input2) diff --git a/lidm/eval/modules/chamfer2D/setup.py b/lidm/eval/modules/chamfer2D/setup.py new file mode 100755 index 0000000000000000000000000000000000000000..11d01237f57ad386dd88adda4cc869a53f94f4f2 --- /dev/null +++ b/lidm/eval/modules/chamfer2D/setup.py @@ -0,0 +1,14 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='chamfer_2D', + ext_modules=[ + CUDAExtension('chamfer_2D', [ + "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), + "/".join(__file__.split('/')[:-1] + ['chamfer2D.cu']), + ]), + ], + cmdclass={ + 'build_ext': BuildExtension + }) \ No newline at end of file diff --git a/lidm/eval/modules/chamfer3D/__init__.py b/lidm/eval/modules/chamfer3D/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/eval/modules/chamfer3D/chamfer3D.cu b/lidm/eval/modules/chamfer3D/chamfer3D.cu new file mode 100644 index 0000000000000000000000000000000000000000..d5b886dff11733be30519247d1fdb784818bff4a --- /dev/null +++ b/lidm/eval/modules/chamfer3D/chamfer3D.cu @@ -0,0 +1,196 @@ + +#include +#include + +#include +#include + +#include + + + +__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ + const int batch=512; + __shared__ float buf[batch*3]; + for (int i=blockIdx.x;ibest){ + result[(i*n+j)]=best; + result_i[(i*n+j)]=best_i; + } + } + __syncthreads(); + } + } +} +// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ + + const auto batch_size = xyz1.size(0); + const auto n = xyz1.size(1); //num_points point cloud A + const auto m = xyz2.size(1); //num_points point cloud B + + NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); + NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + + +} +__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ + for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); + NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + +} + diff --git a/lidm/eval/modules/chamfer3D/chamfer_cuda.cpp b/lidm/eval/modules/chamfer3D/chamfer_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..67574e21818ae9388f44f964d606b2f00355865e --- /dev/null +++ b/lidm/eval/modules/chamfer3D/chamfer_cuda.cpp @@ -0,0 +1,33 @@ +#include +#include + +///TMP +//#include "common.h" +/// NOT TMP + + +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); + + +int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); + + + + +int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { + return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); +} + + +int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, + at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { + + return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); + m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); +} \ No newline at end of file diff --git a/lidm/eval/modules/chamfer3D/dist_chamfer_3D.py b/lidm/eval/modules/chamfer3D/dist_chamfer_3D.py new file mode 100644 index 0000000000000000000000000000000000000000..30063e50293445f0785f1cb7fd05719a1a29b172 --- /dev/null +++ b/lidm/eval/modules/chamfer3D/dist_chamfer_3D.py @@ -0,0 +1,76 @@ +from torch import nn +from torch.autograd import Function +import torch +import importlib +import os + +chamfer_found = importlib.find_loader("chamfer_3D") is not None +if not chamfer_found: + ## Cool trick from https://github.com/chrdiller + print("Jitting Chamfer 3D") + + from torch.utils.cpp_extension import load + + chamfer_3D = load(name="chamfer_3D", + sources=[ + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]), + ]) + print("Loaded JIT 3D CUDA chamfer distance") + +else: + import chamfer_3D + print("Loaded compiled 3D CUDA chamfer distance") + + +# Chamfer's distance module @thibaultgroueix +# GPU tensors only +class chamfer_3DFunction(Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + batchsize, n, _ = xyz1.size() + _, m, _ = xyz2.size() + device = xyz1.device + + dist1 = torch.zeros(batchsize, n) + dist2 = torch.zeros(batchsize, m) + + idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) + idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) + + dist1 = dist1.to(device) + dist2 = dist2.to(device) + idx1 = idx1.to(device) + idx2 = idx2.to(device) + torch.cuda.set_device(device) + + chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) + ctx.save_for_backward(xyz1, xyz2, idx1, idx2) + return dist1, dist2, idx1, idx2 + + @staticmethod + def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): + xyz1, xyz2, idx1, idx2 = ctx.saved_tensors + graddist1 = graddist1.contiguous() + graddist2 = graddist2.contiguous() + device = graddist1.device + + gradxyz1 = torch.zeros(xyz1.size()) + gradxyz2 = torch.zeros(xyz2.size()) + + gradxyz1 = gradxyz1.to(device) + gradxyz2 = gradxyz2.to(device) + chamfer_3D.backward( + xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 + ) + return gradxyz1, gradxyz2 + + +class chamfer_3DDist(nn.Module): + def __init__(self): + super(chamfer_3DDist, self).__init__() + + def forward(self, input1, input2): + input1 = input1.contiguous() + input2 = input2.contiguous() + return chamfer_3DFunction.apply(input1, input2) diff --git a/lidm/eval/modules/chamfer3D/setup.py b/lidm/eval/modules/chamfer3D/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..9a23aadadde026eb8c3db68a43d63086f6be856a --- /dev/null +++ b/lidm/eval/modules/chamfer3D/setup.py @@ -0,0 +1,14 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='chamfer_3D', + ext_modules=[ + CUDAExtension('chamfer_3D', [ + "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), + "/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']), + ]), + ], + cmdclass={ + 'build_ext': BuildExtension + }) \ No newline at end of file diff --git a/lidm/eval/modules/emd/__init__.py b/lidm/eval/modules/emd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/eval/modules/emd/emd.cpp b/lidm/eval/modules/emd/emd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5036016370aa5f148b17ec44ea7a30bc576c3a82 --- /dev/null +++ b/lidm/eval/modules/emd/emd.cpp @@ -0,0 +1,31 @@ +// EMD approximation module (based on auction algorithm) +// author: Minghua Liu +#include +#include + +int emd_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist, at::Tensor assignment, at::Tensor price, + at::Tensor assignment_inv, at::Tensor bid, at::Tensor bid_increments, at::Tensor max_increments, + at::Tensor unass_idx, at::Tensor unass_cnt, at::Tensor unass_cnt_sum, at::Tensor cnt_tmp, at::Tensor max_idx, float eps, int iters); + +int emd_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx); + + + +int emd_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist, at::Tensor assignment, at::Tensor price, + at::Tensor assignment_inv, at::Tensor bid, at::Tensor bid_increments, at::Tensor max_increments, + at::Tensor unass_idx, at::Tensor unass_cnt, at::Tensor unass_cnt_sum, at::Tensor cnt_tmp, at::Tensor max_idx, float eps, int iters) { + return emd_cuda_forward(xyz1, xyz2, dist, assignment, price, assignment_inv, bid, bid_increments, max_increments, unass_idx, unass_cnt, unass_cnt_sum, cnt_tmp, max_idx, eps, iters); +} + +int emd_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx) { + + return emd_cuda_backward(xyz1, xyz2, gradxyz, graddist, idx); +} + + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &emd_forward, "emd forward (CUDA)"); + m.def("backward", &emd_backward, "emd backward (CUDA)"); +} \ No newline at end of file diff --git a/lidm/eval/modules/emd/emd_cuda.cu b/lidm/eval/modules/emd/emd_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..08999bbac057bd09ed7668b71e2f691e4f06fd6b --- /dev/null +++ b/lidm/eval/modules/emd/emd_cuda.cu @@ -0,0 +1,316 @@ +// EMD approximation module (based on auction algorithm) +// author: Minghua Liu +#include +#include + +#include +#include +#include + +__device__ __forceinline__ float atomicMax(float *address, float val) +{ + int ret = __float_as_int(*address); + while(val > __int_as_float(ret)) + { + int old = ret; + if((ret = atomicCAS((int *)address, old, __float_as_int(val))) == old) + break; + } + return __int_as_float(ret); +} + + +__global__ void clear(int b, int * cnt_tmp, int * unass_cnt) { + for (int i = threadIdx.x; i < b; i += blockDim.x) { + cnt_tmp[i] = 0; + unass_cnt[i] = 0; + } +} + +__global__ void calc_unass_cnt(int b, int n, int * assignment, int * unass_cnt) { + // count the number of unassigned points in each batch + const int BLOCK_SIZE = 1024; + __shared__ int scan_array[BLOCK_SIZE]; + for (int i = blockIdx.x; i < b; i += gridDim.x) { + scan_array[threadIdx.x] = assignment[i * n + blockIdx.y * BLOCK_SIZE + threadIdx.x] == -1 ? 1 : 0; + __syncthreads(); + + int stride = 1; + while(stride <= BLOCK_SIZE / 2) { + int index = (threadIdx.x + 1) * stride * 2 - 1; + if(index < BLOCK_SIZE) + scan_array[index] += scan_array[index - stride]; + stride = stride * 2; + __syncthreads(); + } + __syncthreads(); + + if (threadIdx.x == BLOCK_SIZE - 1) { + atomicAdd(&unass_cnt[i], scan_array[threadIdx.x]); + } + __syncthreads(); + } +} + +__global__ void calc_unass_cnt_sum(int b, int * unass_cnt, int * unass_cnt_sum) { + // count the cumulative sum over over unass_cnt + const int BLOCK_SIZE = 512; // batch_size <= 512 + __shared__ int scan_array[BLOCK_SIZE]; + scan_array[threadIdx.x] = unass_cnt[threadIdx.x]; + __syncthreads(); + + int stride = 1; + while(stride <= BLOCK_SIZE / 2) { + int index = (threadIdx.x + 1) * stride * 2 - 1; + if(index < BLOCK_SIZE) + scan_array[index] += scan_array[index - stride]; + stride = stride * 2; + __syncthreads(); + } + __syncthreads(); + stride = BLOCK_SIZE / 4; + while(stride > 0) { + int index = (threadIdx.x + 1) * stride * 2 - 1; + if((index + stride) < BLOCK_SIZE) + scan_array[index + stride] += scan_array[index]; + stride = stride / 2; + __syncthreads(); + } + __syncthreads(); + + //printf("%d\n", unass_cnt_sum[b - 1]); + unass_cnt_sum[threadIdx.x] = scan_array[threadIdx.x]; +} + +__global__ void calc_unass_idx(int b, int n, int * assignment, int * unass_idx, int * unass_cnt, int * unass_cnt_sum, int * cnt_tmp) { + // list all the unassigned points + for (int i = blockIdx.x; i < b; i += gridDim.x) { + if (assignment[i * n + blockIdx.y * 1024 + threadIdx.x] == -1) { + int idx = atomicAdd(&cnt_tmp[i], 1); + unass_idx[unass_cnt_sum[i] - unass_cnt[i] + idx] = blockIdx.y * 1024 + threadIdx.x; + } + } +} + +__global__ void Bid(int b, int n, const float * xyz1, const float * xyz2, float eps, int * assignment, int * assignment_inv, float * price, + int * bid, float * bid_increments, float * max_increments, int * unass_cnt, int * unass_cnt_sum, int * unass_idx) { + const int batch = 2048, block_size = 1024, block_cnt = n / 1024; + __shared__ float xyz2_buf[batch * 3]; + __shared__ float price_buf[batch]; + __shared__ float best_buf[block_size]; + __shared__ float better_buf[block_size]; + __shared__ int best_i_buf[block_size]; + for (int i = blockIdx.x; i < b; i += gridDim.x) { + int _unass_cnt = unass_cnt[i]; + if (_unass_cnt == 0) + continue; + int _unass_cnt_sum = unass_cnt_sum[i]; + int unass_per_block = (_unass_cnt + block_cnt - 1) / block_cnt; + int thread_per_unass = block_size / unass_per_block; + int unass_this_block = max(min(_unass_cnt - (int) blockIdx.y * unass_per_block, unass_per_block), 0); + + float x1, y1, z1, best = -1e9, better = -1e9; + int best_i = -1, _unass_id = -1, thread_in_unass; + + if (threadIdx.x < thread_per_unass * unass_this_block) { + _unass_id = unass_per_block * blockIdx.y + threadIdx.x / thread_per_unass + _unass_cnt_sum - _unass_cnt; + _unass_id = unass_idx[_unass_id]; + thread_in_unass = threadIdx.x % thread_per_unass; + + x1 = xyz1[(i * n + _unass_id) * 3 + 0]; + y1 = xyz1[(i * n + _unass_id) * 3 + 1]; + z1 = xyz1[(i * n + _unass_id) * 3 + 2]; + } + + for (int k2 = 0; k2 < n; k2 += batch) { + int end_k = min(n, k2 + batch) - k2; + for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x) { + xyz2_buf[j] = xyz2[(i * n + k2) * 3 + j]; + } + for (int j = threadIdx.x; j < end_k; j += blockDim.x) { + price_buf[j] = price[i * n + k2 + j]; + } + __syncthreads(); + + if (_unass_id != -1) { + int delta = (end_k + thread_per_unass - 1) / thread_per_unass; + int l = thread_in_unass * delta; + int r = min((thread_in_unass + 1) * delta, end_k); + for (int k = l; k < r; k++) + //if (!last || assignment_inv[i * n + k + k2] == -1) + { + float x2 = xyz2_buf[k * 3 + 0] - x1; + float y2 = xyz2_buf[k * 3 + 1] - y1; + float z2 = xyz2_buf[k * 3 + 2] - z1; + // the coordinates of points should be normalized to [0, 1] + float d = 3.0 - sqrtf(x2 * x2 + y2 * y2 + z2 * z2) - price_buf[k]; + if (d > best) { + better = best; + best = d; + best_i = k + k2; + } + else if (d > better) { + better = d; + } + } + } + __syncthreads(); + } + + best_buf[threadIdx.x] = best; + better_buf[threadIdx.x] = better; + best_i_buf[threadIdx.x] = best_i; + __syncthreads(); + + if (_unass_id != -1 && thread_in_unass == 0) { + for (int j = threadIdx.x + 1; j < threadIdx.x + thread_per_unass; j++) { + if (best_buf[j] > best) { + better = max(best, better_buf[j]); + best = best_buf[j]; + best_i = best_i_buf[j]; + } + else better = max(better, best_buf[j]); + } + bid[i * n + _unass_id] = best_i; + bid_increments[i * n + _unass_id] = best - better + eps; + atomicMax(&max_increments[i * n + best_i], best - better + eps); + } + } +} + +__global__ void GetMax(int b, int n, int * assignment, int * bid, float * bid_increments, float * max_increments, int * max_idx) { + for (int i = blockIdx.x; i < b; i += gridDim.x) { + int j = threadIdx.x + blockIdx.y * blockDim.x; + if (assignment[i * n + j] == -1) { + int bid_id = bid[i * n + j]; + float bid_inc = bid_increments[i * n + j]; + float max_inc = max_increments[i * n + bid_id]; + if (bid_inc - 1e-6 <= max_inc && max_inc <= bid_inc + 1e-6) + { + max_idx[i * n + bid_id] = j; + } + } + } +} + +__global__ void Assign(int b, int n, int * assignment, int * assignment_inv, float * price, int * bid, float * bid_increments, float * max_increments, int * max_idx, bool last) { + for (int i = blockIdx.x; i < b; i += gridDim.x) { + int j = threadIdx.x + blockIdx.y * blockDim.x; + if (assignment[i * n + j] == -1) { + int bid_id = bid[i * n + j]; + if (last || max_idx[i * n + bid_id] == j) + { + float bid_inc = bid_increments[i * n + j]; + int ass_inv = assignment_inv[i * n + bid_id]; + if (!last && ass_inv != -1) { + assignment[i * n + ass_inv] = -1; + } + assignment_inv[i * n + bid_id] = j; + assignment[i * n + j] = bid_id; + price[i * n + bid_id] += bid_inc; + max_increments[i * n + bid_id] = -1e9; + } + } + } +} + +__global__ void CalcDist(int b, int n, float * xyz1, float * xyz2, float * dist, int * assignment) { + for (int i = blockIdx.x; i < b; i += gridDim.x) { + int j = threadIdx.x + blockIdx.y * blockDim.x; + int k = assignment[i * n + j]; + float deltax = xyz1[(i * n + j) * 3 + 0] - xyz2[(i * n + k) * 3 + 0]; + float deltay = xyz1[(i * n + j) * 3 + 1] - xyz2[(i * n + k) * 3 + 1]; + float deltaz = xyz1[(i * n + j) * 3 + 2] - xyz2[(i * n + k) * 3 + 2]; + dist[i * n + j] = deltax * deltax + deltay * deltay + deltaz * deltaz; + } +} + +int emd_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist, at::Tensor assignment, at::Tensor price, + at::Tensor assignment_inv, at::Tensor bid, at::Tensor bid_increments, at::Tensor max_increments, + at::Tensor unass_idx, at::Tensor unass_cnt, at::Tensor unass_cnt_sum, at::Tensor cnt_tmp, at::Tensor max_idx, float eps, int iters) { + + const auto batch_size = xyz1.size(0); + const auto n = xyz1.size(1); //num_points point cloud A + const auto m = xyz2.size(1); //num_points point cloud B + + if (n != m) { + printf("Input Error! The two point clouds should have the same size.\n"); + return -1; + } + + if (batch_size > 512) { + printf("Input Error! The batch size should be less than 512.\n"); + return -1; + } + + if (n % 1024 != 0) { + printf("Input Error! The size of the point clouds should be a multiple of 1024.\n"); + return -1; + } + + //cudaEvent_t start,stop; + //cudaEventCreate(&start); + //cudaEventCreate(&stop); + //cudaEventRecord(start); + //int iters = 50; + for (int i = 0; i < iters; i++) { + clear<<<1, batch_size>>>(batch_size, cnt_tmp.data(), unass_cnt.data()); + calc_unass_cnt<<>>(batch_size, n, assignment.data(), unass_cnt.data()); + calc_unass_cnt_sum<<<1, batch_size>>>(batch_size, unass_cnt.data(), unass_cnt_sum.data()); + calc_unass_idx<<>>(batch_size, n, assignment.data(), unass_idx.data(), unass_cnt.data(), + unass_cnt_sum.data(), cnt_tmp.data()); + Bid<<>>(batch_size, n, xyz1.data(), xyz2.data(), eps, assignment.data(), assignment_inv.data(), + price.data(), bid.data(), bid_increments.data(), max_increments.data(), + unass_cnt.data(), unass_cnt_sum.data(), unass_idx.data()); + GetMax<<>>(batch_size, n, assignment.data(), bid.data(), bid_increments.data(), max_increments.data(), max_idx.data()); + Assign<<>>(batch_size, n, assignment.data(), assignment_inv.data(), price.data(), bid.data(), + bid_increments.data(), max_increments.data(), max_idx.data(), i == iters - 1); + } + CalcDist<<>>(batch_size, n, xyz1.data(), xyz2.data(), dist.data(), assignment.data()); + //cudaEventRecord(stop); + //cudaEventSynchronize(stop); + //float elapsedTime; + //cudaEventElapsedTime(&elapsedTime,start,stop); + //printf("%lf\n", elapsedTime); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd Output: %s\n", cudaGetErrorString(err)); + return 0; + } + return 1; +} + +__global__ void NmDistanceGradKernel(int b, int n, const float * xyz1, const float * xyz2, const float * grad_dist, const int * idx, float * grad_xyz){ + for (int i = blockIdx.x; i < b; i += gridDim.x) { + for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x * gridDim.y) { + float x1 = xyz1[(i * n + j) * 3 + 0]; + float y1 = xyz1[(i * n + j) * 3 + 1]; + float z1 = xyz1[(i * n + j) * 3 + 2]; + int j2 = idx[i * n + j]; + float x2 = xyz2[(i * n + j2) * 3 + 0]; + float y2 = xyz2[(i * n + j2) * 3 + 1]; + float z2 = xyz2[(i * n + j2) * 3 + 2]; + float g = grad_dist[i * n + j] * 2; + atomicAdd(&(grad_xyz[(i * n + j) * 3 + 0]), g * (x1 - x2)); + atomicAdd(&(grad_xyz[(i * n + j) * 3 + 1]), g * (y1 - y2)); + atomicAdd(&(grad_xyz[(i * n + j) * 3 + 2]), g * (z1 - z2)); + } + } +} + +int emd_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx){ + const auto batch_size = xyz1.size(0); + const auto n = xyz1.size(1); + const auto m = xyz2.size(1); + + NmDistanceGradKernel<<>>(batch_size, n, xyz1.data(), xyz2.data(), graddist.data(), idx.data(), gradxyz.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); + return 0; + } + return 1; + +} diff --git a/lidm/eval/modules/emd/emd_module.py b/lidm/eval/modules/emd/emd_module.py new file mode 100644 index 0000000000000000000000000000000000000000..3065509eb77942610aab93a85c21f30bf6bf9bed --- /dev/null +++ b/lidm/eval/modules/emd/emd_module.py @@ -0,0 +1,112 @@ +# EMD approximation module (based on auction algorithm) +# memory complexity: O(n) +# time complexity: O(n^2 * iter) +# author: Minghua Liu + +# Input: +# xyz1, xyz2: [#batch, #points, 3] +# where xyz1 is the predicted point cloud and xyz2 is the ground truth point cloud +# two point clouds should have same size and be normalized to [0, 1] +# #points should be a multiple of 1024 +# #batch should be no greater than 512 +# eps is a parameter which balances the error rate and the speed of convergence +# iters is the number of iteration +# we only calculate gradient for xyz1 + +# Output: +# dist: [#batch, #points], sqrt(dist) -> L2 distance +# assignment: [#batch, #points], index of the matched point in the ground truth point cloud +# the result is an approximation and the assignment is not guranteed to be a bijection +import importlib +import os +import time +import numpy as np +import torch +from torch import nn +from torch.autograd import Function + +emd_found = importlib.find_loader("emd") is not None +if not emd_found: + ## Cool trick from https://github.com/chrdiller + print("Jitting EMD 3D") + + from torch.utils.cpp_extension import load + + emd = load(name="emd", + sources=[ + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["emd.cpp"]), + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["emd_cuda.cu"]), + ]) + print("Loaded JIT 3D CUDA emd") +else: + import emd + print("Loaded compiled 3D CUDA emd") + + +class emdFunction(Function): + @staticmethod + def forward(ctx, xyz1, xyz2, eps, iters): + batchsize, n, _ = xyz1.size() + _, m, _ = xyz2.size() + + assert (n == m) + assert (xyz1.size()[0] == xyz2.size()[0]) + # assert(n % 1024 == 0) + assert (batchsize <= 512) + + xyz1 = xyz1.contiguous().float().cuda() + xyz2 = xyz2.contiguous().float().cuda() + dist = torch.zeros(batchsize, n, device='cuda').contiguous() + assignment = torch.zeros(batchsize, n, device='cuda', dtype=torch.int32).contiguous() - 1 + assignment_inv = torch.zeros(batchsize, m, device='cuda', dtype=torch.int32).contiguous() - 1 + price = torch.zeros(batchsize, m, device='cuda').contiguous() + bid = torch.zeros(batchsize, n, device='cuda', dtype=torch.int32).contiguous() + bid_increments = torch.zeros(batchsize, n, device='cuda').contiguous() + max_increments = torch.zeros(batchsize, m, device='cuda').contiguous() + unass_idx = torch.zeros(batchsize * n, device='cuda', dtype=torch.int32).contiguous() + max_idx = torch.zeros(batchsize * m, device='cuda', dtype=torch.int32).contiguous() + unass_cnt = torch.zeros(512, dtype=torch.int32, device='cuda').contiguous() + unass_cnt_sum = torch.zeros(512, dtype=torch.int32, device='cuda').contiguous() + cnt_tmp = torch.zeros(512, dtype=torch.int32, device='cuda').contiguous() + + emd.forward(xyz1, xyz2, dist, assignment, price, assignment_inv, bid, bid_increments, max_increments, unass_idx, + unass_cnt, unass_cnt_sum, cnt_tmp, max_idx, eps, iters) + + ctx.save_for_backward(xyz1, xyz2, assignment) + return dist, assignment + + @staticmethod + def backward(ctx, graddist, gradidx): + xyz1, xyz2, assignment = ctx.saved_tensors + graddist = graddist.contiguous() + + gradxyz1 = torch.zeros(xyz1.size(), device='cuda').contiguous() + gradxyz2 = torch.zeros(xyz2.size(), device='cuda').contiguous() + + emd.backward(xyz1, xyz2, gradxyz1, graddist, assignment) + return gradxyz1, gradxyz2, None, None + + +class emdModule(nn.Module): + def __init__(self): + super(emdModule, self).__init__() + + def forward(self, input1, input2, eps, iters): + return emdFunction.apply(input1, input2, eps, iters) + + +def test_emd(): + x1 = torch.rand(20, 8192, 3).cuda() + x2 = torch.rand(20, 8192, 3).cuda() + emd = emdModule() + start_time = time.perf_counter() + dis, assigment = emd(x1, x2, 0.05, 3000) + print("Input_size: ", x1.shape) + print("Runtime: %lfs" % (time.perf_counter() - start_time)) + print("EMD: %lf" % np.sqrt(dis.cpu()).mean()) + print("|set(assignment)|: %d" % assigment.unique().numel()) + assigment = assigment.cpu().numpy() + assigment = np.expand_dims(assigment, -1) + x2 = np.take_along_axis(x2, assigment, axis=1) + d = (x1 - x2) * (x1 - x2) + print("Verified EMD: %lf" % np.sqrt(d.cpu().sum(-1)).mean()) diff --git a/lidm/eval/modules/emd/setup.py b/lidm/eval/modules/emd/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..8588de957bc285372be5a59a6e086af4954dc99b --- /dev/null +++ b/lidm/eval/modules/emd/setup.py @@ -0,0 +1,14 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='emd', + ext_modules=[ + CUDAExtension('emd', [ + 'emd.cpp', + 'emd_cuda.cu', + ]), + ], + cmdclass={ + 'build_ext': BuildExtension + }) \ No newline at end of file diff --git a/lidm/models/__init__.py b/lidm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/models/autoencoder.py b/lidm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..089130695164fbb30460439c78efd96abcae33e5 --- /dev/null +++ b/lidm/models/autoencoder.py @@ -0,0 +1,465 @@ +import numpy as np +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ..modules.diffusion import model_lidm, model_ldm +from ..modules.distributions.distributions import DiagonalGaussianDistribution +from ..modules.ema import LitEma +from ..utils.misc_utils import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + n_embed, + embed_dim, + lossconfig=None, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False, + lib_name='ldm', + use_mask=False, + **kwargs + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.use_mask = use_mask + model_lib = eval(f'model_{lib_name}') + self.encoder = model_lib.Encoder(**ddconfig) + self.decoder = model_lib.Decoder(**ddconfig) + if lossconfig is not None: + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_, _, ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + # if len(x.shape) == 3: + # x = x[..., None] + + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size + 16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def get_mask(self, batch): + mask = batch['mask'] + # if len(mask.shape) == 3: + # mask = mask[..., None] + return mask + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + m = self.get_mask(batch) if self.use_mask else None + x_rec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencoder + aeloss, log_dict_ae = self.loss(qloss, x, x_rec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=None, masks=m) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, x_rec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + masks=m) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + if self.use_ema: + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + m = self.get_mask(batch) if self.use_mask else None + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + suffix, + predicted_indices=None, + masks=m + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + suffix, + predicted_indices=None, + masks=m + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor * self.learning_rate + # print("lr_d", lr_d) + # print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quantize.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if self.use_mask: + mask = xrec[:, 1:2] < 0. + xrec = xrec[:, 0:1] + xrec[mask] = -1. + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + if self.use_mask: + mask = dec[:, 1:2] < 0. + dec = dec[:, 0:1] + dec[mask] = -1. + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + lib_name='ldm', + use_mask=False + ): + super().__init__() + self.image_key = image_key + self.use_mask = use_mask + model_lib = eval(f'model_{lib_name}') + self.encoder = model_lib.Encoder(**ddconfig) + self.decoder = model_lib.Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[:, None] + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/lidm/models/diffusion/__init__.py b/lidm/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/models/diffusion/classifier.py b/lidm/models/diffusion/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..82d2d6d0ef132a83e25bc84e324252a9aa1d513b --- /dev/null +++ b/lidm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ...modules.diffusion.openaimodel import EncoderUNetModel, UNetModel +from ...utils.misc_utils import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/lidm/models/diffusion/ddim.py b/lidm/models/diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..8c49dd42cac10dd1743eab581557633124e95826 --- /dev/null +++ b/lidm/models/diffusion/ddim.py @@ -0,0 +1,204 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from ...modules.basic import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from ...utils.misc_utils import print_fn + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=False): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=False, + disable_tqdm=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print_fn(f'Data shape for DDIM sampling is {size}, eta {eta}', verbose) + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + verbose=verbose, disable_tqdm=disable_tqdm) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=False, disable_tqdm=True): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print_fn(f"Running DDIM Sampling with {total_steps} timesteps", verbose) + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps, disable=disable_tqdm) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/lidm/models/diffusion/ddpm.py b/lidm/models/diffusion/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..706e1c14188e40c3edc11d123ea74768e237cf1e --- /dev/null +++ b/lidm/models/diffusion/ddpm.py @@ -0,0 +1,1455 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only + +from ...utils.misc_utils import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config, print_fn +from ...modules.ema import LitEma +from ...modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ...models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ...modules.basic import make_beta_schedule, extract_into_tensor, noise_like +from ...models.diffusion.ddim import DDIMSampler + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=[256, 256], + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + verbose=False + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.print_fn = partial(print_fn, verbose=verbose) + self.verbose = verbose + self.parameterization = parameterization + self.print_fn(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + self.print_fn(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + self.logvar = nn.Parameter(self.logvar, requires_grad=self.learn_logvar) + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + self.print_fn(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + self.print_fn(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + self.print_fn("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + self.print_fn(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + self.print_fn(f"Missing Keys: {missing}") + if len(unexpected) > 0: + self.print_fn(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, *image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + # if len(x.shape) == 3: + # x = x[..., None] + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + use_mask=False, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + self.use_mask = use_mask + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + self.print_fn("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + self.print_fn(f"setting self.scale_factor to {self.scale_factor}") + self.print_fn("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + self.print_fn("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + self.print_fn(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None): + # ground truth + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + + # encoding + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'bbox', 'center', 'camera']: + xc = batch[cond_key] + elif cond_key in ['class_label']: + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + # if bs is not None: + # xc = xc[:bs] + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, (dict, list)): + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + self.print_fn("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + self.print_fn("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + self.print_fn("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + self.print_fn("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + self.print_fn("reducing kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + self.print_fn("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if self.cond_stage_key in ["image", "LR_image", "segmentation", 'bbox_img'] and self.model.conditioning_key: + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key in ['bbox', 'center']: + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** num_downs + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer.crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + self.print_fn(patch_limits_tknzd[0].shape) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + self.print_fn(cut_cond.shape) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + self.print_fn(adapted_cond.shape) + adapted_cond = self.get_learned_conditioning(adapted_cond) + self.print_fn(adapted_cond.shape) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + self.print_fn(adapted_cond.shape) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + # simple loss + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + # vlb loss + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + + # total loss + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=False, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=False, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=False, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None, **kwargs): + if shape is None: + shape = (batch_size, self.channels, *self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, *self.image_size) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, + shape, cond, verbose=self.verbose, **kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True, **kwargs) + + return samples, intermediates + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=False, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False, + plot_diffusion_rows=False, dset=None, **kwargs): + + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif self.cond_stage_key in ['camera']: + if isinstance(batch["camera"], list): + xc = torch.cat(batch["camera"], -1) + else: + xc = batch["camera"].permute(0, 2, 3, 1, 4) + xc = xc.reshape(*xc.shape[:3], -1) * 2. - 1. + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + if dset is None: + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key].data + label2rgb = torch.from_numpy(dset.label2rgb).to(self.device) / 127.5 - 1. + # log["original_conditioning"] = self.to_rgb(xc) + log["original_conditioning"] = label2rgb[xc.argmax(1)].permute(0, 3, 1, 2) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with self.ema_scope("Plotting Inpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, shape=(self.channels, *self.image_size), batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + self.print_fn(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + self.print_fn('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + self.print_fn("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class Layout2ImgDiffusion(LatentDiffusion): + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key in ['bbox', 'center'], 'Layout2ImgDiffusion only for cond_stage_key="bbox" or "center"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, dset=None, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + if dset is None: + dset = self.trainer.datamodule.datasets[key].data + mapper = dset.conditional_builders[self.cond_stage_key] + H, W = batch['image'].shape[2:] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label_for_category_id(catno) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (W, H)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs diff --git a/lidm/models/diffusion/plms.py b/lidm/models/diffusion/plms.py new file mode 100644 index 0000000000000000000000000000000000000000..b68ab413ca361d8bc8ba815c12ad2dbb2144d728 --- /dev/null +++ b/lidm/models/diffusion/plms.py @@ -0,0 +1,236 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ...modules.diffusion.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=False): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=False, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/lidm/modules/__init__.py b/lidm/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/modules/attention.py b/lidm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..83708247da4a5d293aa6bbc83a15783eb32282df --- /dev/null +++ b/lidm/modules/attention.py @@ -0,0 +1,261 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from .basic import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = self.proj_out(x) + return x + x_in \ No newline at end of file diff --git a/lidm/modules/basic.py b/lidm/modules/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..e0571d7a5f33ef7c26b219e20d8f31859936b5ed --- /dev/null +++ b/lidm/modules/basic.py @@ -0,0 +1,392 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat +from torch import Tensor + +from ..utils.misc_utils import instantiate_from_config, print_fn + + +class CircularPad(nn.Module): + def __init__(self, pad_size): + super(CircularPad, self).__init__() + h1, h2, v1, v2 = pad_size + self.h_pad, self.v_pad = (h1, h2, 0, 0), (0, 0, v1, v2) + + def forward(self, x): + if sum(self.h_pad) > 0: + x = nn.functional.pad(x, self.h_pad, mode="circular") # horizontal pad + if sum(self.v_pad) > 0: + x = nn.functional.pad(x, self.v_pad, mode="constant") # vertical pad + return x + + +class CircularConv2d(nn.Conv2d): + def __init__(self, *args, **kwargs): + if 'padding' in kwargs: + self.is_pad = True + if isinstance(kwargs['padding'], int): + h1 = h2 = v1 = v2 = kwargs['padding'] + elif isinstance(kwargs['padding'], tuple): + h1, h2, v1, v2 = kwargs['padding'] + else: + raise NotImplementedError + self.h_pad, self.v_pad = (h1, h2, 0, 0), (0, 0, v1, v2) + del kwargs['padding'] + else: + self.is_pad = False + + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor) -> Tensor: + if self.is_pad: + if sum(self.h_pad) > 0: + x = nn.functional.pad(x, self.h_pad, mode="circular") # horizontal pad + if sum(self.v_pad) > 0: + x = nn.functional.pad(x, self.v_pad, mode="constant") # vertical pad + x = self._conv_forward(x, self.weight, self.bias) + return x + + +class ActNorm(nn.Module): + def __init__(self, num_features, logdet=False, affine=True, + allow_reverse_init=False): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:,:,None,None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height*width*torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." + ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:, :, None, None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=False): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + print_fn(f'Selected timesteps for ddim sampler: {steps_out}', False) + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=False): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + print_fn(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}', False) + print_fn(f'For the chosen value of eta, which is {eta}, this results in the following sigma_t schedule for ddim sampler {sigmas}', False) + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, cconv=False, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + if cconv: + return CircularConv2d(*args, **kwargs) + else: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/lidm/modules/diffusion/__init__.py b/lidm/modules/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/modules/diffusion/model_ldm.py b/lidm/modules/diffusion/model_ldm.py new file mode 100644 index 0000000000000000000000000000000000000000..5374a1142212bf9292a0fc3fbbfe7ff772aa8d0e --- /dev/null +++ b/lidm/modules/diffusion/model_ldm.py @@ -0,0 +1,817 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from ...utils.misc_utils import instantiate_from_config +from ...modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, + attn_levels, dropout=0.0, resamp_with_conv=True, in_channels, + use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if i_level in attn_levels: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if i_level in attn_levels: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, + attn_levels, dropout=0.0, resamp_with_conv=True, in_channels, + z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if i_level in attn_levels: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, + attn_levels, dropout=0.0, resamp_with_conv=True, in_channels, + z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if i_level in attn_levels: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, + ch_mult=(2, 2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + block_in = in_channels + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=( + int(round(x.shape[2] * self.factor)), int(round(x.shape[3] * self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, out_ch, num_res_blocks, + attn_levels, dropout=0.0, resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, + attn_levels=attn_levels, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, num_res_blocks, attn_levels, ch, ch_mult=(1, 2, 4, 8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_levels=attn_levels, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1. + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2 * in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, z_channels=in_channels, num_res_blocks=2, + attn_levels=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult: list, in_channels, + pretrained_model: nn.Module = None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) + self.proj = nn.Conv2d(in_channels, n_channels, kernel_size=3, + stride=1, padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in, out_channels=m * n_channels, dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def encode_with_pretrained(self, x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self, x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model, self.downsampler): + z = submodel(z, temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z, 'b c h w -> b (h w) c') + return z diff --git a/lidm/modules/diffusion/model_lidm.py b/lidm/modules/diffusion/model_lidm.py new file mode 100644 index 0000000000000000000000000000000000000000..2803291c509d6d7c8bc5d5fb519fd7c46fd14707 --- /dev/null +++ b/lidm/modules/diffusion/model_lidm.py @@ -0,0 +1,681 @@ +# pytorch_diffusion + derived encoder decoder +import math + +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from ..basic import CircularConv2d +from ...utils.misc_utils import instantiate_from_config +from ...modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +UPSAMPLE_STRIDE2KERNEL_DICT = {(1, 2): (1, 5), (1, 4): (1, 7), (2, 1): (5, 1), (2, 2): (3, 3)} +UPSAMPLE_STRIDE2PAD_DICT = {(1, 2): (2, 2, 0, 0), (1, 4): (3, 3, 0, 0), (2, 1): (0, 0, 2, 2), (2, 2): (1, 1, 1, 1)} + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv, stride): + super().__init__() + self.with_conv = with_conv + self.stride = stride + if self.with_conv: + k, p = UPSAMPLE_STRIDE2KERNEL_DICT[stride], UPSAMPLE_STRIDE2PAD_DICT[stride] + self.conv = CircularConv2d(in_channels, in_channels, kernel_size=k, padding=p) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=self.stride, mode='bilinear', align_corners=True) + if self.with_conv: + x = self.conv(x) + return x + + +DOWNSAMPLE_STRIDE2KERNEL_DICT = {(1, 2): (3, 3), (1, 4): (3, 5), (2, 1): (3, 3), (2, 2): (3, 3)} +DOWNSAMPLE_STRIDE2PAD_DICT = {(1, 2): (0, 1, 1, 1), (1, 4): (1, 1, 1, 1), (2, 1): (1, 1, 1, 1), (2, 2): (0, 1, 0, 1)} + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv, stride): + super().__init__() + self.with_conv = with_conv + self.stride = stride + if self.with_conv: + k, p = DOWNSAMPLE_STRIDE2KERNEL_DICT[stride], DOWNSAMPLE_STRIDE2PAD_DICT[stride] + self.conv = CircularConv2d(in_channels, in_channels, kernel_size=k, stride=stride, padding=p) + + def forward(self, x): + if self.with_conv: + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=self.stride, stride=self.stride) # modified for lidar + return x + + +UNIFORM_KERNEL2PAD_DICT = {(3, 3): (1, 1, 1, 1), (1, 4): (1, 2, 0, 0)} + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, kernel_size=(3, 3), conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + pad = UNIFORM_KERNEL2PAD_DICT[kernel_size] + + self.norm1 = Normalize(in_channels) + self.conv1 = CircularConv2d(in_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=pad) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = CircularConv2d(out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=pad) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CircularConv2d(in_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=pad) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult, strides, num_res_blocks, + attn_levels, dropout=0.0, resamp_with_conv=True, in_channels, z_channels, + double_z=True, use_linear_attn=False, attn_type="vanilla", use_mask=False, + **ignore_kwargs): + super().__init__() + if use_mask: + assert out_ch == in_channels + 1, 'Set "out_ch = out_ch + 1" for mask prediction.' + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_channels + + # downsampling + self.conv_in = CircularConv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if i_level in attn_levels: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + stride = tuple(strides[i_level]) + down.downsample = Downsample(block_in, resamp_with_conv, stride) + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = CircularConv2d(block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult, strides, num_res_blocks, attn_levels, + dropout=0.0, resamp_with_conv=True, in_channels, z_channels, give_pre_end=False, + tanh_out=False, use_linear_attn=False, attn_type="vanilla", use_mask=False, + **ignorekwargs): + super().__init__() + stride2kernel = {(2, 2): (3, 3), (1, 2): (1, 4)} + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + + # z to block_in + self.conv_in = CircularConv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + stride = tuple(strides[i_level - 1]) if i_level > 0 else None + kernel = stride2kernel[stride] if stride is not None else (1, 4) + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + kernel_size=kernel, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if i_level in attn_levels: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if stride is not None: + up.upsample = Upsample(block_in, resamp_with_conv, stride) + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = CircularConv2d(block_in, + out_ch, + kernel_size=(1, 4), + stride=1, + padding=(1, 2, 0, 0)) + + def forward(self, z): + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, + ch_mult=(2, 2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=( + int(round(x.shape[2] * self.factor)), int(round(x.shape[3] * self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, out_ch, num_res_blocks, + attn_levels, dropout=0.0, resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, attn_levels=attn_levels, dropout=dropout, + resamp_with_conv=resamp_with_conv, out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, num_res_blocks, attn_levels, ch, ch_mult=(1, 2, 4, 8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_levels=attn_levels, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1. + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2 * in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, z_channels=in_channels, num_res_blocks=2, + attn_levels=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult: list, in_channels, + pretrained_model: nn.Module = None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) + self.proj = nn.Conv2d(in_channels, n_channels, kernel_size=3, + stride=1, padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in, out_channels=m * n_channels, dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def encode_with_pretrained(self, x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self, x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model, self.downsampler): + z = submodel(z, temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z, 'b c h w -> b (h w) c') + return z diff --git a/lidm/modules/diffusion/openaimodel.py b/lidm/modules/diffusion/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..d21458a9cf84e6e0a797b408d14e3266eeadd704 --- /dev/null +++ b/lidm/modules/diffusion/openaimodel.py @@ -0,0 +1,971 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ..basic import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ...modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, cconv=False): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, cconv=cconv) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, cconv=False): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, cconv=cconv + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + cconv=False + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1, cconv=cconv), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims, cconv=cconv) + self.x_upd = Upsample(channels, False, dims, cconv=cconv) + elif down: + self.h_upd = Downsample(channels, False, dims, cconv=cconv) + self.x_upd = Downsample(channels, False, dims, cconv=cconv) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, cconv=cconv) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1, cconv=cconv + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, cconv=cconv) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), + True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + lib_name='ldm' + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + self.cconv = lib_name in ['lidm', 'lidm_v0'] + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1, cconv=self.cconv) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + cconv=self.cconv + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + cconv=self.cconv + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch, cconv=self.cconv + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + cconv=self.cconv + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + cconv=self.cconv + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + cconv=self.cconv + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + cconv=self.cconv + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, cconv=self.cconv) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, cconv=self.cconv)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1, use_cconv=self.use_cconv), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + lib_name='ldm', + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.cconv = lib_name == 'lidm' + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1, cconv=self.cconv) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) diff --git a/lidm/modules/distributions/__init__.py b/lidm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/modules/distributions/distributions.py b/lidm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9 --- /dev/null +++ b/lidm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/lidm/modules/ema.py b/lidm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..ea6d1d3594bb6336ed896557cf33275ea94380d3 --- /dev/null +++ b/lidm/modules/ema.py @@ -0,0 +1,76 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates + else torch.tensor(-1, dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace('.', '') + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/lidm/modules/encoders/__init__.py b/lidm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/modules/encoders/modules.py b/lidm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..fb37ae10aec342c6ff03fc1ecc979588dd92aa0a --- /dev/null +++ b/lidm/modules/encoders/modules.py @@ -0,0 +1,327 @@ +import torch +import torch.nn as nn +from functools import partial +import clip +from einops import rearrange, repeat +import kornia + +from ...modules.x_transformer import Encoder, TransformerWrapper + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer)) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" + + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast + # self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.tokenizer = BertTokenizerFast.from_pretrained('./models/bert') + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + + def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, + device="cuda", use_tokenizer=True, embedding_dropout=0.0): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout) + + def forward(self, text): + if self.use_tknz_fn: + tokens = self.tknz_fn(text) # .to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, text): + # output of length 77 + return self(text) + + +class SpatialRescaler(nn.Module): + def __init__(self, + strides=[], + method='bilinear', + in_channels=3, + out_channels=None, + bias=False): + super().__init__() + self.strides = strides + assert method in ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area'] + self.interpolator = partial(torch.nn.functional.interpolate, mode=method, align_corners=True) + self.remap_output = out_channels is not None + if self.remap_output: + print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias) + + def forward(self, x): + for h_s, w_s in self.strides: + x = self.interpolator(x, scale_factor=(1/h_s, 1/w_s)) + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + + +class FrozenCLIPTextEmbedder(nn.Module): + """ + Uses the CLIP transformer encoder for text. + """ + + def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + super().__init__() + self.model, _ = clip.load(version, jit=False, device="cpu") + self.model.to(device) + self.device = device + self.max_length = max_length + self.n_repeat = n_repeat + self.normalize = normalize + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = clip.tokenize(text).to(self.device) + z = self.model.encode_text(tokens) + if self.normalize: + z = z / torch.linalg.norm(z, dim=1, keepdim=True) + return z + + def encode(self, text): + z = self(text) + if z.ndim == 2: + z = z[:, None, :] + z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) + return z + + +class FrozenClipMultiTextEmbedder(FrozenCLIPTextEmbedder): + def __init__(self, num_views=1, apply_all=False, **kwargs): + super().__init__(**kwargs) + self.num_views = num_views + self.apply_all = apply_all + + def encode(self, text): + z = self(text) + if z.ndim == 2: + z = z[:, None, :] + + if not self.apply_all: + new_z = torch.zeros(*z.shape[:2], z.shape[2] * self.num_views, device=z.device) + new_z[:, :, self.num_views // 2 * z.shape[2]: (self.num_views // 2 + 1) * z.shape[2]] = z + else: + new_z = repeat(z, 'b 1 d -> b 1 (d m)', m=self.num_views) + + return new_z + + +class FrozenClipImageEmbedder(nn.Module): + """ + Uses the CLIP image encoder. + """ + + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=False, + ): + super().__init__() + self.model, _ = clip.load(name=model, device='cpu', jit=jit) + self.init() + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + def init(self): + for param in self.model.parameters(): + param.requires_grad = False + self.model.eval() + + def preprocess(self, x): + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic', align_corners=True, + antialias=self.antialias) + # x = (x + 1.) / 2. + + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [0,1] + return self.model.encode_image(self.preprocess(x)) + + +class FrozenClipMultiImageEmbedder(FrozenClipImageEmbedder): + """ + Uses the CLIP image encoder with multi-image as input. + """ + + def __init__(self, num_views=1, split_per_view=1, img_dim=768, out_dim=512, key='camera', **kwargs): + super().__init__(**kwargs) + self.split_per_view = split_per_view + self.key = key + self.linear = nn.Linear(img_dim, out_dim) + self.view_embedding = nn.Parameter(img_dim ** -0.5 * torch.randn((1, num_views * split_per_view, img_dim))) + + def forward(self, x): + # x is assumed to be in range [0,1] + if isinstance(x, torch.Tensor) and x.ndim == 5: + x = x.permute(1, 0, 2, 3, 4) + elif isinstance(x, dict): + x = x[self.key] + elif isinstance(x, torch.Tensor) and x.ndim == 3: + x = self.linear(x) + return x + + with torch.no_grad(): + img_feats = [self.model.encode_image(self.preprocess(img))[:, None] for img in x] + x = torch.cat(img_feats, 1).float() + self.view_embedding + x = self.linear(x) + + return x + + +class FrozenClipImagePatchEmbedder(nn.Module): + """ + Uses the CLIP image encoder. + """ + + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=False, + img_dim=1024, + out_dim=512, + num_views=1, + split_per_view=1 + ): + super().__init__() + self.model, _ = clip.load(name=model, device='cpu', jit=jit) + self.init() + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.view_embedding = nn.Parameter(img_dim ** -0.5 * torch.randn((1, num_views * split_per_view, 1, img_dim))) + + self.linear = nn.Linear(img_dim, out_dim) + + def init(self): + for param in self.model.parameters(): + param.requires_grad = False + self.model.eval() + + def preprocess(self, x): + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic', align_corners=True, + antialias=self.antialias) + # x = (x + 1.) / 2. + + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def encode_image_patch(self, x): + visual_encoder = self.model.visual + x = x.type(self.model.dtype) + x = visual_encoder.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([visual_encoder.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + visual_encoder.positional_embedding.to(x.dtype) + x = visual_encoder.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = visual_encoder.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = x[:, 1:, :] + + return x + + def forward(self, x): + # x is assumed to be in range [0,1] + img_feats = [self.encode_image_patch(self.preprocess(img))[:, None] for img in x] + x = torch.cat(img_feats, 1).float() + self.view_embedding + x = rearrange(x, 'b v n c -> b (v n) c') + x = self.linear(x) + return x diff --git a/lidm/modules/image_degradation/__init__.py b/lidm/modules/image_degradation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0db8d3e862120b2c4587c6b904c0c1153ebcdf48 --- /dev/null +++ b/lidm/modules/image_degradation/__init__.py @@ -0,0 +1,2 @@ +from .bsrgan import degradation_bsrgan_variant as degradation_fn_bsr +from .bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/lidm/modules/image_degradation/bsrgan.py b/lidm/modules/image_degradation/bsrgan.py new file mode 100644 index 0000000000000000000000000000000000000000..5b28874c1cc2aeb47f5f5d067397272401e59dd6 --- /dev/null +++ b/lidm/modules/image_degradation/bsrgan.py @@ -0,0 +1,730 @@ +# -*- coding: utf-8 -*- +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +from . import utils_image as util + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + elif i == 1: + image = add_blur(image, sf=sf) + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image":image} + return example + + +# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + + return img, hq + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') + + diff --git a/lidm/modules/image_degradation/bsrgan_light.py b/lidm/modules/image_degradation/bsrgan_light.py new file mode 100644 index 0000000000000000000000000000000000000000..38f08a613ccbc6f18cd6c78f23919c222062a7cf --- /dev/null +++ b/lidm/modules/image_degradation/bsrgan_light.py @@ -0,0 +1,650 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +from . import utils_image as util + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2/4 + wd = wd/4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + # elif i == 1: + # image = add_blur(image, sf=sf) + + if i == 0: + pass + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + # + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image": image} + return example + + + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') diff --git a/lidm/modules/image_degradation/utils/test.png b/lidm/modules/image_degradation/utils/test.png new file mode 100644 index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6 Binary files /dev/null and b/lidm/modules/image_degradation/utils/test.png differ diff --git a/lidm/modules/image_degradation/utils_image.py b/lidm/modules/image_degradation/utils_image.py new file mode 100644 index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98 --- /dev/null +++ b/lidm/modules/image_degradation/utils_image.py @@ -0,0 +1,916 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + + +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + plt.show() + + +''' +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) +# print(w1) +# print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(unit) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. +# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +''' +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +''' + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border:h-border, border:w-border] + return img + + +''' +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# ycbcr2rgb(img): +# -------------------------------------------- +''' + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + print('---') +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/lidm/modules/losses/__init__.py b/lidm/modules/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef46a16db54ad2fc925f47c54843a11315575f73 --- /dev/null +++ b/lidm/modules/losses/__init__.py @@ -0,0 +1,54 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + + +def measure_perplexity(predicted_indices, n_embed): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use + + +def l1(x, y): + return torch.abs(x - y) + + +def l2(x, y): + return torch.pow((x - y), 2) + + +def square_dist_loss(x, y): + return torch.sum((x - y) ** 2, dim=1, keepdim=True) + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) \ No newline at end of file diff --git a/lidm/modules/losses/contperceptual.py b/lidm/modules/losses/contperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..f9afbe01c2d1f6c30a10a810f92feadc082a7090 --- /dev/null +++ b/lidm/modules/losses/contperceptual.py @@ -0,0 +1,110 @@ +import torch +import torch.nn as nn + +from . import weights_init, hinge_d_loss, vanilla_d_loss +from .discriminator import LiDARNLayerDiscriminator +from .lpips import LPIPS + + +class LPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + p_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_loss="hinge", **kwargs): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_weight = p_weight + if p_weight > 0.: + self.perceptual_loss = LPIPS().eval() + # output log variance + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = LiDARNLayerDiscriminator( + input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, inputs, reconstructions, posteriors, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", + weights=None): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights*nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # now the GAN part + disc_factor = 0. if global_step > self.discriminator_iter_start else self.disc_factor + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + loss = weighted_nll_loss + self.kl_weight * kl_loss + disc_factor * d_weight * g_loss + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/lidm/modules/losses/discriminator.py b/lidm/modules/losses/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..64969ce102bc0088ccce2769bf43cf575c0a3968 --- /dev/null +++ b/lidm/modules/losses/discriminator.py @@ -0,0 +1,216 @@ +import functools +import torch.nn as nn + + +from ..basic import ActNorm, CircularConv2d + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, input_nc=1, output_nc=1, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, output_nc, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + + +class LiDARNLayerDiscriminator(nn.Module): + """Modified PatchGAN discriminator from Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, input_nc=1, output_nc=1, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(LiDARNLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = (4, 4) + sequence = [CircularConv2d(input_nc, ndf, kernel_size=kw, stride=(1, 2), padding=(1, 2, 1, 2)), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + CircularConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=(1, 2), bias=use_bias, padding=(1, 2, 1, 2)), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + CircularConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, bias=use_bias, padding=(1, 2, 1, 2)), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + CircularConv2d(ndf * nf_mult, output_nc, kernel_size=kw, stride=1, padding=(1, 2, 1, 2))] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + + +class LiDARNLayerDiscriminatorV2(nn.Module): + """Modified PatchGAN discriminator from Pix2Pix (larger receptive field) + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, input_nc=1, output_nc=1, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(LiDARNLayerDiscriminatorV2, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = (4, 4) + sequence = [CircularConv2d(input_nc, ndf, kernel_size=kw, stride=(1, 2), padding=(1, 2, 1, 2)), nn.LeakyReLU(0.2, True), + CircularConv2d(ndf, ndf, kernel_size=kw, stride=(1, 2), padding=(1, 2, 1, 2)), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + CircularConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=(2, 2), bias=use_bias, padding=(1, 2, 1, 2)), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + CircularConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, bias=use_bias, padding=(1, 2, 1, 2)), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + CircularConv2d(ndf * nf_mult, output_nc, kernel_size=kw, stride=1, padding=(1, 2, 1, 2))] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + + +class LiDARNLayerDiscriminatorV3(nn.Module): + """Modified PatchGAN discriminator from Pix2Pix (larger receptive field) + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, input_nc=1, output_nc=1, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(LiDARNLayerDiscriminatorV3, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = (4, 4) + sequence = [CircularConv2d(input_nc, ndf, kernel_size=(1, 4), stride=(1, 1), padding=(1, 2, 1, 2)), nn.LeakyReLU(0.2, True), + CircularConv2d(ndf, ndf, kernel_size=kw, stride=(2, 2), padding=(1, 2, 1, 2)), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + CircularConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=(2, 2), bias=use_bias, padding=(1, 2, 1, 2)), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + CircularConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, bias=use_bias, padding=(1, 2, 1, 2)), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + CircularConv2d(ndf * nf_mult, output_nc, kernel_size=kw, stride=1, padding=(1, 2, 1, 2))] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + import pdb; pdb.set_trace() + return self.main(input) diff --git a/lidm/modules/losses/geometric.py b/lidm/modules/losses/geometric.py new file mode 100644 index 0000000000000000000000000000000000000000..62cdc1d71da440bed0a3833c723fad3bec2fdc3d --- /dev/null +++ b/lidm/modules/losses/geometric.py @@ -0,0 +1,78 @@ +from functools import partial + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F + + +class GeoConverter(nn.Module): + def __init__(self, curve_length=4, bev_only=False, dataset_config=dict()): + super().__init__() + self.curve_length = curve_length + self.coord_dim = 3 if not bev_only else 2 + self.convert_fn = self.batch_range2bev if bev_only else self.batch_range2xyz + + fov = dataset_config.fov + self.fov_up = fov[0] / 180.0 * np.pi # field of view up in rad + self.fov_down = fov[1] / 180.0 * np.pi # field of view down in rad + self.fov_range = abs(self.fov_down) + abs(self.fov_up) # get field of view total in rad + self.depth_scale = dataset_config.depth_scale + self.depth_min, self.depth_max = dataset_config.depth_range + self.log_scale = dataset_config.log_scale + self.size = dataset_config['size'] + self.register_conversion() + + def register_conversion(self): + scan_x, scan_y = np.meshgrid(np.arange(self.size[1]), np.arange(self.size[0])) + scan_x = scan_x.astype(np.float64) / self.size[1] + scan_y = scan_y.astype(np.float64) / self.size[0] + + yaw = (np.pi * (scan_x * 2 - 1)) + pitch = ((1.0 - scan_y) * self.fov_range - abs(self.fov_down)) + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('cos_yaw', torch.cos(to_torch(yaw))) + self.register_buffer('sin_yaw', torch.sin(to_torch(yaw))) + self.register_buffer('cos_pitch', torch.cos(to_torch(pitch))) + self.register_buffer('sin_pitch', torch.sin(to_torch(pitch))) + + def batch_range2xyz(self, imgs): + batch_depth = (imgs * 0.5 + 0.5) * self.depth_scale + if self.log_scale: + batch_depth = torch.exp2(batch_depth) - 1 + batch_depth = batch_depth.clamp(self.depth_min, self.depth_max) + + batch_x = self.cos_yaw * self.cos_pitch * batch_depth + batch_y = -self.sin_yaw * self.cos_pitch * batch_depth + batch_z = self.sin_pitch * batch_depth + batch_xyz = torch.cat([batch_x, batch_y, batch_z], dim=1) + + return batch_xyz + + def batch_range2bev(self, imgs): + batch_depth = (imgs * 0.5 + 0.5) * self.depth_scale + if self.log_scale: + batch_depth = torch.exp2(batch_depth) - 1 + batch_depth = batch_depth.clamp(self.depth_min, self.depth_max) + + batch_x = self.cos_yaw * self.cos_pitch * batch_depth + batch_y = -self.sin_yaw * self.cos_pitch * batch_depth + batch_bev = torch.cat([batch_x, batch_y], dim=1) + + return batch_bev + + def curve_compress(self, batch_coord): + compressed_batch_coord = F.avg_pool2d(batch_coord, (1, self.curve_length)) + + return compressed_batch_coord + + def forward(self, input): + input = input / 2. + .5 # [-1, 1] -> [0, 1] + + input_coord = self.convert_fn(input) + if self.curve_length > 1: + input_coord = self.curve_compress(input_coord) + + return input_coord diff --git a/lidm/modules/losses/perceptual.py b/lidm/modules/losses/perceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..c1255a584194c6fdeae16b5e4df8ecf63e5224ed --- /dev/null +++ b/lidm/modules/losses/perceptual.py @@ -0,0 +1,123 @@ +import hashlib +import os + +import requests +import torch +import torch.nn as nn + +from tqdm import tqdm + +from . import l1, l2 +from ...utils.model_utils import build_model + +URL_MAP = { +} + +CKPT_MAP = { +} + +MD5_MAP = { +} + +PERCEPTUAL_TYPE = { + 'rangenet_full': [('enc_0', 32), ('enc_1', 64), ('enc_2', 128), ('enc_3', 256), ('enc_4', 512), ('enc_5', 1024), + ('dec_4', 512), ('dec_3', 256), ('dec_2', 128), ('dec_1', 64), ('dec_0', 32)], + 'rangenet_enc': [('enc_0', 32), ('enc_1', 64), ('enc_2', 128), ('enc_3', 256), ('enc_4', 512), ('enc_5', 1024)], + 'rangenet_dec': [('dec_4', 512), ('dec_3', 256), ('dec_2', 128), ('dec_1', 64), ('dec_0', 32)], + 'rangenet_final': [('dec_0', 32)] +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class PerceptualLoss(nn.Module): + def __init__(self, ptype, depth_scale, log_scale=True, use_dropout=True, lpips=False, p_loss='l1'): + super().__init__() + self.depth_scale = depth_scale + self.log_scale = log_scale + + if p_loss == "l1": + self.p_loss = l1 + else: + self.p_loss = l2 + + self.chns = PERCEPTUAL_TYPE[ptype] + self.return_list = [name for name, _ in self.chns] + self.loss_scale = [5.0, 3.39, 2.29, 1.61, 0.895] # predefined based on the loss of each stage after a few epochs (refer ) + self.net = build_model('kitti', 'rangenet') + self.lin_list = nn.ModuleList([NetLinLayer(ch, use_dropout=use_dropout) for _, ch in self.chns]) if lpips else None + for param in self.parameters(): + param.requires_grad = False + + @staticmethod + def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + @staticmethod + def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) + + def preprocess(self, *inputs): + assert len(inputs) == 2, 'input with both depth images and coord images' + depth_img, xyz_img = inputs + + # scale to standard rangenet input + depth_img = (depth_img * 0.5 + 0.5) * self.depth_scale + if self.log_scale: + depth_img = torch.exp2(depth_img) - 1 + + img = torch.cat([depth_img, xyz_img], 1) + return img + + def forward(self, target, input): + in0_input, in1_input = self.preprocess(*input), self.preprocess(*target) + outs0, outs1 = self.net(in0_input, return_list=self.return_list), self.net(in1_input, return_list=self.return_list) + + val_list = [] + for i, (name, _) in enumerate(self.chns): + feats0, feats1 = self.normalize_tensor(outs0[name].to(in0_input.device)), \ + self.normalize_tensor(outs1[name].to(in0_input.device)) + diffs = self.p_loss(feats0, feats1) + res = self.lin_list[i].model(diffs) if self.lin_list is not None else diffs.mean(1, keepdim=True) + res = self.spatial_average(res, keepdim=True) * self.loss_scale[i] + val_list.append(res) + val = sum(val_list) + return val diff --git a/lidm/modules/losses/vqperceptual.py b/lidm/modules/losses/vqperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..e9cd7220338d6beacdf5e70108c6279c1fe18b7c --- /dev/null +++ b/lidm/modules/losses/vqperceptual.py @@ -0,0 +1,176 @@ +import torch +from torch import nn + +from . import weights_init, l1, l2, hinge_d_loss, vanilla_d_loss, measure_perplexity, square_dist_loss +from .geometric import GeoConverter +from .discriminator import NLayerDiscriminator, LiDARNLayerDiscriminator, LiDARNLayerDiscriminatorV2 +from .perceptual import PerceptualLoss + +VERSION2DISC = {'v0': NLayerDiscriminator, 'v1': LiDARNLayerDiscriminator, 'v2': LiDARNLayerDiscriminatorV2} + + +class VQGeoLPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_out_channels=1, disc_factor=1.0, disc_weight=1.0, + mask_factor=0.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge", n_classes=None, pixel_loss="l1", disc_version='v1', + geo_factor=1.0, curve_length=4, perceptual_factor=1.0, perceptual_type='rangenet_final', + dataset_config=dict()): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + assert pixel_loss in ["l1", "l2"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + self.mask_factor = mask_factor + self.geo_factor = geo_factor + + # scale of reconstruction loss + self.rec_scale = 1 + if mask_factor > 0: + self.rec_scale += 1. + if geo_factor > 0: + self.rec_scale += 1. + if perceptual_factor > 0: + self.rec_scale += 1. + + if pixel_loss == "l1": + self.pixel_loss = l1 + else: + self.pixel_loss = l2 + + self.perceptual_factor = perceptual_factor + if perceptual_factor > 0.: + print(f"{self.__class__.__name__}: Running with LPIPS.") + self.perceptual_loss = PerceptualLoss(perceptual_type, dataset_config.depth_scale, + dataset_config.log_scale).eval() + + disc_cls = VERSION2DISC[disc_version] + self.discriminator = disc_cls(input_nc=disc_in_channels, + output_nc=disc_out_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQGeoLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + self.n_classes = n_classes + + self.geometry_converter = GeoConverter(curve_length, False, dataset_config) # force converting xyz output + self.geo_loss = square_dist_loss + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", predicted_indices=None, masks=None): + input_coord = self.geometry_converter(inputs) + rec_coord = self.geometry_converter(reconstructions[:, 0:1].contiguous()) + + ############# Reconstruction ############# + # pixel reconstruction loss + if self.mask_factor > 0 and masks is not None: + pixel_rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions[:, 0:1].contiguous()) + mask_rec_loss = self.pixel_loss(masks.contiguous(), reconstructions[:, 1:2].contiguous()) * self.mask_factor + else: + pixel_rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) + mask_rec_loss = torch.tensor(0.0) + + # geometry reconstruction loss (bev) + if self.geo_factor > 0: + geo_rec_loss = self.geo_loss(input_coord[:, :2], rec_coord[:, :2]) * self.geo_factor + else: + geo_rec_loss = torch.tensor(0.0) + + # perceptual loss + if self.perceptual_factor > 0: + perceptual_loss = self.perceptual_loss((inputs.contiguous(), input_coord), + (reconstructions[:, 0:1].contiguous(), rec_coord)) * self.perceptual_factor + else: + perceptual_loss = torch.tensor(0.0) + + # overall reconstruction loss + rec_loss = (pixel_rec_loss + mask_rec_loss + geo_rec_loss + perceptual_loss) / self.rec_scale + nll_loss = rec_loss + nll_loss = torch.mean(nll_loss) + + ############# GAN ############# + disc_factor = 0. if global_step > self.discriminator_iter_start else self.disc_factor + # update generator (input: img, mask, coord, [cond]) + if optimizer_idx == 0: + disc_recons = reconstructions.contiguous() + if self.geo_factor > 0: + disc_recons = torch.cat((disc_recons, rec_coord[:, :2].contiguous()), dim=1) + if cond is not None and self.disc_conditional: + disc_recons = torch.cat((disc_recons, cond), dim=1) + logits_fake = self.discriminator(disc_recons) + + # adversarial loss + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/pix_rec_loss".format(split): pixel_rec_loss.detach().mean(), + "{}/geo_rec_loss".format(split): geo_rec_loss.detach().mean(), + "{}/mask_rec_loss".format(split): mask_rec_loss.detach().mean(), + "{}/perceptual_loss".format(split): perceptual_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean()} + + if predicted_indices is not None: + assert self.n_classes is not None + with torch.no_grad(): + perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) + log[f"{split}/perplexity"] = perplexity + log[f"{split}/cluster_usage"] = cluster_usage + return loss, log + + # update discriminator (input: img, mask, coord, [cond]) + if optimizer_idx == 1: + disc_inputs, disc_recons = inputs.contiguous().detach(), reconstructions.contiguous().detach() + if self.mask_factor > 0 and masks is not None: + disc_inputs = torch.cat((disc_inputs, masks.contiguous().detach()), dim=1) + if self.geo_factor > 0: + disc_inputs = torch.cat((disc_inputs, input_coord[:, :2].contiguous()), dim=1) + disc_recons = torch.cat((disc_recons, rec_coord[:, :2].contiguous()), dim=1) + if cond is not None: + disc_inputs = torch.cat((disc_inputs, cond), dim=1) + disc_recons = torch.cat((disc_recons, cond), dim=1) + logits_real = self.discriminator(disc_inputs) + logits_fake = self.discriminator(disc_recons) + + # gan loss + d_loss = self.disc_loss(logits_real, logits_fake) * disc_factor + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean()} + + return d_loss, log diff --git a/lidm/modules/minkowskinet/__init__.py b/lidm/modules/minkowskinet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/modules/minkowskinet/model.py b/lidm/modules/minkowskinet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..daa36a5009e96f3331653341c6b185627d688a19 --- /dev/null +++ b/lidm/modules/minkowskinet/model.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn + +try: + import torchsparse + import torchsparse.nn as spnn + from ..ts import basic_blocks +except ImportError: + raise Exception('Required ts lib. Reference: https://github.com/mit-han-lab/torchsparse/tree/v1.4.0') + + +class Model(nn.Module): + def __init__(self, config): + super().__init__() + + cr = config.model_params.cr + cs = config.model_params.layer_num + cs = [int(cr * x) for x in cs] + + self.pres = self.vres = config.model_params.voxel_size + self.num_classes = config.model_params.num_class + + self.stem = nn.Sequential( + spnn.Conv3d(config.model_params.input_dims, cs[0], kernel_size=3, stride=1), + spnn.BatchNorm(cs[0]), spnn.ReLU(True), + spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1), + spnn.BatchNorm(cs[0]), spnn.ReLU(True)) + + self.stage1 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1), + ) + + self.stage2 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1), + ) + + self.stage3 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1), + ) + + self.stage4 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1), + ) + + self.up1 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[4], cs[5], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[5] + cs[3], cs[5], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[5], cs[5], ks=3, stride=1, dilation=1), + ) + ]) + + self.up2 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[5], cs[6], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[6] + cs[2], cs[6], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[6], cs[6], ks=3, stride=1, dilation=1), + ) + ]) + + self.up3 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[6], cs[7], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[7] + cs[1], cs[7], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[7], cs[7], ks=3, stride=1, dilation=1), + ) + ]) + + self.up4 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[7], cs[8], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[8] + cs[0], cs[8], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[8], cs[8], ks=3, stride=1, dilation=1), + ) + ]) + + self.classifier = nn.Sequential(nn.Linear(cs[8], self.num_classes)) + + self.weight_initialization() + self.dropout = nn.Dropout(0.3, True) + + def weight_initialization(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, data_dict, return_logits=False, return_final_logits=False): + x = data_dict['lidar'] + x.C = x.C.int() + + x0 = self.stem(x) + x1 = self.stage1(x0) + x2 = self.stage2(x1) + x3 = self.stage3(x2) + x4 = self.stage4(x3) + + if return_logits: + output_dict = dict() + output_dict['logits'] = x4.F + output_dict['batch_indices'] = x4.C[:, -1] + return output_dict + + y1 = self.up1[0](x4) + y1 = torchsparse.cat([y1, x3]) + y1 = self.up1[1](y1) + + y2 = self.up2[0](y1) + y2 = torchsparse.cat([y2, x2]) + y2 = self.up2[1](y2) + + y3 = self.up3[0](y2) + y3 = torchsparse.cat([y3, x1]) + y3 = self.up3[1](y3) + + y4 = self.up4[0](y3) + y4 = torchsparse.cat([y4, x0]) + y4 = self.up4[1](y4) + if return_final_logits: + output_dict = dict() + output_dict['logits'] = y4.F + output_dict['coords'] = y4.C[:, :3] + output_dict['batch_indices'] = y4.C[:, -1] + return output_dict + + output = self.classifier(y4.F) + data_dict['output'] = output.F + + return data_dict diff --git a/lidm/modules/rangenet/__init__.py b/lidm/modules/rangenet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/modules/rangenet/model.py b/lidm/modules/rangenet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..752fae9effd476a6bff255e0063675d1cc2f72e2 --- /dev/null +++ b/lidm/modules/rangenet/model.py @@ -0,0 +1,372 @@ +#!/usr/bin/env python3 +# This file is covered by the LICENSE file in the root of this project. +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + def __init__(self, inplanes, planes, bn_d=0.1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes[0], kernel_size=1, + stride=1, padding=0, bias=False) + self.bn1 = nn.BatchNorm2d(planes[0], momentum=bn_d) + self.relu1 = nn.LeakyReLU(0.1) + self.conv2 = nn.Conv2d(planes[0], planes[1], kernel_size=3, + stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes[1], momentum=bn_d) + self.relu2 = nn.LeakyReLU(0.1) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu2(out) + + out += residual + return out + + +# ****************************************************************************** + +# number of layers per model +model_blocks = { + 21: [1, 1, 2, 2, 1], + 53: [1, 2, 8, 8, 4], +} + + +class Backbone(nn.Module): + """ + Class for DarknetSeg. Subclasses PyTorch's own "nn" module + """ + + def __init__(self, params): + super(Backbone, self).__init__() + self.use_range = params["input_depth"]["range"] + self.use_xyz = params["input_depth"]["xyz"] + self.use_remission = params["input_depth"]["remission"] + self.drop_prob = params["dropout"] + self.bn_d = params["bn_d"] + self.OS = params["OS"] + self.layers = params["extra"]["layers"] + + # input depth calc + self.input_depth = 0 + self.input_idxs = [] + if self.use_range: + self.input_depth += 1 + self.input_idxs.append(0) + if self.use_xyz: + self.input_depth += 3 + self.input_idxs.extend([1, 2, 3]) + if self.use_remission: + self.input_depth += 1 + self.input_idxs.append(4) + + # stride play + self.strides = [2, 2, 2, 2, 2] + # check current stride + current_os = 1 + for s in self.strides: + current_os *= s + + # make the new stride + if self.OS > current_os: + print("Can't do OS, ", self.OS, + " because it is bigger than original ", current_os) + else: + # redo strides according to needed stride + for i, stride in enumerate(reversed(self.strides), 0): + if int(current_os) != self.OS: + if stride == 2: + current_os /= 2 + self.strides[-1 - i] = 1 + if int(current_os) == self.OS: + break + + # check that darknet exists + assert self.layers in model_blocks.keys() + + # generate layers depending on darknet type + self.blocks = model_blocks[self.layers] + + # input layer + self.conv1 = nn.Conv2d(self.input_depth, 32, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(32, momentum=self.bn_d) + self.relu1 = nn.LeakyReLU(0.1) + + # encoder + self.enc1 = self._make_enc_layer(BasicBlock, [32, 64], self.blocks[0], + stride=self.strides[0], bn_d=self.bn_d) + self.enc2 = self._make_enc_layer(BasicBlock, [64, 128], self.blocks[1], + stride=self.strides[1], bn_d=self.bn_d) + self.enc3 = self._make_enc_layer(BasicBlock, [128, 256], self.blocks[2], + stride=self.strides[2], bn_d=self.bn_d) + self.enc4 = self._make_enc_layer(BasicBlock, [256, 512], self.blocks[3], + stride=self.strides[3], bn_d=self.bn_d) + self.enc5 = self._make_enc_layer(BasicBlock, [512, 1024], self.blocks[4], + stride=self.strides[4], bn_d=self.bn_d) + + # for a bit of fun + self.dropout = nn.Dropout2d(self.drop_prob) + + # last channels + self.last_channels = 1024 + + # make layer useful function + def _make_enc_layer(self, block, planes, blocks, stride, bn_d=0.1): + layers = [] + + # downsample + layers.append(("conv", nn.Conv2d(planes[0], planes[1], + kernel_size=3, + stride=[1, stride], dilation=1, + padding=1, bias=False))) + layers.append(("bn", nn.BatchNorm2d(planes[1], momentum=bn_d))) + layers.append(("relu", nn.LeakyReLU(0.1))) + + # blocks + inplanes = planes[1] + for i in range(0, blocks): + layers.append(("residual_{}".format(i), + block(inplanes, planes, bn_d))) + + return nn.Sequential(OrderedDict(layers)) + + def run_layer(self, x, layer, skips, os): + y = layer(x) + if y.shape[2] < x.shape[2] or y.shape[3] < x.shape[3]: + skips[os] = x.detach() + os *= 2 + x = y + return x, skips, os + + def forward(self, x, return_logits=False, return_list=None): + # filter input + x = x[:, self.input_idxs] + + # run cnn + # store for skip connections + skips = {} + out_dict = {} + os = 1 + + # first layer + x, skips, os = self.run_layer(x, self.conv1, skips, os) + x, skips, os = self.run_layer(x, self.bn1, skips, os) + x, skips, os = self.run_layer(x, self.relu1, skips, os) + if return_list and 'enc_0' in return_list: + out_dict['enc_0'] = x.detach().cpu() # 32, 64, 1024 + + # all encoder blocks with intermediate dropouts + x, skips, os = self.run_layer(x, self.enc1, skips, os) + if return_list and 'enc_1' in return_list: + out_dict['enc_1'] = x.detach().cpu() # 64, 64, 512 + x, skips, os = self.run_layer(x, self.dropout, skips, os) + + x, skips, os = self.run_layer(x, self.enc2, skips, os) + if return_list and 'enc_2' in return_list: + out_dict['enc_2'] = x.detach().cpu() # 128, 64, 256 + x, skips, os = self.run_layer(x, self.dropout, skips, os) + + x, skips, os = self.run_layer(x, self.enc3, skips, os) + if return_list and 'enc_3' in return_list: + out_dict['enc_3'] = x.detach().cpu() # 256, 64, 128 + x, skips, os = self.run_layer(x, self.dropout, skips, os) + + x, skips, os = self.run_layer(x, self.enc4, skips, os) + if return_list and 'enc_4' in return_list: + out_dict['enc_4'] = x.detach().cpu() # 512, 64, 64 + x, skips, os = self.run_layer(x, self.dropout, skips, os) + + x, skips, os = self.run_layer(x, self.enc5, skips, os) + if return_list and 'enc_5' in return_list: + out_dict['enc_5'] = x.detach().cpu() # 1024, 64, 32 + if return_logits: + return x + + x, skips, os = self.run_layer(x, self.dropout, skips, os) + + if return_list is not None: + return x, skips, out_dict + return x, skips + + def get_last_depth(self): + return self.last_channels + + def get_input_depth(self): + return self.input_depth + + +class Decoder(nn.Module): + """ + Class for DarknetSeg. Subclasses PyTorch's own "nn" module + """ + + def __init__(self, params, OS=32, feature_depth=1024): + super(Decoder, self).__init__() + self.backbone_OS = OS + self.backbone_feature_depth = feature_depth + self.drop_prob = params["dropout"] + self.bn_d = params["bn_d"] + self.index = 0 + + # stride play + self.strides = [2, 2, 2, 2, 2] + # check current stride + current_os = 1 + for s in self.strides: + current_os *= s + # redo strides according to needed stride + for i, stride in enumerate(self.strides): + if int(current_os) != self.backbone_OS: + if stride == 2: + current_os /= 2 + self.strides[i] = 1 + if int(current_os) == self.backbone_OS: + break + + # decoder + self.dec5 = self._make_dec_layer(BasicBlock, + [self.backbone_feature_depth, 512], + bn_d=self.bn_d, + stride=self.strides[0]) + self.dec4 = self._make_dec_layer(BasicBlock, [512, 256], bn_d=self.bn_d, + stride=self.strides[1]) + self.dec3 = self._make_dec_layer(BasicBlock, [256, 128], bn_d=self.bn_d, + stride=self.strides[2]) + self.dec2 = self._make_dec_layer(BasicBlock, [128, 64], bn_d=self.bn_d, + stride=self.strides[3]) + self.dec1 = self._make_dec_layer(BasicBlock, [64, 32], bn_d=self.bn_d, + stride=self.strides[4]) + + # layer list to execute with skips + self.layers = [self.dec5, self.dec4, self.dec3, self.dec2, self.dec1] + + # for a bit of fun + self.dropout = nn.Dropout2d(self.drop_prob) + + # last channels + self.last_channels = 32 + + def _make_dec_layer(self, block, planes, bn_d=0.1, stride=2): + layers = [] + + # downsample + if stride == 2: + layers.append(("upconv", nn.ConvTranspose2d(planes[0], planes[1], + kernel_size=[1, 4], stride=[1, 2], + padding=[0, 1]))) + else: + layers.append(("conv", nn.Conv2d(planes[0], planes[1], + kernel_size=3, padding=1))) + layers.append(("bn", nn.BatchNorm2d(planes[1], momentum=bn_d))) + layers.append(("relu", nn.LeakyReLU(0.1))) + + # blocks + layers.append(("residual", block(planes[1], planes, bn_d))) + + return nn.Sequential(OrderedDict(layers)) + + def run_layer(self, x, layer, skips, os): + feats = layer(x) # up + if feats.shape[-1] > x.shape[-1]: + os //= 2 # match skip + feats = feats + skips[os].detach() # add skip + x = feats + return x, skips, os + + def forward(self, x, skips, return_logits=False, return_list=None): + os = self.backbone_OS + out_dict = {} + + # run layers + x, skips, os = self.run_layer(x, self.dec5, skips, os) + if return_list and 'dec_4' in return_list: + out_dict['dec_4'] = x.detach().cpu() # 512, 64, 64 + x, skips, os = self.run_layer(x, self.dec4, skips, os) + if return_list and 'dec_3' in return_list: + out_dict['dec_3'] = x.detach().cpu() # 256, 64, 128 + x, skips, os = self.run_layer(x, self.dec3, skips, os) + if return_list and 'dec_2' in return_list: + out_dict['dec_2'] = x.detach().cpu() # 128, 64, 256 + x, skips, os = self.run_layer(x, self.dec2, skips, os) + if return_list and 'dec_1' in return_list: + out_dict['dec_1'] = x.detach().cpu() # 64, 64, 512 + x, skips, os = self.run_layer(x, self.dec1, skips, os) + if return_list and 'dec_0' in return_list: + out_dict['dec_0'] = x.detach().cpu() # 32, 64, 1024 + + logits = torch.clone(x).detach() + x = self.dropout(x) + + if return_logits: + return x, logits + if return_list is not None: + return out_dict + return x + + def get_last_depth(self): + return self.last_channels + + +class Model(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = Backbone(params=self.config["backbone"]) + self.decoder = Decoder(params=self.config["decoder"], OS=self.config["backbone"]["OS"], + feature_depth=self.backbone.get_last_depth()) + + def load_pretrained_weights(self, path): + w_dict = torch.load(path + "/backbone", + map_location=lambda storage, loc: storage) + self.backbone.load_state_dict(w_dict, strict=True) + w_dict = torch.load(path + "/segmentation_decoder", + map_location=lambda storage, loc: storage) + self.decoder.load_state_dict(w_dict, strict=True) + + def forward(self, x, return_logits=False, return_final_logits=False, return_list=None, agg_type='depth'): + if return_logits: + logits = self.backbone(x, return_logits) + logits = F.adaptive_avg_pool2d(logits, (1, 1)).squeeze() + logits = torch.clone(logits).detach().cpu().numpy() + return logits + elif return_list is not None: + x, skips, enc_dict = self.backbone(x, return_list=return_list) + dec_dict = self.decoder(x, skips, return_list=return_list) + out_dict = {**enc_dict, **dec_dict} + return out_dict + elif return_final_logits: + assert agg_type in ['all', 'sector', 'depth'] + y, skips = self.backbone(x) + y, logits = self.decoder(y, skips, True) + + B, C, H, W = logits.shape + N = 16 + + # avg all + if agg_type == 'all': + logits = logits.mean([2, 3]) + # avg in patch + elif agg_type == 'sector': + logits = logits.view(B, C, H, N, W // N).mean([2, 4]).reshape(B, -1) + # avg in row + elif agg_type == 'depth': + logits = logits.view(B, C, N, H // N, W).mean([3, 4]).reshape(B, -1) + + logits = torch.clone(logits).detach().cpu().numpy() + return logits + else: + y, skips = self.backbone(x) + y = self.decoder(y, skips, False) + return y diff --git a/lidm/modules/spvcnn/__init__.py b/lidm/modules/spvcnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/modules/spvcnn/model.py b/lidm/modules/spvcnn/model.py new file mode 100644 index 0000000000000000000000000000000000000000..dd7793f8e81cef35f331c0c2c70062023999e83a --- /dev/null +++ b/lidm/modules/spvcnn/model.py @@ -0,0 +1,179 @@ +import torch.nn as nn + +try: + import torchsparse + import torchsparse.nn as spnn + from torchsparse import PointTensor + from ..ts.utils import initial_voxelize, point_to_voxel, voxel_to_point + from ..ts import basic_blocks +except ImportError: + raise Exception('Required torchsparse lib. Reference: https://github.com/mit-han-lab/torchsparse/tree/v1.4.0') + + +class Model(nn.Module): + def __init__(self, config): + super().__init__() + cr = config.model_params.cr + cs = config.model_params.layer_num + cs = [int(cr * x) for x in cs] + + self.pres = self.vres = config.model_params.voxel_size + self.num_classes = config.model_params.num_class + + self.stem = nn.Sequential( + spnn.Conv3d(config.model_params.input_dims, cs[0], kernel_size=3, stride=1), + spnn.BatchNorm(cs[0]), spnn.ReLU(True), + spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1), + spnn.BatchNorm(cs[0]), spnn.ReLU(True)) + + self.stage1 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1), + ) + + self.stage2 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1), + ) + + self.stage3 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1), + ) + + self.stage4 = nn.Sequential( + basic_blocks.BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1), + basic_blocks.ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1), + basic_blocks.ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1), + ) + + self.up1 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[4], cs[5], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[5] + cs[3], cs[5], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[5], cs[5], ks=3, stride=1, dilation=1), + ) + ]) + + self.up2 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[5], cs[6], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[6] + cs[2], cs[6], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[6], cs[6], ks=3, stride=1, dilation=1), + ) + ]) + + self.up3 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[6], cs[7], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[7] + cs[1], cs[7], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[7], cs[7], ks=3, stride=1, dilation=1), + ) + ]) + + self.up4 = nn.ModuleList([ + basic_blocks.BasicDeconvolutionBlock(cs[7], cs[8], ks=2, stride=2), + nn.Sequential( + basic_blocks.ResidualBlock(cs[8] + cs[0], cs[8], ks=3, stride=1, + dilation=1), + basic_blocks.ResidualBlock(cs[8], cs[8], ks=3, stride=1, dilation=1), + ) + ]) + + self.classifier = nn.Sequential(nn.Linear(cs[8], self.num_classes)) + + self.point_transforms = nn.ModuleList([ + nn.Sequential( + nn.Linear(cs[0], cs[4]), + nn.BatchNorm1d(cs[4]), + nn.ReLU(True), + ), + nn.Sequential( + nn.Linear(cs[4], cs[6]), + nn.BatchNorm1d(cs[6]), + nn.ReLU(True), + ), + nn.Sequential( + nn.Linear(cs[6], cs[8]), + nn.BatchNorm1d(cs[8]), + nn.ReLU(True), + ) + ]) + + self.weight_initialization() + self.dropout = nn.Dropout(0.3, True) + + def weight_initialization(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, data_dict, return_logits=False, return_final_logits=False): + x = data_dict['lidar'] + + # x: SparseTensor z: PointTensor + z = PointTensor(x.F, x.C.float()) + + x0 = initial_voxelize(z, self.pres, self.vres) + + x0 = self.stem(x0) + z0 = voxel_to_point(x0, z, nearest=False) + z0.F = z0.F + + x1 = point_to_voxel(x0, z0) + x1 = self.stage1(x1) + x2 = self.stage2(x1) + x3 = self.stage3(x2) + x4 = self.stage4(x3) + z1 = voxel_to_point(x4, z0) + z1.F = z1.F + self.point_transforms[0](z0.F) + + y1 = point_to_voxel(x4, z1) + + if return_logits: + output_dict = dict() + output_dict['logits'] = y1.F + output_dict['batch_indices'] = y1.C[:, -1] + return output_dict + + y1.F = self.dropout(y1.F) + y1 = self.up1[0](y1) + y1 = torchsparse.cat([y1, x3]) + y1 = self.up1[1](y1) + + y2 = self.up2[0](y1) + y2 = torchsparse.cat([y2, x2]) + y2 = self.up2[1](y2) + z2 = voxel_to_point(y2, z1) + z2.F = z2.F + self.point_transforms[1](z1.F) + + y3 = point_to_voxel(y2, z2) + y3.F = self.dropout(y3.F) + y3 = self.up3[0](y3) + y3 = torchsparse.cat([y3, x1]) + y3 = self.up3[1](y3) + + y4 = self.up4[0](y3) + y4 = torchsparse.cat([y4, x0]) + y4 = self.up4[1](y4) + z3 = voxel_to_point(y4, z2) + z3.F = z3.F + self.point_transforms[2](z2.F) + + if return_final_logits: + output_dict = dict() + output_dict['logits'] = z3.F + output_dict['coords'] = z3.C[:, :3] + output_dict['batch_indices'] = z3.C[:, -1].long() + return output_dict + + # output = self.classifier(z3.F) + data_dict['logits'] = z3.F + + return data_dict diff --git a/lidm/modules/ts/__init__.py b/lidm/modules/ts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/modules/ts/basic_blocks.py b/lidm/modules/ts/basic_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..15720f7ce0a39aa147367cb53b88e04386499a43 --- /dev/null +++ b/lidm/modules/ts/basic_blocks.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +# encoding: utf-8 +''' +@author: Xu Yan +@file: basic_blocks.py +@time: 2021/4/14 22:53 +''' +import torch.nn as nn +import torchsparse.nn as spnn + + +class BasicConvolutionBlock(nn.Module): + def __init__(self, inc, outc, ks=3, stride=1, dilation=1): + super().__init__() + self.net = nn.Sequential( + spnn.Conv3d( + inc, + outc, + kernel_size=ks, + dilation=dilation, + stride=stride), spnn.BatchNorm(outc), + spnn.ReLU(True)) + + def forward(self, x): + out = self.net(x) + return out + + +class BasicDeconvolutionBlock(nn.Module): + def __init__(self, inc, outc, ks=3, stride=1): + super().__init__() + self.net = nn.Sequential( + spnn.Conv3d( + inc, + outc, + kernel_size=ks, + stride=stride, + transposed=True), + spnn.BatchNorm(outc), + spnn.ReLU(True)) + + def forward(self, x): + return self.net(x) + + +class ResidualBlock(nn.Module): + def __init__(self, inc, outc, ks=3, stride=1, dilation=1): + super().__init__() + self.net = nn.Sequential( + spnn.Conv3d( + inc, + outc, + kernel_size=ks, + dilation=dilation, + stride=stride), spnn.BatchNorm(outc), + spnn.ReLU(True), + spnn.Conv3d( + outc, + outc, + kernel_size=ks, + dilation=dilation, + stride=1), + spnn.BatchNorm(outc)) + + self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \ + nn.Sequential( + spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride), + spnn.BatchNorm(outc) + ) + + self.ReLU = spnn.ReLU(True) + + def forward(self, x): + out = self.ReLU(self.net(x) + self.downsample(x)) + return out diff --git a/lidm/modules/ts/utils.py b/lidm/modules/ts/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..77e5e055ce584c5b09bee5857166a950c597ee0d --- /dev/null +++ b/lidm/modules/ts/utils.py @@ -0,0 +1,86 @@ +import torch +import torchsparse.nn.functional as F +from torchsparse import PointTensor, SparseTensor +from torchsparse.nn.utils import get_kernel_offsets + +__all__ = ['initial_voxelize', 'point_to_voxel', 'voxel_to_point'] + + +# z: PointTensor +# return: SparseTensor +def initial_voxelize(z, init_res, after_res): + new_float_coord = torch.cat([(z.C[:, :3] * init_res) / after_res, z.C[:, -1].view(-1, 1)], 1) + + pc_hash = F.sphash(torch.floor(new_float_coord).int()) + sparse_hash = torch.unique(pc_hash) + idx_query = F.sphashquery(pc_hash, sparse_hash) + counts = F.spcount(idx_query.int(), len(sparse_hash)) + + inserted_coords = F.spvoxelize(torch.floor(new_float_coord), idx_query, counts) + inserted_coords = torch.round(inserted_coords).int() + inserted_feat = F.spvoxelize(z.F, idx_query, counts) + + new_tensor = SparseTensor(inserted_feat, inserted_coords, 1) + new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords) + z.additional_features['idx_query'][1] = idx_query + z.additional_features['counts'][1] = counts + z.C = new_float_coord + + return new_tensor + + +# x: SparseTensor, z: PointTensor +# return: SparseTensor +def point_to_voxel(x, z): + if z.additional_features is None or \ + z.additional_features.get('idx_query') is None or \ + z.additional_features['idx_query'].get(x.s) is None: + pc_hash = F.sphash( + torch.cat([torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0], z.C[:, -1].int().view(-1, 1)], 1)) + sparse_hash = F.sphash(x.C) + idx_query = F.sphashquery(pc_hash, sparse_hash) + counts = F.spcount(idx_query.int(), x.C.shape[0]) + z.additional_features['idx_query'][x.s] = idx_query + z.additional_features['counts'][x.s] = counts + else: + idx_query = z.additional_features['idx_query'][x.s] + counts = z.additional_features['counts'][x.s] + + inserted_feat = F.spvoxelize(z.F, idx_query, counts) + new_tensor = SparseTensor(inserted_feat, x.C, x.s) + new_tensor.cmaps = x.cmaps + new_tensor.kmaps = x.kmaps + + return new_tensor + + +# x: SparseTensor, z: PointTensor +# return: PointTensor +def voxel_to_point(x, z, nearest=False): + if z.idx_query is None or z.weights is None or z.idx_query.get(x.s) is None or z.weights.get(x.s) is None: + off = get_kernel_offsets(2, x.s, 1, device=z.F.device) + old_hash = F.sphash( + torch.cat([ + torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0], + z.C[:, -1].int().view(-1, 1)], 1), off) + pc_hash = F.sphash(x.C.to(z.F.device)) + idx_query = F.sphashquery(old_hash, pc_hash) + weights = F.calc_ti_weights(z.C, idx_query, scale=x.s[0]).transpose(0, 1).contiguous() + idx_query = idx_query.transpose(0, 1).contiguous() + if nearest: + weights[:, 1:] = 0. + idx_query[:, 1:] = -1 + new_feat = F.spdevoxelize(x.F, idx_query, weights) + new_tensor = PointTensor(new_feat, z.C, idx_query=z.idx_query, weights=z.weights) + new_tensor.additional_features = z.additional_features + new_tensor.idx_query[x.s] = idx_query + new_tensor.weights[x.s] = weights + z.idx_query[x.s] = idx_query + z.weights[x.s] = weights + + else: + new_feat = F.spdevoxelize(x.F, z.idx_query.get(x.s), z.weights.get(x.s)) + new_tensor = PointTensor(new_feat, z.C, idx_query=z.idx_query, weights=z.weights) + new_tensor.additional_features = z.additional_features + + return new_tensor \ No newline at end of file diff --git a/lidm/modules/x_transformer.py b/lidm/modules/x_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..3b969b4e904a4f8ecfcdb4b561ad62c24bff087f --- /dev/null +++ b/lidm/modules/x_transformer.py @@ -0,0 +1,642 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat, reduce + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + + return inner + + +def not_equals(val): + def inner(x): + return x != val + + return inner + + +def equals(val): + def inner(x): + return x == val + + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + # self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out diff --git a/lidm/utils/__init__.py b/lidm/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lidm/utils/aug_utils.py b/lidm/utils/aug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..01c130f305afee85e02d2248122059e67ec158ad --- /dev/null +++ b/lidm/utils/aug_utils.py @@ -0,0 +1,107 @@ +import numpy as np + + +def get_lidar_transform(config, split): + transform_list = [] + if config['rotate']: + transform_list.append(RandomRotateAligned()) + if config['flip']: + transform_list.append(RandomFlip()) + return Compose(transform_list) if len(transform_list) > 0 and split == 'train' else None + + +def get_camera_transform(config, split): + # import open_clip + # transform = open_clip.image_transform((224, 224), split == 'train', resize_longest_max=True) + # TODO + transform = None + return transform + + +def get_anno_transform(config, split): + if config['keypoint_drop'] and split == 'train': + drop_range = config['keypoint_drop_range'] if 'keypoint_drop_range' in config else (5, 60) + transform = RandomKeypointDrop(drop_range) + else: + transform = None + return transform + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, pcd, pcd1=None): + for t in self.transforms: + pcd, pcd1 = t(pcd, pcd1) + return pcd, pcd1 + + +class RandomFlip(object): + def __init__(self, p=1.): + self.p = p + + def __call__(self, coord, coord1=None): + if np.random.rand() < self.p: + if np.random.rand() < 0.5: + coord[:, 0] = -coord[:, 0] + if coord1 is not None: + coord1[:, 0] = -coord1[:, 0] + if np.random.rand() < 0.5: + coord[:, 1] = -coord[:, 1] + if coord1 is not None: + coord1[:, 1] = -coord1[:, 1] + return coord, coord1 + + +class RandomRotateAligned(object): + def __init__(self, rot=np.pi / 4, p=1.): + self.rot = rot + self.p = p + + def __call__(self, coord, coord1=None): + if np.random.rand() < self.p: + angle_z = np.random.uniform(-self.rot, self.rot) + cos_z, sin_z = np.cos(angle_z), np.sin(angle_z) + R = np.array([[cos_z, -sin_z, 0], [sin_z, cos_z, 0], [0, 0, 1]]) + coord = np.dot(coord, R) + if coord1 is not None: + coord1 = np.dot(coord1, R) + return coord, coord1 + + +class RandomKeypointDrop(object): + def __init__(self, num_range=(5, 60), p=.5): + self.num_range = num_range + self.p = p + + def __call__(self, center, category=None): + if np.random.rand() < self.p: + num = len(center) + if num > self.num_range[0]: + num_kept = np.random.randint(self.num_range[0], min(self.num_range[1], num)) + idx_kept = np.random.choice(num, num_kept, replace=False) + center, category = center[idx_kept], category[idx_kept] + return center, category + + +# class ResizeMaxSize(object): +# def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): +# super().__init__() +# if not isinstance(max_size, int): +# raise TypeError(f"Size should be int. Got {type(max_size)}") +# self.max_size = max_size +# self.interpolation = interpolation +# self.fn = min if fn == 'min' else min +# self.fill = fill +# +# def forward(self, img): +# width, height = img.size +# scale = self.max_size / float(max(height, width)) +# if scale != 1.0: +# new_size = tuple(round(dim * scale) for dim in (height, width)) +# img = F.resize(img, new_size, self.interpolation) +# pad_h = self.max_size - new_size[0] +# pad_w = self.max_size - new_size[1] +# img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) +# return img diff --git a/lidm/utils/lidar_utils.py b/lidm/utils/lidar_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54845771e9fcff03d265a4426f6e214386aee2ba --- /dev/null +++ b/lidm/utils/lidar_utils.py @@ -0,0 +1,206 @@ +import math + +import numpy as np + + +def pcd2coord2d(pcd, fov, depth_range, labels=None): + # laser parameters + fov_up = fov[0] / 180.0 * np.pi # field of view up in rad + fov_down = fov[1] / 180.0 * np.pi # field of view down in rad + fov_range = abs(fov_down) + abs(fov_up) # get field of view total in rad + + # get depth (distance) of all points + depth = np.linalg.norm(pcd, 2, axis=-1) + + # mask points out of range + mask = np.logical_and(depth > depth_range[0], depth < depth_range[1]) + if pcd.ndim == 3: + mask = mask.all(axis=1) + depth, pcd = depth[mask], pcd[mask] + + # get scan components + scan_x, scan_y, scan_z = pcd[..., 0], pcd[..., 1], pcd[..., 2] + + # get angles of all points + yaw = -np.arctan2(scan_y, scan_x) + pitch = np.arcsin(scan_z / depth) + + # get projections in image coords + proj_x = np.clip(0.5 * (yaw / np.pi + 1.0), 0., 1.) # in [0.0, 1.0] + proj_y = np.clip(1.0 - (pitch + abs(fov_down)) / fov_range, 0., 1.) # in [0.0, 1.0] + proj_coord2d = np.stack([proj_x, proj_y], axis=-1) + + if labels is not None: + proj_labels = labels[mask] + else: + proj_labels = None + + return proj_coord2d, proj_labels + + +def pcd2range(pcd, size, fov, depth_range, remission=None, labels=None, **kwargs): + # laser parameters + fov_up = fov[0] / 180.0 * np.pi # field of view up in rad + fov_down = fov[1] / 180.0 * np.pi # field of view down in rad + fov_range = abs(fov_down) + abs(fov_up) # get field of view total in rad + + # get depth (distance) of all points + depth = np.linalg.norm(pcd, 2, axis=1) + + # mask points out of range + mask = np.logical_and(depth > depth_range[0], depth < depth_range[1]) + depth, pcd = depth[mask], pcd[mask] + + # get scan components + scan_x, scan_y, scan_z = pcd[:, 0], pcd[:, 1], pcd[:, 2] + + # get angles of all points + yaw = -np.arctan2(scan_y, scan_x) + pitch = np.arcsin(scan_z / depth) + + # get projections in image coords + proj_x = 0.5 * (yaw / np.pi + 1.0) # in [0.0, 1.0] + proj_y = 1.0 - (pitch + abs(fov_down)) / fov_range # in [0.0, 1.0] + + # scale to image size using angular resolution + proj_x *= size[1] # in [0.0, W] + proj_y *= size[0] # in [0.0, H] + + # round and clamp for use as index + proj_x = np.maximum(0, np.minimum(size[1] - 1, np.floor(proj_x))).astype(np.int32) # in [0,W-1] + proj_y = np.maximum(0, np.minimum(size[0] - 1, np.floor(proj_y))).astype(np.int32) # in [0,H-1] + + # order in decreasing depth + order = np.argsort(depth)[::-1] + proj_x, proj_y = proj_x[order], proj_y[order] + + # project depth + depth = depth[order] + proj_range = np.full(size, -1, dtype=np.float32) + proj_range[proj_y, proj_x] = depth + + # project point feature + if remission is not None: + remission = remission[mask][order] + proj_feature = np.full(size, -1, dtype=np.float32) + proj_feature[proj_y, proj_x] = remission + elif labels is not None: + labels = labels[mask][order] + proj_feature = np.full(size, 0, dtype=np.float32) + proj_feature[proj_y, proj_x] = labels + else: + proj_feature = None + + return proj_range, proj_feature + + +def range2pcd(range_img, fov, depth_range, depth_scale, log_scale=True, label=None, color=None, **kwargs): + # laser parameters + size = range_img.shape + fov_up = fov[0] / 180.0 * np.pi # field of view up in rad + fov_down = fov[1] / 180.0 * np.pi # field of view down in rad + fov_range = abs(fov_down) + abs(fov_up) # get field of view total in rad + + # inverse transform from depth + depth = (range_img * depth_scale).flatten() + if log_scale: + depth = np.exp2(depth) - 1 + + scan_x, scan_y = np.meshgrid(np.arange(size[1]), np.arange(size[0])) + scan_x = scan_x.astype(np.float64) / size[1] + scan_y = scan_y.astype(np.float64) / size[0] + + yaw = (np.pi * (scan_x * 2 - 1)).flatten() + pitch = ((1.0 - scan_y) * fov_range - abs(fov_down)).flatten() + + pcd = np.zeros((len(yaw), 3)) + pcd[:, 0] = np.cos(yaw) * np.cos(pitch) * depth + pcd[:, 1] = -np.sin(yaw) * np.cos(pitch) * depth + pcd[:, 2] = np.sin(pitch) * depth + + # mask out invalid points + mask = np.logical_and(depth > depth_range[0], depth < depth_range[1]) + pcd = pcd[mask, :] + + # label + if label is not None: + label = label.flatten()[mask] + + # default point color + if color is not None: + color = color.reshape(-1, 3)[mask, :] + else: + color = np.ones((pcd.shape[0], 3)) * [0.7, 0.7, 1] + + return pcd, color, label + + +def range2xyz(range_img, fov, depth_range, depth_scale, log_scale=True, **kwargs): + # laser parameters + size = range_img.shape + fov_up = fov[0] / 180.0 * np.pi # field of view up in rad + fov_down = fov[1] / 180.0 * np.pi # field of view down in rad + fov_range = abs(fov_down) + abs(fov_up) # get field of view total in rad + + # inverse transform from depth + if log_scale: + depth = (np.exp2(range_img * depth_scale) - 1) + else: + depth = range_img + + scan_x, scan_y = np.meshgrid(np.arange(size[1]), np.arange(size[0])) + scan_x = scan_x.astype(np.float64) / size[1] + scan_y = scan_y.astype(np.float64) / size[0] + + yaw = np.pi * (scan_x * 2 - 1) + pitch = (1.0 - scan_y) * fov_range - abs(fov_down) + + xyz = -np.ones((3, *size)) + xyz[0] = np.cos(yaw) * np.cos(pitch) * depth + xyz[1] = -np.sin(yaw) * np.cos(pitch) * depth + xyz[2] = np.sin(pitch) * depth + + # mask out invalid points + mask = np.logical_and(depth > depth_range[0], depth < depth_range[1]) + xyz[:, ~mask] = -1 + + return xyz + + +def pcd2bev(pcd, x_range, y_range, z_range, resolution, **kwargs): + # mask out invalid points + mask_x = np.logical_and(pcd[:, 0] > x_range[0], pcd[:, 0] < x_range[1]) + mask_y = np.logical_and(pcd[:, 1] > y_range[0], pcd[:, 1] < y_range[1]) + mask_z = np.logical_and(pcd[:, 2] > z_range[0], pcd[:, 2] < z_range[1]) + mask = mask_x & mask_y & mask_z + pcd = pcd[mask] + + # points to bev coords + bev_x = np.floor((pcd[:, 0] - x_range[0]) / resolution).astype(np.int32) + bev_y = np.floor((pcd[:, 1] - y_range[0]) / resolution).astype(np.int32) + + # 2D bev grid + bev_shape = (math.ceil((x_range[1] - x_range[0]) // resolution), math.ceil((y_range[1] - y_range[0]) // resolution)) + bev_grid = np.zeros(bev_shape, dtype=np.float64) + + # populate the BEV grid with bev coords + bev_grid[bev_x, bev_y] = 1 + + return bev_grid + + +if __name__ == '__main__': + # test = np.loadtxt('test_range.txt') + # pcd, _, _ = range2pcd(test, (32, 1024), (10, -30)) + # np.savetxt('test_pcd.txt', pcd, fmt='%.4f') + + # import matplotlib.pyplot as plt + # pcd = np.loadtxt('test_origin.txt') + # bev_grid = pcd2bev(pcd) + # plt.imshow(bev_grid[:, :, 0], cmap='gray') # Display the BEV for the first height level + # plt.savefig('test.png', dpi=300, bbox_inches='tight', pad_inches=0, transparent=True) + + from PIL import Image + img = Image.open('assets/kitti/range.png') + img.convert('L') + img = np.array(img) / 255. diff --git a/lidm/utils/lr_scheduler.py b/lidm/utils/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..be39da9ca6dacc22bf3df9c7389bbb403a4a3ade --- /dev/null +++ b/lidm/utils/lr_scheduler.py @@ -0,0 +1,98 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/lidm/utils/misc_utils.py b/lidm/utils/misc_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..339e5e27c7a8a5029a9f2dd0cf9f44006b2b756d --- /dev/null +++ b/lidm/utils/misc_utils.py @@ -0,0 +1,243 @@ +import argparse +import importlib +import random + +import torch +import numpy as np +from collections import abc +from einops import rearrange +from functools import partial + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def set_seed(seed): + """ + Setting of Global Seed for Reproducibility (for inference only) + + refer to: https://pytorch.org/docs/stable/notes/randomness.html + + """ + np.random.seed(seed) + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def print_fn(msg, verbose): + if verbose: + print(msg) + + +def dict2namespace(config): + namespace = argparse.Namespace() + for key, value in config.items(): + if isinstance(value, dict): + new_value = dict2namespace(value) + else: + new_value = value + setattr(namespace, key, new_value) + return namespace + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def isdepth(x): + if not isinstance(x, (torch.Tensor, np.ndarray)): + return False + return ((len(x.shape) == 4) and (x.shape[1] == 1)) or (len(x.shape) == 3) + + +def ismap(x): + if not isinstance(x, (torch.Tensor, np.ndarray)): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, (torch.Tensor, np.ndarray)): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i: i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == 'ndarray': + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == 'list': + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/lidm/utils/model_utils.py b/lidm/utils/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f040e025d44d6c310c874ed7bcef9e789ef73296 --- /dev/null +++ b/lidm/utils/model_utils.py @@ -0,0 +1,41 @@ +import os + +import torch +import yaml + +from lidm.utils.misc_utils import dict2namespace +from ..modules.rangenet.model import Model as rangenet + +try: + from ..modules.spvcnn.model import Model as spvcnn + from ..modules.minkowskinet.model import Model as minkowskinet +except: + print('To install torchsparse 1.4.0, please refer to https://github.com/mit-han-lab/torchsparse/tree/74099d10a51c71c14318bce63d6421f698b24f24') + +DEFAULT_ROOT = './pretrained_weights' + + +def build_model(dataset_name, model_name, device='cpu'): + # config + model_folder = os.path.join(DEFAULT_ROOT, dataset_name, model_name) + + if not os.path.isdir(model_folder): + raise Exception('Not Available Pretrained Weights!') + + config = yaml.safe_load(open(os.path.join(model_folder, 'config.yaml'), 'r')) + if model_name != 'rangenet': + config = dict2namespace(config) + + # build model + model = eval(model_name)(config) + + # load checkpoint + if model_name == 'rangenet': + model.load_pretrained_weights(model_folder) + else: + ckpt = torch.load(os.path.join(model_folder, 'model.ckpt'), map_location="cpu") + model.load_state_dict(ckpt['state_dict'], strict=False) + model.to(device) + model.eval() + + return model diff --git a/models/first_stage_models/kitti/f_c2_p4/config.yaml b/models/first_stage_models/kitti/f_c2_p4/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fbedeed9f280022b163af049bb2d6f2de8da7e88 --- /dev/null +++ b/models/first_stage_models/kitti/f_c2_p4/config.yaml @@ -0,0 +1,57 @@ +model: + base_learning_rate: 4.5e-6 + target: lidm.models.autoencoder.VQModel + params: + monitor: val/rec_loss + embed_dim: 8 + n_embed: 16384 + lib_name: lidm + use_mask: False # False + ddconfig: + double_z: false + z_channels: 8 + in_channels: 1 + out_ch: 1 + ch: 64 + ch_mult: [1,2,2,4] # num_down = len(ch_mult)-1 + strides: [[1,2],[2,2],[2,2]] + num_res_blocks: 2 + attn_levels: [] + dropout: 0.0 + + +data: + target: main.DataModuleFromConfig + params: + batch_size: 4 + num_workers: 8 + wrap: true + dataset: + size: [64, 1024] + fov: [ 3,-25 ] + depth_range: [ 1.0,56.0 ] + depth_scale: 5.84 # np.log2(depth_max + 1) + log_scale: true + x_range: [ -50.0, 50.0 ] + y_range: [ -50.0, 50.0 ] + z_range: [ -3.0, 1.0 ] + resolution: 1 + num_channels: 1 + num_cats: 10 + num_views: 2 + num_sem_cats: 19 + filtered_map_cats: [ ] + aug: + flip: true + rotate: true + keypoint_drop: false + keypoint_drop_range: [ 5,20 ] + randaug: false + train: + target: lidm.data.kitti.KITTIImageTrain + params: + condition_key: image + validation: + target: lidm.data.kitti.KITTIImageValidation + params: + condition_key: image diff --git a/models/first_stage_models/kitti/f_c2_p4_wo_logscale/config.yaml b/models/first_stage_models/kitti/f_c2_p4_wo_logscale/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8c4933ad3fe9c1ebba731074487983b1acfe2fb4 --- /dev/null +++ b/models/first_stage_models/kitti/f_c2_p4_wo_logscale/config.yaml @@ -0,0 +1,57 @@ +model: + base_learning_rate: 4.5e-6 + target: lidm.models.autoencoder.VQModel + params: + monitor: val/rec_loss + embed_dim: 8 + n_embed: 16384 + lib_name: lidm + use_mask: False # False + ddconfig: + double_z: false + z_channels: 8 + in_channels: 1 + out_ch: 1 + ch: 64 + ch_mult: [1,2,2,4] # num_down = len(ch_mult)-1 + strides: [[1,2],[2,2],[2,2]] + num_res_blocks: 2 + attn_levels: [] + dropout: 0.0 + + +data: + target: main.DataModuleFromConfig + params: + batch_size: 4 + num_workers: 8 + wrap: true + dataset: + size: [64, 1024] + fov: [ 3,-25 ] + depth_range: [ 1.0,56.0 ] + depth_scale: 56 # np.log2(depth_max + 1) + log_scale: false + x_range: [ -50.0, 50.0 ] + y_range: [ -50.0, 50.0 ] + z_range: [ -3.0, 1.0 ] + resolution: 1 + num_channels: 1 + num_cats: 10 + num_views: 2 + num_sem_cats: 19 + filtered_map_cats: [ ] + aug: + flip: true + rotate: true + keypoint_drop: false + keypoint_drop_range: [ 5,20 ] + randaug: false + train: + target: lidm.data.kitti.KITTIImageTrain + params: + condition_key: image + validation: + target: lidm.data.kitti.KITTIImageValidation + params: + condition_key: image diff --git a/models/first_stage_models/kitti/f_c2_p4_wo_logscale/model.ckpt b/models/first_stage_models/kitti/f_c2_p4_wo_logscale/model.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..821d069cbd3698a542eee442ee91977dc47975d5 --- /dev/null +++ b/models/first_stage_models/kitti/f_c2_p4_wo_logscale/model.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:987c413105ad4959899632033091ce9c0a49a56aed1a4f293dcb263ae7022d17 +size 215383923 diff --git a/models/lidm/kitti/cam2lidar/config.yaml b/models/lidm/kitti/cam2lidar/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ef2f705c08fa7f811f96090e9f04ec8298c79fd6 --- /dev/null +++ b/models/lidm/kitti/cam2lidar/config.yaml @@ -0,0 +1,110 @@ +model: + base_learning_rate: 2.0e-06 + target: lidm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 100 + timesteps: 1000 + image_size: [16, 128] + channels: 8 + monitor: val/loss_simple_ema + first_stage_key: image + cond_stage_key: camera + conditioning_key: crossattn + cond_stage_trainable: true + verbose: false + unet_config: + target: lidm.modules.diffusion.openaimodel.UNetModel + params: + image_size: [16, 128] + in_channels: 8 + out_channels: 8 + model_channels: 256 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 32 + use_spatial_transformer: true + context_dim: 512 + lib_name: lidm + first_stage_config: + target: lidm.models.autoencoder.VQModelInterface + params: + embed_dim: 8 + n_embed: 16384 + lib_name: lidm + use_mask: False # False + ckpt_path: models/first_stage_models/kitti/f_c2_p4_wo_ls/model.ckpt + ddconfig: + double_z: false + z_channels: 8 + in_channels: 1 + out_ch: 1 + ch: 64 + ch_mult: [1,2,2,4] + strides: [[1,2],[2,2],[2,2]] + num_res_blocks: 2 + attn_levels: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: lidm.modules.encoders.modules.FrozenClipMultiImageEmbedder + params: + model: ViT-L/14 + out_dim: 512 + split_per_view: 4 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 8 + num_workers: 8 + wrap: true + dataset: + size: [64, 1024] + fov: [ 3,-25 ] + depth_range: [ 1.0,56.0 ] + depth_scale: 56 # np.log2(depth_max + 1) + log_scale: false + x_range: [ -50.0, 50.0 ] + y_range: [ -50.0, 50.0 ] + z_range: [ -3.0, 1.0 ] + resolution: 1 + num_channels: 1 + num_cats: 10 + num_views: 1 + num_sem_cats: 19 + filtered_map_cats: [ ] + aug: + flip: false + rotate: false + keypoint_drop: false + keypoint_drop_range: + randaug: false + camera_drop: 0.5 + train: + target: lidm.data.kitti.KITTI360Train + params: + condition_key: camera + split_per_view: 4 + validation: + target: lidm.data.kitti.KITTI360Validation + params: + condition_key: camera + split_per_view: 4 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True \ No newline at end of file diff --git a/models/lidm/kitti/cam2lidar/model.ckpt b/models/lidm/kitti/cam2lidar/model.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..807af1089443089a16864c3fa1275a731576c712 --- /dev/null +++ b/models/lidm/kitti/cam2lidar/model.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bc6b466d3865ff73813268ba291b32022461e9cfca7cf5bc1534ec08c9cf932a +size 8093788309 diff --git a/models/lidm/kitti/sem2lidar/config.yaml b/models/lidm/kitti/sem2lidar/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d2cfc8e68c3bbb324570e0e85a2d7181ae78cc43 --- /dev/null +++ b/models/lidm/kitti/sem2lidar/config.yaml @@ -0,0 +1,105 @@ +model: + base_learning_rate: 1.0e-06 + target: lidm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0205 + num_timesteps_cond: 1 + log_every_t: 100 + timesteps: 1000 + image_size: [16, 128] + channels: 8 + monitor: val/loss_simple_ema + first_stage_key: image + cond_stage_key: segmentation + concat_mode: true + cond_stage_trainable: true + verbose: false + unet_config: + target: lidm.modules.diffusion.openaimodel.UNetModel + params: + image_size: [16, 128] + in_channels: 16 + out_channels: 8 + model_channels: 256 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 32 + lib_name: lidm + first_stage_config: + target: lidm.models.autoencoder.VQModelInterface + params: + embed_dim: 8 + n_embed: 16384 + lib_name: lidm + use_mask: False # False + ckpt_path: models/first_stage_models/kitti/f_c2_p4_wo_ls/model.ckpt + ddconfig: + double_z: false + z_channels: 8 + in_channels: 1 + out_ch: 1 + ch: 64 + ch_mult: [1,2,2,4] + strides: [[1,2],[2,2],[2,2]] + num_res_blocks: 2 + attn_levels: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: lidm.modules.encoders.modules.SpatialRescaler + params: + strides: [[1,2],[2,2],[2,2]] + in_channels: 20 + out_channels: 8 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 16 + num_workers: 8 + wrap: true + dataset: + size: [64, 1024] + fov: [ 3,-25 ] + depth_range: [ 1.0,56.0 ] + depth_scale: 56 # np.log2(depth_max + 1) + log_scale: false + x_range: [ -50.0, 50.0 ] + y_range: [ -50.0, 50.0 ] + z_range: [ -3.0, 1.0 ] + resolution: 1 + num_channels: 1 + num_cats: 10 + num_views: 2 + num_sem_cats: 19 + filtered_map_cats: [ ] + aug: + flip: true + rotate: false + keypoint_drop: false + keypoint_drop_range: [ 5,20 ] + randaug: false + train: + target: lidm.data.kitti.SemanticKITTITrain + params: + condition_key: segmentation + validation: + target: lidm.data.kitti.SemanticKITTIValidation + params: + condition_key: segmentation + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: true \ No newline at end of file diff --git a/models/lidm/kitti/text2lidar/config.yaml b/models/lidm/kitti/text2lidar/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e4ceee5ef9a50533fc9b54fd3b9a384109ebeda8 --- /dev/null +++ b/models/lidm/kitti/text2lidar/config.yaml @@ -0,0 +1,111 @@ +model: + base_learning_rate: 2.0e-06 + target: lidm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 100 + timesteps: 1000 + image_size: [16, 128] + channels: 8 + monitor: val/loss_simple_ema + first_stage_key: image + cond_stage_key: camera + conditioning_key: crossattn + cond_stage_trainable: true + verbose: false + unet_config: + target: lidm.modules.diffusion.openaimodel.UNetModel + params: + image_size: [16, 128] + in_channels: 8 + out_channels: 8 + model_channels: 256 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 32 + use_spatial_transformer: true + context_dim: 512 + lib_name: lidm + first_stage_config: + target: lidm.models.autoencoder.VQModelInterface + params: + embed_dim: 8 + n_embed: 16384 + lib_name: lidm + use_mask: False # False + ckpt_path: models/first_stage_models/kitti/f_c2_p4_wo_ls/model.ckpt + ddconfig: + double_z: false + z_channels: 8 + in_channels: 1 + out_ch: 1 + ch: 64 + ch_mult: [1,2,2,4] + strides: [[1,2],[2,2],[2,2]] + num_res_blocks: 2 + attn_levels: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: lidm.modules.encoders.modules.FrozenClipMultiImageEmbedder + params: + model: ViT-L/14 + split_per_view: 4 + key: camera + out_dim: 512 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 8 + num_workers: 8 + wrap: true + dataset: + size: [64, 1024] + fov: [ 3,-25 ] + depth_range: [ 1.0,56.0 ] + depth_scale: 56 # np.log2(depth_max + 1) + log_scale: false + x_range: [ -50.0, 50.0 ] + y_range: [ -50.0, 50.0 ] + z_range: [ -3.0, 1.0 ] + resolution: 1 + num_channels: 1 + num_cats: 10 + num_views: 1 + num_sem_cats: 19 + filtered_map_cats: [ ] + aug: + flip: false + rotate: false + keypoint_drop: false + keypoint_drop_range: + randaug: false + camera_drop: 0.5 + train: + target: lidm.data.kitti.KITTI360Train + params: + condition_key: camera + split_per_view: 4 + validation: + target: lidm.data.kitti.KITTI360Validation + params: + condition_key: camera + split_per_view: 4 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True \ No newline at end of file diff --git a/models/lidm/kitti/uncond/config.yaml b/models/lidm/kitti/uncond/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d2ddd0e4c9916f0a6d4a4335684028e6251a93c2 --- /dev/null +++ b/models/lidm/kitti/uncond/config.yaml @@ -0,0 +1,96 @@ +model: + base_learning_rate: 1.0e-06 + target: lidm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + image_size: [16, 128] + channels: 8 + monitor: val/loss_simple_ema + first_stage_key: image + unet_config: + target: lidm.modules.diffusion.openaimodel.UNetModel + params: + image_size: [16, 128] + in_channels: 8 + out_channels: 8 + model_channels: 256 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 32 + lib_name: lidm + first_stage_config: + target: lidm.models.autoencoder.VQModelInterface + params: + embed_dim: 8 + n_embed: 16384 + lib_name: lidm + use_mask: False # False + ckpt_path: models/first_stage_models/kitti/f_c2_p4/model.ckpt + ddconfig: + double_z: false + z_channels: 8 + in_channels: 1 + out_ch: 1 + ch: 64 + ch_mult: [1,2,2,4] + strides: [[1,2],[2,2],[2,2]] + num_res_blocks: 2 + attn_levels: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: "__is_unconditional__" + +data: + target: main.DataModuleFromConfig + params: + batch_size: 4 + num_workers: 8 + wrap: true + dataset: + size: [64, 1024] + fov: [ 3,-25 ] + depth_range: [ 1.0,56.0 ] + depth_scale: 5.84 # np.log2(depth_max + 1) + log_scale: true + x_range: [ -50.0, 50.0 ] + y_range: [ -50.0, 50.0 ] + z_range: [ -3.0, 1.0 ] + resolution: 1 + num_channels: 1 + num_cats: 10 + num_views: 2 + num_sem_cats: 19 + filtered_map_cats: [ ] + aug: + flip: true + rotate: false + keypoint_drop: false + keypoint_drop_range: [ 5,20 ] + randaug: false + train: + target: lidm.data.kitti.KITTI360Train + params: + condition_key: image + validation: + target: lidm.data.kitti.KITTI360Validation + params: + condition_key: image + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: true diff --git a/models/lidm/kitti/uncond_wo_logscale/config.yaml b/models/lidm/kitti/uncond_wo_logscale/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9eb12d7c5c36c8c4dc0ef13f8997df880c39080b --- /dev/null +++ b/models/lidm/kitti/uncond_wo_logscale/config.yaml @@ -0,0 +1,96 @@ +model: + base_learning_rate: 1.0e-06 + target: lidm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + image_size: [16, 128] + channels: 8 + monitor: val/loss_simple_ema + first_stage_key: image + unet_config: + target: lidm.modules.diffusion.openaimodel.UNetModel + params: + image_size: [16, 128] + in_channels: 8 + out_channels: 8 + model_channels: 256 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 32 + lib_name: lidm + first_stage_config: + target: lidm.models.autoencoder.VQModelInterface + params: + embed_dim: 8 + n_embed: 16384 + lib_name: lidm + use_mask: False # False + ckpt_path: models/first_stage_models/kitti/f_c2_p4_wo_ls/model.ckpt + ddconfig: + double_z: false + z_channels: 8 + in_channels: 1 + out_ch: 1 + ch: 64 + ch_mult: [1,2,2,4] + strides: [[1,2],[2,2],[2,2]] + num_res_blocks: 2 + attn_levels: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: "__is_unconditional__" + +data: + target: main.DataModuleFromConfig + params: + batch_size: 4 + num_workers: 8 + wrap: true + dataset: + size: [64, 1024] + fov: [ 3,-25 ] + depth_range: [ 1.0,56.0 ] + depth_scale: 56 # np.log2(depth_max + 1) + log_scale: false + x_range: [ -50.0, 50.0 ] + y_range: [ -50.0, 50.0 ] + z_range: [ -3.0, 1.0 ] + resolution: 1 + num_channels: 1 + num_cats: 10 + num_views: 2 + num_sem_cats: 19 + filtered_map_cats: [ ] + aug: + flip: true + rotate: false + keypoint_drop: false + keypoint_drop_range: [ 5,20 ] + randaug: false + train: + target: lidm.data.kitti.KITTI360Train + params: + condition_key: image + validation: + target: lidm.data.kitti.KITTI360Validation + params: + condition_key: image + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: true diff --git a/sample_cond.py b/sample_cond.py new file mode 100644 index 0000000000000000000000000000000000000000..4609b4cd88d7ad8a90000b0e901d457352c88f54 --- /dev/null +++ b/sample_cond.py @@ -0,0 +1,109 @@ +import os +import torch +import numpy as np + +from omegaconf import OmegaConf +from PIL import Image + +from lidm.models.diffusion.ddim import DDIMSampler +from lidm.utils.misc_utils import instantiate_from_config, isimage, ismap +from lidm.utils.lidar_utils import range2pcd +from app_config import DEVICE + + +CUSTOM_STEPS = 50 +ETA = 1.0 + +# model loading +MODEL_PATH = './models/lidm/kitti/cam2lidar' +CFG_PATH = os.path.join(MODEL_PATH, 'config.yaml') +CKPT_PATH = os.path.join(MODEL_PATH, 'model.ckpt') + +# settings +model_config = OmegaConf.load(CFG_PATH) + + +def custom_to_pcd(x, config, rgb=None): + x = x.squeeze().detach().cpu().numpy() + x = (np.clip(x, -1., 1.) + 1.) / 2. + if rgb is not None: + rgb = rgb.squeeze().detach().cpu().numpy() + rgb = (np.clip(rgb, -1., 1.) + 1.) / 2. + rgb = rgb.transpose(1, 2, 0) + xyz, rgb, _ = range2pcd(x, color=rgb, **config['data']['params']['dataset']) + + return xyz, rgb + + +def custom_to_pil(x): + x = x.detach().cpu().squeeze().numpy() + x = (np.clip(x, -1., 1.) + 1.) / 2. + x = (255 * x).astype(np.uint8) + + if x.ndim == 3: + x = x.transpose(1, 2, 0) + x = Image.fromarray(x) + + return x + + +def logs2pil(logs, keys=["sample"]): + imgs = dict() + for k in logs: + try: + if len(logs[k].shape) == 4: + img = custom_to_pil(logs[k][0, ...]) + elif len(logs[k].shape) == 3: + img = custom_to_pil(logs[k]) + else: + print(f"Unknown format for key {k}. ") + img = None + except: + img = None + imgs[k] = img + return imgs + + +def load_model_from_config(config, sd, device): + model = instantiate_from_config(config) + model.load_state_dict(sd, strict=False) + model.to(device) + model.eval() + return model + + +def load_model(): + pl_sd = torch.load(CKPT_PATH, map_location="cpu") + model = load_model_from_config(model_config.model, pl_sd["state_dict"], DEVICE) + return model + + +@torch.no_grad() +def convsample_ddim(model, cond, steps, shape, eta=1.0, verbose=False): + ddim = DDIMSampler(model) + bs = shape[0] + shape = shape[1:] + samples, intermediates = ddim.sample(steps, conditioning=cond, batch_size=bs, shape=shape, eta=eta, verbose=verbose, disable_tqdm=True) + return samples, intermediates + + +@torch.no_grad() +def make_convolutional_sample(model, batch, batch_size, custom_steps=None, eta=1.0): + xc = batch['camera'] + c = model.get_learned_conditioning(xc.to(model.device)) + + with model.ema_scope("Plotting"): + samples, z_denoise_row = model.sample_log(cond=c, batch_size=batch_size, ddim=True, + ddim_steps=custom_steps, eta=eta) + x_samples = model.decode_first_stage(samples) + + return x_samples + + +def sample(model, cond): + batch = {'camera': cond} + img = make_convolutional_sample(model, batch, batch_size=1, custom_steps=CUSTOM_STEPS, eta=ETA) # TODO add arguments for batch_size, custom_steps and eta + img = img[0, 0] + pcd = custom_to_pcd(img, model_config)[0].astype(np.float32) + return img, pcd +