|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from scipy.ndimage import map_coordinates |
|
import cv2 |
|
import math |
|
from os import makedirs |
|
from os.path import join, exists |
|
|
|
|
|
class Equirec2Cube:
    """Resample an equirectangular image onto the six faces of a cubemap.

    The faces are laid out side by side in [F R B L U D] order, so the
    outputs of :meth:`run` have shape ``(face_w, 6 * face_w, channels)``.
    All sampling coordinates are pre-computed in the constructor.
    """

    def __init__(self, equ_h, equ_w, face_w):
        '''
        equ_h: int, height of the equirectangular image
        equ_w: int, width of the equirectangular image
        face_w: int, the length of each face of the cubemap
        '''
        self.equ_h = equ_h
        self.equ_w = equ_w
        self.face_w = face_w

        self._xyzcube()
        self._xyz2coor()

        # For a face point (u, v) at distance 0.5 from the centre this is the
        # cosine of the angle between the viewing ray and the face normal.
        # run() multiplies the sampled depth by it.
        cosmap = 1 / np.sqrt((2 * self.grid[..., 0]) ** 2 + (2 * self.grid[..., 1]) ** 2 + 1)
        self.cosmaps = np.concatenate(6 * [cosmap], axis=1)[..., np.newaxis]

    def _xyzcube(self):
        '''
        Compute the xyz coordinates of the unit cube in [F R B L U D] format.
        '''
        self.xyz = np.zeros((self.face_w, self.face_w * 6, 3), np.float32)
        rng = np.linspace(-0.5, 0.5, num=self.face_w, dtype=np.float32)
        # grid[..., 0] runs left-to-right, grid[..., 1] runs top-to-bottom from +0.5 to -0.5.
        self.grid = np.stack(np.meshgrid(rng, -rng), -1)

        # Front face (z = 0.5)
        self.xyz[:, 0 * self.face_w:1 * self.face_w, [0, 1]] = self.grid
        self.xyz[:, 0 * self.face_w:1 * self.face_w, 2] = 0.5

        # Right face (x = 0.5)
        self.xyz[:, 1 * self.face_w:2 * self.face_w, [2, 1]] = self.grid[:, ::-1]
        self.xyz[:, 1 * self.face_w:2 * self.face_w, 0] = 0.5

        # Back face (z = -0.5)
        self.xyz[:, 2 * self.face_w:3 * self.face_w, [0, 1]] = self.grid[:, ::-1]
        self.xyz[:, 2 * self.face_w:3 * self.face_w, 2] = -0.5

        # Left face (x = -0.5)
        self.xyz[:, 3 * self.face_w:4 * self.face_w, [2, 1]] = self.grid
        self.xyz[:, 3 * self.face_w:4 * self.face_w, 0] = -0.5

        # Up face (y = 0.5)
        self.xyz[:, 4 * self.face_w:5 * self.face_w, [0, 2]] = self.grid[::-1, :]
        self.xyz[:, 4 * self.face_w:5 * self.face_w, 1] = 0.5

        # Down face (y = -0.5)
        self.xyz[:, 5 * self.face_w:6 * self.face_w, [0, 2]] = self.grid
        self.xyz[:, 5 * self.face_w:6 * self.face_w, 1] = -0.5

    def _xyz2coor(self):
        '''
        Convert the cube xyz coordinates to (fractional) pixel coordinates
        in the equirectangular image, stored as self.coor_x / self.coor_y.
        '''
        # Direction of every cube pixel -> longitude / latitude.
        x, y, z = np.split(self.xyz, 3, axis=-1)
        lon = np.arctan2(x, z)
        c = np.sqrt(x ** 2 + z ** 2)
        lat = np.arctan2(y, c)

        # Longitude / latitude -> equirectangular pixel position.
        self.coor_x = (lon / (2 * np.pi) + 0.5) * self.equ_w - 0.5
        self.coor_y = (-lat / np.pi + 0.5) * self.equ_h - 0.5

    def sample_equirec(self, e_img, order=0):
        '''
        Sample one 2-D channel of the equirectangular image at the
        pre-computed cube coordinates.

        e_img: 2-D array of shape (equ_h, equ_w)
        order: spline order passed to scipy's map_coordinates
        returns: array of shape (face_w, 6 * face_w)
        '''
        # Append the first and last rows, each rolled by half the width, so
        # that coordinates just outside the vertical range wrap onto them.
        pad_u = np.roll(e_img[[0]], self.equ_w // 2, 1)
        pad_d = np.roll(e_img[[-1]], self.equ_w // 2, 1)
        e_img = np.concatenate([e_img, pad_d, pad_u], 0)

        return map_coordinates(e_img, [self.coor_y, self.coor_x],
                               order=order, mode='wrap')[..., 0]

    @staticmethod
    def _with_channel_axis(arr):
        '''Return arr with a trailing channel axis added when it is 2-D.'''
        return arr[..., np.newaxis] if arr.ndim == 2 else arr

    def run(self, equ_img, equ_dep=None):
        '''
        Convert an equirectangular image (and optionally a depth map) to a
        cubemap strip.

        equ_img: array (H, W, C) or (H, W); resized to (equ_h, equ_w) if needed
        equ_dep: optional array (H, W, C) or (H, W); it is resized together
                 with equ_img (nearest neighbour) when equ_img needs resizing
        returns: cube_img of shape (face_w, 6 * face_w, C), or the tuple
                 (cube_img, cube_dep) when equ_dep is given
        '''
        h, w = equ_img.shape[:2]
        if h != self.equ_h or w != self.equ_w:
            equ_img = cv2.resize(equ_img, (self.equ_w, self.equ_h))
            if equ_dep is not None:
                equ_dep = cv2.resize(equ_dep, (self.equ_w, self.equ_h), interpolation=cv2.INTER_NEAREST)

        # A 2-D input (for example a single-channel map after resizing) has no
        # channel axis to iterate over; restore it instead of failing below.
        equ_img = self._with_channel_axis(equ_img)
        cube_img = np.stack([self.sample_equirec(equ_img[..., i], order=1)
                             for i in range(equ_img.shape[2])], axis=-1)

        if equ_dep is None:
            return cube_img

        equ_dep = self._with_channel_axis(equ_dep)
        # Depth uses order=0 so that values are never blended across pixels.
        cube_dep = np.stack([self.sample_equirec(equ_dep[..., i], order=0)
                             for i in range(equ_dep.shape[2])], axis=-1)
        cube_dep = cube_dep * self.cosmaps

        return cube_img, cube_dep
|
|
|
|
|
class Cube2Equirec(nn.Module):
    """Resample a cubemap strip ([F R B L U D] side by side) to an equirectangular map.

    The per-pixel lookup (face index plus in-face u/v) is built once in the
    constructor and stored as a frozen parameter; forward() only calls
    grid_sample.
    """

    def __init__(self, face_w, equ_h, equ_w):
        '''
        face_w: int, the length of each face of the cubemap
        equ_h: int, height of the equirectangular image
        equ_w: int, width of the equirectangular image
        '''
        super(Cube2Equirec, self).__init__()

        self.face_w = face_w
        self.equ_h = equ_h
        self.equ_w = equ_w

        self._equirect_facetype()
        self._equirect_faceuv()

    def _equirect_facetype(self):
        '''
        Label every equirectangular pixel with the cube face it looks at:
        0F 1R 2B 3L 4U 5D
        '''
        quarter_w = self.equ_w // 4
        shift = 3 * self.equ_w // 8

        # Four equal-width side-face bands, rolled by `shift` columns.
        side_row = np.arange(4).repeat(quarter_w)[None, :]
        tp = np.roll(side_row.repeat(self.equ_h, 0), shift, 1)

        # For each column of one band, the number of top rows that belong
        # to the up face.
        mask = np.zeros((self.equ_h, quarter_w), bool)
        angles = np.linspace(-np.pi, np.pi, quarter_w) / 4
        up_rows = self.equ_h // 2 - np.round(np.arctan(np.cos(angles)) * self.equ_h / np.pi).astype(int)
        for col, n_rows in enumerate(up_rows):
            mask[:n_rows, col] = 1
        mask = np.roll(np.concatenate([mask] * 4, 1), shift, 1)

        # Up face on top, and its vertical mirror image is the down face.
        tp[mask] = 4
        tp[np.flip(mask, 0)] = 5

        self.tp = tp
        self.mask = mask

    def _equirect_faceuv(self):
        '''
        Compute, for every equirectangular pixel, the (u, v) position inside
        its cube face and assemble the grid used by forward().
        '''
        # Longitude / latitude of every pixel centre.
        lon = ((np.linspace(0, self.equ_w - 1, num=self.equ_w, dtype=np.float32) + 0.5) / self.equ_w - 0.5) * 2 * np.pi
        lat = -((np.linspace(0, self.equ_h - 1, num=self.equ_h, dtype=np.float32) + 0.5) / self.equ_h - 0.5) * np.pi
        lon, lat = np.meshgrid(lon, lat)

        coor_u = np.zeros((self.equ_h, self.equ_w), dtype=np.float32)
        coor_v = np.zeros((self.equ_h, self.equ_w), dtype=np.float32)

        # Side faces: longitude is measured relative to the face centre,
        # which sits at i * pi / 2.
        for i in range(4):
            on_face = (self.tp == i)
            rel_lon = lon[on_face] - np.pi * i / 2
            coor_u[on_face] = 0.5 * np.tan(rel_lon)
            coor_v[on_face] = -0.5 * np.tan(lat[on_face]) / np.cos(rel_lon)

        # Up face.
        on_face = (self.tp == 4)
        radius = 0.5 * np.tan(np.pi / 2 - lat[on_face])
        coor_u[on_face] = radius * np.sin(lon[on_face])
        coor_v[on_face] = radius * np.cos(lon[on_face])

        # Down face.
        on_face = (self.tp == 5)
        radius = 0.5 * np.tan(np.pi / 2 - np.abs(lat[on_face]))
        coor_u[on_face] = radius * np.sin(lon[on_face])
        coor_v[on_face] = -radius * np.cos(lon[on_face])

        # Clip to the face extent and rescale to grid_sample's [-1, 1] range.
        coor_u = (np.clip(coor_u, -0.5, 0.5)) * 2
        coor_v = (np.clip(coor_v, -0.5, 0.5)) * 2

        # Face indices 0..5 become the third grid coordinate, also in [-1, 1].
        self.tp = torch.from_numpy(self.tp.astype(np.float32) / 2.5 - 1)
        self.coor_u = torch.from_numpy(coor_u)
        self.coor_v = torch.from_numpy(coor_v)

        stacked = torch.stack([self.coor_u, self.coor_v, self.tp], dim=-1)
        sample_grid = stacked.view(1, 1, self.equ_h, self.equ_w, 3)
        self.sample_grid = nn.Parameter(sample_grid, requires_grad=False)

    def forward(self, cube_feat):
        '''
        cube_feat: tensor (bs, ch, face_w, 6 * face_w), faces side by side
        returns: tensor (bs, ch, equ_h, equ_w)
        '''
        bs, ch, h, w = cube_feat.shape
        assert h == self.face_w and w // 6 == self.face_w

        # Cut the strip into faces and stack them along a new depth axis.
        strip = cube_feat.view(bs, ch, 1, h, w)
        faces = torch.cat(torch.split(strip, self.face_w, dim=-1), dim=2)
        faces = faces.view([bs, ch, 6, self.face_w, self.face_w])

        batch_grid = torch.cat([self.sample_grid for _ in range(bs)], dim=0)
        equi_feat = F.grid_sample(faces, batch_grid, padding_mode="border", align_corners=True)

        return equi_feat.squeeze(2)
|
|
|
|
|
|
|
def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)
|
|
|
def uv2xyz(uv):
    """Convert angle pairs to unit vectors.

    uv: array whose last axis holds two angles in radians; index 0 is used
        as the azimuth and index 1 as the elevation.
    returns: float32 numpy array with the same leading shape and a last axis
        of size 3 holding (cos(el) * sin(az), cos(el) * cos(az), sin(el)).
    """
    azimuth = uv[..., 0]
    elevation = uv[..., 1]
    cos_el = np.cos(elevation)

    out = np.zeros(tuple(uv.shape[:-1]) + (3,), dtype=np.float32)
    out[..., 0] = np.multiply(cos_el, np.sin(azimuth))
    out[..., 1] = np.multiply(cos_el, np.cos(azimuth))
    out[..., 2] = np.sin(elevation)
    return out
|
|
|
def equi2pers(erp_img, fov, nrows, patch_size):
    """Sample a set of perspective (tangent) patches from an equirectangular image.

    Patch centres are arranged in ``nrows`` latitude rows with a fixed number
    of patches per row; every patch is produced with an inverse gnomonic
    projection followed by bilinear ``grid_sample``.

    erp_img: tensor (bs, C, erp_h, erp_w)
    fov: field of view in degrees, either a scalar or a tuple (fov_h, fov_w)
    nrows: number of latitude rows; one of 3, 4, 5, 6
    patch_size: patch size in pixels, either a scalar or a tuple (height, width)

    returns: (pers, xyz, uv, center_p)
        pers: tensor (bs, C, height, width, num_patch)
        xyz: tensor (num_patch, 3, height, width), uv2xyz() of the per-pixel
            (lon, lat) in radians
        uv: tensor (num_patch, 2, height, width) taken from the normalised
            sampling grid
        center_p: tensor (num_patch, 2), patch centres (lon, lat) scaled to [-1, 1]

    raises: ValueError if ``nrows`` is not a supported layout.
    """
    # nrows -> (patches per row, row-centre latitude in degrees).
    # NOTE(review): pers2equi uses +/-59.6 for the 3-row layout where this
    # function uses +/-60; left as is, confirm which value is intended.
    layouts = {
        3: ([3, 4, 3], [-60, 0, 60]),
        4: ([3, 6, 6, 3], [-67.5, -22.5, 22.5, 67.5]),
        5: ([3, 6, 8, 6, 3], [-72.2, -36.1, 0, 36.1, 72.2]),
        6: ([3, 8, 12, 12, 8, 3], [-75.2, -45.93, -15.72, 15.72, 45.93, 75.2]),
    }
    if nrows not in layouts:
        raise ValueError(
            "unsupported nrows=%r; expected one of %s" % (nrows, sorted(layouts)))
    num_cols, phi_centers = layouts[nrows]

    bs, _, _, _ = erp_img.shape
    height, width = pair(patch_size)
    fov_h, fov_w = pair(fov)
    FOV = torch.tensor([fov_w / 360.0, fov_h / 180.0], dtype=torch.float32)

    PI = math.pi
    PI_2 = math.pi * 0.5

    # Patch pixel positions in [0, 1], flattened row-major as (x, y) pairs.
    yy, xx = torch.meshgrid(torch.linspace(0, 1, height), torch.linspace(0, 1, width))
    screen_points = torch.stack([xx.flatten(), yy.flatten()], -1)

    # Patch centres (theta, phi) in degrees: evenly spaced longitudes per row.
    all_combos = []
    for n_cols, phi_center in zip(num_cols, phi_centers):
        theta_interval = 360 / n_cols
        for j in range(n_cols):
            theta_center = j * theta_interval + theta_interval / 2
            all_combos.append([theta_center, phi_center])
    all_combos = np.vstack(all_combos)
    num_patch = all_combos.shape[0]

    # Degrees -> [0, 1] -> [-1, 1] (returned as center_p) -> radians.
    center_point = torch.from_numpy(all_combos).float()
    center_point[:, 0] = (center_point[:, 0]) / 360
    center_point[:, 1] = (center_point[:, 1] + 90) / 180

    cp = center_point * 2 - 1
    center_p = cp.clone()
    cp[:, 0] = cp[:, 0] * PI
    cp[:, 1] = cp[:, 1] * PI_2
    cp = cp.unsqueeze(1)
    cp_lon = cp[:, :, 0]
    cp_lat = cp[:, :, 1]

    # Tangent-plane coordinates of every patch pixel, scaled by the field of view.
    convertedCoord = screen_points * 2 - 1
    convertedCoord[:, 0] = convertedCoord[:, 0] * PI
    convertedCoord[:, 1] = convertedCoord[:, 1] * PI_2
    convertedCoord = convertedCoord * FOV
    convertedCoord = convertedCoord.unsqueeze(0).repeat(num_patch, 1, 1)

    x = convertedCoord[:, :, 0]
    y = convertedCoord[:, :, 1]

    # Inverse gnomonic projection.
    rou = torch.sqrt(x ** 2 + y ** 2)
    c = torch.atan(rou)
    sin_c = torch.sin(c)
    cos_c = torch.cos(c)
    # At the exact patch centre (odd patch sizes) rou is 0 and the original
    # expression evaluated 0 / 0 = NaN. The numerator is 0 there as well, so
    # dividing by 1 gives the correct limit and leaves all other pixels as is.
    safe_rou = torch.where(rou > 0, rou, torch.ones_like(rou))
    sin_lat = cos_c * torch.sin(cp_lat) + (y * sin_c * torch.cos(cp_lat)) / safe_rou
    # Guard asin against arguments pushed just outside [-1, 1] by rounding.
    lat = torch.asin(torch.clamp(sin_lat, -1.0, 1.0))
    lon = cp_lon + torch.atan2(x * sin_c, rou * torch.cos(cp_lat) * cos_c - y * torch.sin(cp_lat) * sin_c)

    # Normalise to grid_sample's [-1, 1] range and wrap longitude once.
    lat_new = lat / PI_2
    lon_new = lon / PI
    lon_new[lon_new > 1] -= 2
    lon_new[lon_new < -1] += 2

    # Lay all patches side by side: (num_patch, height*width) -> (height, num_patch*width).
    lon_new = lon_new.view(1, num_patch, height, width).permute(0, 2, 1, 3).contiguous().view(height, num_patch * width)
    lat_new = lat_new.view(1, num_patch, height, width).permute(0, 2, 1, 3).contiguous().view(height, num_patch * width)
    grid = torch.stack([lon_new, lat_new], -1)
    grid = grid.unsqueeze(0).repeat(bs, 1, 1, 1).to(erp_img.device)
    pers = F.grid_sample(erp_img, grid, mode='bilinear', padding_mode='border', align_corners=True)
    # Cut the strip back into patches; unfold's last axis enumerates the patches.
    pers = F.unfold(pers, kernel_size=(height, width), stride=(height, width))
    pers = pers.reshape(bs, -1, height, width, num_patch)

    # Unit direction per patch pixel, computed from the un-normalised angles.
    grid_tmp = torch.stack([lon, lat], -1)
    xyz = uv2xyz(grid_tmp.numpy())
    xyz = xyz.reshape(num_patch, height, width, 3).transpose(0, 3, 1, 2)
    xyz = torch.from_numpy(xyz).to(pers.device).contiguous()

    # NOTE(review): grid[0] is laid out as (height, num_patch * width) with the
    # patch index varying slower than the column index, so this reshape to
    # (height, width, num_patch, 2) appears to interleave patches and columns
    # unless width == num_patch. Kept unchanged because downstream code may
    # depend on the current values - confirm before fixing.
    uv = grid[0, ...].reshape(height, width, num_patch, 2).permute(2, 3, 0, 1)
    uv = uv.contiguous()
    return pers, xyz, uv, center_p
|
|
|
def _build_pers2equi_tables(fov, nrows, patch_size, erp_size):
    """Compute the bilinear lookup tables that map patches back to the panorama.

    returns: dict with integer corner indices 'x0', 'y0', 'x1', 'y1' and the
        validity 'mask', each of shape (n_patch, erp_h, erp_w), plus 'w_list'
        of shape (n_patch, erp_h, erp_w, 4) holding the masked corner weights.

    raises: ValueError if ``nrows`` is not a supported layout.
    """
    # nrows -> (patches per row, row-centre latitude in degrees).
    # NOTE(review): equi2pers uses +/-60 for the 3-row layout where this
    # function uses +/-59.6; left as is, confirm which value is intended.
    layouts = {
        3: ([3, 4, 3], [-59.6, 0, 59.6]),
        4: ([3, 6, 6, 3], [-67.5, -22.5, 22.5, 67.5]),
        5: ([3, 6, 8, 6, 3], [-72.2, -36.1, 0, 36.1, 72.2]),
        6: ([3, 8, 12, 12, 8, 3], [-75.2, -45.93, -15.72, 15.72, 45.93, 75.2]),
    }
    if nrows not in layouts:
        raise ValueError(
            "unsupported nrows=%r; expected one of %s" % (nrows, sorted(layouts)))
    num_cols, phi_centers = layouts[nrows]

    height, width = pair(patch_size)
    fov_h, fov_w = pair(fov)
    erp_h, erp_w = pair(erp_size)
    FOV = torch.tensor([fov_w / 360.0, fov_h / 180.0], dtype=torch.float32)

    PI = math.pi
    PI_2 = math.pi * 0.5

    # Patch centres (theta, phi) in degrees: evenly spaced longitudes per row.
    all_combos = []
    for n_cols, phi_center in zip(num_cols, phi_centers):
        theta_interval = 360 / n_cols
        for j in range(n_cols):
            theta_center = j * theta_interval + theta_interval / 2
            all_combos.append([theta_center, phi_center])
    all_combos = np.vstack(all_combos)
    n_patch = all_combos.shape[0]

    # Degrees -> [0, 1] -> [-1, 1] -> radians.
    center_point = torch.from_numpy(all_combos).float()
    center_point[:, 0] = (center_point[:, 0]) / 360
    center_point[:, 1] = (center_point[:, 1] + 90) / 180

    cp = center_point * 2 - 1
    cp[:, 0] = cp[:, 0] * PI
    cp[:, 1] = cp[:, 1] * PI_2
    cp = cp.unsqueeze(1)
    cp_lon = cp[..., 0]
    cp_lat = cp[..., 1]

    # Longitude / latitude of every panorama pixel, flattened to one row.
    lat_grid, lon_grid = torch.meshgrid(torch.linspace(-PI_2, PI_2, erp_h), torch.linspace(-PI, PI, erp_w))
    lon_grid = lon_grid.float().reshape(1, -1)
    lat_grid = lat_grid.float().reshape(1, -1)

    # Forward gnomonic projection of every panorama pixel onto every patch plane.
    cos_c = torch.sin(cp_lat) * torch.sin(lat_grid) + torch.cos(cp_lat) * torch.cos(lat_grid) * torch.cos(lon_grid - cp_lon)
    new_x = (torch.cos(lat_grid) * torch.sin(lon_grid - cp_lon)) / cos_c
    new_y = (torch.cos(cp_lat) * torch.sin(lat_grid) - torch.sin(cp_lat) * torch.cos(lat_grid) * torch.cos(lon_grid - cp_lon)) / cos_c
    new_x = new_x / FOV[0] / PI
    new_y = new_y / FOV[1] / PI_2

    # Only the hemisphere facing the patch (cos_c > 0) is a valid projection.
    cos_c_mask = cos_c.reshape(n_patch, erp_h, erp_w)
    cos_c_mask = torch.where(cos_c_mask > 0, 1, 0)

    # [-1, 1] -> patch pixel coordinates. x runs along the width and y along
    # the height; the original scaled x by height and y by width, which is
    # only correct for square patches.
    new_x_patch = (new_x + 1) * 0.5 * width
    new_y_patch = (new_y + 1) * 0.5 * height
    new_x_patch = new_x_patch.reshape(n_patch, erp_h, erp_w)
    new_y_patch = new_y_patch.reshape(n_patch, erp_h, erp_w)
    mask = torch.where((new_x_patch < width) & (new_x_patch > 0) & (new_y_patch < height) & (new_y_patch > 0), 1, 0)
    mask *= cos_c_mask

    # Integer corners for bilinear interpolation, clamped into the patch.
    x0 = torch.floor(new_x_patch).type(torch.int64)
    x1 = x0 + 1
    y0 = torch.floor(new_y_patch).type(torch.int64)
    y1 = y0 + 1

    x0 = torch.clamp(x0, 0, width - 1)
    x1 = torch.clamp(x1, 0, width - 1)
    y0 = torch.clamp(y0, 0, height - 1)
    y1 = torch.clamp(y1, 0, height - 1)

    x0_f = x0.type(torch.float32)
    x1_f = x1.type(torch.float32)
    y0_f = y0.type(torch.float32)
    y1_f = y1.type(torch.float32)

    # Corner weights, zeroed wherever the pixel does not fall inside the patch.
    wa = (x1_f - new_x_patch) * (y1_f - new_y_patch) * mask
    wb = (x1_f - new_x_patch) * (new_y_patch - y0_f) * mask
    wc = (new_x_patch - x0_f) * (y1_f - new_y_patch) * mask
    wd = (new_x_patch - x0_f) * (new_y_patch - y0_f) * mask
    w_list = torch.stack([wa, wb, wc, wd], dim=-1)

    return {'x0': x0, 'y0': y0, 'x1': x1, 'y1': y1, 'w_list': w_list, 'mask': mask}


def pers2equi(pers_img, fov, nrows, patch_size, erp_size, layer_name):
    """Blend perspective patches back into a single equirectangular map.

    The lookup tables are cached on disk as ``./grid/<layer_name>.pth``. They
    are computed and saved on the first call and loaded on later calls.
    NOTE: the cache is keyed by ``layer_name`` only, so an existing file is
    reused even if fov / nrows / patch_size / erp_size have changed.

    pers_img: tensor (bs, C, height, width, n_patch)
    fov: field of view in degrees, scalar or tuple (fov_h, fov_w)
    nrows: number of latitude rows; one of 3, 4, 5, 6
    patch_size: patch size in pixels, scalar or tuple (height, width)
    erp_size: output size, scalar or tuple (erp_h, erp_w)
    layer_name: str, name of the cache file (without extension)

    returns: tensor (bs, C, erp_h, erp_w)
    raises: ValueError if the tables must be built and ``nrows`` is unsupported.
    """
    device = pers_img.device
    erp_h, erp_w = pair(erp_size)
    n_patch = pers_img.shape[-1]

    grid_dir = './grid'
    # exist_ok avoids failing when another process creates the directory
    # between a separate existence check and the creation.
    makedirs(grid_dir, exist_ok=True)
    grid_file = join(grid_dir, layer_name + '.pth')

    if not exists(grid_file):
        tables = _build_pers2equi_tables(fov, nrows, patch_size, erp_size)
        n_patch = tables['mask'].shape[0]
        torch.save(tables, grid_file)
    else:
        # NOTE: torch.load unpickles the file; it is expected to be a cache
        # previously written by this function, not external data.
        tables = torch.load(grid_file)

    x0 = tables['x0']
    y0 = tables['y0']
    x1 = tables['x1']
    y1 = tables['y1']
    w_list = tables['w_list'].to(device)
    mask = tables['mask'].to(device)

    # Gather the four corner values of every patch for every panorama pixel;
    # each result has shape (bs, C, n_patch, erp_h, erp_w).
    z = torch.arange(n_patch)
    z = z.reshape(n_patch, 1, 1)
    Ia = pers_img[:, :, y0, x0, z]
    Ib = pers_img[:, :, y1, x0, z]
    Ic = pers_img[:, :, y0, x1, z]
    Id = pers_img[:, :, y1, x1, z]
    output_a = Ia * mask.expand_as(Ia)
    output_b = Ib * mask.expand_as(Ib)
    output_c = Ic * mask.expand_as(Ic)
    output_d = Id * mask.expand_as(Id)

    # Move the patch axis last: (bs, C, erp_h, erp_w, n_patch).
    output_a = output_a.permute(0, 1, 3, 4, 2)
    output_b = output_b.permute(0, 1, 3, 4, 2)
    output_c = output_c.permute(0, 1, 3, 4, 2)
    output_d = output_d.permute(0, 1, 3, 4, 2)

    # Per panorama pixel, drop negligible or negative weights and L1-normalise
    # over all corners of all patches so overlapping patches are averaged.
    w_list = w_list.permute(1, 2, 0, 3)
    w_list = w_list.flatten(2)
    w_list *= torch.gt(w_list, 1e-5).type(torch.float32)
    w_list = F.normalize(w_list, p=1, dim=-1).reshape(erp_h, erp_w, n_patch, 4)
    w_list = w_list.unsqueeze(0).unsqueeze(0)
    output = output_a * w_list[..., 0] + output_b * w_list[..., 1] + \
        output_c * w_list[..., 2] + output_d * w_list[..., 3]
    img_erp = output.sum(-1)

    return img_erp
|
|
|
def img2windows(img, H_sp, W_sp):
    """
    Split a feature map into non-overlapping windows.

    img: B C H W
    returns: (B * H // H_sp * W // W_sp) H_sp W_sp C, windows ordered by
             batch, then window row, then window column
    """
    batch, channels, full_h, full_w = img.shape
    rows = full_h // H_sp
    cols = full_w // W_sp

    tiled = img.view(batch, channels, rows, H_sp, cols, W_sp)
    # -> (batch, rows, cols, H_sp, W_sp, channels)
    channels_last = tiled.permute(0, 2, 4, 3, 5, 1).contiguous()
    return channels_last.reshape(-1, H_sp, W_sp, channels)
|
|
|
def windows2img(img_splits_hw, H_sp, W_sp, H, W):
    """
    Reassemble non-overlapping windows into full feature maps.

    img_splits_hw: B' H_sp W_sp C, windows ordered by batch, then window
                   row, then window column
    returns: B H W C
    """
    windows_per_img = H * W / H_sp / W_sp
    B = int(img_splits_hw.shape[0] / windows_per_img)

    tiled = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1)
    # -> (B, window row, H_sp, window column, W_sp, C)
    interleaved = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return interleaved.view(B, H, W, -1)