import math

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core import freeze, unfreeze
from jaxtyping import Array, ArrayLike, PyTree

from .edsr import EDSR
from .rdn import RDN
from .hyper import Hypernetwork
from .tail import build_tail
from .init import uniform_between, linear_up
from utils import make_grid, interpolate_grid, repeat_vmap


class Thermal(nn.Module):
    w0_scale: float = 1.

    @nn.compact
    def __call__(self, x: ArrayLike, t: ArrayLike, norm: ArrayLike, k: ArrayLike) -> Array:
        phase = self.param('phase', nn.initializers.uniform(.5), x.shape[-1:])
        return jnp.sin(self.w0_scale * x + phase) \
            * jnp.exp(-(self.w0_scale * norm)**2 * k * t)
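
# Behavioural sketch (illustrative, not part of the model): at t = 0 the decay
# factor exp(-(w0 * norm)^2 * k * t) equals 1, so Thermal reduces to a plain
# SIREN-style sine activation; for t > 0, components with a larger frequency
# norm are damped more strongly, matching the heat equation in frequency space.
#
#   act = Thermal(w0_scale=1.)
#   p = act.init(jax.random.PRNGKey(0), jnp.zeros((4,)), 0., jnp.ones((4,)), 1.)
#   y0 = act.apply(p, jnp.ones((4,)), 0., jnp.ones((4,)), 1.)  # undamped sine
#   y1 = act.apply(p, jnp.ones((4,)), 1., jnp.ones((4,)), 1.)  # |y1| = |y0| * e**-1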


class TheraField(nn.Module):
    dim_hidden: int
    dim_out: int
    w0: float = 1.
    c: float = 6.

    @nn.compact
    def __call__(self, x: ArrayLike, t: ArrayLike, k: ArrayLike,
                 components: ArrayLike) -> Array:
        # coordinate projection according to shared components ("first layer")
        x = x @ components

        # thermal activations
        norm = jnp.linalg.norm(components, axis=-2)
        x = Thermal(self.w0)(x, t, norm, k)

        # linear projection from hidden to output space ("second layer")
        w_std = math.sqrt(self.c / self.dim_hidden) / self.w0
        dense_init_fn = uniform_between(-w_std, w_std)
        x = nn.Dense(self.dim_out, kernel_init=dense_init_fn, use_bias=False)(x)

        return x
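
# Minimal single-field sketch (shapes are assumptions for illustration): one
# field maps a local 2D offset and a diffusion time t to `dim_out` channels,
# given the globally shared `components` and the scalar diffusivity k.
#
#   field = TheraField(dim_hidden=32, dim_out=3)
#   comps = jnp.ones((2, 32))
#   p = field.init(jax.random.PRNGKey(0), jnp.zeros((2,)), 0., 1., comps)
#   rgb = field.apply(p, jnp.array([.1, -.2]), .5, 1., comps)  # shape (3,)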


class Thera:

    def __init__(
        self,
        hidden_dim: int,
        out_dim: int,
        backbone: nn.Module,
        tail: nn.Module,
        k_init: float | None = None,
        components_init_scale: float | None = None
    ):
        self.hidden_dim = hidden_dim
        self.k_init = k_init
        self.components_init_scale = components_init_scale

        # single TheraField object whose `apply` method is used for all grid cells
        self.field = TheraField(hidden_dim, out_dim)

        # infer output size of the hypernetwork from a sample pass through the
        # field; the key doesn't matter, as the field params are only used for
        # size inference
        sample_params = self.field.init(
            jax.random.PRNGKey(0), jnp.zeros((2,)), 0., 0., jnp.zeros((2, hidden_dim)))
        sample_params_flat, tree_def = jax.tree_util.tree_flatten(sample_params)
        param_shapes = [p.shape for p in sample_params_flat]
        self.hypernet = Hypernetwork(backbone, tail, param_shapes, tree_def)
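
    # Illustration (concrete shapes are assumptions): for hidden_dim=32 and
    # out_dim=3 the flattened field params are the Dense kernel (32, 3) and the
    # Thermal phase (32,), so `param_shapes` tells the hypernetwork how many
    # values to emit per pixel and `tree_def` how to reassemble them, e.g.:
    #
    #   flat = [jnp.zeros(s) for s in param_shapes]
    #   phi = jax.tree_util.tree_unflatten(tree_def, flat)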

    def init(self, key, sample_source) -> PyTree:
        keys = jax.random.split(key, 2)
        sample_coords = jnp.zeros(sample_source.shape[:-1] + (2,))
        params = unfreeze(self.hypernet.init(keys[0], sample_source, sample_coords))
        params['params']['k'] = jnp.array(self.k_init)
        params['params']['components'] = \
            linear_up(self.components_init_scale)(keys[1], (2, self.hidden_dim))
        return freeze(params)
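
    # Usage sketch (module arguments and shapes are assumptions for illustration):
    #
    #   model = Thera(32, 3, backbone_module, tail_module,
    #                 k_init=1., components_init_scale=1.)
    #   params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 48, 48, 3)))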

    def apply_encoder(self, params: PyTree, source: ArrayLike, **kwargs) -> Array:
        """
        Performs a forward pass through the hypernetwork to obtain an encoding.
        """
        return self.hypernet.apply(
            params, source, method=self.hypernet.get_encoding, **kwargs)

    def apply_decoder(
        self,
        params: PyTree,
        encoding: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False
    ) -> Array | tuple[Array, Array]:
        """
        Performs a forward prediction through a grid of H x W Thera fields,
        informed by `encoding`, at the spatial and temporal coordinates
        `coords` and `t`, respectively.

        Args:
            params: Model parameters (PyTree), including `k` and `components`
            encoding: Encoding tensor, shape (B, H, W, C)
            coords: Spatial coordinates in [-0.5, 0.5], shape (B, H, W, 2)
            t: Temporal coordinates, shape (B, 1)
            return_jac: If True, also return the spatial Jacobians at t = 0
        """
        phi_params: PyTree = self.hypernet.apply(
            params, encoding, coords, method=self.hypernet.get_params_at_coords)

        # create local coordinate systems
        source_grid = jnp.asarray(make_grid(encoding.shape[-3:-1]))
        source_coords = jnp.tile(source_grid, (encoding.shape[0], 1, 1, 1))
        interp_coords = interpolate_grid(coords, source_coords)
        rel_coords = coords - interp_coords
        rel_coords = rel_coords.at[..., 0].set(rel_coords[..., 0] * encoding.shape[-3])
        rel_coords = rel_coords.at[..., 1].set(rel_coords[..., 1] * encoding.shape[-2])
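
        # Note (explanatory, behaviour unchanged): as we read it, `interpolate_grid`
        # assigns each query the coordinates of its source-grid cell, so `rel_coords`
        # is the offset within that cell; scaling by the grid height / width converts
        # it from global [-0.5, 0.5] units into per-cell units, the native domain of
        # each local field.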

        # three maps over params and coords, one over t; don't map k and components
        in_axes = [(0, 0, None, None, None), (0, 0, None, None, None), (0, 0, 0, None, None)]
        apply_field = repeat_vmap(self.field.apply, in_axes)
        out = apply_field(phi_params, rel_coords, t, params['params']['k'],
                          params['params']['components'])

        if return_jac:
            apply_jac = repeat_vmap(jax.jacrev(self.field.apply, argnums=1), in_axes)
            jac = apply_jac(phi_params, rel_coords, jnp.zeros_like(t), params['params']['k'],
                            params['params']['components'])
            return out, jac

        return out
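
    # Shape sketch of the Jacobian path (illustrative): jax.jacrev differentiates
    # the field w.r.t. its spatial input (argnums=1), evaluated at t = 0, so with
    # outputs of shape (B, H, W, dim_out) the Jacobians come out as
    # (B, H, W, dim_out, 2):
    #
    #   out, jac = model.apply_decoder(params, encoding, coords, t, return_jac=True)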

    def apply(
        self,
        params: PyTree,
        source: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False,
        **kwargs
    ) -> Array | tuple[Array, Array]:
        """
        Performs a full forward pass through the Thera model.
        """
        encoding = self.apply_encoder(params, source, **kwargs)
        out = self.apply_decoder(params, encoding, coords, t, return_jac=return_jac)
        return out


def build_thera(
    out_dim: int,
    backbone: str,
    size: str,
    k_init: float | None = None,
    components_init_scale: float | None = None
):
    """
    Convenience function for building the three Thera sizes described in the paper.
    """
    hidden_dim = 32 if size == 'air' else 512

    if backbone == 'edsr-baseline':
        backbone_module = EDSR(None, num_blocks=16, num_feats=64)
    elif backbone == 'rdn':
        backbone_module = RDN()
    else:
        raise NotImplementedError(backbone)

    tail_module = build_tail(size)
    return Thera(hidden_dim, out_dim, backbone_module, tail_module,
                 k_init, components_init_scale)
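

# End-to-end usage sketch (hyperparameters, shapes, and the [-0.5, 0.5] grid
# convention of `make_grid` are assumptions for illustration):
#
#   model = build_thera(out_dim=3, backbone='edsr-baseline', size='air',
#                       k_init=1., components_init_scale=1.)
#   source = jnp.zeros((1, 48, 48, 3))               # low-res input
#   params = model.init(jax.random.PRNGKey(0), source)
#   coords = jnp.asarray(make_grid((96, 96)))[None]  # (1, 96, 96, 2) target grid
#   t = jnp.array([[1.]])                            # diffusion time (scale-dependent)
#   sr = model.apply(params, source, coords, t)      # (1, 96, 96, 3)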