| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
| |
| |
class VectorField(nn.Module):
    """Time-conditioned MLP that predicts a velocity v(x, t).

    The network consumes the concatenation ``[x, t]`` (``data_dim + 1``
    features) and returns a vector with ``data_dim`` features, i.e. the same
    width as ``x``.

    Args:
        data_dim: Number of features per sample in ``x`` (default 2).
        hidden_dim: Width of the two hidden layers (default 64).
    """

    def __init__(self, data_dim=2, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(data_dim + 1, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, data_dim),
        )

    def forward(self, x, t):
        """Evaluate the velocity field.

        Args:
            x: Tensor of shape (batch, data_dim).
            t: Time(s). Accepted forms are a Python number, a 0-d tensor, a
                tensor of shape (batch,) or (batch, 1), or a single-element
                tensor. Scalars and single-element tensors are broadcast to
                every row of ``x``.

        Returns:
            Tensor of shape (batch, data_dim).
        """
        # Coerce t to x's dtype/device so plain floats or float64 tensors do
        # not break torch.cat / nn.Linear.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        if t.dim() == 0:
            t = t.expand(x.shape[0], 1)
        elif t.dim() == 1:
            t = t.view(-1, 1)
        # A single time value applies to the whole batch.
        if t.shape[0] == 1:
            t = t.expand(x.shape[0], 1)

        xt = torch.cat([x, t], dim=1)
        return self.net(xt)
|
|
| |
# Velocity-field network and the Adam optimizer that trains it.
model = VectorField()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
|
|
| |
def sample_data(batch_size, centers=None, std=0.5):
    """Draw samples from an equal-weight mixture of isotropic Gaussians.

    Args:
        batch_size: Number of points to draw.
        centers: Array-like of shape (num_modes, dim) holding the component
            means. Defaults to the two 2-D modes (-2, -2) and (2, 2).
        std: Standard deviation shared by every component (default 0.5).

    Returns:
        Tensor of shape (batch_size, dim).

    Raises:
        ValueError: If ``centers`` is not a non-empty 2-D array.
    """
    centers = torch.as_tensor(
        ((-2.0, -2.0), (2.0, 2.0)) if centers is None else centers,
        dtype=torch.get_default_dtype(),
    )
    if centers.dim() != 2 or centers.shape[0] == 0:
        raise ValueError("centers must be a non-empty (num_modes, dim) array")
    # Draw component indices before the noise, matching the original RNG
    # order so seeded runs are reproducible.
    indices = torch.randint(0, centers.shape[0], (batch_size,))
    noise = torch.randn(batch_size, centers.shape[1]) * std
    return centers[indices] + noise
|
|
| |
def sample_source(batch_size):
    """Return ``batch_size`` draws from the 2-D standard normal prior N(0, I)."""
    shape = (batch_size, 2)
    return torch.randn(shape)
|
|
| |
# Conditional flow-matching training: regress the network onto the velocity
# of the straight-line path from a prior sample x0 to a data sample x1.
print("Training Flow Matching Model...")
batch_size = 256
log_every = 500
for step in range(2000):
    # Independent prior / data pairs (source drawn first, then data).
    x0, x1 = sample_source(batch_size), sample_data(batch_size)
    t = torch.rand(batch_size, 1)

    # Point on the linear interpolant at time t, and its constant derivative.
    x_t = x0 * (1 - t) + x1 * t
    target_velocity = x1 - x0

    pred_velocity = model(x_t, t)
    loss = (pred_velocity - target_velocity).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % log_every == 0:
        print(f"Step {step}: Loss = {loss.item():.4f}")
|
|
| |
| |
# Generate samples by integrating dx/dt = model(x, t) from t=0 to t=1 with
# forward Euler, starting from prior samples.
print("\nSampling (solving ODE)...")
with torch.no_grad():
    x = sample_source(1000)
    num_steps = 100
    dt = 1.0 / num_steps
    # Iterate over an integer step count: np.arange with a float step can
    # yield an off-by-one number of points due to rounding.
    for step_idx in range(num_steps):
        t_value = step_idx * dt
        t_tensor = torch.full((x.shape[0], 1), t_value, dtype=x.dtype)
        velocity = model(x, t_tensor)
        x = x + velocity * dt
|
|
| |
# Scatter the generated samples and write the figure to disk.
final_samples = x.detach().cpu().numpy()
plt.figure(figsize=(6, 6))
plt.scatter(final_samples[:, 0], final_samples[:, 1], s=10, alpha=0.6, label="Generated")
plt.title("Flow Matching Output (Approx. Data Dist.)")
# Without an explicit legend call, the "Generated" label is never rendered.
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig("flow_matching_output.png")
plt.close()
|
|