DMD-Neural-Operator / DMDNeuralOperator.py
ScientificOperator's picture
Upload 32 files
5515ef1 verified
import numpy as np
import torch
import torch.nn as nn
from pydmd import DMD
class DMDProcessor:
    """Run Dynamic Mode Decomposition (DMD) on a batch of snapshots.

    The first axis of ``data`` indexes snapshots; every remaining axis is
    flattened, so the matrix handed to ``DMD.fit`` has one column per
    snapshot.
    """

    def __init__(self, data: torch.Tensor, rank: int):
        """Store the snapshots and the SVD rank.

        Args:
            data: Input tensor of shape (batch_size, ny, nx).
            rank: Value forwarded to ``DMD(svd_rank=...)``.
        """
        self.data = data
        self.rank = rank

    def _validate_input(self):
        """Raise ``ValueError`` unless ``self.rank`` is strictly positive.

        This is an opt-in check: nothing in this class calls it.
        NOTE(review): it is intentionally not wired into ``_compute_dmd``
        because PyDMD documents non-positive ``svd_rank`` values (0, -1) as
        valid, so enforcing it could reject callers that work today.
        """
        if self.rank <= 0:
            raise ValueError("Rank must be positive integer")

    def _snapshot_matrix(self) -> np.ndarray:
        """Return the snapshots as a NumPy matrix of shape (features, batch)."""
        data = self.data
        if isinstance(data, torch.Tensor):
            # Tensors that track gradients or live on an accelerator cannot
            # be converted to NumPy implicitly; detach and move them first.
            data = data.detach().cpu().numpy()
        else:
            data = np.asarray(data)
        return data.reshape(data.shape[0], -1).T

    def _compute_dmd(self):
        """Fit a DMD model on the snapshots and return the fitted object.

        Raises:
            RuntimeError: If any step fails. The original exception is
                attached as ``__cause__`` so the traceback is preserved.
        """
        try:
            snapshots = self._snapshot_matrix()
            dmd = DMD(svd_rank=self.rank)
            dmd.fit(snapshots)
            if dmd.reconstructed_data is None:
                raise RuntimeError("DMD reconstruction failed")
            return dmd
        except Exception as e:
            raise RuntimeError(f"DMD processing failed: {str(e)}") from e

    def _calc_energy(self, threshold: float = 0.95):
        """Return how many leading modes exceed ``threshold`` of the energy.

        "Energy" is the cumulative sum of amplitude magnitudes divided by
        their total.

        Args:
            threshold: Fraction of the total amplitude magnitude that the
                retained modes must strictly exceed.

        Returns:
            The 1-based count of modes. All modes are returned when the
            threshold is never exceeded; at most one mode is returned when
            every amplitude is zero.
        """
        dmd = self._compute_dmd()
        weights = np.abs(dmd.amplitudes)
        total = weights.sum()
        if total == 0:
            # Avoid a 0/0 division; keep the historical answer (one mode)
            # whenever at least one mode exists.
            return min(int(weights.size), 1)
        energy = np.cumsum(weights) / total
        reached = np.flatnonzero(energy > threshold)
        if reached.size == 0:
            return int(weights.size)
        return int(reached[0]) + 1

    def method(self):
        """Return ``[modes, dynamics]`` as two lists with one entry per mode.

        ``modes[i]`` is the real part of column ``i`` of ``dmd.modes`` and
        ``dynamics[i]`` is the real part of row ``i`` of ``dmd.dynamics``.
        """
        dmd = self._compute_dmd()
        n_modes = len(dmd.amplitudes)
        # Read each attribute once instead of once per mode.
        modes_real = dmd.modes.real
        dynamics_real = dmd.dynamics.real
        modes = [modes_real[:, i] for i in range(n_modes)]
        dynamics = [dynamics_real[i] for i in range(n_modes)]
        return [modes, dynamics]
class DMDNeuralOperator(nn.Module):
    """Neural operator combining a primary branch, two DMD branches and a trunk.

    Each sub-network is a stack of ``Linear -> Tanh`` blocks, including a
    ``Tanh`` after the last layer. A final linear layer maps the combined
    features to ``out_dim`` outputs.
    """

    def __init__(self, branch1_dim, branch_dmd_dim_modes, branch_dmd_dim_dynamics, trunk_dim, out_dim: int = 10):
        """Build the four sub-networks and the output layer.

        Every ``*_dim`` argument is a sequence of layer widths: the first
        entry is the input width, and each later entry adds one
        ``Linear -> Tanh`` block.

        Args:
            branch1_dim: Layer dimensions for primary branch.
            branch_dmd_dim_modes: Layer dimensions for DMD modes branch.
            branch_dmd_dim_dynamics: Layer dimensions for DMD dynamics branch.
            trunk_dim: Layer dimensions for trunk network.
            out_dim: Number of output features of the final linear layer.
                Defaults to 10, the previously hard-coded value.
        """
        super().__init__()
        # The construction order is kept as-is so that seeded parameter
        # initialisation and state_dict keys match existing checkpoints.
        self._branch_1 = self._build_mlp(branch1_dim)
        self._branch_dmd_modes = self._build_mlp(branch_dmd_dim_modes)
        self._branch_dmd_dynamics = self._build_mlp(branch_dmd_dim_dynamics)
        self._trunk = self._build_mlp(trunk_dim)
        self.final_linear = nn.Linear(trunk_dim[-1], out_dim)

    @staticmethod
    def _build_mlp(layer_dims) -> nn.Sequential:
        """Return a Sequential of ``Sequential(Linear, Tanh)`` blocks for consecutive widths."""
        dims = list(layer_dims)
        blocks = [
            nn.Sequential(nn.Linear(n_in, n_out), nn.Tanh())
            for n_in, n_out in zip(dims[:-1], dims[1:])
        ]
        return nn.Sequential(*blocks)

    def forward(self, f: torch.Tensor, f_dmd_modes: torch.Tensor, f_dmd_dynamics: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            f: Input to the primary branch; last dimension must equal
                ``branch1_dim[0]``.
            f_dmd_modes: Input to the DMD modes branch; last dimension must
                equal ``branch_dmd_dim_modes[0]``.
            f_dmd_dynamics: Input to the DMD dynamics branch; last dimension
                must equal ``branch_dmd_dim_dynamics[0]``.
            x: Input to the trunk network; last dimension must equal
                ``trunk_dim[0]``.

        Returns:
            Tensor whose last dimension is ``out_dim`` (10 by default).

        Note:
            The three branch outputs are multiplied element-wise and then
            matrix-multiplied with the trunk output (``y_br @ y_tr``), so
            the branch output width has to match the second-to-last
            dimension of the trunk output.
            NOTE(review): DeepONet-style models usually contract over the
            feature axis (``y_br @ y_tr.T``); the product is kept exactly as
            written because ``final_linear`` is sized for ``trunk_dim[-1]``.
            Confirm this is intended.
        """
        branch_dmd_modes = self._branch_dmd_modes(f_dmd_modes)
        branch_dmd_dynamics = self._branch_dmd_dynamics(f_dmd_dynamics)
        y_branch_dmd = branch_dmd_modes * branch_dmd_dynamics
        y_br = self._branch_1(f) * y_branch_dmd
        y_tr = self._trunk(x)
        # No throw-away nn.Linear/nn.Tanh here: creating a layer per call
        # allocated unused parameters and consumed RNG state on every pass.
        return self.final_linear(y_br @ y_tr)

    def loss(self, f, f_dmd_modes, f_dmd_dynamics, x, y):
        """Return the mean squared error between the prediction and ``y``."""
        # Call the module itself rather than .forward so registered hooks run.
        y_out = self(f, f_dmd_modes, f_dmd_dynamics, x)
        return ((y_out - y) ** 2).mean()