# File size: 4,421 Bytes
# 5515ef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import numpy as np
import torch
import torch.nn as nn
from pydmd import DMD


class DMDProcessor:
    """Run Dynamic Mode Decomposition (via pydmd) on a stack of snapshots."""

    def __init__(self, data: torch.Tensor, rank: int):
        """Process input data using Dynamic Mode Decomposition.

        Args:
            data: Snapshot stack. The first axis indexes snapshots; all
                remaining axes are flattened into one spatial vector, e.g.
                shape (batch_size, ny, nx). A torch tensor or any array-like
                accepted by ``np.asarray``.
            rank: SVD truncation rank, passed to pydmd as ``svd_rank``.
                Must be > 0.

        Raises:
            ValueError: If ``rank`` is not positive.
        """
        self.data = data
        self.rank = rank
        # Validate eagerly so a bad rank fails at construction time rather
        # than deep inside pydmd.
        self._validate_input()

    def _validate_input(self):
        """Raise ValueError unless ``self.rank`` is positive."""
        if self.rank <= 0:
            raise ValueError("Rank must be positive integer")

    def _snapshot_matrix(self) -> np.ndarray:
        """Return the data as a (n_features, n_snapshots) numpy array.

        Each column is one snapshot (the first axis of ``self.data``) with
        all remaining axes flattened.
        """
        data = self.data
        if isinstance(data, torch.Tensor):
            # pydmd works on numpy arrays; detach/cpu so tensors that require
            # grad or live on an accelerator can be converted.
            data = data.detach().cpu().numpy()
        else:
            data = np.asarray(data)
        return data.reshape(data.shape[0], -1).T

    def _compute_dmd(self):
        """Fit DMD on the snapshots and return the fitted pydmd ``DMD``.

        Raises:
            RuntimeError: If building the snapshot matrix or fitting fails,
                or no reconstruction is available. The underlying exception
                is chained as ``__cause__``.
        """
        try:
            dmd = DMD(svd_rank=self.rank)
            dmd.fit(self._snapshot_matrix())

            if dmd.reconstructed_data is None:
                raise RuntimeError("DMD reconstruction failed")
        except Exception as e:
            raise RuntimeError(f"DMD processing failed: {str(e)}") from e

        return dmd

    def _calc_energy(self, threshold: float = 0.95) -> int:
        """Return how many modes are needed to exceed ``threshold`` of the total.

        The "energy" measure is the cumulative share of ``|amplitude|``,
        taken in the order pydmd returns the modes (no sorting is applied).

        Args:
            threshold: Fraction of the total |amplitude| that the selected
                leading modes must strictly exceed. Defaults to 0.95.

        Returns:
            The 1-based count of leading modes. Falls back to the total
            number of modes when the threshold is never exceeded (e.g.
            ``threshold >= 1.0``) or when every amplitude is zero.
        """
        # NOTE(review): this uses |amplitude|, not |amplitude|**2 — confirm
        # that is the intended "energy".
        weights = np.abs(self._compute_dmd().amplitudes)
        total = weights.sum()
        if total == 0:
            # Avoid a 0/0 division; no subset is more informative than another.
            return len(weights)
        cumulative = np.cumsum(weights) / total
        above = np.flatnonzero(cumulative > threshold)
        return int(above[0]) + 1 if above.size else len(weights)

    def method(self):
        """Fit DMD and return ``[modes, dynamics]`` as per-mode real arrays.

        Returns:
            A two-element list. ``modes`` holds one 1-D array per mode (the
            real part of each column of ``dmd.modes``). ``dynamics`` holds one
            1-D array per mode (the real part of each row of ``dmd.dynamics``).
        """
        dmd = self._compute_dmd()
        n_modes = len(dmd.amplitudes)

        # Read each pydmd property once; they may be recomputed on every access.
        modes_real = dmd.modes.real
        dynamics_real = dmd.dynamics.real

        modes = list(modes_real[:, :n_modes].T)
        dynamics = list(dynamics_real[:n_modes])
        return [modes, dynamics]


  class DMDNeuralOperator(nn.Module):
  def __init__(self, branch1_dim, branch_dmd_dim_modes, branch_dmd_dim_dynamics, trunk_dim):
      """Neural operator with DMD preprocessing.

      Args:
          branch1_dim: Layer dimensions for primary branch
          branch_dmd_dim_modes: Layer dimensions for DMD modes branch
          branch_dmd_dim_dynamics: Layer dimensions for DMD dynamics branch
          trunk_dims: Layer dimensions for trunk network
      """
      super(DMDNeuralOperator, self).__init__()

      modules = []
      for i, h_dim in enumerate(branch1_dim):
        if i == 0:
          in_channels = h_dim
        else:
          modules.append(nn.Sequential(
              nn.Linear(in_channels, h_dim),
              nn.Tanh()
              )
          )
          in_channels = h_dim

      self._branch_1 = nn.Sequential(*modules)

      modules = []
      for i, h_dim in enumerate(branch_dmd_dim_modes):
        if i == 0:
          in_channels = h_dim
        else:
          modules.append(nn.Sequential(
              nn.Linear(in_channels, h_dim),
              nn.Tanh()
              )
          )
          in_channels = h_dim
      self._branch_dmd_modes = nn.Sequential(*modules)

      modules = []
      for i, h_dim in enumerate(branch_dmd_dim_dynamics):
        if i == 0:
          in_channels = h_dim
        else:
          modules.append(nn.Sequential(
              nn.Linear(in_channels, h_dim),
              nn.Tanh()
              )
          )
          in_channels = h_dim
      self._branch_dmd_dynamics = nn.Sequential(*modules)

      modules = []
      for i, h_dim in enumerate(trunk_dim):
        if i == 0:
          in_channels = h_dim
        else:
          modules.append(nn.Sequential(
              nn.Linear(in_channels, h_dim),
              nn.Tanh()
              )
          )
          in_channels = h_dim

      self._trunk = nn.Sequential(*modules)

      self.final_linear = nn.Linear(trunk_dim[-1], 10)

  def forward(self, f: torch.Tensor, f_dmd_modes: torch.Tensor, f_dmd_dynamics: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Forward pass.

    Args:
        f: Input function (batch_size, *spatial_dims)
        x: Evaluation points (num_points, coord_dim)

    Returns:
        Output tensor (batch_size, num_points)
    """
    modes, dynamics = f_dmd_modes, f_dmd_dynamics

    branch_dmd_modes = self._branch_dmd_modes(modes)
    branch_dmd_dynamics = self._branch_dmd_dynamics(dynamics)
    y_branch_dmd = branch_dmd_modes * branch_dmd_dynamics

    y_branch1 = self._branch_1(f)
    y_br = y_branch1 * y_branch_dmd

    y_tr = self._trunk(x)

    y_out = y_br @ y_tr

    linear_out = nn.Linear(y_out.shape[-1], 10)
    tanh_out = nn.Tanh()

    y_out = self.final_linear(y_out)

    return y_out

  def loss(self, f, f_dmd_modes, f_dmd_dynamics, x, y):
    y_out = self.forward(f, f_dmd_modes, f_dmd_dynamics, x)
    loss = ((y_out - y) ** 2).mean()
    return loss