krystv committed on
Commit
d7d1235
·
verified ·
1 Parent(s): 95b50da

Upload liquid_flow/mamba2_ssd.py

Browse files
Files changed (1) hide show
  1. liquid_flow/mamba2_ssd.py +143 -163
liquid_flow/mamba2_ssd.py CHANGED
@@ -1,22 +1,21 @@
1
  """
2
- Mamba-2 SSD — Pure PyTorch, autograd-safe, fully parallel.
3
 
4
- IMPORTANT DESIGN DECISIONS:
5
- 1. NO in-place operations (breaks autograd)
6
- 2. Uses chunk-based scan instead of Blelloch (simpler, still parallel within chunks)
7
- 3. Correct dimension handling for dt, B, C projections
8
- 4. Works on CPU and GPU without custom kernels
9
 
10
- The SSM recurrence:
11
- h_t = exp(A * Δ_t) * h_{t-1} + Δ_t * B_t * x_t
12
- y_t = C_t^T * h_t + D * x_t
13
 
14
- We compute this via the "chunk scan" approach from Mamba-2:
15
- - Split sequence into chunks of size T
16
- - Within each chunk: compute via matrix multiply (O(T²) but T is small)
17
- - Across chunks: carry hidden state forward (O(L/T) steps)
18
 
19
- For L=256 (16×16 latent) with T=16: only 16 chunks, each parallelized.
 
 
 
 
 
 
20
  """
21
 
22
  import torch
@@ -27,30 +26,27 @@ import math
27
 
28
  class Mamba2SSD(nn.Module):
29
  """
30
- Mamba-2 State Space Duality module.
31
-
32
- Pure PyTorch implementation using chunk-scan parallelism.
33
- No in-place ops, fully autograd-compatible.
34
 
35
  Args:
36
  dim: Input/output dimension
37
- d_state: State dimension (default 16)
38
- d_conv: Conv1d kernel size
39
- expand: Inner dimension expansion factor
40
- chunk_size: Chunk size for parallel scan (default 16)
41
  """
42
 
43
- def __init__(self, dim, d_state=16, d_conv=4, expand=2, chunk_size=16):
44
  super().__init__()
45
  self.dim = dim
46
  self.d_state = d_state
47
  self.chunk_size = chunk_size
48
  self.inner_dim = dim * expand
49
 
50
- # Input projection: x and z (gate) branches
51
  self.in_proj = nn.Linear(dim, self.inner_dim * 2, bias=False)
52
 
53
- # Short conv for local context (depthwise, causal)
54
  self.conv1d = nn.Conv1d(
55
  self.inner_dim, self.inner_dim,
56
  kernel_size=d_conv, padding=d_conv - 1,
@@ -58,169 +54,173 @@ class Mamba2SSD(nn.Module):
58
  )
59
 
60
  # SSM parameter projections
61
- # dt: [inner_dim] — one scalar per channel
62
- # B: [d_state] — state input matrix
63
- # C: [d_state] — state output matrix
64
  self.dt_proj = nn.Linear(self.inner_dim, self.inner_dim, bias=True)
65
  self.B_proj = nn.Linear(self.inner_dim, d_state, bias=False)
66
  self.C_proj = nn.Linear(self.inner_dim, d_state, bias=False)
67
 
68
- # A: learnable log-space parameter (negative for stability)
69
  A = torch.arange(1, d_state + 1, dtype=torch.float32)
70
- self.A_log = nn.Parameter(torch.log(A)) # [d_state]
71
 
72
- # D: skip connection (residual)
73
  self.D = nn.Parameter(torch.ones(self.inner_dim))
74
 
75
- # Output projection
76
  self.norm = nn.LayerNorm(self.inner_dim)
77
  self.out_proj = nn.Linear(self.inner_dim, dim, bias=False)
78
 
79
  self._init_weights()
80
 
81
  def _init_weights(self):
82
- # Initialize dt bias to small positive values (fast dynamics)
83
  nn.init.constant_(self.dt_proj.bias, -4.0) # softplus(-4) ≈ 0.018
84
  nn.init.xavier_uniform_(self.in_proj.weight, gain=0.1)
85
  nn.init.xavier_uniform_(self.out_proj.weight, gain=0.1)
86
 
87
  def forward(self, x):
88
- """
89
- Args:
90
- x: [B, L, dim]
91
- Returns:
92
- [B, L, dim]
93
- """
94
- return self._process_sequence(x)
95
 
96
- def _process_sequence(self, x):
97
- """Full Mamba-2 SSD forward pass."""
98
  B, L, D = x.shape
99
 
100
  # Input projection
101
- xz = self.in_proj(x) # [B, L, inner_dim * 2]
102
- x_inner, z = xz.chunk(2, dim=-1) # each [B, L, inner_dim]
103
 
104
- # Causal conv1d
105
- x_conv = x_inner.transpose(1, 2) # [B, inner_dim, L]
106
- x_conv = self.conv1d(x_conv)[:, :, :L] # Remove right padding (causal)
107
- x_conv = F.silu(x_conv).transpose(1, 2) # [B, L, inner_dim]
108
 
109
- # Compute SSM parameters (all per-position, parallel)
110
  dt = F.softplus(self.dt_proj(x_conv)) # [B, L, inner_dim], positive
111
- B_mat = self.B_proj(x_conv) # [B, L, d_state]
112
- C_mat = self.C_proj(x_conv) # [B, L, d_state]
 
113
 
114
- # A: negative for stable dynamics
115
- A = -torch.exp(self.A_log) # [d_state]
116
 
117
- # Run selective scan via chunk decomposition
118
- y = self._chunk_scan(x_conv, dt, A, B_mat, C_mat)
119
-
120
- # Add skip connection
121
  y = y + x_conv * self.D.unsqueeze(0).unsqueeze(0)
 
122
 
123
- # Normalize + gate with z
124
- y = self.norm(y)
125
- y = y * F.silu(z)
126
-
127
- # Output projection
128
  return self.out_proj(y)
129
 
130
- def _chunk_scan(self, u, delta, A, B, C):
131
  """
132
- Chunk-based selective scan (Mamba-2 style).
133
 
134
- Within each chunk: parallel matmul computation.
135
- Across chunks: sequential state propagation (only L/chunk_size steps).
136
 
137
- For L=256, chunk_size=16: only 16 sequential steps, each doing
138
- parallel matmul over 16 positions. Much better than 256 sequential steps.
139
 
140
- Args:
141
- u: [B, L, inner_dim] — input
142
- delta: [B, L, inner_dim] — timestep (positive)
143
- A: [d_state] — state decay (negative)
144
- B: [B, L, d_state] — state input projection
145
- C: [B, L, d_state] — state output projection
146
-
147
- Returns:
148
- y: [B, L, inner_dim]
149
  """
150
  batch, L, d_inner = u.shape
151
  d_state = A.shape[0]
152
- T = self.chunk_size
153
-
154
- # Pad sequence to multiple of chunk_size
155
- pad_len = (T - L % T) % T
156
- if pad_len > 0:
157
- u = F.pad(u, (0, 0, 0, pad_len))
158
- delta = F.pad(delta, (0, 0, 0, pad_len))
159
- B = F.pad(B, (0, 0, 0, pad_len))
160
- C = F.pad(C, (0, 0, 0, pad_len))
161
-
162
- L_padded = u.shape[1]
163
- num_chunks = L_padded // T
164
-
165
- # Reshape into chunks: [B, num_chunks, T, ...]
166
- u_chunks = u.reshape(batch, num_chunks, T, d_inner)
167
- dt_chunks = delta.reshape(batch, num_chunks, T, d_inner)
168
- B_chunks = B.reshape(batch, num_chunks, T, d_state)
169
- C_chunks = C.reshape(batch, num_chunks, T, d_state)
170
-
171
- # Compute discretized A for each position
172
- # dA[b, chunk, t, d_inner, d_state] = exp(delta[b,chunk,t,d_inner] * A[d_state])
173
- # But A is shared across inner_dim, so we expand:
174
- # For scalar-A per state dim: A is [d_state]
175
- # delta is [B, num_chunks, T, d_inner]
176
- # We need: for each (batch, chunk, t): dA = exp(delta_mean * A)
177
- # Simplification: use mean delta across inner_dim for state decay
178
- dt_mean = dt_chunks.mean(dim=-1, keepdim=True) # [B, nc, T, 1]
179
-
180
- # dA per position: [B, nc, T, d_state]
181
- dA = torch.exp(dt_mean * A.view(1, 1, 1, -1)) # [B, nc, T, d_state]
182
-
183
- # dB * u: [B, nc, T, d_state, d_inner]
184
- # B is [B, nc, T, d_state], u is [B, nc, T, d_inner]
185
- # delta is [B, nc, T, d_inner]
186
- dBu = dt_chunks.unsqueeze(-2) * B_chunks.unsqueeze(-1) * u_chunks.unsqueeze(-2)
 
 
 
 
 
 
 
 
 
 
 
 
187
  # dBu: [B, nc, T, d_state, d_inner]
188
 
189
- # Within each chunk: compute states via cumulative product of dA + dBu
190
- # h_t = dA_t * h_{t-1} + dBu_t
191
- # This is a linear recurrence within the chunk — compute via sequential scan
192
- # but T is small (16), so this is fast
193
-
194
- outputs = []
195
- h = torch.zeros(batch, d_state, d_inner, device=u.device, dtype=u.dtype)
196
 
197
- for chunk_idx in range(num_chunks):
198
- chunk_out = torch.zeros(batch, T, d_inner, device=u.device, dtype=u.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
- for t in range(T):
201
- # State update: h = dA * h + dBu
202
- h = dA[:, chunk_idx, t, :].unsqueeze(-1) * h + dBu[:, chunk_idx, t, :, :]
203
- # h: [B, d_state, d_inner]
204
-
205
- # Output: y = C^T * h
206
- c_t = C_chunks[:, chunk_idx, t, :] # [B, d_state]
207
- y_t = (c_t.unsqueeze(-1) * h).sum(dim=1) # [B, d_inner]
208
- chunk_out[:, t, :] = y_t
209
 
210
- outputs.append(chunk_out)
 
 
 
 
211
 
212
- y = torch.cat(outputs, dim=1) # [B, L_padded, d_inner]
 
 
 
213
 
214
- # Remove padding
 
215
  return y[:, :L, :]
216
 
217
 
218
  class Mamba2Block(nn.Module):
219
  """
220
  Mamba-2 block with bidirectional scanning for 2D images.
221
-
222
- Uses forward + backward raster scan and averages them.
223
- This captures 2D spatial context without quadratic cost.
224
  """
225
 
226
  def __init__(self, dim, d_state=16, d_conv=4, expand=2, dropout=0.0):
@@ -228,50 +228,30 @@ class Mamba2Block(nn.Module):
228
  self.norm1 = nn.LayerNorm(dim)
229
  self.norm2 = nn.LayerNorm(dim)
230
 
231
- # Forward and backward SSM
232
  self.ssd_fwd = Mamba2SSD(dim, d_state, d_conv, expand)
233
  self.ssd_bwd = Mamba2SSD(dim, d_state, d_conv, expand)
234
-
235
- # Merge projection
236
  self.merge = nn.Linear(dim * 2, dim, bias=False)
237
 
238
- # Feed-forward
239
  ff_dim = dim * expand
240
  self.ff = nn.Sequential(
241
- nn.Linear(dim, ff_dim),
242
- nn.GELU(),
243
- nn.Dropout(dropout),
244
- nn.Linear(ff_dim, dim),
245
- nn.Dropout(dropout),
246
  )
247
 
248
  def forward(self, x):
249
- """
250
- Args:
251
- x: [B, C, H, W] or [B, L, C]
252
- Returns:
253
- Same shape as input
254
- """
255
  is_2d = x.dim() == 4
256
  if is_2d:
257
  B, C, H, W = x.shape
258
- x = x.flatten(2).transpose(1, 2) # [B, HW, C]
259
 
260
  residual = x
261
  x_norm = self.norm1(x)
262
 
263
- # Forward scan
264
- fwd_out = self.ssd_fwd(x_norm)
265
-
266
- # Backward scan (flip, process, flip back)
267
- x_flip = torch.flip(x_norm, dims=[1])
268
- bwd_out = self.ssd_bwd(x_flip)
269
- bwd_out = torch.flip(bwd_out, dims=[1])
270
-
271
- # Merge both directions
272
- merged = self.merge(torch.cat([fwd_out, bwd_out], dim=-1))
273
 
274
- # Residual + FF
275
  x = residual + merged
276
  x = x + self.ff(self.norm2(x))
277
 
 
1
  """
2
+ Mamba-2 SSD — OPTIMIZED: intra-chunk parallelism via matrix multiply.
3
 
4
+ The key Mamba-2 insight (State Space Duality):
5
+ Within each chunk of size T, the SSM can be computed as a MATRIX MULTIPLY:
 
 
 
6
 
7
+ Y_chunk = (L ⊙ (C B^T)) @ (Δ ⊙ X)
 
 
8
 
9
+ Where L is a lower-triangular mask with cumulative A products.
10
+ This replaces the T sequential steps with a single matmul of size T×T.
 
 
11
 
12
+ Example with L=256 and chunk_size T=16 (num_chunks=16; the default T=64 gives 4 chunks):
13
+ - Within chunk: parallel matmul (T×T = 16×16)
14
+ - Across chunks: 16 sequential state carries (unavoidable, but trivial)
15
+
16
+ Total: 16 sequential state carries + 16 parallel matmuls = FAST.
17
+
18
+ NO in-place ops. Fully autograd safe. Works on CPU and GPU.
19
  """
20
 
21
  import torch
 
26
 
27
  class Mamba2SSD(nn.Module):
28
  """
29
+ Mamba-2 SSD with intra-chunk matrix-multiply parallelism.
 
 
 
30
 
31
  Args:
32
  dim: Input/output dimension
33
+ d_state: SSM state dimension (default 16)
34
+ d_conv: Conv1d kernel size (default 4)
35
+ expand: Inner dimension expansion (default 2)
36
+ chunk_size: Chunk size for scan (default 64 — larger = more parallel)
37
  """
38
 
39
+ def __init__(self, dim, d_state=16, d_conv=4, expand=2, chunk_size=64):
40
  super().__init__()
41
  self.dim = dim
42
  self.d_state = d_state
43
  self.chunk_size = chunk_size
44
  self.inner_dim = dim * expand
45
 
46
+ # Input projection: x and gate
47
  self.in_proj = nn.Linear(dim, self.inner_dim * 2, bias=False)
48
 
49
+ # Short causal conv for local context
50
  self.conv1d = nn.Conv1d(
51
  self.inner_dim, self.inner_dim,
52
  kernel_size=d_conv, padding=d_conv - 1,
 
54
  )
55
 
56
  # SSM parameter projections
 
 
 
57
  self.dt_proj = nn.Linear(self.inner_dim, self.inner_dim, bias=True)
58
  self.B_proj = nn.Linear(self.inner_dim, d_state, bias=False)
59
  self.C_proj = nn.Linear(self.inner_dim, d_state, bias=False)
60
 
61
+ # A: learnable decay rates, stored in log-space (negated via -exp(A_log) at use for stability)
62
  A = torch.arange(1, d_state + 1, dtype=torch.float32)
63
+ self.A_log = nn.Parameter(torch.log(A))
64
 
65
+ # D: residual skip
66
  self.D = nn.Parameter(torch.ones(self.inner_dim))
67
 
68
+ # Output
69
  self.norm = nn.LayerNorm(self.inner_dim)
70
  self.out_proj = nn.Linear(self.inner_dim, dim, bias=False)
71
 
72
  self._init_weights()
73
 
74
  def _init_weights(self):
 
75
  nn.init.constant_(self.dt_proj.bias, -4.0) # softplus(-4) ≈ 0.018
76
  nn.init.xavier_uniform_(self.in_proj.weight, gain=0.1)
77
  nn.init.xavier_uniform_(self.out_proj.weight, gain=0.1)
78
 
79
  def forward(self, x):
80
+ """x: [B, L, dim] → [B, L, dim]"""
81
+ return self._process(x)
 
 
 
 
 
82
 
83
def _process(self, x):
    """Full SSD forward pass: project, convolve, scan, gate, project back.

    Args:
        x: input sequence; must be 3-D (unpacked as batch, length, dim).

    Returns:
        The result of ``out_proj`` applied to the gated, normalised scan output.
    """
    # Unpacking enforces a 3-D input; only the sequence length is needed below.
    _, seq_len, _ = x.shape

    # One projection yields both the SSM branch and the gate branch.
    x_inner, z = self.in_proj(x).chunk(2, dim=-1)

    # Conv1d works channel-first; trimming the output to the first seq_len
    # steps drops the surplus positions created by the conv's padding.
    conv_out = self.conv1d(x_inner.transpose(1, 2))
    x_conv = F.silu(conv_out[:, :, :seq_len].transpose(1, 2))

    # Input-dependent SSM parameters.
    dt = F.softplus(self.dt_proj(x_conv))  # positive step sizes
    B_mat = self.B_proj(x_conv)
    C_mat = self.C_proj(x_conv)
    A = -torch.exp(self.A_log)  # negative decay rates

    scanned = self._chunk_ssm(x_conv, dt, A, B_mat, C_mat)

    # Per-channel skip connection, then LayerNorm and SiLU gating with z.
    skipped = scanned + x_conv * self.D.unsqueeze(0).unsqueeze(0)
    gated = self.norm(skipped) * F.silu(z)
    return self.out_proj(gated)
108
 
109
def _chunk_ssm(self, u, dt, A, B, C):
    """Chunk-parallel evaluation of the selective SSM recurrence.

    Computes, for every position t (with h_{-1} = 0)::

        h_t = exp(mean(dt_t) * A) * h_{t-1} + dt_t * B_t * u_t
        y_t = C_t^T h_t

    Inside a chunk the states come from a lower-triangular decay matrix
    (``h_t = sum_{s<=t} prod_{k=s+1..t} dA_k * dBu_s``), evaluated in one
    einsum. Across chunks only the final state is carried forward, so the
    Python loop runs ``n_chunks`` times.

    Args:
        u:  [batch, L, d_inner] input sequence.
        dt: [batch, L, d_inner] positive step sizes.
        A:  [d_state] negative decay rates.
        B:  [batch, L, d_state] state-input projection.
        C:  [batch, L, d_state] state-output projection.

    Returns:
        [batch, L, d_inner] tensor, in the dtype of ``dt * B * u``.

    Raises:
        ValueError: if ``self.chunk_size`` is smaller than 1.
    """
    batch, L, d_inner = u.shape
    d_state = A.shape[0]

    if self.chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {self.chunk_size}")
    if L == 0:
        # Nothing to scan; also avoids a modulo by zero below (T would be 0).
        return u.new_zeros((batch, 0, d_inner))

    T = min(self.chunk_size, L)

    # Right-pad to a multiple of T. Padded steps have dt == 0, so their decay
    # is exp(0) == 1 and their input term is 0; they are sliced off at the end.
    pad = (T - L % T) % T
    if pad > 0:
        u = F.pad(u, (0, 0, 0, pad))
        dt = F.pad(dt, (0, 0, 0, pad))
        B = F.pad(B, (0, 0, 0, pad))
        C = F.pad(C, (0, 0, 0, pad))

    L_pad = u.shape[1]
    n_chunks = L_pad // T

    u_c = u.reshape(batch, n_chunks, T, d_inner)
    dt_c = dt.reshape(batch, n_chunks, T, d_inner)
    B_c = B.reshape(batch, n_chunks, T, d_state)
    C_c = C.reshape(batch, n_chunks, T, d_state)

    # State input per position: [batch, nc, T, d_state, d_inner].
    dBu = dt_c.unsqueeze(-2) * B_c.unsqueeze(-1) * u_c.unsqueeze(-2)
    # All downstream tensors are brought to this dtype so einsum never sees
    # mixed operands (the previous float32 mask/carry broke half/bfloat16).
    work_dtype = dBu.dtype

    # Simplification kept from the original: one decay step per position,
    # using the mean of dt over channels. log_dA: [batch, nc, T, d_state].
    log_dA = dt_c.mean(dim=-1).unsqueeze(-1) * A.view(1, 1, 1, -1)
    log_dA_cumsum = torch.cumsum(log_dA, dim=2)

    # diff[..., t, s, :] = cumsum[t] - cumsum[s] = sum_{k=s+1..t} log_dA_k.
    diff = log_dA_cumsum.unsqueeze(3) - log_dA_cumsum.unsqueeze(2)
    # Entries with s > t are positive and may overflow exp(); set them to
    # -inf before exponentiating so they become exactly 0.
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=u.device))
    diff = diff.masked_fill(~causal.view(1, 1, T, T, 1), float("-inf"))
    decay_matrix = torch.exp(diff).to(work_dtype)  # [batch, nc, T, T, d_state]

    # States produced by inputs inside the same chunk.
    h_intra = torch.einsum('bctsn,bcsnd->bctnd', decay_matrix, dBu)

    # Decay applied to the incoming carry at each position of a chunk.
    position_decay = torch.exp(log_dA_cumsum).to(work_dtype)  # [batch, nc, T, d_state]

    h_carry = dBu.new_zeros((batch, d_state, d_inner))
    h_chunks = []
    for c_idx in range(n_chunks):
        h_from_prev = position_decay[:, c_idx].unsqueeze(-1) * h_carry.unsqueeze(1)
        h_total = h_intra[:, c_idx] + h_from_prev  # [batch, T, d_state, d_inner]
        h_chunks.append(h_total)
        # The last position's state seeds the next chunk.
        h_carry = h_total[:, -1]

    h_all = torch.stack(h_chunks, dim=1)  # [batch, nc, T, d_state, d_inner]

    # y[t] = C[t]^T h[t], contracting over the state dimension.
    y = torch.einsum('bctn,bctnd->bctd', C_c.to(work_dtype), h_all)
    return y.reshape(batch, L_pad, d_inner)[:, :L, :]
218
 
219
 
220
  class Mamba2Block(nn.Module):
221
  """
222
  Mamba-2 block with bidirectional scanning for 2D images.
223
+ Forward + backward raster scan, merged via learned projection.
 
 
224
  """
225
 
226
  def __init__(self, dim, d_state=16, d_conv=4, expand=2, dropout=0.0):
 
228
  self.norm1 = nn.LayerNorm(dim)
229
  self.norm2 = nn.LayerNorm(dim)
230
 
 
231
  self.ssd_fwd = Mamba2SSD(dim, d_state, d_conv, expand)
232
  self.ssd_bwd = Mamba2SSD(dim, d_state, d_conv, expand)
 
 
233
  self.merge = nn.Linear(dim * 2, dim, bias=False)
234
 
 
235
  ff_dim = dim * expand
236
  self.ff = nn.Sequential(
237
+ nn.Linear(dim, ff_dim), nn.GELU(), nn.Dropout(dropout),
238
+ nn.Linear(ff_dim, dim), nn.Dropout(dropout),
 
 
 
239
  )
240
 
241
  def forward(self, x):
242
+ """x: [B, C, H, W] or [B, L, C]"""
 
 
 
 
 
243
  is_2d = x.dim() == 4
244
  if is_2d:
245
  B, C, H, W = x.shape
246
+ x = x.flatten(2).transpose(1, 2)
247
 
248
  residual = x
249
  x_norm = self.norm1(x)
250
 
251
+ fwd = self.ssd_fwd(x_norm)
252
+ bwd = torch.flip(self.ssd_bwd(torch.flip(x_norm, [1])), [1])
253
+ merged = self.merge(torch.cat([fwd, bwd], dim=-1))
 
 
 
 
 
 
 
254
 
 
255
  x = residual + merged
256
  x = x + self.ff(self.norm2(x))
257