# project2/model.py — uploaded by nikethanreddy
# Commit 865db26 ("Upload 6 files", verified); original size 620 bytes.
# model.py
import flax.linen as nn
import jax.numpy as jnp
class AQIPredictor(nn.Module):
    """1-D convolutional regressor producing a single output value per sequence.

    Two ``Conv -> ReLU -> LayerNorm`` stages (64 channels, kernel size 3) run
    along the sequence axis. The result is mean-pooled over that axis, and a
    small MLP (``Dense 128 -> SiLU -> Dropout(0.1) -> Dense 64 -> SiLU ->
    Dense 1``) maps the pooled vector to one output.

    Attributes:
        features: Required at construction but not read anywhere in this
            module. NOTE(review): presumably the number of input feature
            columns — verify against callers before relying on or removing it.
    """

    features: int

    @nn.compact
    def __call__(self, x, deterministic: bool):
        """Run the forward pass.

        Args:
            x: Channels-last input in the layout ``flax.linen.Conv`` expects
                for a 1-D kernel, i.e. ``(*batch_dims, length, channels)``.
                An unbatched ``(length, channels)`` array is also accepted.
                Presumably ``length`` is time steps and ``channels`` the
                per-step measurements — TODO confirm with the data pipeline.
            deterministic: Passed to the dropout layer. When True, dropout is
                a no-op (inference). When False, dropout is applied, and Flax
                then needs a ``'dropout'`` RNG to be supplied to ``apply``.

        Returns:
            Array of shape ``(*batch_dims, 1)``.
        """
        # Two conv stages. Each applies ReLU, then LayerNorm over the channel axis.
        x = nn.Conv(features=64, kernel_size=(3,))(x)
        x = nn.relu(x)
        x = nn.LayerNorm()(x)
        x = nn.Conv(features=64, kernel_size=(3,))(x)
        x = nn.relu(x)
        x = nn.LayerNorm()(x)
        # Global average pool over the sequence axis. Index it from the end
        # (-2, just before channels) instead of as axis 1. Then unbatched
        # (length, channels) inputs and inputs with several batch dims pool
        # the sequence axis, not the channel axis or a batch axis. For the
        # usual (batch, length, channels) input this equals axis=1.
        x = jnp.mean(x, axis=-2)
        x = nn.Dense(128)(x)
        x = nn.Dropout(0.1)(nn.silu(x), deterministic=deterministic)
        x = nn.Dense(64)(x)
        x = nn.silu(x)
        return nn.Dense(1)(x)