rogermt commited on
Commit
bc9ef03
·
verified ·
1 Parent(s): 81d2fba

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +126 -0
inference.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """inference.py — Sampling / inference for NSGF and NSGF++.
2
+
3
+ Implements:
4
+ - NSGF Euler-step inference (standard model)
5
+ - NSGF++ two-phase inference (NSGF → phase transition → NSF)
6
+
7
+ Reference: arXiv:2401.14069, Section 4.4, Appendix D
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from typing import Optional, Tuple, List
13
+ from dataset_loader import DatasetLoader
14
+
15
+
16
+ class NSGFSampler:
17
+ """Sampler using a trained NSGF velocity field model."""
18
+ def __init__(self, model: nn.Module, data_loader: DatasetLoader,
19
+ num_steps: int = 10, device: str = "cpu"):
20
+ self.model = model.to(device)
21
+ self.model.eval()
22
+ self.data_loader = data_loader
23
+ self.num_steps = num_steps
24
+ self.device = device
25
+
26
+ @torch.no_grad()
27
+ def sample(self, n: int) -> torch.Tensor:
28
+ X = self.data_loader.sample_source(n, self.device)
29
+ dt = 1.0 / self.num_steps
30
+ for step in range(self.num_steps):
31
+ t = torch.full((n,), step * dt, device=self.device)
32
+ v = self.model(X, t)
33
+ X = X + dt * v
34
+ return X
35
+
36
+ @torch.no_grad()
37
+ def sample_trajectory(self, n: int) -> List[torch.Tensor]:
38
+ X = self.data_loader.sample_source(n, self.device)
39
+ trajectory = [X.clone()]
40
+ dt = 1.0 / self.num_steps
41
+ for step in range(self.num_steps):
42
+ t = torch.full((n,), step * dt, device=self.device)
43
+ v = self.model(X, t)
44
+ X = X + dt * v
45
+ trajectory.append(X.clone())
46
+ return trajectory
47
+
48
+
49
+ class NSGFPlusPlusSampler:
50
+ """Sampler for the NSGF++ two-phase model.
51
+ Phase 1 (NSGF): ≤5 Euler steps with Sinkhorn velocity field
52
+ Phase 2 (NSF): Straight flow velocity field
53
+ Total NFE = nsgf_steps + nsf_steps
54
+ """
55
+ def __init__(self, nsgf_model: nn.Module, nsf_model: nn.Module,
56
+ phase_predictor: Optional[nn.Module], data_loader: DatasetLoader,
57
+ nsgf_steps: int = 5, nsf_steps: int = 55, device: str = "cpu"):
58
+ self.nsgf_model = nsgf_model.to(device)
59
+ self.nsf_model = nsf_model.to(device)
60
+ self.nsgf_model.eval()
61
+ self.nsf_model.eval()
62
+ if phase_predictor is not None:
63
+ self.phase_predictor = phase_predictor.to(device)
64
+ self.phase_predictor.eval()
65
+ else:
66
+ self.phase_predictor = None
67
+ self.data_loader = data_loader
68
+ self.nsgf_steps = nsgf_steps
69
+ self.nsf_steps = nsf_steps
70
+ self.device = device
71
+
72
+ @torch.no_grad()
73
+ def sample(self, n: int) -> torch.Tensor:
74
+ X = self.data_loader.sample_source(n, self.device)
75
+ dt_nsgf = 1.0 / self.nsgf_steps
76
+ for step in range(self.nsgf_steps):
77
+ t = torch.full((n,), step * dt_nsgf, device=self.device)
78
+ v = self.nsgf_model(X, t)
79
+ X = X + dt_nsgf * v
80
+ if self.phase_predictor is not None:
81
+ t_start = self.phase_predictor(X)
82
+ else:
83
+ t_start = torch.zeros(n, device=self.device)
84
+ dt_nsf = 1.0 / self.nsf_steps
85
+ for step in range(self.nsf_steps):
86
+ t_current = t_start + step * dt_nsf * (1.0 - t_start)
87
+ t_current = t_current.clamp(0, 1)
88
+ v = self.nsf_model(X, t_current)
89
+ X = X + dt_nsf * (1.0 - t_start.view(-1, *([1] * (X.dim() - 1)))) * v
90
+ return X
91
+
92
+ @torch.no_grad()
93
+ def sample_simple(self, n: int) -> torch.Tensor:
94
+ """Simplified: NSGF then NSF from t=0 to t=1."""
95
+ X = self.data_loader.sample_source(n, self.device)
96
+ dt_nsgf = 1.0 / self.nsgf_steps
97
+ for step in range(self.nsgf_steps):
98
+ t = torch.full((n,), step * dt_nsgf, device=self.device)
99
+ v = self.nsgf_model(X, t)
100
+ X = X + dt_nsgf * v
101
+ dt_nsf = 1.0 / self.nsf_steps
102
+ for step in range(self.nsf_steps):
103
+ t = torch.full((n,), step * dt_nsf, device=self.device)
104
+ v = self.nsf_model(X, t)
105
+ X = X + dt_nsf * v
106
+ return X
107
+
108
+ @torch.no_grad()
109
+ def sample_trajectory(self, n: int) -> Tuple[List[torch.Tensor], int]:
110
+ trajectory = []
111
+ X = self.data_loader.sample_source(n, self.device)
112
+ trajectory.append(X.clone())
113
+ dt_nsgf = 1.0 / self.nsgf_steps
114
+ for step in range(self.nsgf_steps):
115
+ t = torch.full((n,), step * dt_nsgf, device=self.device)
116
+ v = self.nsgf_model(X, t)
117
+ X = X + dt_nsgf * v
118
+ trajectory.append(X.clone())
119
+ phase_boundary = len(trajectory) - 1
120
+ dt_nsf = 1.0 / self.nsf_steps
121
+ for step in range(self.nsf_steps):
122
+ t = torch.full((n,), step * dt_nsf, device=self.device)
123
+ v = self.nsf_model(X, t)
124
+ X = X + dt_nsf * v
125
+ trajectory.append(X.clone())
126
+ return trajectory, phase_boundary