| from typing import Optional, Protocol, runtime_checkable |
|
|
| import torch |
| from jaxtyping import Float |
| from torch import Tensor |
|
|
| from .types import Pair, sanitize_pair |
|
|
|
|
@runtime_checkable
class ConversionFunction(Protocol):
    """Callable that maps 2D coordinates from one coordinate frame to another.

    Implementations receive a tensor whose last dimension holds ``(x, y)``
    pairs and return a tensor with the same leading batch dimensions and a
    trailing dimension of size 2.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj,
    ConversionFunction)`` only checks that ``obj`` has a ``__call__``
    attribute; it does not validate the signature or any tensor shapes.
    """

    def __call__(
        self,
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        """Convert the ``(x, y)`` pairs in ``xy`` into the target frame."""
        ...
|
|
|
|
def generate_conversions(
    shape: tuple[int, int],
    device: torch.device,
    x_range: Optional[Pair] = None,
    y_range: Optional[Pair] = None,
) -> tuple[
    ConversionFunction,
    ConversionFunction,
]:
    """Build a pair of linear maps between world and pixel coordinates.

    The world-space rectangle ``x_range`` x ``y_range`` is mapped linearly
    onto the pixel rectangle ``[0, width]`` x ``[0, height]``. The first entry
    of each range maps to pixel 0, and the second entry maps to ``width``
    (for x) or ``height`` (for y).

    Args:
        shape: Image size as ``(height, width)``.
        device: Device on which the internal range and size tensors are
            created.
        x_range: World-space x extent; defaults to ``(0, width)``.
        y_range: World-space y extent; defaults to ``(0, height)``.

    Returns:
        ``(convert_world_to_pixel, convert_pixel_to_world)``. Both accept a
        tensor whose last dimension holds ``(x, y)`` and return a tensor of
        the same shape. The second is the algebraic inverse of the first.

    Note:
        Both ranges are normalized by ``sanitize_pair`` (from ``.types``).
        This code assumes it returns a 2-element tensor ``(start, end)`` on
        ``device`` -- TODO confirm. Nothing here rejects a zero-width range,
        so ``start == end`` makes ``convert_world_to_pixel`` divide by zero.
    """
    height, width = shape
    x_bounds = sanitize_pair((0, width) if x_range is None else x_range, device)
    y_bounds = sanitize_pair((0, height) if y_range is None else y_range, device)

    # Stacking along the last axis gives [[x_start, y_start], [x_end, y_end]].
    # Unpacking the rows therefore yields per-axis (x, y) start and end vectors.
    lower, upper = torch.stack((x_bounds, y_bounds), dim=-1)
    # The world-space span is fixed for the closures' lifetime, so compute it once.
    extent = upper - lower
    pixel_size = torch.tensor((width, height), dtype=torch.float32, device=device)

    def convert_world_to_pixel(
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        # Normalize each axis relative to its range, then scale to pixel units.
        normalized = (xy - lower) / extent
        return normalized * pixel_size

    def convert_pixel_to_world(
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        # Undo the pixel scaling, then stretch and shift back into world units.
        normalized = xy / pixel_size
        return normalized * extent + lower

    return convert_world_to_pixel, convert_pixel_to_world
|
|