import torch
import numpy as np


def generate_meshgrid_2d(h: int, w: int, device) -> torch.Tensor:
    """Return a (h, w, 2) grid of coordinates evenly spaced in [-1, 1].

    Args:
        h: Number of samples along the first axis.
        w: Number of samples along the second axis.
        device: Device on which the grid is allocated.

    Returns:
        Tensor of shape (h, w, 2). ``grid[i, j, 0]`` is the i-th of ``h``
        values in [-1, 1] (varies with the row index), and ``grid[i, j, 1]`` is
        the j-th of ``w`` values in [-1, 1] (varies with the column index).
    """
    x = torch.linspace(-1, 1, h, device=device)
    y = torch.linspace(-1, 1, w, device=device)
    # "ij" is the historical default; stating it silences the torch>=1.10
    # warning and keeps the (h, w) layout explicit.
    # NOTE(review): channel 0 follows rows; grid_sample-style consumers expect
    # (x=width, y=height) ordering -- confirm callers want this order.
    grid_x, grid_y = torch.meshgrid(x, y, indexing="ij")
    return torch.stack([grid_x, grid_y], dim=2)


def _to_uint8(img):
    """Quantize a [0, 1] float image to uint8 levels 0..255.

    Values are rounded to the nearest level and clipped first, so slightly
    out-of-range inputs saturate instead of wrapping around modulo 256.
    """
    scaled = np.rint(np.asarray(img) * 255.0)
    return np.clip(scaled, 0, 255).astype(np.uint8)


def _channel_cdf(channel):
    """Return the 256-bin normalized cumulative histogram of a uint8 channel."""
    # Bins of width 1 over [0, 256): bin k counts the pixels equal to k, and
    # density=True makes the bins sum to 1.
    hist, _ = np.histogram(channel, bins=256, range=(0, 256), density=True)
    return np.cumsum(hist)


def his_match(src, dst):
    """Remap ``dst``'s intensities so each channel's histogram matches ``src``'s.

    Both images are quantized to 256 levels. For every level ``v`` of a
    ``dst`` channel, the output level is the smallest ``i`` with
    ``cdf_src[i] >= cdf_dst[v]``.

    Args:
        src: Reference image, indexed as ``[row, col, channel]`` with values
            expected in [0, 1]. It must have at least as many channels as
            ``dst``.
        dst: Image to remap, same layout. Values outside [0, 1] are clipped.

    Returns:
        Float array with ``dst``'s shape and values in [0, 1].
    """
    src_u8 = _to_uint8(src)
    dst_u8 = _to_uint8(dst)
    res = np.zeros_like(dst_u8)
    for ch in range(dst_u8.shape[2]):
        lut = np.searchsorted(
            _channel_cdf(src_u8[:, :, ch]),
            _channel_cdf(dst_u8[:, :, ch]),
            side="left",
        )
        # Float round-off can leave cdf_dst[-1] a hair above cdf_src[-1], so
        # searchsorted may return 256; clamp back to a valid level.
        np.clip(lut, 0, 255, out=lut)
        res[:, :, ch] = lut[dst_u8[:, :, ch]]
    return res / 255.0