from typing import Sequence

import flax.linen as nn


class MLP(nn.Module):
    """Multi-layer perceptron built from a stack of ``nn.Dense`` layers.

    Every entry of ``features`` except the last produces a ``nn.Dense``
    layer followed by a ReLU; the last entry produces a final ``nn.Dense``
    layer whose output is returned without an activation.
    """

    # One entry per Dense layer, in order; each value is handed to
    # ``nn.Dense`` as its first argument.
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        """Run ``x`` through the hidden layers and then the output layer."""
        hidden_sizes = self.features[:-1]
        for width in hidden_sizes:
            pre_activation = nn.Dense(width)(x)
            x = nn.relu(pre_activation)

        # The final layer is deliberately left without a non-linearity.
        output_layer = nn.Dense(self.features[-1])
        return output_layer(x)


def assertEqual(actual, expected, msg, first="Got", second="Expected"):
    """Raise ``ValueError`` when ``actual != expected``.

    The error text is ``msg`` followed by both values in double quotes,
    labelled with ``first`` and ``second`` respectively.
    """
    if actual != expected:
        details = f' {first}: "{actual}" {second}: "{expected}"'
        raise ValueError(msg + details)


def assertIn(actual, expected, msg, first="Got", second="Expected one of"):
    """Raise ``ValueError`` when ``actual`` is not contained in ``expected``.

    The error text is ``msg`` followed by ``actual`` in double quotes
    (labelled ``first``) and ``expected`` unquoted (labelled ``second``).
    """
    if actual not in expected:
        details = f' {first}: "{actual}" {second}: {expected}'
        raise ValueError(msg + details)