| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from pydmd import DMD | |
class DMDProcessor:
    """Run Dynamic Mode Decomposition (via PyDMD ``DMD``) on a stack of snapshots."""

    def __init__(self, data: torch.Tensor, rank: int):
        """Process input data using Dynamic Mode Decomposition.

        Args:
            data: Input tensor of shape (batch_size, ny, nx). Each leading index
                is flattened into one snapshot column before fitting. A NumPy
                array (or anything ``np.asarray`` accepts) also works.
            rank: Rank for SVD approximation; must be a positive integer.

        Raises:
            ValueError: If ``rank`` is not positive.
        """
        self.data = data
        self.rank = rank
        # Fail fast: the validator previously existed but was never invoked, so
        # an invalid rank was passed straight through to PyDMD.
        self._validate_input()

    def _validate_input(self):
        """Raise ValueError unless ``self.rank`` is positive."""
        if self.rank <= 0:
            raise ValueError("Rank must be positive integer")

    def _snapshot_matrix(self) -> np.ndarray:
        """Return ``data`` as a NumPy matrix with one flattened snapshot per column.

        For input of shape (batch_size, ny, nx) the result has shape
        (ny * nx, batch_size).
        """
        data = self.data
        if isinstance(data, torch.Tensor):
            # PyDMD operates on NumPy arrays; a tensor that requires grad or
            # lives on a GPU cannot be converted implicitly, so detach and move
            # it to host memory first.
            data = data.detach().cpu().numpy()
        else:
            data = np.asarray(data)
        return data.reshape(data.shape[0], -1).T

    def _compute_dmd(self):
        """Fit a PyDMD ``DMD`` model to the snapshot matrix and return it.

        Raises:
            RuntimeError: If building the snapshots or fitting fails, or if the
                fitted model has no reconstruction. The underlying exception is
                chained as ``__cause__``.
        """
        try:
            snapshots = self._snapshot_matrix()
            dmd = DMD(svd_rank=self.rank)
            dmd.fit(snapshots)
            # Raised inside the try on purpose: the handler below re-wraps it as
            # "DMD processing failed: DMD reconstruction failed".
            if dmd.reconstructed_data is None:
                raise RuntimeError("DMD reconstruction failed")
            return dmd
        except Exception as e:
            raise RuntimeError(f"DMD processing failed: {e}") from e

    @staticmethod
    def _modes_for_energy(amplitudes, threshold: float = 0.95) -> int:
        """Return how many leading modes are needed to exceed ``threshold``.

        "Energy" here is the cumulative share of ``|amplitude|`` (not squared),
        taken in the given order.

        Returns:
            The 1-based count of the first prefix whose share is strictly greater
            than ``threshold``. Returns ``len(amplitudes)`` if no prefix exceeds
            it, and 0 when there are no amplitudes or all of them are zero.
        """
        weights = np.abs(np.asarray(amplitudes))
        total = weights.sum()
        if total == 0:
            # Nothing carries energy; avoid a 0/0 division producing NaNs.
            return 0
        energy = np.cumsum(weights) / total
        above = np.flatnonzero(energy > threshold)
        # argmax on an all-False mask would wrongly report index 0 (=> 1 mode).
        return int(above[0]) + 1 if above.size else len(weights)

    def _calc_energy(self, threshold: float = 0.95):
        """Fit DMD and return the number of modes needed to exceed ``threshold``.

        Args:
            threshold: Cumulative |amplitude| share to exceed (default 0.95).
        """
        dmd = self._compute_dmd()
        return self._modes_for_energy(dmd.amplitudes, threshold)

    def method(self):
        """Fit DMD and return ``[modes, dynamics]`` as lists of real parts.

        ``modes[i]`` is the real part of column ``i`` of ``dmd.modes``, and
        ``dynamics[i]`` is the real part of row ``i`` of ``dmd.dynamics``. There
        is one entry per DMD amplitude.
        """
        dmd = self._compute_dmd()
        n_modes = len(dmd.amplitudes)
        modes_real = dmd.modes.real
        dynamics_real = dmd.dynamics.real
        modes = [modes_real[:, i] for i in range(n_modes)]
        dynamics = [dynamics_real[i] for i in range(n_modes)]
        return [modes, dynamics]
class DMDNeuralOperator(nn.Module):
    """Neural operator whose branch output is gated by two DMD-feature branches."""

    def __init__(self, branch1_dim, branch_dmd_dim_modes, branch_dmd_dim_dynamics, trunk_dim, out_dim: int = 10):
        """Neural operator with DMD preprocessing.

        Each ``*_dim`` argument is a sequence of layer widths. The first entry is
        the input width, and every following entry adds an ``nn.Linear`` followed
        by ``nn.Tanh``. A single-entry sequence yields an identity network.

        Args:
            branch1_dim: Layer dimensions for primary branch.
            branch_dmd_dim_modes: Layer dimensions for DMD modes branch.
            branch_dmd_dim_dynamics: Layer dimensions for DMD dynamics branch.
            trunk_dim: Layer dimensions for trunk network.
            out_dim: Output width of the final linear layer (default 10, the
                previously hard-coded value).
        """
        super().__init__()
        # Sub-networks are built in the original order so that seeded parameter
        # initialisation is unchanged.
        self._branch_1 = self._build_mlp(branch1_dim)
        self._branch_dmd_modes = self._build_mlp(branch_dmd_dim_modes)
        self._branch_dmd_dynamics = self._build_mlp(branch_dmd_dim_dynamics)
        self._trunk = self._build_mlp(trunk_dim)
        self.final_linear = nn.Linear(trunk_dim[-1], out_dim)

    @staticmethod
    def _build_mlp(dims):
        """Return an nn.Sequential of (Linear, Tanh) stages between consecutive widths."""
        widths = list(dims)
        # Each stage stays wrapped in its own nn.Sequential so state_dict keys
        # (e.g. "_branch_1.0.0.weight") match checkpoints from earlier versions.
        return nn.Sequential(*(
            nn.Sequential(nn.Linear(d_in, d_out), nn.Tanh())
            for d_in, d_out in zip(widths[:-1], widths[1:])
        ))

    def forward(self, f: torch.Tensor, f_dmd_modes: torch.Tensor, f_dmd_dynamics: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        The primary branch output is multiplied elementwise by the product of
        the modes and dynamics branch outputs. The result is matrix-multiplied
        with the trunk output and passed through ``final_linear``.

        Args:
            f: Input function (batch_size, *spatial_dims), fed to the primary branch.
            f_dmd_modes: DMD-mode features, fed to the modes branch.
            f_dmd_dynamics: DMD-dynamics features, fed to the dynamics branch.
            x: Evaluation points (num_points, coord_dim), fed to the trunk.

        Returns:
            ``final_linear(branch @ trunk)``; its last dimension is ``out_dim``.
            NOTE(review): ``branch @ trunk`` contracts the branch width against the
            trunk output's num_points axis, so ``branch1_dim[-1]`` must equal
            num_points. Confirm this is intended.
        """
        y_branch_dmd = self._branch_dmd_modes(f_dmd_modes) * self._branch_dmd_dynamics(f_dmd_dynamics)
        y_br = self._branch_1(f) * y_branch_dmd
        y_tr = self._trunk(x)
        return self.final_linear(y_br @ y_tr)

    def loss(self, f, f_dmd_modes, f_dmd_dynamics, x, y):
        """Return the mean squared error between ``forward(...)`` and target ``y``."""
        y_out = self.forward(f, f_dmd_modes, f_dmd_dynamics, x)
        return ((y_out - y) ** 2).mean()