MuseVSpace / MuseV /musev /utils /tensor_util.py
anchorxia's picture
add musev
96d7ad8
raw
history blame contribute delete
No virus
1.11 kB
import torch
import numpy as np
def generate_meshgrid_2d(h: int, w: int, device) -> torch.Tensor:
    """Build a 2-D grid of normalized coordinates spanning [-1, 1] on both axes.

    Args:
        h: Number of rows (samples along the first axis).
        w: Number of columns (samples along the second axis).
        device: Device the grid is allocated on (anything ``torch.linspace``
            accepts, e.g. ``"cpu"`` or a ``torch.device``).

    Returns:
        Tensor of shape ``(h, w, 2)`` in torch's default float dtype, where
        ``grid[i, j, 0]`` is the coordinate of row ``i`` and
        ``grid[i, j, 1]`` is the coordinate of column ``j``. Both are evenly
        spaced from -1 to 1 inclusive.

    NOTE(review): the last axis is in (row, column) order. ``F.grid_sample``
    expects (x=column, y=row), so confirm that callers feeding this grid to
    it flip the last axis.
    """
    rows = torch.linspace(-1, 1, h, device=device)
    cols = torch.linspace(-1, 1, w, device=device)
    # "ij" is torch's historical default, so results are unchanged. Passing
    # it explicitly silences the warning torch emits when indexing is omitted.
    grid_rows, grid_cols = torch.meshgrid(rows, cols, indexing="ij")
    return torch.stack([grid_rows, grid_cols], dim=2)
def _to_uint8(img) -> np.ndarray:
    """Quantize an image with values nominally in [0, 1] to uint8 levels 0-255.

    Rounds to the nearest level and clips. Float error (e.g. ``k / 255 * 255``
    landing just below ``k``) therefore does not drop a level, and values
    slightly outside [0, 1] do not wrap around in the uint8 cast.
    """
    scaled = np.asarray(img, dtype=np.float64) * 255.0
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)


def _normalized_cdf(channel: np.ndarray) -> np.ndarray:
    """Return the 256-level cumulative distribution of a non-empty uint8 channel."""
    counts = np.bincount(channel.ravel(), minlength=256)
    # Divide the integer cumsum once. The last entry is then exactly 1.0, and
    # equal count ratios in src and dst produce bit-identical floats.
    return np.cumsum(counts) / counts.sum()


def his_match(src, dst):
    """Remap ``dst`` so that each channel's histogram matches that of ``src``.

    Performs CDF-based histogram matching independently per channel on 256
    intensity levels. Each level ``v`` of a ``dst`` channel is mapped to the
    smallest ``src`` level whose cumulative frequency is at least the
    cumulative frequency of ``v`` in ``dst``.

    Args:
        src: Reference image, array-like of shape (H, W, C) with values in
            [0, 1]. Only its first ``dst.shape[2]`` channels are used.
        dst: Image to recolor, array-like of shape (H', W', C') with values in
            [0, 1] and C' <= C. Values outside [0, 1] are clipped.

    Returns:
        float64 ndarray with the shape of ``dst``, holding values ``k / 255``
        for integers ``k`` in [0, 255].

    Raises:
        ValueError: if ``src`` has no pixels while ``dst`` does, because an
            empty reference histogram has no CDF to match against.
    """
    src_u8 = _to_uint8(src)
    dst_u8 = _to_uint8(dst)
    if dst_u8.size == 0:
        return np.zeros(dst_u8.shape, dtype=np.float64)
    if src_u8.size == 0:
        raise ValueError("his_match: src must contain at least one pixel")

    res = np.empty_like(dst_u8)
    for ch in range(dst_u8.shape[2]):
        src_cdf = _normalized_cdf(src_u8[:, :, ch])
        dst_cdf = _normalized_cdf(dst_u8[:, :, ch])
        # Lookup table: dst level -> first src level whose CDF reaches dst's.
        lut = np.searchsorted(src_cdf, dst_cdf, side="left")
        # Safeguard: searchsorted returns 256 if a dst CDF value exceeds every src entry.
        np.clip(lut, 0, 255, out=lut)
        res[:, :, ch] = lut[dst_u8[:, :, ch]]
    return res / 255.0