rogermt committed on
Commit
81d2fba
·
verified ·
1 Parent(s): abe114d

Upload trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +268 -0
trainer.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """trainer.py — Training procedures for NSGF and NSGF++.
2
+
3
+ Implements:
4
+ 1. Trajectory pool construction (Phase 1: Sinkhorn gradient flow)
5
+ 2. NSGF velocity field matching training
6
+ 3. NSF (Neural Straight Flow) training for NSGF++
7
+ 4. Phase-transition time predictor training
8
+ 5. End-to-end NSGF++ training pipeline
9
+
10
+ Reference: arXiv:2401.14069, Section 4.2–4.4, Appendix D, E
11
+ """
12
+
13
+ import os
14
+ import logging
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.optim as optim
18
+ from typing import Optional, Dict, Any, Tuple
19
+
20
+ from dataset_loader import DatasetLoader
21
+ from sinkhorn_flow import (
22
+ SinkhornPotentialComputer, SinkhornGradientFlow, TrajectoryPool,
23
+ )
24
+ from model import VelocityMLP, VelocityUNet, PhaseTransitionPredictor
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
class NSGFTrainer:
    """Trainer for the Neural Sinkhorn Gradient Flow model.

    Loss (Eq. 14): L(θ) = E_{(x,v,t) ~ pool} ||v_θ(x, t) - v̂(x)||²
    """

    def __init__(self, model: nn.Module, data_loader: DatasetLoader,
                 config: dict, device: str = "cpu"):
        self.model = model.to(device)
        self.data_loader = data_loader
        self.config = config
        self.device = device

        # Phase-1 machinery: Sinkhorn potentials drive the particle flow
        # whose (x, v, t) triples populate the trajectory pool.
        sinkhorn_cfg = config.get("sinkhorn", {})
        self.potential_computer = SinkhornPotentialComputer(
            blur=sinkhorn_cfg.get("blur", 0.5),
            scaling=sinkhorn_cfg.get("scaling", 0.80),
        )
        self.gradient_flow = SinkhornGradientFlow(
            potential_computer=self.potential_computer,
            eta=sinkhorn_cfg.get("eta", 1.0),
            num_steps=sinkhorn_cfg.get("num_steps", 5),
        )
        self.pool = TrajectoryPool(max_size=5_000_000)

        # Optimization hyper-parameters: the "training" section takes
        # precedence over the alternative "nsgf_training" section.
        opt_cfg = config.get("training", config.get("nsgf_training", {}))
        self.num_iterations = opt_cfg.get("num_iterations", 20000)
        self.train_batch_size = opt_cfg.get("batch_size", 256)
        self.lr = opt_cfg.get("learning_rate", 1e-3)
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.lr,
            betas=(opt_cfg.get("beta1", 0.9), opt_cfg.get("beta2", 0.999)),
            weight_decay=opt_cfg.get("weight_decay", 0.0),
        )

    def build_trajectory_pool(self, num_batches: Optional[int] = None):
        """Populate the pool (Phase 1) by running the Sinkhorn gradient
        flow on freshly sampled source/target mini-batches.

        Args:
            num_batches: number of flow roll-outs; defaults to
                config["pool"]["num_batches"] (200).
        """
        if num_batches is None:
            num_batches = self.config.get("pool", {}).get("num_batches", 200)
        sink_batch_size = self.config.get("sinkhorn", {}).get("batch_size", 256)
        logger.info(
            f"Building trajectory pool: {num_batches} batches × "
            f"{sink_batch_size} samples × {self.gradient_flow.num_steps} steps"
        )
        report_every = max(1, num_batches // 10)
        for done in range(1, num_batches + 1):
            source = self.data_loader.sample_source(sink_batch_size, self.device)
            target = self.data_loader.sample_target(sink_batch_size, self.device)
            _, trajectory = self.gradient_flow.run_flow(source, target, store_trajectory=True)
            self.pool.add_trajectory(trajectory)
            if done % report_every == 0:
                logger.info(f" Pool building: {done}/{num_batches}, pool size: {len(self.pool)}")
        logger.info(f"Trajectory pool built. Total entries: {len(self.pool)}")

    def train(self) -> Dict[str, list]:
        """Fit the velocity field by matching pooled targets (Eq. 14).

        Returns:
            History dict with "loss" and "step" lists, recorded at step 1
            and every 500 steps thereafter.
        """
        self.model.train()
        history = {"loss": [], "step": []}
        logger.info(f"Starting NSGF velocity field matching: {self.num_iterations} iterations")
        # Pool times are integer flow-step indices; normalize into [0, 1].
        horizon = max(self.gradient_flow.num_steps, 1.0)
        for step in range(self.num_iterations):
            xs, vs, ts = self.pool.sample(self.train_batch_size, self.device)
            predicted = self.model(xs, ts / horizon)
            loss = ((predicted - vs) ** 2).mean()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if step == 0 or (step + 1) % 500 == 0:
                current = loss.item()
                history["loss"].append(current)
                history["step"].append(step + 1)
                logger.info(f" Step {step + 1}/{self.num_iterations}, Loss: {current:.6f}")
        logger.info("NSGF training complete.")
        return history
98
+
99
class NSFTrainer:
    """Trainer for Neural Straight Flow (Phase 2 of NSGF++).

    The frozen Phase-1 NSGF model pushes source samples to an intermediate
    distribution P_0; the straight flow is then trained on linear
    interpolants X_t = (1-t)*P_0 + t*P_1 with target velocity P_1 - P_0.

    Args:
        model: velocity network v_θ(x, t) being trained.
        nsgf_model: frozen NSGF velocity field used to generate P_0.
        data_loader: source/target distribution samplers.
        config: config dict; reads "nsf_training" (falls back to "training").
        nsgf_num_steps: Euler steps for the NSGF roll-out.
        device: torch device string.
    """
    def __init__(self, model: nn.Module, nsgf_model: nn.Module,
                 data_loader: DatasetLoader, config: dict,
                 nsgf_num_steps: int = 5, device: str = "cpu"):
        self.model = model.to(device)
        self.nsgf_model = nsgf_model.to(device)
        self.nsgf_model.eval()  # frozen: only used to generate P_0 samples
        self.data_loader = data_loader
        self.config = config
        self.device = device
        self.nsgf_num_steps = nsgf_num_steps

        train_cfg = config.get("nsf_training", config.get("training", {}))
        self.num_iterations = train_cfg.get("num_iterations", 100000)
        self.train_batch_size = train_cfg.get("batch_size", 128)
        self.lr = train_cfg.get("learning_rate", 1e-4)
        self.optimizer = optim.Adam(
            self.model.parameters(), lr=self.lr,
            betas=(train_cfg.get("beta1", 0.9), train_cfg.get("beta2", 0.999)),
            weight_decay=train_cfg.get("weight_decay", 0.0),
        )

    @torch.no_grad()
    def _generate_nsgf_samples(self, n: int) -> torch.Tensor:
        """Roll out the frozen NSGF model with forward Euler to draw P_0."""
        X = self.data_loader.sample_source(n, self.device)
        dt = 1.0 / self.nsgf_num_steps
        for step in range(self.nsgf_num_steps):
            t = torch.full((n,), step * dt, device=self.device)
            v = self.nsgf_model(X, t)
            X = X + dt * v
        return X

    def train(self) -> Dict[str, list]:
        """Train the straight flow by velocity matching.

        Returns:
            History dict with "loss" and "step" lists, recorded at step 1
            and every 500 steps thereafter.
        """
        self.model.train()
        history = {"loss": [], "step": []}
        logger.info(f"Starting NSF training: {self.num_iterations} iterations")
        for step in range(self.num_iterations):
            P0 = self._generate_nsgf_samples(self.train_batch_size)
            P1 = self.data_loader.sample_target(self.train_batch_size, self.device)
            t = torch.rand(self.train_batch_size, device=self.device)
            # Broadcast t over every non-batch dimension. This generalizes
            # the previous rank-2 / rank-4 special-casing (which crashed on
            # e.g. rank-3 sequence data) while being identical for those ranks.
            t_expand = t.view(-1, *([1] * (P0.dim() - 1)))
            X_t = (1 - t_expand) * P0 + t_expand * P1
            v_target = P1 - P0
            v_pred = self.model(X_t, t)
            loss = ((v_pred - v_target) ** 2).mean()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if (step + 1) % 500 == 0 or step == 0:
                loss_val = loss.item()
                history["loss"].append(loss_val)
                history["step"].append(step + 1)
                logger.info(f" Step {step + 1}/{self.num_iterations}, Loss: {loss_val:.6f}")
        logger.info("NSF training complete.")
        return history
162
+
163
class PhaseTransitionTrainer:
    """Trainer for the phase-transition time predictor.

    Loss: L(ϕ) = E_{t~U(0,1)} ||t - t_ϕ(X_t)||², where X_t lies on the
    straight interpolant between NSGF samples (P_0) and target data (P_1).

    Args:
        predictor: network t_ϕ(x) regressing the interpolation time.
        nsgf_model: frozen NSGF velocity field used to generate P_0.
        data_loader: source/target distribution samplers.
        config: config dict; reads the "time_predictor" section.
        nsgf_num_steps: Euler steps for the NSGF roll-out.
        device: torch device string.
    """
    def __init__(self, predictor: PhaseTransitionPredictor, nsgf_model: nn.Module,
                 data_loader: DatasetLoader, config: dict,
                 nsgf_num_steps: int = 5, device: str = "cpu"):
        self.predictor = predictor.to(device)
        self.nsgf_model = nsgf_model.to(device)
        self.nsgf_model.eval()  # frozen: only used to generate P_0 samples
        self.data_loader = data_loader
        self.config = config
        self.device = device
        self.nsgf_num_steps = nsgf_num_steps
        tp_cfg = config.get("time_predictor", {})
        self.num_iterations = tp_cfg.get("num_iterations", 40000)
        self.batch_size = tp_cfg.get("batch_size", 128)
        self.lr = tp_cfg.get("learning_rate", 1e-4)
        self.optimizer = optim.Adam(self.predictor.parameters(), lr=self.lr, betas=(0.9, 0.999))

    @torch.no_grad()
    def _generate_nsgf_samples(self, n: int) -> torch.Tensor:
        """Roll out the frozen NSGF model with forward Euler to draw P_0."""
        X = self.data_loader.sample_source(n, self.device)
        dt = 1.0 / self.nsgf_num_steps
        for step in range(self.nsgf_num_steps):
            t = torch.full((n,), step * dt, device=self.device)
            v = self.nsgf_model(X, t)
            X = X + dt * v
        return X

    def train(self) -> Dict[str, list]:
        """Train the time predictor by regressing t from X_t.

        Returns:
            History dict with "loss" and "step" lists, recorded at step 1
            and every 1000 steps thereafter.
        """
        self.predictor.train()
        history = {"loss": [], "step": []}
        logger.info(f"Starting phase predictor training: {self.num_iterations} iterations")
        for step in range(self.num_iterations):
            P0 = self._generate_nsgf_samples(self.batch_size)
            P1 = self.data_loader.sample_target(self.batch_size, self.device)
            t = torch.rand(self.batch_size, device=self.device)
            # Broadcast t over every non-batch dimension. This generalizes
            # the previous rank-4 / rank-2 special-casing (whose `else`
            # branch mis-broadcast rank-3 input) and is identical for
            # rank-2 and rank-4 tensors.
            t_expand = t.view(-1, *([1] * (P0.dim() - 1)))
            X_t = (1 - t_expand) * P0 + t_expand * P1
            t_pred = self.predictor(X_t)
            loss = ((t_pred - t) ** 2).mean()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if (step + 1) % 1000 == 0 or step == 0:
                loss_val = loss.item()
                history["loss"].append(loss_val)
                history["step"].append(step + 1)
                logger.info(f" Step {step + 1}/{self.num_iterations}, Loss: {loss_val:.6f}")
        logger.info("Phase predictor training complete.")
        return history
219
+
220
class NSGFPlusPlusTrainer:
    """End-to-end NSGF++ trainer (Algorithm 3 / Appendix D)."""

    def __init__(self, nsgf_model: nn.Module, nsf_model: nn.Module,
                 phase_predictor: PhaseTransitionPredictor,
                 data_loader: DatasetLoader, config: dict, device: str = "cpu"):
        self.nsgf_model = nsgf_model
        self.nsf_model = nsf_model
        self.phase_predictor = phase_predictor
        self.data_loader = data_loader
        self.config = config
        self.device = device

    @staticmethod
    def _banner(message: str) -> None:
        """Log *message* framed by ruler lines."""
        rule = "=" * 60
        logger.info(rule)
        logger.info(message)
        logger.info(rule)

    def train_all(self) -> Dict[str, Any]:
        """Run the three training phases in order.

        Returns:
            Dict of per-phase loss histories keyed by "nsgf", "nsf",
            and "phase_predictor".
        """
        results: Dict[str, Any] = {}

        self._banner("Phase 1: Training NSGF model")
        phase1 = NSGFTrainer(
            model=self.nsgf_model, data_loader=self.data_loader,
            config=self.config, device=self.device,
        )
        phase1.build_trajectory_pool()
        results["nsgf"] = phase1.train()

        self._banner("Phase 2: Training NSF (Neural Straight Flow) model")
        # Phase 2 and 3 roll out the NSGF model for the same step count
        # the Sinkhorn flow used.
        nsgf_steps = self.config.get("sinkhorn", {}).get("num_steps", 5)
        phase2 = NSFTrainer(
            model=self.nsf_model, nsgf_model=self.nsgf_model,
            data_loader=self.data_loader, config=self.config,
            nsgf_num_steps=nsgf_steps, device=self.device,
        )
        results["nsf"] = phase2.train()

        self._banner("Phase 3: Training phase-transition time predictor")
        phase3 = PhaseTransitionTrainer(
            predictor=self.phase_predictor, nsgf_model=self.nsgf_model,
            data_loader=self.data_loader, config=self.config,
            nsgf_num_steps=nsgf_steps, device=self.device,
        )
        results["phase_predictor"] = phase3.train()

        self._banner("NSGF++ training complete!")
        return results