File size: 4,417 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
from typing import Protocol, runtime_checkable
import torch
from einops import rearrange, reduce
from jaxtyping import Bool, Float
from torch import Tensor
@runtime_checkable
class ColorFunction(Protocol):
    """Structural type for callables that color a set of 2D sample points.

    Any object with a matching ``__call__`` satisfies this protocol. Because the
    protocol is ``runtime_checkable``, ``isinstance(obj, ColorFunction)`` works,
    but it only checks that ``__call__`` exists, not its signature.
    """

    def __call__(
        self,
        xy: Float[Tensor, "point 2"],
    ) -> Float[Tensor, "point 4"]:  # RGBA color
        """Return one RGBA color per input ``(x, y)`` point."""
        ...
def generate_sample_grid(
    shape: tuple[int, int],
    device: torch.device,
) -> Float[Tensor, "height width 2"]:
    """Return the center of every cell of a unit-spaced ``height x width`` grid.

    Entry ``[row, col]`` holds ``(col + 0.5, row + 0.5)``: the first component
    (x) runs along the width axis and the second (y) along the height axis.

    Args:
        shape: ``(height, width)`` of the grid.
        device: Device on which the coordinates are created.
    """
    height, width = shape
    cols = torch.arange(width, device=device) + 0.5
    rows = torch.arange(height, device=device) + 0.5
    # "ij" indexing with (rows, cols) yields (height, width)-shaped grids.
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing="ij")
    return torch.stack((grid_x, grid_y), dim=-1)
def detect_msaa_pixels(
    image: Float[Tensor, "batch 4 height width"],
) -> Bool[Tensor, "batch height width"]:
    """Flag pixels that lie on a color discontinuity.

    Each pixel is compared with its 8 neighbors (horizontal, vertical and both
    diagonals). Whenever two neighboring pixels differ in any channel, *both*
    pixels are flagged. The comparison is exact (``!=``), so any difference
    counts.

    Args:
        image: Batched image whose channels lie along dimension 1.

    Returns:
        Boolean mask of shape ``(batch, height, width)``.
    """
    batch, _, height, width = image.shape
    mask = torch.zeros((batch, height, width), dtype=torch.bool, device=image.device)

    everything = slice(None)
    drop_last = slice(None, -1)
    drop_first = slice(1, None)

    # Each entry is ((rows_a, cols_a), (rows_b, cols_b)). The two index pairs
    # select aligned views so that element k of view "a" and element k of
    # view "b" are neighbors in one direction.
    neighbor_pairs = (
        ((everything, drop_first), (everything, drop_last)),  # horizontal
        ((drop_first, everything), (drop_last, everything)),  # vertical
        ((drop_first, drop_first), (drop_last, drop_last)),  # top-left to bottom-right
        ((drop_last, drop_first), (drop_first, drop_last)),  # top-right to bottom-left
    )
    for (rows_a, cols_a), (rows_b, cols_b) in neighbor_pairs:
        differs = (image[:, :, rows_a, cols_a] != image[:, :, rows_b, cols_b]).any(dim=1)
        mask[:, rows_a, cols_a] |= differs
        mask[:, rows_b, cols_b] |= differs
    return mask
def reduce_straight_alpha(
    rgba: Float[Tensor, "batch 4 height width"],
) -> Float[Tensor, "batch 4"]:
    """Collapse each ``height x width`` patch of straight-alpha RGBA into one color.

    - RGB: the alpha-weighted average of the patch's colors, so transparent
      samples contribute nothing.
    - Alpha: the plain mean of the patch's alphas.

    A tiny epsilon in the denominator avoids division by zero for fully
    transparent patches (effective for float32; ``1e-10`` underflows in
    float16).
    """
    rgb, alpha = rgba.split((3, 1), dim=1)
    pattern = "b c h w -> b c"
    total_alpha = reduce(alpha, pattern, "sum")
    premultiplied_sum = reduce(rgb * alpha, pattern, "sum")
    mean_alpha = reduce(alpha, pattern, "mean")
    mean_rgb = premultiplied_sum / (total_alpha + 1e-10)
    return torch.cat([mean_rgb, mean_alpha], dim=-1)
@torch.no_grad()
def run_msaa_pass(
    xy: Float[Tensor, "batch height width 2"],
    color_function: ColorFunction,
    scale: float,
    subdivision: int,
    remaining_passes: int,
    device: torch.device,
    batch_size: int = int(2**16),
) -> Float[Tensor, "batch 4 height width"]:  # color (RGBA with straight alpha)
    """Sample ``color_function`` at ``xy`` and recursively supersample edge pixels.

    After sampling, pixels flagged by ``detect_msaa_pixels`` are re-sampled on a
    ``subdivision x subdivision`` grid of sub-cell centers. The grid spans a
    square of side ``scale`` centered on the original sample point. Each flagged
    pixel is then replaced by ``reduce_straight_alpha`` of its sub-samples.
    This repeats for ``remaining_passes`` levels, each using
    ``scale / subdivision``.

    Args:
        xy: Sample positions, shape ``(batch, height, width, 2)``.
        color_function: Called on chunks of flattened points of shape ``(N, 2)``
            with ``N <= batch_size``.
        scale: Side length of the square footprint around each sample point.
        subdivision: Number of sub-samples per axis for a refined pixel.
        remaining_passes: Number of further refinement levels to run.
        device: Device on which the sub-sample offsets are created. It
            presumably must match ``xy.device``, since the offsets are added
            to ``xy``.
        batch_size: Maximum number of points per ``color_function`` call.

    Returns:
        Colors of shape ``(batch, C, height, width)``, where ``C`` is the
        channel count returned by ``color_function`` (4 per ``ColorFunction``).
    """
    b, h, w, _ = xy.shape

    # Evaluate the color function in bounded-size chunks to cap peak memory.
    flat_xy = rearrange(xy, "b h w xy -> (b h w) xy")
    color = torch.cat(
        [color_function(chunk) for chunk in flat_xy.split(batch_size)], dim=0
    )
    color = rearrange(color, "(b h w) c -> b c h w", b=b, h=h, w=w)

    if remaining_passes <= 0:
        return color

    # Only pixels that differ from a neighbor need refinement.
    batch_index, row_index, col_index = torch.where(detect_msaa_pixels(color))
    if batch_index.numel() == 0:
        # Nothing to refine. Skip the recursion instead of calling
        # color_function on an empty batch, once per remaining pass.
        return color

    centers = xy[batch_index, row_index, col_index]
    offsets = generate_sample_grid((subdivision, subdivision), device)
    # Map sub-cell centers from [0, subdivision) to a square of side `scale`
    # centered on the sample point.
    offsets = (offsets / subdivision - 0.5) * scale

    # Each refined pixel becomes its own batch element with a
    # (subdivision, subdivision) sub-grid.
    color_fine = run_msaa_pass(
        centers[:, None, None] + offsets,
        color_function,
        scale / subdivision,
        subdivision,
        remaining_passes - 1,
        device,
        batch_size=batch_size,
    )
    color[batch_index, :, row_index, col_index] = reduce_straight_alpha(color_fine)
    return color
@torch.no_grad()
def render(
    shape: tuple[int, int],
    color_function: ColorFunction,
    device: torch.device,
    subdivision: int = 8,
    num_passes: int = 2,
    batch_size: int = int(2**16),
) -> Float[Tensor, "4 height width"]:  # color (RGBA with straight alpha)
    """Rasterize ``color_function`` over a ``height x width`` pixel grid with MSAA.

    Pixel ``[row, col]`` is first sampled at ``(col + 0.5, row + 0.5)``. Pixels
    on color edges are then refined by ``run_msaa_pass``.

    Args:
        shape: ``(height, width)`` of the output image.
        color_function: Maps ``(x, y)`` points in pixel units to RGBA colors.
        device: Device on which sample positions are created.
        subdivision: Sub-samples per axis for each refined pixel.
        num_passes: Number of refinement passes applied to edge pixels.
        batch_size: Maximum number of points passed to ``color_function`` per
            call. Lower it to reduce peak memory. The default matches
            ``run_msaa_pass``.

    Returns:
        Image of shape ``(C, height, width)``. ``C`` is 4 (straight-alpha
        RGBA) per ``ColorFunction``.
    """
    xy = generate_sample_grid(shape, device)
    return run_msaa_pass(
        xy[None],  # add a singleton batch dimension
        color_function,
        1.0,  # each pixel covers a unit square
        subdivision,
        num_passes,
        device,
        batch_size=batch_size,
    )[0]
def render_over_image(
    image: Float[Tensor, "3 height width"],
    color_function: ColorFunction,
    device: torch.device,
    subdivision: int = 8,
    num_passes: int = 1,
) -> Float[Tensor, "3 height width"]:
    """Render ``color_function`` at ``image``'s resolution and composite it on top.

    The overlay is blended with the straight-alpha "over" formula
    ``image * (1 - alpha) + rgb * alpha``.

    Args:
        image: Background RGB image, shape ``(3, height, width)``.
        color_function: Maps ``(x, y)`` points in pixel units to RGBA colors.
        device: Device on which the overlay is rendered.
        subdivision: Sub-samples per axis for each refined pixel.
        num_passes: Number of MSAA refinement passes.
    """
    _, height, width = image.shape
    overlay = render(
        (height, width),
        color_function,
        device,
        subdivision=subdivision,
        num_passes=num_passes,
    )
    rgb, coverage = overlay.split((3, 1), dim=0)
    background_weight = 1 - coverage
    return image * background_weight + rgb * coverage
|