| | from typing import Literal, Optional |
| |
|
| | import torch |
| | from einops import einsum, repeat |
| | from jaxtyping import Float |
| | from torch import Tensor |
| |
|
| | from .coordinate_conversion import generate_conversions |
| | from .rendering import render_over_image |
| | from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector |
| |
|
| |
|
def draw_lines(
    image: Float[Tensor, "3 height width"],
    start: Vector,
    end: Vector,
    color: Vector,
    width: Scalar,
    cap: Literal["butt", "round", "square"] = "round",
    num_msaa_passes: int = 1,
    x_range: Optional[Pair] = None,
    y_range: Optional[Pair] = None,
) -> Float[Tensor, "3 height width"]:
    """Draw one or more solid-colored line segments over ``image``.

    ``start``, ``end``, ``color`` and ``width`` are broadcast against one another
    to give ``num_lines`` segments. Where several segments cover the same sample,
    the segment with the highest index wins.

    Args:
        image: Image to draw on, shaped ``(3, height, width)``.
        start: Segment start point(s) in world coordinates (2 components each).
        end: Segment end point(s) in world coordinates (2 components each).
        color: Segment color(s) (3 components each).
        width: Stroke width(s). Width is compared against distances measured
            after ``start``/``end`` go through ``world_to_pixel``, so it is
            presumably in pixel units (NOTE(review): confirm the coordinate
            space of the ``xy`` samples supplied by ``render_over_image``).
        cap: End-cap style. ``"butt"`` stops at the endpoints. ``"square"``
            extends both ends by half the width. ``"round"`` adds a disk of
            diameter ``width`` at each endpoint.
        num_msaa_passes: Forwarded to ``render_over_image`` as ``num_passes``.
        x_range: Optional world-space x range forwarded to
            ``generate_conversions``.
        y_range: Optional world-space y range forwarded to
            ``generate_conversions``.

    Returns:
        The image produced by ``render_over_image``.

    Raises:
        ValueError: If ``cap`` is not ``"butt"``, ``"round"`` or ``"square"``.
    """
    if cap not in ("butt", "round", "square"):
        raise ValueError(f'cap must be "butt", "round" or "square", got {cap!r}')

    device = image.device
    start = sanitize_vector(start, 2, device)
    end = sanitize_vector(end, 2, device)
    color = sanitize_vector(color, 3, device)
    width = sanitize_scalar(width, device)
    (num_lines,) = torch.broadcast_shapes(
        start.shape[0],
        end.shape[0],
        color.shape[0],
        width.shape,
    )

    # Convert world-space endpoints to pixel space. Then expand every per-line
    # input to exactly num_lines entries, so the per-line math below (including
    # the einsum, whose shared "l" label needs matching sizes) lines up.
    _, h, w = image.shape
    world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range)
    start = world_to_pixel(start).broadcast_to((num_lines, 2))
    end = world_to_pixel(end).broadcast_to((num_lines, 2))
    color = color.broadcast_to((num_lines, 3))
    width = width.broadcast_to((num_lines,))

    # Per-line geometry does not depend on the sample points. Compute it once
    # here rather than on every color_function call.
    delta = end - start
    delta_norm = delta.norm(dim=-1, keepdim=True)

    # A zero-length segment has no direction. Dividing by its zero norm would
    # produce NaN, which fails every comparison below and silently drops square
    # caps. Give such segments an x-axis direction instead. Butt caps then still
    # draw nothing, round caps draw a disk, and square caps draw an axis-aligned
    # square. This matches SVG's rules for zero-length subpaths.
    degenerate = delta_norm == 0
    safe_norm = torch.where(degenerate, torch.ones_like(delta_norm), delta_norm)
    x_axis = torch.tensor([1.0, 0.0], dtype=delta.dtype, device=delta.device)
    u_delta = torch.where(degenerate, x_axis, delta / safe_norm)

    half_width = 0.5 * width[:, None]
    # Square caps extend the covered parallel range by half the width at each end.
    extra = half_width if cap == "square" else 0

    def color_function(
        xy: Float[Tensor, "point 2"],
    ) -> Float[Tensor, "point 4"]:
        # Offset from each line's start point to each sample: (line, point, 2).
        indicator = xy - start[:, None]

        # Inside along the segment direction (extended by `extra` for square caps).
        parallel = einsum(u_delta, indicator, "l xy, l s xy -> l s")
        parallel_inside_line = (parallel <= delta_norm + extra) & (parallel > -extra)

        # Inside across the segment: distance to the infinite line < half width.
        perpendicular = indicator - parallel[..., None] * u_delta[:, None]
        perpendicular_inside_line = perpendicular.norm(dim=-1) < half_width

        inside_line = parallel_inside_line & perpendicular_inside_line

        # Round caps: disks of radius half_width around both endpoints.
        if cap == "round":
            inside_line |= indicator.norm(dim=-1) < half_width
            end_indicator = xy - end[:, None]
            inside_line |= end_indicator.norm(dim=-1) < half_width

        # For each sample, pick the highest-index covering line. Non-covering lines
        # contribute 0, which only ties with line 0. argmax returns the first
        # maximum, so index 0 is chosen in that case: correct if line 0 covers the
        # sample, and harmless otherwise because alpha is then 0.
        arrangement = inside_line * torch.arange(num_lines, device=device)[:, None]
        top_color = color.gather(
            dim=0,
            index=repeat(arrangement.argmax(dim=0), "s -> s c", c=3),
        )
        alpha = inside_line.any(dim=0).float()[:, None]
        return torch.cat((top_color, alpha), dim=-1)

    return render_over_image(image, color_function, device, num_passes=num_msaa_passes)
| |
|