import json
import math
import os
import re
import shutil
from typing import List, Optional, Union

import cv2
import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch

# import wandb
from matplotlib import cm
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image, ImageDraw

from .typing import *


def tensor_to_image(
    data: Union[Image.Image, torch.Tensor, np.ndarray],
    batched: bool = False,
    format: str = "HWC",
) -> Union[Image.Image, List[Image.Image]]:
    """Convert a tensor / array into one PIL image or a list of PIL images.

    Args:
        data: A PIL image (returned unchanged), a torch tensor, or a numpy
            array. Floating-point data is scaled by 255, boolean data maps to
            0 / 255, and uint8 data is used as is.
        batched: When True, iterate over the first axis and return a list
            with one image per entry.
        format: ``"HWC"`` or ``"CHW"``. For ``"CHW"`` the channel axis is
            moved last when the rank matches (4 if batched, 3 otherwise).

    Returns:
        A single ``Image.Image``, or a list of them when ``batched`` is True.
    """
    if isinstance(data, Image.Image):
        return data
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu()
        # numpy has no bfloat16, so widen before converting.
        if data.dtype == torch.bfloat16:
            data = data.float()
        data = data.numpy()

    if np.issubdtype(data.dtype, np.floating):
        # Clip first: casting out-of-range floats to uint8 wraps around.
        data = (data.clip(0.0, 1.0) * 255).astype(np.uint8)
    elif data.dtype == np.bool_:
        data = data.astype(np.uint8) * 255
    assert data.dtype == np.uint8

    if format == "CHW":
        if batched and data.ndim == 4:
            data = data.transpose((0, 2, 3, 1))
        elif not batched and data.ndim == 3:
            data = data.transpose((1, 2, 0))

    if batched:
        return [Image.fromarray(d) for d in data]
    return Image.fromarray(data)


def largest_factor_near_sqrt(n: int) -> int:
    """Return the largest factor of ``n`` that does not exceed ``sqrt(n)``.

    Args:
        n (int): The integer to factor.

    Returns:
        int: The largest divisor ``d`` of ``n`` with ``d * d <= n``. For a
        perfect square this is its root (so ``0`` for ``n == 0``).
    """
    # isqrt is exact for arbitrarily large ints, unlike int(math.sqrt(n)).
    root = math.isqrt(n)
    if root * root == n:
        return root

    # Walk downwards; 1 divides everything, so this returns for any n >= 1.
    for candidate in range(root, 0, -1):
        if n % candidate == 0:
            return candidate
    return 1


def make_image_grid(
    images: List[Image.Image],
    rows: Optional[int] = None,
    cols: Optional[int] = None,
    resize: Optional[int] = None,
) -> Image.Image:
    """Paste a list of images into a single RGB grid, filled row by row.

    Args:
        images: Images to lay out. The cell size is taken from the first one.
        rows: Number of rows; derived from ``cols`` or chosen automatically
            when omitted.
        cols: Number of columns; derived from ``rows`` or chosen
            automatically when omitted.
        resize: When given, every image is resized to ``(resize, resize)``.

    Returns:
        The assembled grid image.

    Raises:
        ValueError: If ``images`` is empty.
        AssertionError: If the image count does not match ``rows * cols``.
    """
    if not images:
        raise ValueError("make_image_grid requires at least one image")

    count = len(images)
    if rows is None and cols is not None:
        assert count % cols == 0
        rows = count // cols
    elif cols is None and rows is not None:
        assert count % rows == 0
        cols = count // rows
    elif rows is None and cols is None:
        rows = largest_factor_near_sqrt(count)
        cols = count // rows

    assert count == rows * cols

    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]

    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(images):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


class SaverMixin:
    """Mixin that writes images, arrays, videos and JSON under a save dir.

    The image helpers build BGR(A) arrays because they are written with
    ``cv2.imwrite``.
    """

    _save_dir: Optional[str] = None
    _wandb_logger: Optional[Any] = None

    def set_save_dir(self, save_dir: str):
        """Set the directory that all relative filenames are resolved against."""
        self._save_dir = save_dir

    def get_save_dir(self):
        """Return the save directory.

        Raises:
            ValueError: If no save directory has been set.
        """
        if self._save_dir is None:
            raise ValueError("Save dir is not set")
        return self._save_dir

    def convert_data(self, data):
        """Recursively convert tensors to numpy arrays.

        ``None`` and numpy arrays pass through; lists and dicts are converted
        element-wise. Half-precision tensors (float16 / bfloat16) are widened
        to float32 first.

        Raises:
            TypeError: For any other input type.
        """
        if data is None:
            return None
        if isinstance(data, np.ndarray):
            return data
        if isinstance(data, torch.Tensor):
            if data.dtype in [torch.float16, torch.bfloat16]:
                data = data.float()
            return data.detach().cpu().numpy()
        if isinstance(data, list):
            return [self.convert_data(d) for d in data]
        if isinstance(data, dict):
            return {k: self.convert_data(v) for k, v in data.items()}
        # Single formatted message; the original passed the type as a second
        # exception argument, which rendered as a tuple.
        raise TypeError(
            "Data must be in type numpy.ndarray, torch.Tensor, list or dict, "
            f"getting {type(data)}"
        )

    def get_save_path(self, filename):
        """Join ``filename`` onto the save dir and create its parent dirs."""
        save_path = os.path.join(self.get_save_dir(), filename)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        return save_path

    DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)}
    DEFAULT_UV_KWARGS = {
        "data_format": "HWC",
        "data_range": (0, 1),
        "cmap": "checkerboard",
    }
    DEFAULT_GRAYSCALE_KWARGS = {"data_range": None, "cmap": "jet"}
    DEFAULT_GRID_KWARGS = {"align": "max"}

    def get_rgb_image_(self, img, data_format, data_range, rgba=False):
        """Build a BGR (or BGRA) uint8 array from an RGB(A) image.

        Non-uint8 input is clipped to ``data_range`` and rescaled to 0..255.
        The channel axis is split into groups of 3 (4 when ``rgba``); a short
        trailing group is zero-padded, and the groups are concatenated side by
        side along the width.
        """
        img = self.convert_data(img)
        assert data_format in ["CHW", "HWC"]
        if data_format == "CHW":
            img = img.transpose(1, 2, 0)
        if img.dtype != np.uint8:
            lo, hi = data_range[0], data_range[1]
            img = img.clip(min=lo, max=hi)
            img = ((img - lo) / (hi - lo) * 255.0).astype(np.uint8)

        nc = 4 if rgba else 3
        groups = []
        for start in range(0, img.shape[-1], nc):
            group = img[..., start : start + nc]
            missing = nc - group.shape[-1]
            if missing:
                # Pad an incomplete trailing group with zero channels.
                pad = np.zeros(
                    (group.shape[0], group.shape[1], missing), dtype=group.dtype
                )
                group = np.concatenate([group, pad], axis=-1)
            groups.append(group)
        img = np.concatenate(groups, axis=1)

        code = cv2.COLOR_RGBA2BGRA if rgba else cv2.COLOR_RGB2BGR
        return cv2.cvtColor(img, code)

    def _save_rgb_image(
        self,
        filename,
        img,
        data_format,
        data_range,
        name: Optional[str] = None,
        step: Optional[int] = None,
    ):
        """Write an RGB image to ``filename`` and optionally log it."""
        img = self.get_rgb_image_(img, data_format, data_range)
        cv2.imwrite(filename, img)
        if name and self._wandb_logger:
            # Log the path that was actually written. Re-resolving it through
            # get_save_path would prepend the save dir a second time.
            self._wandb_logger.log_image(key=name, images=[filename], step=step)

    def save_rgb_image(
        self,
        filename,
        img,
        data_format=DEFAULT_RGB_KWARGS["data_format"],
        data_range=DEFAULT_RGB_KWARGS["data_range"],
        name: Optional[str] = None,
        step: Optional[int] = None,
    ) -> str:
        """Save an RGB image under the save dir and return its path."""
        save_path = self.get_save_path(filename)
        self._save_rgb_image(save_path, img, data_format, data_range, name, step)
        return save_path

    def get_uv_image_(self, img, data_format, data_range, cmap):
        """Visualise the first two channels of ``img`` as a BGR uint8 image.

        ``cmap="checkerboard"`` draws a 64x64-cell magenta/white checker
        pattern; ``cmap="color"`` puts the two channels into red and green.
        """
        img = self.convert_data(img)
        assert data_format in ["CHW", "HWC"]
        if data_format == "CHW":
            img = img.transpose(1, 2, 0)
        lo, hi = data_range[0], data_range[1]
        img = img.clip(min=lo, max=hi)
        img = (img - lo) / (hi - lo)

        assert cmap in ["checkerboard", "color"]
        height, width = img.shape[0], img.shape[1]
        if cmap == "checkerboard":
            n_grid = 64
            cells = (img * n_grid).astype(int)
            mask = (cells[..., 0] + cells[..., 1]) % 2 == 0
            out = np.ones((height, width, 3), dtype=np.uint8) * 255
            out[mask] = np.array([255, 0, 255], dtype=np.uint8)
        else:
            out = np.zeros((height, width, 3), dtype=np.uint8)
            out[..., 0] = (img[..., 0] * 255).astype(np.uint8)
            out[..., 1] = (img[..., 1] * 255).astype(np.uint8)
        return cv2.cvtColor(out, cv2.COLOR_RGB2BGR)

    def save_uv_image(
        self,
        filename,
        img,
        data_format=DEFAULT_UV_KWARGS["data_format"],
        data_range=DEFAULT_UV_KWARGS["data_range"],
        cmap=DEFAULT_UV_KWARGS["cmap"],
    ) -> str:
        """Save a UV visualisation under the save dir and return its path."""
        save_path = self.get_save_path(filename)
        img = self.get_uv_image_(img, data_format, data_range, cmap)
        cv2.imwrite(save_path, img)
        return save_path

    def get_grayscale_image_(self, img, data_range, cmap):
        """Map a scalar image to a 3-channel uint8 image.

        Args:
            img: Scalar image; NaNs are replaced via ``np.nan_to_num``.
            data_range: ``(lo, hi)`` to clip and normalise with, or ``None``
                to normalise by the image's own min / max.
            cmap: ``None`` (replicate gray into 3 channels), ``"jet"``,
                ``"magma"`` or ``"spectral"``.
        """
        img = self.convert_data(img)
        img = np.nan_to_num(img)
        if data_range is None:
            lo, hi = img.min(), img.max()
            span = hi - lo
            if span > 0:
                img = (img - lo) / span
            else:
                # Constant image: avoid 0 / 0 and map everything to 0.
                img = np.zeros(img.shape, dtype=np.float32)
        else:
            img = img.clip(data_range[0], data_range[1])
            img = (img - data_range[0]) / (data_range[1] - data_range[0])

        assert cmap in [None, "jet", "magma", "spectral"]
        if cmap is None:
            img = (img * 255.0).astype(np.uint8)
            img = np.repeat(img[..., None], 3, axis=2)
        elif cmap == "jet":
            img = (img * 255.0).astype(np.uint8)
            img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
        elif cmap == "magma":
            img = 1.0 - img
            # plt.get_cmap (already used for "spectral" below) replaces
            # cm.get_cmap, which recent matplotlib releases no longer provide.
            base = plt.get_cmap("magma")
            num_bins = 256
            colormap = LinearSegmentedColormap.from_list(
                f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins
            )(np.linspace(0, 1, num_bins))[:, :3]
            # Linear interpolation between the two neighbouring palette bins.
            a = np.floor(img * 255.0)
            b = (a + 1).clip(max=255.0)
            f = img * 255.0 - a
            a = a.astype(np.uint16).clip(0, 255)
            b = b.astype(np.uint16).clip(0, 255)
            img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None]
            img = (img * 255.0).astype(np.uint8)
        elif cmap == "spectral":
            colormap = plt.get_cmap("Spectral")

            def blend_rgba(image):
                # Blend alpha onto a white background to get plain RGB.
                return image[..., :3] * image[..., -1:] + (1.0 - image[..., -1:])

            img = colormap(img)
            img = blend_rgba(img)
            img = (img * 255).astype(np.uint8)
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        return img

    def _save_grayscale_image(
        self,
        filename,
        img,
        data_range,
        cmap,
        name: Optional[str] = None,
        step: Optional[int] = None,
    ):
        """Write a colour-mapped scalar image to ``filename``; optionally log it."""
        img = self.get_grayscale_image_(img, data_range, cmap)
        cv2.imwrite(filename, img)
        if name and self._wandb_logger:
            # Log the path that was actually written (see _save_rgb_image).
            self._wandb_logger.log_image(key=name, images=[filename], step=step)

    def save_grayscale_image(
        self,
        filename,
        img,
        data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"],
        cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"],
        name: Optional[str] = None,
        step: Optional[int] = None,
    ) -> str:
        """Save a colour-mapped scalar image under the save dir; return its path."""
        save_path = self.get_save_path(filename)
        self._save_grayscale_image(save_path, img, data_range, cmap, name, step)
        return save_path

    def get_image_grid_(self, imgs, align):
        """Assemble image specs into one array.

        Args:
            imgs: Either a list of rows (each a list of specs), stacked
                vertically, or a flat list of specs placed side by side. A
                spec is a dict with ``"type"`` (``"rgb"``, ``"uv"`` or
                ``"grayscale"``), ``"img"`` and optional ``"kwargs"`` that
                override the matching ``DEFAULT_*_KWARGS``.
            align: ``"max"``, ``"min"`` or an int: the common height every
                column is resized to.

        Raises:
            ValueError: If ``align`` is not one of the supported values.
        """
        if isinstance(imgs[0], list):
            return np.concatenate(
                [self.get_image_grid_(row, align) for row in imgs], axis=0
            )

        cols = []
        for col in imgs:
            assert col["type"] in ["rgb", "uv", "grayscale"]
            overrides = col.get("kwargs", {})
            if col["type"] == "rgb":
                kwargs = {**self.DEFAULT_RGB_KWARGS, **overrides}
                cols.append(self.get_rgb_image_(col["img"], **kwargs))
            elif col["type"] == "uv":
                kwargs = {**self.DEFAULT_UV_KWARGS, **overrides}
                cols.append(self.get_uv_image_(col["img"], **kwargs))
            elif col["type"] == "grayscale":
                kwargs = {**self.DEFAULT_GRAYSCALE_KWARGS, **overrides}
                cols.append(self.get_grayscale_image_(col["img"], **kwargs))

        if align == "max":
            h = max(col.shape[0] for col in cols)
        elif align == "min":
            h = min(col.shape[0] for col in cols)
        elif isinstance(align, int):
            h = align
        else:
            raise ValueError(
                f"Unsupported image grid align: {align}, should be min, max, or int"
            )

        for i, col in enumerate(cols):
            if col.shape[0] != h:
                # Keep the aspect ratio while matching the target height.
                w = int(col.shape[1] * h / col.shape[0])
                cols[i] = cv2.resize(col, (w, h), interpolation=cv2.INTER_CUBIC)
        return np.concatenate(cols, axis=1)

    def save_image_grid(
        self,
        filename,
        imgs,
        align=DEFAULT_GRID_KWARGS["align"],
        name: Optional[str] = None,
        step: Optional[int] = None,
        texts: Optional[List[float]] = None,
    ) -> str:
        """Save an image grid, optionally annotated with one text per band.

        ``texts`` are drawn at the left edge, spaced evenly over the grid
        height, in black with a white outline. Returns the written path.
        """
        save_path = self.get_save_path(filename)
        img = self.get_image_grid_(imgs, align=align)

        if texts:
            img = Image.fromarray(img)
            draw = ImageDraw.Draw(img)
            black, white = (0, 0, 0), (255, 255, 255)
            band = img.size[1] // len(texts)
            for i, text in enumerate(texts):
                y = band * i
                label = f"{text}"
                # Four white copies offset diagonally form the outline.
                for dx, dy in ((2, 1), (0, 1), (2, -1), (0, -1)):
                    draw.text((dx, y + dy), label, white)
                draw.text((1, y), label, black)
            img = np.asarray(img)

        cv2.imwrite(save_path, img)
        if name and self._wandb_logger:
            self._wandb_logger.log_image(key=name, images=[save_path], step=step)
        return save_path

    def save_image(self, filename, img) -> str:
        """Save a uint8 / uint16 image, converting RGB(A) to BGR(A) first."""
        save_path = self.get_save_path(filename)
        img = self.convert_data(img)
        assert img.dtype == np.uint8 or img.dtype == np.uint16
        if img.ndim == 3 and img.shape[-1] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        elif img.ndim == 3 and img.shape[-1] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
        cv2.imwrite(save_path, img)
        return save_path

    def save_cubemap(self, filename, img, data_range=(0, 1), rgba=False) -> str:
        """Save a cubemap of shape ``(6, S, S, C)`` as an unfolded cross.

        Channels are processed in groups of 3 (4 when ``rgba``); each group
        yields one cross and the crosses are concatenated along the width.
        """
        save_path = self.get_save_path(filename)
        img = self.convert_data(img)
        assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2]

        # Group size must match what get_rgb_image_ emits per face; a fixed 3
        # made rgba=True fail on the channel mismatch with the placeholder.
        nc = 4 if rgba else 3
        imgs_full = []
        for start in range(0, img.shape[-1], nc):
            faces = img[..., start : start + nc]
            faces = np.stack(
                [
                    self.get_rgb_image_(face, "HWC", data_range, rgba=rgba)
                    for face in faces
                ],
                axis=0,
            )
            # Blank tile with the same shape and dtype as a converted face.
            placeholder = np.zeros_like(faces[0])
            img_full = np.concatenate(
                [
                    np.concatenate(
                        [placeholder, faces[2], placeholder, placeholder], axis=1
                    ),
                    np.concatenate([faces[1], faces[4], faces[0], faces[5]], axis=1),
                    np.concatenate(
                        [placeholder, faces[3], placeholder, placeholder], axis=1
                    ),
                ],
                axis=0,
            )
            imgs_full.append(img_full)

        imgs_full = np.concatenate(imgs_full, axis=1)
        cv2.imwrite(save_path, imgs_full)
        return save_path

    def save_data(self, filename, data) -> str:
        """Save a dict as ``.npz`` or anything else as ``.npy``.

        The matching extension is appended when ``filename`` lacks it.
        """
        data = self.convert_data(data)
        if isinstance(data, dict):
            if not filename.endswith(".npz"):
                filename += ".npz"
            save_path = self.get_save_path(filename)
            np.savez(save_path, **data)
        else:
            if not filename.endswith(".npy"):
                filename += ".npy"
            save_path = self.get_save_path(filename)
            np.save(save_path, data)
        return save_path

    def save_state_dict(self, filename, data) -> str:
        """Serialise ``data`` with ``torch.save`` and return the path."""
        save_path = self.get_save_path(filename)
        torch.save(data, save_path)
        return save_path

    def save_img_sequence(
        self,
        filename,
        img_dir,
        matcher,
        save_format="mp4",
        fps=30,
        name: Optional[str] = None,
        step: Optional[int] = None,
    ) -> str:
        """Encode the images in ``img_dir`` as a GIF or MP4.

        Args:
            filename: Output name relative to the save dir; the extension is
                appended when missing.
            img_dir: Image directory, relative to the save dir.
            matcher: Regex selecting the files; its first capture group must
                be the integer frame index used for ordering.
            save_format: ``"gif"`` or ``"mp4"``.
            fps: Frames per second passed to ``imageio.mimsave``.
            name: Logging key; only triggers a "not supported" warning.
            step: Unused here.

        Returns:
            Path of the written video.
        """
        assert save_format in ["gif", "mp4"]
        # Compare against ".ext": a bare suffix test would accept "clipmp4".
        if not filename.endswith(f".{save_format}"):
            filename += f".{save_format}"
        save_path = self.get_save_path(filename)

        pattern = re.compile(matcher)
        img_dir = os.path.join(self.get_save_dir(), img_dir)
        frames = [f for f in os.listdir(img_dir) if pattern.search(f)]
        frames.sort(key=lambda f: int(pattern.search(f).groups()[0]))

        # cv2 reads BGR; imageio expects RGB.
        imgs = [
            cv2.cvtColor(cv2.imread(os.path.join(img_dir, f)), cv2.COLOR_BGR2RGB)
            for f in frames
        ]
        if save_format == "gif":
            imageio.mimsave(save_path, imgs, fps=fps, palettesize=256)
        elif save_format == "mp4":
            imageio.mimsave(save_path, imgs, fps=fps)

        if name and self._wandb_logger:
            from .core import warn

            warn("Wandb logger does not support video logging yet!")
        return save_path

    def save_img_sequences(
        self,
        seq_dir,
        matcher,
        save_format="mp4",
        fps=30,
        delete=True,
        name: Optional[str] = None,
        step: Optional[int] = None,
    ):
        """Encode every sub-directory of ``seq_dir`` as its own video.

        Failures are reported through ``warn`` and do not stop the remaining
        directories. With ``delete`` set, a directory is removed after its
        video was written successfully.
        """
        seq_dir_ = os.path.join(self.get_save_dir(), seq_dir)
        for f in os.listdir(seq_dir_):
            img_dir_ = os.path.join(seq_dir_, f)
            if not os.path.isdir(img_dir_):
                continue
            try:
                self.save_img_sequence(
                    os.path.join(seq_dir, f),
                    os.path.join(seq_dir, f),
                    matcher,
                    save_format=save_format,
                    fps=fps,
                    # Keep None as None: f"{None}_{f}" would be a truthy name.
                    name=f"{name}_{f}" if name else None,
                    step=step,
                )
                if delete:
                    shutil.rmtree(img_dir_)
            except Exception as err:
                # Best effort per directory; KeyboardInterrupt and SystemExit
                # are no longer swallowed as they were by the bare except.
                from .core import warn

                warn(f"Video saving for directory {img_dir_} failed! ({err!r})")

    def save_file(self, filename, src_path, delete=False) -> str:
        """Copy ``src_path`` into the save dir, optionally removing the source."""
        save_path = self.get_save_path(filename)
        shutil.copyfile(src_path, save_path)
        if delete:
            os.remove(src_path)
        return save_path

    def save_json(self, filename, payload) -> str:
        """Write ``payload`` as JSON under the save dir and return the path."""
        save_path = self.get_save_path(filename)
        with open(save_path, "w") as f:
            f.write(json.dumps(payload))
        return save_path