Spaces:
Sleeping
Sleeping
| # model.py | |
| import flax.linen as nn | |
| import jax.numpy as jnp | |
class AQIPredictor(nn.Module):
    """Regress a single scalar (presumably an AQI value) from a sequence input.

    Architecture, as written below: two [Conv(64, kernel 3) -> ReLU ->
    LayerNorm] stages, a mean over axis 1, then an MLP head
    Dense(128) -> SiLU -> Dropout(0.1) -> Dense(64) -> SiLU -> Dense(1).

    Attributes:
        features: Declared module field. NOTE(review): it is not referenced
            in ``__call__`` -- every layer width is hard-coded. It is kept so
            existing constructor calls such as ``AQIPredictor(features=...)``
            keep working.
    """

    features: int

    # Submodules are constructed inline in __call__, which Flax only permits
    # inside a method decorated with @nn.compact (or in setup()); without the
    # decorator, init/apply raises instead of registering the layers.
    @nn.compact
    def __call__(self, x, deterministic: bool):
        """Run the forward pass.

        Args:
            x: Input array. Assumed to be channels-last (batch, length,
                channels), the layout a 1-D ``nn.Conv`` with
                ``kernel_size=(3,)`` expects -- TODO confirm against the
                data pipeline.
            deterministic: When True, dropout is disabled. When False, Flax's
                ``nn.Dropout`` draws from the ``'dropout'`` PRNG stream, which
                must be supplied via ``rngs=`` at ``apply`` time.

        Returns:
            Array whose last dimension has size 1 (output of ``nn.Dense(1)``).
        """
        # Convolutional feature extraction: two identical conv stages.
        x = nn.Conv(features=64, kernel_size=(3,))(x)
        x = nn.relu(x)
        x = nn.LayerNorm()(x)
        x = nn.Conv(features=64, kernel_size=(3,))(x)
        x = nn.relu(x)
        x = nn.LayerNorm()(x)
        # Average over axis 1 -- the length axis if x is (batch, length,
        # channels) -- to get one fixed-size vector per example.
        x = jnp.mean(x, axis=1)
        # Regression head; dropout is applied to the SiLU activations.
        x = nn.Dense(128)(x)
        x = nn.Dropout(0.1)(nn.silu(x), deterministic=deterministic)
        x = nn.Dense(64)(x)
        x = nn.silu(x)
        return nn.Dense(1)(x)