rogermt commited on
Commit
950b617
·
verified ·
1 Parent(s): 4b9c417

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +138 -44
main.py CHANGED
@@ -1,10 +1,19 @@
1
  """main.py — Entry point for NSGF/NSGF++ experiments.
2
 
 
 
 
 
 
 
 
 
3
  Usage:
4
  python main.py --experiment 2d --dataset 8gaussians --steps 10
5
- python main.py --experiment 2d --dataset moons --steps 100
6
- python main.py --experiment mnist
7
- python main.py --experiment cifar10
 
8
 
9
  Reference: arXiv:2401.14069 (Neural Sinkhorn Gradient Flow)
10
  """
@@ -38,15 +47,28 @@ logger = logging.getLogger(__name__)
38
 
39
 
40
  def load_config(config_path: str = "config.yaml") -> dict:
 
41
  with open(config_path, "r") as f:
42
  return yaml.safe_load(f)
43
 
44
 
 
 
 
 
 
 
 
45
  def run_2d_experiment(config: dict, args):
46
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
47
  logger.info(f"Running 2D experiment on {device}")
48
  logger.info(f"Dataset: {config['dataset']}, Steps: {config['sinkhorn']['num_steps']}")
49
-
 
50
  if args.dataset:
51
  config["dataset"] = args.dataset
52
  if args.steps:
@@ -56,92 +78,143 @@ def run_2d_experiment(config: dict, args):
56
  config["pool"]["num_batches"] = args.pool_batches
57
  if args.train_iters:
58
  config["training"]["num_iterations"] = args.train_iters
59
-
 
60
  data_loader = DatasetLoader(config)
61
  model = create_velocity_model_2d(config)
 
62
  logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
63
-
 
64
  start_time = time.time()
65
- trainer = NSGFTrainer(model=model, data_loader=data_loader, config=config, device=device)
 
 
 
 
 
 
 
 
66
  trainer.build_trajectory_pool()
 
 
67
  history = trainer.train()
 
68
  train_time = time.time() - start_time
69
  logger.info(f"Training completed in {train_time:.1f}s")
70
-
 
71
  num_eval = config.get("evaluation", {}).get("num_test_samples", 1024)
72
  num_steps = config.get("inference", {}).get("num_euler_steps", 10)
73
- sampler = NSGFSampler(model=model, data_loader=data_loader, num_steps=num_steps, device=device)
 
 
 
 
 
 
 
74
  samples = sampler.sample(num_eval)
 
 
75
  trajectory = sampler.sample_trajectory(min(200, num_eval))
76
-
 
77
  test_samples = data_loader.get_test_samples(num_eval, device)
78
  evaluator = Evaluation(config, device)
79
  metrics = evaluator.evaluate(samples, test_samples)
80
-
81
  logger.info(f"\n{'='*50}")
82
  logger.info(f"RESULTS — 2D {config['dataset']}, {num_steps} steps")
83
  logger.info(f"{'='*50}")
84
  for k, v in metrics.items():
85
  logger.info(f" {k}: {v:.4f}")
86
  logger.info(f" Training time: {train_time:.1f}s")
87
-
 
88
  os.makedirs("results", exist_ok=True)
89
  plot_2d_samples(
90
  samples, test_samples,
91
  title=f"NSGF — {config['dataset']} ({num_steps} steps), W2={metrics.get('w2', 0):.4f}",
92
  save_path=f"results/nsgf_2d_{config['dataset']}_{num_steps}steps.png",
93
  )
 
94
  plot_2d_trajectory(
95
  trajectory, test_samples,
96
  title=f"NSGF Trajectory — {config['dataset']}",
97
  save_path=f"results/nsgf_trajectory_{config['dataset']}_{num_steps}steps.png",
98
  )
 
 
99
  torch.save(model.state_dict(), f"results/nsgf_2d_{config['dataset']}.pt")
100
  logger.info("Model saved.")
 
101
  return metrics
102
 
103
 
104
  def run_image_experiment(config: dict, args, dataset_name: str):
105
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
106
  logger.info(f"Running {dataset_name.upper()} experiment on {device}")
107
-
 
108
  if args.pool_batches:
109
  config["pool"]["num_batches"] = args.pool_batches
110
  if args.train_iters:
111
  config["nsgf_training"]["num_iterations"] = args.train_iters
112
  config["nsf_training"]["num_iterations"] = args.train_iters
113
-
 
114
  data_loader = DatasetLoader(config)
 
 
115
  nsgf_model = create_velocity_unet(config)
116
  nsf_model = create_velocity_unet(config)
117
  phase_predictor = create_phase_predictor(config)
118
-
119
  logger.info(f"NSGF UNet params: {sum(p.numel() for p in nsgf_model.parameters()):,}")
120
  logger.info(f"NSF UNet params: {sum(p.numel() for p in nsf_model.parameters()):,}")
121
  logger.info(f"Phase predictor params: {sum(p.numel() for p in phase_predictor.parameters()):,}")
122
-
 
123
  start_time = time.time()
 
124
  pp_trainer = NSGFPlusPlusTrainer(
125
- nsgf_model=nsgf_model, nsf_model=nsf_model,
126
- phase_predictor=phase_predictor, data_loader=data_loader,
127
- config=config, device=device,
 
 
 
128
  )
 
129
  results = pp_trainer.train_all()
 
130
  train_time = time.time() - start_time
131
  logger.info(f"Training completed in {train_time:.1f}s")
132
-
 
133
  inference_cfg = config.get("inference", {})
134
  nsgf_steps = inference_cfg.get("nsgf_steps", 5)
135
  nsf_steps = inference_cfg.get("nsf_steps", 55)
136
  num_gen = config.get("evaluation", {}).get("num_generated", 10000)
137
-
138
  sampler = NSGFPlusPlusSampler(
139
- nsgf_model=nsgf_model, nsf_model=nsf_model,
140
- phase_predictor=phase_predictor, data_loader=data_loader,
141
- nsgf_steps=nsgf_steps, nsf_steps=nsf_steps, device=device,
 
 
 
 
142
  )
143
-
144
  logger.info(f"Generating {num_gen} samples...")
 
145
  batch_size = 128
146
  all_samples = []
147
  for i in range(0, num_gen, batch_size):
@@ -149,11 +222,15 @@ def run_image_experiment(config: dict, args, dataset_name: str):
149
  samples = sampler.sample_simple(n)
150
  all_samples.append(samples.cpu())
151
  generated = torch.cat(all_samples, dim=0)
152
-
153
- test_samples = data_loader.get_test_samples(num_gen, device="cpu")
 
 
 
 
154
  evaluator = Evaluation(config, device)
155
  metrics = evaluator.evaluate(generated, test_samples)
156
-
157
  logger.info(f"\n{'='*50}")
158
  logger.info(f"RESULTS — NSGF++ on {dataset_name.upper()}")
159
  logger.info(f"{'='*50}")
@@ -161,43 +238,60 @@ def run_image_experiment(config: dict, args, dataset_name: str):
161
  logger.info(f" {k}: {v:.4f}")
162
  logger.info(f" NFE: {nsgf_steps + nsf_steps}")
163
  logger.info(f" Training time: {train_time:.1f}s")
164
-
 
165
  os.makedirs("results", exist_ok=True)
166
  plot_image_grid(
167
  generated[:64],
168
  title=f"NSGF++ — {dataset_name.upper()}",
169
  save_path=f"results/nsgf_pp_{dataset_name}_samples.png",
170
  )
 
 
171
  torch.save(nsgf_model.state_dict(), f"results/nsgf_{dataset_name}_nsgf.pt")
172
  torch.save(nsf_model.state_dict(), f"results/nsgf_{dataset_name}_nsf.pt")
173
  torch.save(phase_predictor.state_dict(), f"results/nsgf_{dataset_name}_predictor.pt")
174
  logger.info("Models saved.")
 
175
  return metrics
176
 
177
 
178
  def main():
179
  parser = argparse.ArgumentParser(description="NSGF/NSGF++ Experiments")
180
- parser.add_argument("--experiment", type=str, default="2d",
181
- choices=["2d", "mnist", "cifar10"])
182
- parser.add_argument("--dataset", type=str, default=None)
183
- parser.add_argument("--steps", type=int, default=None)
184
- parser.add_argument("--pool-batches", type=int, default=None)
185
- parser.add_argument("--train-iters", type=int, default=None)
186
- parser.add_argument("--config", type=str, default="config.yaml")
187
- parser.add_argument("--seed", type=int, default=42)
 
 
 
 
 
 
 
188
  args = parser.parse_args()
189
-
 
190
  torch.manual_seed(args.seed)
191
  import numpy as np
192
  np.random.seed(args.seed)
193
-
 
194
  full_config = load_config(args.config)
 
195
  if args.experiment == "2d":
196
- run_2d_experiment(full_config["experiment_2d"], args)
 
197
  elif args.experiment == "mnist":
198
- run_image_experiment(full_config["experiment_mnist"], args, "mnist")
 
199
  elif args.experiment == "cifar10":
200
- run_image_experiment(full_config["experiment_cifar10"], args, "cifar10")
 
201
  else:
202
  logger.error(f"Unknown experiment: {args.experiment}")
203
  sys.exit(1)
 
1
  """main.py — Entry point for NSGF/NSGF++ experiments.
2
 
3
+ Orchestrates the full experiment pipeline:
4
+ 1. Load configuration
5
+ 2. Set up dataset, model, trainer
6
+ 3. Train (build pool → velocity matching → [NSF → phase predictor for NSGF++])
7
+ 4. Generate samples
8
+ 5. Evaluate (W2 for 2D, FID/IS for images)
9
+ 6. Visualize results
10
+
11
  Usage:
12
  python main.py --experiment 2d --dataset 8gaussians --steps 10
13
+ python main.py --experiment 2d --dataset 8gaussians --steps 10 --device cuda
14
+ python main.py --experiment 2d --dataset 8gaussians --steps 10 --device cpu
15
+ python main.py --experiment mnist --device cuda
16
+ python main.py --experiment cifar10 --device cuda
17
 
18
  Reference: arXiv:2401.14069 (Neural Sinkhorn Gradient Flow)
19
  """
 
47
 
48
 
49
  def load_config(config_path: str = "config.yaml") -> dict:
50
+ """Load configuration from YAML file."""
51
  with open(config_path, "r") as f:
52
  return yaml.safe_load(f)
53
 
54
 
55
+ def get_device(args) -> str:
56
+ """Resolve device from CLI args or auto-detect."""
57
+ if args.device:
58
+ return args.device
59
+ return "cuda" if torch.cuda.is_available() else "cpu"
60
+
61
+
62
  def run_2d_experiment(config: dict, args):
63
+ """Run 2D synthetic experiment (NSGF).
64
+
65
+ Reference: Section 5.1, Appendix E.1
66
+ """
67
+ device = get_device(args)
68
  logger.info(f"Running 2D experiment on {device}")
69
  logger.info(f"Dataset: {config['dataset']}, Steps: {config['sinkhorn']['num_steps']}")
70
+
71
+ # Override from args
72
  if args.dataset:
73
  config["dataset"] = args.dataset
74
  if args.steps:
 
78
  config["pool"]["num_batches"] = args.pool_batches
79
  if args.train_iters:
80
  config["training"]["num_iterations"] = args.train_iters
81
+
82
+ # Setup
83
  data_loader = DatasetLoader(config)
84
  model = create_velocity_model_2d(config)
85
+
86
  logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
87
+
88
+ # ---- Training ----
89
  start_time = time.time()
90
+
91
+ trainer = NSGFTrainer(
92
+ model=model,
93
+ data_loader=data_loader,
94
+ config=config,
95
+ device=device,
96
+ )
97
+
98
+ # Build trajectory pool
99
  trainer.build_trajectory_pool()
100
+
101
+ # Train velocity field
102
  history = trainer.train()
103
+
104
  train_time = time.time() - start_time
105
  logger.info(f"Training completed in {train_time:.1f}s")
106
+
107
+ # ---- Inference ----
108
  num_eval = config.get("evaluation", {}).get("num_test_samples", 1024)
109
  num_steps = config.get("inference", {}).get("num_euler_steps", 10)
110
+
111
+ sampler = NSGFSampler(
112
+ model=model,
113
+ data_loader=data_loader,
114
+ num_steps=num_steps,
115
+ device=device,
116
+ )
117
+
118
  samples = sampler.sample(num_eval)
119
+
120
+ # Also get trajectory for visualization
121
  trajectory = sampler.sample_trajectory(min(200, num_eval))
122
+
123
+ # ---- Evaluation ----
124
  test_samples = data_loader.get_test_samples(num_eval, device)
125
  evaluator = Evaluation(config, device)
126
  metrics = evaluator.evaluate(samples, test_samples)
127
+
128
  logger.info(f"\n{'='*50}")
129
  logger.info(f"RESULTS — 2D {config['dataset']}, {num_steps} steps")
130
  logger.info(f"{'='*50}")
131
  for k, v in metrics.items():
132
  logger.info(f" {k}: {v:.4f}")
133
  logger.info(f" Training time: {train_time:.1f}s")
134
+
135
+ # ---- Visualization ----
136
  os.makedirs("results", exist_ok=True)
137
  plot_2d_samples(
138
  samples, test_samples,
139
  title=f"NSGF — {config['dataset']} ({num_steps} steps), W2={metrics.get('w2', 0):.4f}",
140
  save_path=f"results/nsgf_2d_{config['dataset']}_{num_steps}steps.png",
141
  )
142
+
143
  plot_2d_trajectory(
144
  trajectory, test_samples,
145
  title=f"NSGF Trajectory — {config['dataset']}",
146
  save_path=f"results/nsgf_trajectory_{config['dataset']}_{num_steps}steps.png",
147
  )
148
+
149
+ # Save model
150
  torch.save(model.state_dict(), f"results/nsgf_2d_{config['dataset']}.pt")
151
  logger.info("Model saved.")
152
+
153
  return metrics
154
 
155
 
156
  def run_image_experiment(config: dict, args, dataset_name: str):
157
+ """Run image experiment (NSGF++).
158
+
159
+ Reference: Section 5.2, Appendix E.2
160
+ """
161
+ device = get_device(args)
162
  logger.info(f"Running {dataset_name.upper()} experiment on {device}")
163
+
164
+ # Override from args
165
  if args.pool_batches:
166
  config["pool"]["num_batches"] = args.pool_batches
167
  if args.train_iters:
168
  config["nsgf_training"]["num_iterations"] = args.train_iters
169
  config["nsf_training"]["num_iterations"] = args.train_iters
170
+
171
+ # Setup
172
  data_loader = DatasetLoader(config)
173
+
174
+ # Create models
175
  nsgf_model = create_velocity_unet(config)
176
  nsf_model = create_velocity_unet(config)
177
  phase_predictor = create_phase_predictor(config)
178
+
179
  logger.info(f"NSGF UNet params: {sum(p.numel() for p in nsgf_model.parameters()):,}")
180
  logger.info(f"NSF UNet params: {sum(p.numel() for p in nsf_model.parameters()):,}")
181
  logger.info(f"Phase predictor params: {sum(p.numel() for p in phase_predictor.parameters()):,}")
182
+
183
+ # ---- Training ----
184
  start_time = time.time()
185
+
186
  pp_trainer = NSGFPlusPlusTrainer(
187
+ nsgf_model=nsgf_model,
188
+ nsf_model=nsf_model,
189
+ phase_predictor=phase_predictor,
190
+ data_loader=data_loader,
191
+ config=config,
192
+ device=device,
193
  )
194
+
195
  results = pp_trainer.train_all()
196
+
197
  train_time = time.time() - start_time
198
  logger.info(f"Training completed in {train_time:.1f}s")
199
+
200
+ # ---- Inference ----
201
  inference_cfg = config.get("inference", {})
202
  nsgf_steps = inference_cfg.get("nsgf_steps", 5)
203
  nsf_steps = inference_cfg.get("nsf_steps", 55)
204
  num_gen = config.get("evaluation", {}).get("num_generated", 10000)
205
+
206
  sampler = NSGFPlusPlusSampler(
207
+ nsgf_model=nsgf_model,
208
+ nsf_model=nsf_model,
209
+ phase_predictor=phase_predictor,
210
+ data_loader=data_loader,
211
+ nsgf_steps=nsgf_steps,
212
+ nsf_steps=nsf_steps,
213
+ device=device,
214
  )
215
+
216
  logger.info(f"Generating {num_gen} samples...")
217
+ # Generate in batches to avoid OOM
218
  batch_size = 128
219
  all_samples = []
220
  for i in range(0, num_gen, batch_size):
 
222
  samples = sampler.sample_simple(n)
223
  all_samples.append(samples.cpu())
224
  generated = torch.cat(all_samples, dim=0)
225
+
226
+ # ---- Evaluation ----
227
+ # Get test set
228
+ eval_loader = data_loader
229
+ test_samples = eval_loader.get_test_samples(num_gen, device="cpu")
230
+
231
  evaluator = Evaluation(config, device)
232
  metrics = evaluator.evaluate(generated, test_samples)
233
+
234
  logger.info(f"\n{'='*50}")
235
  logger.info(f"RESULTS — NSGF++ on {dataset_name.upper()}")
236
  logger.info(f"{'='*50}")
 
238
  logger.info(f" {k}: {v:.4f}")
239
  logger.info(f" NFE: {nsgf_steps + nsf_steps}")
240
  logger.info(f" Training time: {train_time:.1f}s")
241
+
242
+ # ---- Visualization ----
243
  os.makedirs("results", exist_ok=True)
244
  plot_image_grid(
245
  generated[:64],
246
  title=f"NSGF++ — {dataset_name.upper()}",
247
  save_path=f"results/nsgf_pp_{dataset_name}_samples.png",
248
  )
249
+
250
+ # Save models
251
  torch.save(nsgf_model.state_dict(), f"results/nsgf_{dataset_name}_nsgf.pt")
252
  torch.save(nsf_model.state_dict(), f"results/nsgf_{dataset_name}_nsf.pt")
253
  torch.save(phase_predictor.state_dict(), f"results/nsgf_{dataset_name}_predictor.pt")
254
  logger.info("Models saved.")
255
+
256
  return metrics
257
 
258
 
259
  def main():
260
  parser = argparse.ArgumentParser(description="NSGF/NSGF++ Experiments")
261
+ parser.add_argument(
262
+ "--experiment", type=str, default="2d",
263
+ choices=["2d", "mnist", "cifar10"],
264
+ help="Experiment type"
265
+ )
266
+ parser.add_argument("--dataset", type=str, default=None, help="2D dataset name")
267
+ parser.add_argument("--steps", type=int, default=None, help="Number of flow steps")
268
+ parser.add_argument("--pool-batches", type=int, default=None, help="Pool building batches")
269
+ parser.add_argument("--train-iters", type=int, default=None, help="Training iterations")
270
+ parser.add_argument("--config", type=str, default="config.yaml", help="Config file path")
271
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
272
+ parser.add_argument("--device", type=str, default=None,
273
+ choices=["cpu", "cuda"],
274
+ help="Force device (default: auto-detect)")
275
+
276
  args = parser.parse_args()
277
+
278
+ # Set seed
279
  torch.manual_seed(args.seed)
280
  import numpy as np
281
  np.random.seed(args.seed)
282
+
283
+ # Load config
284
  full_config = load_config(args.config)
285
+
286
  if args.experiment == "2d":
287
+ config = full_config["experiment_2d"]
288
+ run_2d_experiment(config, args)
289
  elif args.experiment == "mnist":
290
+ config = full_config["experiment_mnist"]
291
+ run_image_experiment(config, args, "mnist")
292
  elif args.experiment == "cifar10":
293
+ config = full_config["experiment_cifar10"]
294
+ run_image_experiment(config, args, "cifar10")
295
  else:
296
  logger.error(f"Unknown experiment: {args.experiment}")
297
  sys.exit(1)