""" | |
Adapted from code originally written by David Novotny. | |
""" | |
import torch | |
from pytorch3d.transforms import Rotate, Translate | |
import cv2 | |
import numpy as np | |
import torch | |
from pytorch3d.renderer import PerspectiveCameras, RayBundle | |


def intersect_skew_line_groups(p, r, mask):
    # p, r both of shape (B, N, n_intersected_lines, 3)
    # mask of shape (B, N, n_intersected_lines)
    p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
    if p_intersect is None:
        return None, None, None, None
    _, p_line_intersect = point_line_distance(
        p, r, p_intersect[..., None, :].expand_as(p)
    )
    intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(
        dim=-1
    )
    return p_intersect, p_line_intersect, intersect_dist_squared, r


def intersect_skew_lines_high_dim(p, r, mask=None):
    # Implements https://en.wikipedia.org/wiki/Skew_lines in more than two dimensions.
    dim = p.shape[-1]
    if mask is None:
        mask = torch.ones_like(p[..., 0])
    # Make sure the heading vectors are l2-normed.
    r = torch.nn.functional.normalize(r, dim=-1)

    # For each line, project points onto the subspace orthogonal to its direction,
    # then solve the combined least-squares system for the nearest common point.
    eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
    I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
    sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
    p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]

    if torch.any(torch.isnan(p_intersect)):
        print(p_intersect)
        return None, None
    return p_intersect, r
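

# Usage sketch with hypothetical values: two coplanar lines, one through
# (0, 1, 0) along x and one through (1, 0, 0) along y, meet at (1, 1, 0).
# The least-squares solve above returns that point exactly.
def _demo_intersect_skew_lines():
    p = torch.tensor([[[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]])  # points on the lines
    r = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # line directions
    p_intersect, _ = intersect_skew_lines_high_dim(p, r)
    assert torch.allclose(p_intersect, torch.tensor([[1.0, 1.0, 0.0]]), atol=1e-5)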


def point_line_distance(p1, r1, p2):
    df = p2 - p1
    proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
    line_pt_nearest = p2 - proj_vector
    d = (proj_vector).norm(dim=-1)
    return d, line_pt_nearest


def compute_optical_axis_intersection(cameras):
    centers = cameras.get_camera_center()
    principal_points = cameras.principal_point

    one_vec = torch.ones((len(cameras), 1), device=centers.device)
    optical_axis = torch.cat((principal_points, one_vec), -1)

    # Unproject each camera's principal point at depth 1 to get a second point
    # on its optical axis in world coordinates.
    pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
    pp2 = torch.diagonal(pp, dim1=0, dim2=1).T

    directions = pp2 - centers
    centers = centers.unsqueeze(0).unsqueeze(0)
    directions = directions.unsqueeze(0).unsqueeze(0)

    p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
        p=centers, r=directions, mask=None
    )

    if p_intersect is None:
        dist = None
    else:
        p_intersect = p_intersect.squeeze().unsqueeze(0)
        dist = (p_intersect - centers).norm(dim=-1)

    return p_intersect, dist, p_line_intersect, pp2, r


def normalize_cameras(cameras, scale=1.0):
    """
    Normalizes cameras such that the optical axes point to the origin, the rotation is
    identity, and the norm of the translation of the first camera is 1.

    Args:
        cameras (pytorch3d.renderer.cameras.CamerasBase).
        scale (float): Norm of the translation of the first camera.

    Returns:
        new_cameras (pytorch3d.renderer.cameras.CamerasBase): Normalized cameras.
        undo_transform (function): Function that undoes the normalization.
    """
    # Let distance from first camera to origin be unit
    new_cameras = cameras.clone()
    new_transform = (
        new_cameras.get_world_to_view_transform()
    )  # potential R is not valid matrix

    p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
        cameras
    )

    if p_intersect is None:
        print("Warning: optical axes code has a nan. Returning identity cameras.")
        new_cameras.R[:] = torch.eye(3, device=cameras.R.device, dtype=cameras.R.dtype)
        new_cameras.T[:] = torch.tensor(
            [0, 0, 1], device=cameras.T.device, dtype=cameras.T.dtype
        )
        return new_cameras, lambda x: x

    d = dist.squeeze(dim=1).squeeze(dim=0)[0]
    # Degenerate case: the optical-axis intersection coincides with the first
    # camera center, so the scene scale is undefined.
    if d == 0:
        print(cameras.T)
        print(new_transform.get_matrix()[:, 3, :3])
        raise ValueError("Degenerate cameras: intersection is at first camera center.")

    # Can't figure out how to make scale part of the transform too without messing up R.
    # Ideally, we would just wrap it all in a single Pytorch3D transform so that it
    # would work with any structure (eg PointClouds, Meshes).
    tR = Rotate(new_cameras.R[0].unsqueeze(0)).inverse()
    tT = Translate(p_intersect)
    t = tR.compose(tT)

    new_transform = t.compose(new_transform)
    new_cameras.R = new_transform.get_matrix()[:, :3, :3]
    new_cameras.T = new_transform.get_matrix()[:, 3, :3] / d * scale

    def undo_transform(cameras):
        cameras_copy = cameras.clone()
        cameras_copy.T *= d / scale
        new_t = (
            t.inverse().compose(cameras_copy.get_world_to_view_transform()).get_matrix()
        )
        cameras_copy.R = new_t[:, :3, :3]
        cameras_copy.T = new_t[:, 3, :3]
        return cameras_copy

    return new_cameras, undo_transform
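

# Example with a hypothetical two-camera setup whose optical axes converge near
# the origin. After normalization, the first camera has identity rotation and a
# translation of norm `scale`; `undo_transform` reverses the mapping.
def _demo_normalize_cameras():
    c, s = float(np.cos(0.3)), float(np.sin(0.3))
    R = torch.tensor(
        [
            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
            [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]],  # rotation about y
        ]
    )
    T = torch.tensor([[0.0, 0.0, 2.0], [0.3, 0.0, 2.2]])
    cameras = PerspectiveCameras(R=R, T=T)
    new_cameras, undo_transform = normalize_cameras(cameras, scale=1.0)
    recovered = undo_transform(new_cameras)
    # recovered.R / recovered.T should match cameras.R / cameras.T up to
    # numerical precision.
    return new_cameras, recovered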


def first_camera_transform(cameras, rotation_only=True):
    new_cameras = cameras.clone()
    new_transform = new_cameras.get_world_to_view_transform()
    tR = Rotate(new_cameras.R[0].unsqueeze(0))
    if rotation_only:
        t = tR.inverse()
    else:
        tT = Translate(new_cameras.T[0].unsqueeze(0))
        t = tR.compose(tT).inverse()
    new_transform = t.compose(new_transform)
    new_cameras.R = new_transform.get_matrix()[:, :3, :3]
    new_cameras.T = new_transform.get_matrix()[:, 3, :3]
    return new_cameras


def get_identity_cameras_with_intrinsics(cameras):
    D = len(cameras)
    device = cameras.R.device
    new_cameras = cameras.clone()
    new_cameras.R = torch.eye(3, device=device).unsqueeze(0).repeat((D, 1, 1))
    new_cameras.T = torch.zeros((D, 3), device=device)
    return new_cameras


def normalize_cameras_batch(cameras, scale=1.0, normalize_first_camera=False):
    new_cameras = []
    undo_transforms = []
    for cam in cameras:
        if normalize_first_camera:
            # Normalize cameras such that first camera is identity and origin is
            # at first camera center.
            normalized_cameras = first_camera_transform(cam, rotation_only=False)
            undo_transform = None
        else:
            normalized_cameras, undo_transform = normalize_cameras(cam, scale=scale)
        new_cameras.append(normalized_cameras)
        undo_transforms.append(undo_transform)
    return new_cameras, undo_transforms


class Rays(object):
    def __init__(
        self,
        rays=None,
        origins=None,
        directions=None,
        moments=None,
        is_plucker=False,
        moments_rescale=1.0,
        ndc_coordinates=None,
        crop_parameters=None,
        num_patches_x=16,
        num_patches_y=16,
    ):
        """
        Ray class to keep track of current ray representation.

        Args:
            rays: (..., 6).
            origins: (..., 3).
            directions: (..., 3).
            moments: (..., 3).
            is_plucker: If True, rays are in Plucker coordinates (Default: False).
            moments_rescale: Rescale the moment component of the rays by a scalar.
            ndc_coordinates: (..., 2): NDC coordinates of each ray.
            crop_parameters: (..., 4): crop parameters (cc_x, cc_y, crop_width,
                scale) used to compute NDC coordinates if they are not given.
        """
        if rays is not None:
            self.rays = rays
            self._is_plucker = is_plucker
        elif origins is not None and directions is not None:
            self.rays = torch.cat((origins, directions), dim=-1)
            self._is_plucker = False
        elif directions is not None and moments is not None:
            self.rays = torch.cat((directions, moments), dim=-1)
            self._is_plucker = True
        else:
            raise Exception("Invalid combination of arguments")

        if moments_rescale != 1.0:
            self.rescale_moments(moments_rescale)

        if ndc_coordinates is not None:
            self.ndc_coordinates = ndc_coordinates
        elif crop_parameters is not None:
            # (..., H, W, 2)
            xy_grid = compute_ndc_coordinates(
                crop_parameters,
                num_patches_x=num_patches_x,
                num_patches_y=num_patches_y,
            )[..., :2]
            xy_grid = xy_grid.reshape(*xy_grid.shape[:-3], -1, 2)
            self.ndc_coordinates = xy_grid
        else:
            self.ndc_coordinates = None

    def __getitem__(self, index):
        return Rays(
            rays=self.rays[index],
            is_plucker=self._is_plucker,
            ndc_coordinates=(
                self.ndc_coordinates[index]
                if self.ndc_coordinates is not None
                else None
            ),
        )

    def to_spatial(self, include_ndc_coordinates=False):
        """
        Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W)

        Returns:
            torch.Tensor: (..., 6, H, W)
        """
        rays = self.to_plucker().rays
        *batch_dims, P, D = rays.shape
        H = W = int(np.sqrt(P))
        assert H * W == P

        rays = torch.transpose(rays, -1, -2)  # (..., 6, H * W)
        rays = rays.reshape(*batch_dims, D, H, W)

        if include_ndc_coordinates:
            ndc_coords = self.ndc_coordinates.transpose(-1, -2)  # (..., 2, H * W)
            ndc_coords = ndc_coords.reshape(*batch_dims, 2, H, W)
            rays = torch.cat((rays, ndc_coords), dim=-3)
        return rays

    def rescale_moments(self, scale):
        """
        Rescale the moment component of the rays by a scalar. Might be desirable
        since moments may come from a very narrow distribution.

        Note that this modifies in place!
        """
        if self.is_plucker:
            self.rays[..., 3:] *= scale
            return self
        else:
            return self.to_plucker().rescale_moments(scale)

    @classmethod
    def from_spatial(cls, rays, moments_rescale=1.0, ndc_coordinates=None):
        """
        Converts rays from spatial representation: (..., 6, H, W) --> (..., H * W, 6)

        Args:
            rays: (..., 6, H, W)

        Returns:
            Rays: (..., H * W, 6)
        """
        *batch_dims, D, H, W = rays.shape
        rays = rays.reshape(*batch_dims, D, H * W)
        rays = torch.transpose(rays, -1, -2)
        return cls(
            rays=rays,
            is_plucker=True,
            moments_rescale=moments_rescale,
            ndc_coordinates=ndc_coordinates,
        )

    def to_point_direction(self, normalize_moment=True):
        """
        Convert to point direction representation <O, D>.

        Returns:
            rays: (..., 6).
        """
        if self._is_plucker:
            direction = torch.nn.functional.normalize(self.rays[..., :3], dim=-1)
            moment = self.rays[..., 3:]
            if normalize_moment:
                # Plucker coordinates are homogeneous; dividing the moment by the
                # norm of the unnormalized direction keeps it consistent with the
                # unit direction used below.
                c = torch.linalg.norm(self.rays[..., :3], dim=-1, keepdim=True)
                moment = moment / c
            # d x (o x d) is the point on the ray closest to the origin.
            points = torch.cross(direction, moment, dim=-1)
            return Rays(
                rays=torch.cat((points, direction), dim=-1),
                is_plucker=False,
                ndc_coordinates=self.ndc_coordinates,
            )
        else:
            return self

    def to_plucker(self):
        """
        Convert to Plucker representation <D, OxD>.
        """
        if self.is_plucker:
            return self
        else:
            ray = self.rays.clone()
            ray_origins = ray[..., :3]
            ray_directions = ray[..., 3:]
            # Normalize ray directions to unit vectors
            ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True)
            plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1)
            new_ray = torch.cat([ray_directions, plucker_normal], dim=-1)
            return Rays(
                rays=new_ray, is_plucker=True, ndc_coordinates=self.ndc_coordinates
            )

    def get_directions(self, normalize=True):
        if self.is_plucker:
            directions = self.rays[..., :3]
        else:
            directions = self.rays[..., 3:]
        if normalize:
            directions = torch.nn.functional.normalize(directions, dim=-1)
        return directions

    def get_origins(self):
        if self.is_plucker:
            origins = self.to_point_direction().get_origins()
        else:
            origins = self.rays[..., :3]
        return origins

    def get_moments(self):
        if self.is_plucker:
            moments = self.rays[..., 3:]
        else:
            moments = self.to_plucker().get_moments()
        return moments

    def get_ndc_coordinates(self):
        return self.ndc_coordinates

    @property
    def is_plucker(self):
        return self._is_plucker

    @property
    def device(self):
        return self.rays.device

    def __repr__(self, *args, **kwargs):
        ray_str = self.rays.__repr__(*args, **kwargs)[6:]  # remove "tensor"
        if self._is_plucker:
            return "PluRay" + ray_str
        else:
            return "DirRay" + ray_str

    def to(self, device):
        self.rays = self.rays.to(device)

    def clone(self):
        return Rays(rays=self.rays.clone(), is_plucker=self._is_plucker)

    @property
    def shape(self):
        return self.rays.shape

    def visualize(self):
        directions = torch.nn.functional.normalize(self.get_directions(), dim=-1).cpu()
        moments = torch.nn.functional.normalize(self.get_moments(), dim=-1).cpu()
        return (directions + 1) / 2, (moments + 1) / 2

    def to_ray_bundle(self, length=0.3, recenter=True):
        lengths = torch.ones_like(self.get_origins()[..., :2]) * length
        lengths[..., 0] = 0
        if recenter:
            centers, _ = intersect_skew_lines_high_dim(
                self.get_origins(), self.get_directions()
            )
            centers = centers.unsqueeze(1).repeat(1, lengths.shape[1], 1)
        else:
            centers = self.get_origins()
        return RayBundle(
            origins=centers,
            directions=self.get_directions(),
            lengths=lengths,
            xys=self.get_directions(),
        )
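

# Example round trip with hypothetical random rays: converting to Plucker
# coordinates <D, OxD> and back recovers each ray's (normalized) direction and
# replaces the origin with the point on the ray closest to the world origin.
def _demo_plucker_round_trip():
    origins = torch.randn(2, 16 * 16, 3)
    directions = torch.randn(2, 16 * 16, 3)
    rays = Rays(origins=origins, directions=directions)
    recovered = rays.to_plucker().to_point_direction()
    # recovered.get_directions() matches normalize(directions); the recovered
    # origins lie on the same lines as `origins`.
    return recovered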


def cameras_to_rays(
    cameras,
    crop_parameters,
    use_half_pix=True,
    use_plucker=True,
    num_patches_x=16,
    num_patches_y=16,
):
    """
    Unprojects rays from camera center to grid on image plane.

    Args:
        cameras: Pytorch3D cameras to unproject. Can be batched.
        crop_parameters: Crop parameters in NDC (cc_x, cc_y, crop_width, scale).
            Shape is (B, 4).
        use_half_pix: If True, use half pixel offset (Default: True).
        use_plucker: If True, return rays in Plucker coordinates (Default: True).
        num_patches_x: Number of patches in x direction (Default: 16).
        num_patches_y: Number of patches in y direction (Default: 16).
    """
    unprojected = []
    crop_parameters_list = (
        crop_parameters if crop_parameters is not None else [None for _ in cameras]
    )
    for camera, crop_param in zip(cameras, crop_parameters_list):
        xyd_grid = compute_ndc_coordinates(
            crop_parameters=crop_param,
            use_half_pix=use_half_pix,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
        )
        unprojected.append(
            camera.unproject_points(
                xyd_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True
            )
        )
    unprojected = torch.stack(unprojected, dim=0)  # (N, P, 3)
    origins = cameras.get_camera_center().unsqueeze(1)  # (N, 1, 3)
    origins = origins.repeat(1, num_patches_x * num_patches_y, 1)  # (N, P, 3)
    directions = unprojected - origins

    rays = Rays(
        origins=origins,
        directions=directions,
        crop_parameters=crop_parameters,
        num_patches_x=num_patches_x,
        num_patches_y=num_patches_y,
    )
    if use_plucker:
        return rays.to_plucker()
    return rays
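

# Example with hypothetical cameras and "no crop" parameters (cc = 0, width = 2
# spans the full NDC square): a 16x16 patch grid yields 256 Plucker rays per
# camera, so the result has shape (2, 256, 6).
def _demo_cameras_to_rays():
    cameras = PerspectiveCameras(
        focal_length=torch.full((2, 1), 3.453),
        R=torch.eye(3).repeat(2, 1, 1),
        T=torch.tensor([[0.0, 0.0, 2.0], [0.2, 0.0, 2.0]]),
    )
    crop_parameters = torch.zeros(2, 4)
    crop_parameters[:, 2] = 2.0  # (cc_x, cc_y, crop_width, scale)
    return cameras_to_rays(cameras, crop_parameters)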


def rays_to_cameras(
    rays,
    crop_parameters,
    num_patches_x=16,
    num_patches_y=16,
    use_half_pix=True,
    sampled_ray_idx=None,
    cameras=None,
    focal_length=(3.453,),
):
    """
    If cameras are provided, will use those intrinsics. Otherwise will use the
    provided focal_length(s). Dataset default is 3.32.

    Args:
        rays (Rays): (N, P, 6)
        crop_parameters (torch.Tensor): (N, 4)
    """
    device = rays.device
    origins = rays.get_origins()
    directions = rays.get_directions()
    camera_centers, _ = intersect_skew_lines_high_dim(origins, directions)

    # Retrieve target rays
    if cameras is None:
        if len(focal_length) == 1:
            focal_length = focal_length * rays.shape[0]
        I_camera = PerspectiveCameras(focal_length=focal_length, device=device)
    else:
        # Use same intrinsics but reset to identity extrinsics.
        I_camera = cameras.clone()
        I_camera.R[:] = torch.eye(3, device=device)
        I_camera.T[:] = torch.zeros(3, device=device)
    I_patch_rays = cameras_to_rays(
        cameras=I_camera,
        num_patches_x=num_patches_x,
        num_patches_y=num_patches_y,
        use_half_pix=use_half_pix,
        crop_parameters=crop_parameters,
    ).get_directions()
    if sampled_ray_idx is not None:
        I_patch_rays = I_patch_rays[:, sampled_ray_idx]

    # Compute optimal rotation to align rays
    R = torch.zeros_like(I_camera.R)
    for i in range(len(I_camera)):
        R[i] = compute_optimal_rotation_alignment(
            I_patch_rays[i],
            directions[i],
        )

    # Construct and return rotated camera
    cam = I_camera.clone()
    cam.R = R
    cam.T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2)
    return cam
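

# Round-trip sketch building on the two functions above: rays generated from
# known cameras are mapped back to those cameras, since all of a camera's rays
# intersect at its center and the rotation is re-estimated from ray directions.
def _demo_rays_camera_round_trip():
    cameras = PerspectiveCameras(
        focal_length=torch.full((2, 1), 3.453),
        R=torch.eye(3).repeat(2, 1, 1),
        T=torch.tensor([[0.0, 0.0, 2.0], [0.2, 0.0, 2.0]]),
    )
    crop_parameters = torch.zeros(2, 4)
    crop_parameters[:, 2] = 2.0
    rays = cameras_to_rays(cameras, crop_parameters)
    recovered = rays_to_cameras(rays, crop_parameters, cameras=cameras.clone())
    # recovered.R / recovered.T should closely match cameras.R / cameras.T.
    return recovered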


# https://www.reddit.com/r/learnmath/comments/v1crd7/linear_algebra_qr_to_ql_decomposition/
def ql_decomposition(A):
    # Reverse the columns with the permutation P, run QR, then reverse back;
    # the sign flips below make the diagonal of L positive.
    P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device).float()
    A_tilde = torch.matmul(A, P)
    Q_tilde, R_tilde = torch.linalg.qr(A_tilde)
    Q = torch.matmul(Q_tilde, P)
    L = torch.matmul(torch.matmul(P, R_tilde), P)
    d = torch.diag(L)
    Q[:, 0] *= torch.sign(d[0])
    Q[:, 1] *= torch.sign(d[1])
    Q[:, 2] *= torch.sign(d[2])
    L[0] *= torch.sign(d[0])
    L[1] *= torch.sign(d[1])
    L[2] *= torch.sign(d[2])
    return Q, L
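

# Sanity-check sketch: Q should be orthogonal and L lower triangular with a
# positive diagonal, and the sign flips cancel so that Q @ L reconstructs A.
def _demo_ql_decomposition():
    A = torch.randn(3, 3)
    Q, L = ql_decomposition(A)
    assert torch.allclose(Q @ L, A, atol=1e-4)
    assert torch.allclose(Q @ Q.T, torch.eye(3), atol=1e-4)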


def rays_to_cameras_homography(
    rays,
    crop_parameters,
    num_patches_x=16,
    num_patches_y=16,
    use_half_pix=True,
    sampled_ray_idx=None,
    reproj_threshold=0.2,
):
    """
    Args:
        rays (Rays): (N, P, 6)
        crop_parameters (torch.Tensor): (N, 4)
    """
    device = rays.device
    origins = rays.get_origins()
    directions = rays.get_directions()
    camera_centers, _ = intersect_skew_lines_high_dim(origins, directions)

    # Retrieve target rays
    I_camera = PerspectiveCameras(focal_length=[1] * rays.shape[0], device=device)
    I_patch_rays = cameras_to_rays(
        cameras=I_camera,
        num_patches_x=num_patches_x,
        num_patches_y=num_patches_y,
        use_half_pix=use_half_pix,
        crop_parameters=crop_parameters,
    ).get_directions()
    if sampled_ray_idx is not None:
        I_patch_rays = I_patch_rays[:, sampled_ray_idx]

    # Compute optimal rotation to align rays
    Rs = []
    focal_lengths = []
    principal_points = []
    for i in range(rays.shape[-3]):
        R, f, pp = compute_optimal_rotation_intrinsics(
            I_patch_rays[i],
            directions[i],
            reproj_threshold=reproj_threshold,
        )
        Rs.append(R)
        focal_lengths.append(f)
        principal_points.append(pp)
    R = torch.stack(Rs)
    focal_lengths = torch.stack(focal_lengths)
    principal_points = torch.stack(principal_points)
    T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2)
    return PerspectiveCameras(
        R=R,
        T=T,
        focal_length=focal_lengths,
        principal_point=principal_points,
        device=device,
    )


def compute_optimal_rotation_alignment(A, B):
    """
    Compute optimal R that minimizes: || A - B @ R ||_F

    Args:
        A (torch.Tensor): (N, 3)
        B (torch.Tensor): (N, 3)

    Returns:
        R (torch.tensor): (3, 3)
    """
    # normally with R @ B, this would be A @ B.T
    H = B.T @ A
    U, _, Vh = torch.linalg.svd(H, full_matrices=True)
    # Flip the sign of the last singular vector if needed so that det(R) = +1.
    s = torch.linalg.det(U @ Vh)
    S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device))
    return U @ S_prime @ Vh
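

# Example with synthetic data and a known rotation: with A = B @ R_true and no
# noise, the Procrustes solution above recovers R_true.
def _demo_rotation_alignment():
    c, s = float(np.cos(0.5)), float(np.sin(0.5))
    R_true = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
    B = torch.randn(100, 3)
    R = compute_optimal_rotation_alignment(B @ R_true, B)
    assert torch.allclose(R, R_true, atol=1e-4)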


def compute_optimal_rotation_intrinsics(
    rays_origin, rays_target, z_threshold=1e-4, reproj_threshold=0.2
):
    """
    Note: for some reason, f seems to be 1/f.

    Args:
        rays_origin (torch.Tensor): (N, 3)
        rays_target (torch.Tensor): (N, 3)
        z_threshold (float): Threshold for z value to be considered valid.

    Returns:
        R (torch.tensor): (3, 3)
        focal_length (torch.tensor): (2,)
        principal_point (torch.tensor): (2,)
    """
    device = rays_origin.device
    z_mask = torch.logical_and(
        torch.abs(rays_target) > z_threshold, torch.abs(rays_origin) > z_threshold
    )[:, 2]
    rays_target = rays_target[z_mask]
    rays_origin = rays_origin[z_mask]
    # Project the ray directions onto the z = 1 plane.
    rays_origin = rays_origin[:, :2] / rays_origin[:, -1:]
    rays_target = rays_target[:, :2] / rays_target[:, -1:]

    # The homography relating the two sets of normalized image points is
    # factored (via QL) into a rotation and a triangular intrinsics-like matrix.
    A, _ = cv2.findHomography(
        rays_origin.cpu().numpy(),
        rays_target.cpu().numpy(),
        cv2.RANSAC,
        reproj_threshold,
    )
    A = torch.from_numpy(A).float().to(device)
    if torch.linalg.det(A) < 0:
        A = -A

    R, L = ql_decomposition(A)
    L = L / L[2][2]

    f = torch.stack((L[0][0], L[1][1]))
    pp = torch.stack((L[2][0], L[2][1]))
    return R, f, pp


def compute_ndc_coordinates(
    crop_parameters=None,
    use_half_pix=True,
    num_patches_x=16,
    num_patches_y=16,
    device=None,
):
    """
    Computes NDC Grid using crop_parameters. If crop_parameters is not provided,
    then it assumes that the crop is the entire image (corresponding to an NDC grid
    where top left corner is (1, 1) and bottom right corner is (-1, -1)).
    """
    if crop_parameters is None:
        cc_x, cc_y, width = 0, 0, 2
    else:
        if len(crop_parameters.shape) > 1:
            # Batched crop parameters: recurse over the leading dimension.
            return torch.stack(
                [
                    compute_ndc_coordinates(
                        crop_parameters=crop_param,
                        use_half_pix=use_half_pix,
                        num_patches_x=num_patches_x,
                        num_patches_y=num_patches_y,
                    )
                    for crop_param in crop_parameters
                ],
                dim=0,
            )
        device = crop_parameters.device
        cc_x, cc_y, width, _ = crop_parameters

    dx = 1 / num_patches_x
    dy = 1 / num_patches_y
    if use_half_pix:
        min_y = 1 - dy
        max_y = -min_y
        min_x = 1 - dx
        max_x = -min_x
    else:
        min_y = min_x = 1
        max_y = -1 + 2 * dy
        max_x = -1 + 2 * dx

    y, x = torch.meshgrid(
        torch.linspace(min_y, max_y, num_patches_y, dtype=torch.float32, device=device),
        torch.linspace(min_x, max_x, num_patches_x, dtype=torch.float32, device=device),
        indexing="ij",
    )
    x_prime = x * width / 2 - cc_x
    y_prime = y * width / 2 - cc_y
    xyd_grid = torch.stack([x_prime, y_prime, torch.ones_like(x)], dim=-1)
    return xyd_grid
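

# Example: with no crop parameters and a 4x4 grid, the coordinates span the NDC
# square with a half-pixel inset and a depth channel fixed at 1.
def _demo_ndc_grid():
    grid = compute_ndc_coordinates(num_patches_x=4, num_patches_y=4)
    assert grid.shape == (4, 4, 3)
    assert torch.allclose(grid[0, 0], torch.tensor([0.75, 0.75, 1.0]))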