# pokkiri — "Update model.py" (commit b9cff30, verified)
"""
StableResNet Model for Biomass Prediction
A numerically stable ResNet architecture for regression tasks
Author: najahpokkiri
Date: 2025-05-17
"""
"""
StableResNet Model Architecture
This module defines the StableResNet architecture used for biomass prediction.
The model is designed for numerical stability, using layer/batch normalization and a residual connection.
Author: najahpokkiri
Date: 2025-05-17
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class StableResNet(nn.Module):
    """Numerically stable ResNet for biomass regression.

    Layers, as built in ``__init__``:

    * ``input_proj``: Linear(n_features -> 256) + LayerNorm + ReLU + Dropout
    * ``layer1``: two Linear(256 -> 256) / BatchNorm1d / ReLU stages whose
      output is added back onto its input in ``forward`` (the only skip
      connection in the network)
    * ``layer2`` / ``layer3``: single Linear / BatchNorm1d / ReLU stages that
      narrow the width 256 -> 128 -> 64 (no skip connection)
    * ``regressor``: Linear(64 -> 32) + ReLU + Linear(32 -> 1)

    Args:
        n_features: Number of input features (size of the last input dim).
        dropout: Dropout probability applied after the input projection.

    NOTE: ``nn.BatchNorm1d`` cannot compute batch statistics from a single
    sample, so training-mode forward passes need a batch size > 1.
    """

    def __init__(self, n_features, dropout=0.2):
        super().__init__()
        self.input_proj = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.layer1 = self._make_simple_resblock(256, 256)
        self.layer2 = self._make_simple_resblock(256, 128)
        self.layer3 = self._make_simple_resblock(128, 64)
        self.regressor = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )
        self._init_weights()

    def _make_simple_resblock(self, in_dim, out_dim):
        """Build the body of a residual block or a downsampling block.

        When ``in_dim == out_dim`` this returns two Linear/BatchNorm1d/ReLU
        stages; the residual addition itself is done by the caller in
        ``forward``. Otherwise it returns a single Linear/BatchNorm1d/ReLU
        stage that changes the width from ``in_dim`` to ``out_dim``.
        """
        if in_dim == out_dim:
            # Residual block body (skip connection applied in forward()).
            return nn.Sequential(
                nn.Linear(in_dim, out_dim),
                nn.BatchNorm1d(out_dim),
                nn.ReLU(),
                nn.Linear(out_dim, out_dim),
                nn.BatchNorm1d(out_dim),
                nn.ReLU(),
            )
        # Downsampling block: width change, so no identity shortcut.
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(),
        )

    def _init_weights(self):
        """Kaiming-normal (fan_in, ReLU) weights and zero biases for every Linear."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Predict one scalar per sample.

        Args:
            x: 2-D tensor of shape (batch, n_features); BatchNorm1d after the
                Linear layers requires this (N, C) layout.

        Returns:
            Tensor of shape (batch,), including (1,) for a single sample.
        """
        x = self.input_proj(x)
        # Residual connection around the width-preserving first block.
        x = self.layer1(x) + x
        # Downsampling blocks (256 -> 128 -> 64), no skip connections.
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.regressor(x)  # (batch, 1)
        # Drop only the trailing singleton dim: a bare squeeze() would also
        # remove the batch dim when batch == 1 and return a 0-d tensor.
        return x.squeeze(-1)
def get_model_info():
    """Describe the StableResNet model as a plain, freshly built dictionary.

    Returns:
        dict with keys ``name``, ``description``, ``parameters`` (constructor
        argument name -> human-readable description) and ``architecture``
        (ordered list of stage descriptions).
    """
    constructor_args = {
        'n_features': 'Number of input features',
        'dropout': 'Dropout rate (default: 0.2)',
    }
    stages = [
        'Input projection with layer normalization',
        'Residual blocks with batch normalization',
        'Downsampling blocks',
        'Regression head',
    ]
    return dict(
        name='StableResNet',
        description='Numerically stable ResNet for biomass regression',
        parameters=constructor_args,
        architecture=stages,
    )