rogermt commited on
Commit
abe114d
·
verified ·
1 Parent(s): e23f433

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +327 -0
model.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """model.py — Neural network architectures for NSGF/NSGF++.
2
+
3
+ Contains:
4
+ - VelocityMLP: MLP for 2D velocity field matching
5
+ - VelocityUNet: UNet for image velocity field matching (NSGF + NSF)
6
+ - PhaseTransitionPredictor: CNN for predicting transition time t_ϕ(x)
7
+
8
+ Reference: arXiv:2401.14069, Appendix E.1, E.2
9
+ """
10
+
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from typing import List, Optional
16
+
17
+
18
+ class SinusoidalPosEmb(nn.Module):
19
+ def __init__(self, dim: int):
20
+ super().__init__()
21
+ self.dim = dim
22
+
23
+ def forward(self, t: torch.Tensor) -> torch.Tensor:
24
+ device = t.device
25
+ half_dim = self.dim // 2
26
+ emb = math.log(10000.0) / (half_dim - 1)
27
+ emb = torch.exp(torch.arange(half_dim, device=device, dtype=torch.float32) * -emb)
28
+ emb = t.float().unsqueeze(-1) * emb.unsqueeze(0)
29
+ emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
30
+ return emb
31
+
32
+
33
+ class VelocityMLP(nn.Module):
34
+ """MLP velocity field for 2D experiments.
35
+ Paper: 3 hidden layers, 256 hidden units, SiLU activation.
36
+ """
37
+ def __init__(self, input_dim: int = 2, hidden_dim: int = 256,
38
+ num_hidden_layers: int = 3, time_emb_dim: int = 64):
39
+ super().__init__()
40
+ self.time_emb = SinusoidalPosEmb(time_emb_dim)
41
+ layers = []
42
+ layers.append(nn.Linear(input_dim + time_emb_dim, hidden_dim))
43
+ layers.append(nn.SiLU())
44
+ for _ in range(num_hidden_layers - 1):
45
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
46
+ layers.append(nn.SiLU())
47
+ layers.append(nn.Linear(hidden_dim, input_dim))
48
+ self.net = nn.Sequential(*layers)
49
+
50
+ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
51
+ t_emb = self.time_emb(t)
52
+ xt = torch.cat([x, t_emb], dim=-1)
53
+ return self.net(xt)
54
+
55
+
56
+ class ResBlock(nn.Module):
57
+ """Residual block with AdaGN timestep conditioning."""
58
+ def __init__(self, channels: int, emb_dim: int, out_channels: Optional[int] = None,
59
+ dropout: float = 0.0, use_scale_shift_norm: bool = True):
60
+ super().__init__()
61
+ self.out_channels = out_channels or channels
62
+ self.use_scale_shift_norm = use_scale_shift_norm
63
+ self.norm1 = nn.GroupNorm(32, channels)
64
+ self.conv1 = nn.Conv2d(channels, self.out_channels, 3, padding=1)
65
+ self.time_proj = nn.Sequential(
66
+ nn.SiLU(),
67
+ nn.Linear(emb_dim, 2 * self.out_channels if use_scale_shift_norm else self.out_channels),
68
+ )
69
+ self.norm2 = nn.GroupNorm(32, self.out_channels)
70
+ self.dropout = nn.Dropout(dropout)
71
+ self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)
72
+ if channels != self.out_channels:
73
+ self.skip = nn.Conv2d(channels, self.out_channels, 1)
74
+ else:
75
+ self.skip = nn.Identity()
76
+ nn.init.zeros_(self.conv2.weight)
77
+ nn.init.zeros_(self.conv2.bias)
78
+
79
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
80
+ h = self.norm1(x)
81
+ h = F.silu(h)
82
+ h = self.conv1(h)
83
+ emb_out = self.time_proj(emb)[:, :, None, None]
84
+ if self.use_scale_shift_norm:
85
+ scale, shift = emb_out.chunk(2, dim=1)
86
+ h = self.norm2(h) * (1 + scale) + shift
87
+ else:
88
+ h = self.norm2(h + emb_out)
89
+ h = F.silu(h)
90
+ h = self.dropout(h)
91
+ h = self.conv2(h)
92
+ return h + self.skip(x)
93
+
94
+
95
+ class AttentionBlock(nn.Module):
96
+ def __init__(self, channels: int, num_heads: int = 1, num_head_channels: int = -1):
97
+ super().__init__()
98
+ if num_head_channels > 0:
99
+ self.num_heads = channels // num_head_channels
100
+ else:
101
+ self.num_heads = num_heads
102
+ self.norm = nn.GroupNorm(32, channels)
103
+ self.qkv = nn.Conv1d(channels, channels * 3, 1)
104
+ self.proj = nn.Conv1d(channels, channels, 1)
105
+ nn.init.zeros_(self.proj.weight)
106
+ nn.init.zeros_(self.proj.bias)
107
+
108
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
109
+ B, C, H, W = x.shape
110
+ h = self.norm(x).view(B, C, -1)
111
+ qkv = self.qkv(h).reshape(B, 3, self.num_heads, C // self.num_heads, -1)
112
+ q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]
113
+ scale = (C // self.num_heads) ** -0.5
114
+ attn = torch.einsum("bhcn,bhcm->bhnm", q, k) * scale
115
+ attn = attn.softmax(dim=-1)
116
+ out = torch.einsum("bhnm,bhcm->bhcn", attn, v)
117
+ out = out.reshape(B, C, -1)
118
+ out = self.proj(out).view(B, C, H, W)
119
+ return x + out
120
+
121
+
122
+ class Downsample(nn.Module):
123
+ def __init__(self, channels: int):
124
+ super().__init__()
125
+ self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
126
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
127
+ return self.conv(x)
128
+
129
+
130
+ class Upsample(nn.Module):
131
+ def __init__(self, channels: int):
132
+ super().__init__()
133
+ self.conv = nn.Conv2d(channels, channels, 3, padding=1)
134
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
135
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
136
+ return self.conv(x)
137
+
138
+
139
+ class VelocityUNet(nn.Module):
140
+ """UNet velocity field for image experiments (Dhariwal & Nichol 2021).
141
+
142
+ MNIST: channels=32, depth=1, ch_mult=[1,2,2], heads=1
143
+ CIFAR-10: channels=128, depth=2, ch_mult=[1,2,2,2], heads=4, head_ch=64
144
+ """
145
+ def __init__(self, image_size: int = 32, in_channels: int = 3,
146
+ model_channels: int = 128, num_res_blocks: int = 2,
147
+ channel_mult: List[int] = [1, 2, 2, 2],
148
+ attention_resolutions: List[int] = [16],
149
+ num_heads: int = 4, num_head_channels: int = 64,
150
+ dropout: float = 0.0, use_scale_shift_norm: bool = True):
151
+ super().__init__()
152
+ self.image_size = image_size
153
+ self.in_channels = in_channels
154
+ self.model_channels = model_channels
155
+
156
+ time_dim = model_channels * 4
157
+ self.time_embed = nn.Sequential(
158
+ SinusoidalPosEmb(model_channels),
159
+ nn.Linear(model_channels, time_dim), nn.SiLU(),
160
+ nn.Linear(time_dim, time_dim),
161
+ )
162
+ self.input_conv = nn.Conv2d(in_channels, model_channels, 3, padding=1)
163
+
164
+ self.down_blocks = nn.ModuleList()
165
+ self.down_attns = nn.ModuleList()
166
+ self.downsamplers = nn.ModuleList()
167
+
168
+ ch = model_channels
169
+ ds = image_size
170
+ input_block_channels = [ch]
171
+
172
+ for level, mult in enumerate(channel_mult):
173
+ out_ch = model_channels * mult
174
+ for _ in range(num_res_blocks):
175
+ block = ResBlock(ch, time_dim, out_ch, dropout, use_scale_shift_norm)
176
+ self.down_blocks.append(block)
177
+ if ds in attention_resolutions:
178
+ self.down_attns.append(AttentionBlock(out_ch, num_heads, num_head_channels))
179
+ else:
180
+ self.down_attns.append(nn.Identity())
181
+ ch = out_ch
182
+ input_block_channels.append(ch)
183
+ if level < len(channel_mult) - 1:
184
+ self.downsamplers.append(Downsample(ch))
185
+ ds //= 2
186
+ input_block_channels.append(ch)
187
+ else:
188
+ self.downsamplers.append(nn.Identity())
189
+
190
+ self.mid_block1 = ResBlock(ch, time_dim, ch, dropout, use_scale_shift_norm)
191
+ self.mid_attn = AttentionBlock(ch, num_heads, num_head_channels)
192
+ self.mid_block2 = ResBlock(ch, time_dim, ch, dropout, use_scale_shift_norm)
193
+
194
+ self.up_blocks = nn.ModuleList()
195
+ self.up_attns = nn.ModuleList()
196
+ self.upsamplers = nn.ModuleList()
197
+
198
+ for level in reversed(range(len(channel_mult))):
199
+ mult = channel_mult[level]
200
+ out_ch = model_channels * mult
201
+ for i in range(num_res_blocks + 1):
202
+ skip_ch = input_block_channels.pop()
203
+ block = ResBlock(ch + skip_ch, time_dim, out_ch, dropout, use_scale_shift_norm)
204
+ self.up_blocks.append(block)
205
+ if ds in attention_resolutions:
206
+ self.up_attns.append(AttentionBlock(out_ch, num_heads, num_head_channels))
207
+ else:
208
+ self.up_attns.append(nn.Identity())
209
+ ch = out_ch
210
+ if level > 0:
211
+ self.upsamplers.append(Upsample(ch))
212
+ ds *= 2
213
+ else:
214
+ self.upsamplers.append(nn.Identity())
215
+
216
+ self.out_norm = nn.GroupNorm(32, ch)
217
+ self.out_conv = nn.Conv2d(ch, in_channels, 3, padding=1)
218
+ nn.init.zeros_(self.out_conv.weight)
219
+ nn.init.zeros_(self.out_conv.bias)
220
+
221
+ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
222
+ emb = self.time_embed(t * 1000.0)
223
+ h = self.input_conv(x)
224
+ skips = [h]
225
+
226
+ block_idx = 0
227
+ for level in range(len(self.downsamplers)):
228
+ for _ in range(self._get_num_res_blocks()):
229
+ if block_idx < len(self.down_blocks):
230
+ h = self.down_blocks[block_idx](h, emb)
231
+ h = self.down_attns[block_idx](h)
232
+ skips.append(h)
233
+ block_idx += 1
234
+ if not isinstance(self.downsamplers[level], nn.Identity):
235
+ h = self.downsamplers[level](h)
236
+ skips.append(h)
237
+
238
+ h = self.mid_block1(h, emb)
239
+ h = self.mid_attn(h)
240
+ h = self.mid_block2(h, emb)
241
+
242
+ block_idx = 0
243
+ for level in range(len(self.upsamplers)):
244
+ for _ in range(self._get_num_res_blocks() + 1):
245
+ if block_idx < len(self.up_blocks):
246
+ skip = skips.pop()
247
+ h = torch.cat([h, skip], dim=1)
248
+ h = self.up_blocks[block_idx](h, emb)
249
+ h = self.up_attns[block_idx](h)
250
+ block_idx += 1
251
+ if not isinstance(self.upsamplers[level], nn.Identity):
252
+ h = self.upsamplers[level](h)
253
+
254
+ h = self.out_norm(h)
255
+ h = F.silu(h)
256
+ h = self.out_conv(h)
257
+ return h
258
+
259
+ def _get_num_res_blocks(self):
260
+ total_down = len(self.down_blocks)
261
+ num_levels = len(self.downsamplers)
262
+ return total_down // num_levels
263
+
264
+
265
+ class PhaseTransitionPredictor(nn.Module):
266
+ """CNN predicting phase-transition time t_ϕ(x) ∈ [0, 1].
267
+ 4 conv layers (32→64→128→256), 3x3, AvgPool2d, FC + Sigmoid.
268
+ """
269
+ def __init__(self, in_channels: int = 1, image_size: int = 28,
270
+ conv_channels: List[int] = [32, 64, 128, 256]):
271
+ super().__init__()
272
+ layers = []
273
+ ch = in_channels
274
+ for out_ch in conv_channels:
275
+ layers.extend([
276
+ nn.Conv2d(ch, out_ch, kernel_size=3, stride=1, padding=1),
277
+ nn.ReLU(inplace=True),
278
+ nn.AvgPool2d(kernel_size=2, stride=2),
279
+ ])
280
+ ch = out_ch
281
+ self.conv = nn.Sequential(*layers)
282
+ final_size = image_size
283
+ for _ in conv_channels:
284
+ final_size = final_size // 2
285
+ self.fc_input_dim = conv_channels[-1] * final_size * final_size
286
+ self.fc = nn.Linear(self.fc_input_dim, 1)
287
+ self.sigmoid = nn.Sigmoid()
288
+
289
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
290
+ h = self.conv(x)
291
+ h = h.view(h.size(0), -1)
292
+ h = self.fc(h)
293
+ return self.sigmoid(h).squeeze(-1)
294
+
295
+
296
+ # Factory functions
297
+ def create_velocity_model_2d(config: dict) -> VelocityMLP:
298
+ model_cfg = config.get("model", {})
299
+ return VelocityMLP(
300
+ input_dim=model_cfg.get("input_dim", 2),
301
+ hidden_dim=model_cfg.get("hidden_dim", 256),
302
+ num_hidden_layers=model_cfg.get("num_hidden_layers", 3),
303
+ time_emb_dim=model_cfg.get("time_emb_dim", 64),
304
+ )
305
+
306
+ def create_velocity_unet(config: dict) -> VelocityUNet:
307
+ unet_cfg = config.get("unet", {})
308
+ return VelocityUNet(
309
+ image_size=config.get("image_size", 32),
310
+ in_channels=config.get("in_channels", 3),
311
+ model_channels=unet_cfg.get("model_channels", 128),
312
+ num_res_blocks=unet_cfg.get("num_res_blocks", 2),
313
+ channel_mult=unet_cfg.get("channel_mult", [1, 2, 2, 2]),
314
+ attention_resolutions=unet_cfg.get("attention_resolutions", [16]),
315
+ num_heads=unet_cfg.get("num_heads", 4),
316
+ num_head_channels=unet_cfg.get("num_head_channels", 64),
317
+ dropout=unet_cfg.get("dropout", 0.0),
318
+ use_scale_shift_norm=unet_cfg.get("use_scale_shift_norm", True),
319
+ )
320
+
321
+ def create_phase_predictor(config: dict) -> PhaseTransitionPredictor:
322
+ tp_cfg = config.get("time_predictor", {})
323
+ return PhaseTransitionPredictor(
324
+ in_channels=config.get("in_channels", 1),
325
+ image_size=config.get("image_size", 28),
326
+ conv_channels=tp_cfg.get("conv_channels", [32, 64, 128, 256]),
327
+ )