# Plonk: models/module.py
import pytorch_lightning as L
import torch
import torch.nn as nn
from hydra.utils import instantiate
import copy
import numpy as np
from utils.manifolds import Sphere
from torch.func import jacrev, vjp, vmap
from torchdiffeq import odeint
from geoopt import ProductManifold, Euclidean
from models.samplers.riemannian_flow_sampler import ode_riemannian_flow_sampler
class DiffGeolocalizer(L.LightningModule):
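    """Diffusion / flow-matching geolocalizer.

    Wraps a Hydra-instantiated denoising network and its EMA copy together with
    the noise schedulers, pre/post-processing, samplers, loss and metrics from
    the config, and provides training, validation, test, sampling and exact
    log-likelihood routines.
    """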
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.network = instantiate(cfg.network)
# self.network = torch.compile(self.network, fullgraph=True)
self.input_dim = cfg.network.input_dim
self.train_noise_scheduler = instantiate(cfg.train_noise_scheduler)
self.inference_noise_scheduler = instantiate(cfg.inference_noise_scheduler)
self.data_preprocessing = instantiate(cfg.data_preprocessing)
self.cond_preprocessing = instantiate(cfg.cond_preprocessing)
self.preconditioning = instantiate(cfg.preconditioning)
self.ema_network = copy.deepcopy(self.network).requires_grad_(False)
self.ema_network.eval()
self.postprocessing = instantiate(cfg.postprocessing)
self.val_sampler = instantiate(cfg.val_sampler)
self.test_sampler = instantiate(cfg.test_sampler)
self.loss = instantiate(cfg.loss)(
self.train_noise_scheduler,
)
self.val_metrics = instantiate(cfg.val_metrics)
self.test_metrics = instantiate(cfg.test_metrics)
self.manifold = instantiate(cfg.manifold) if hasattr(cfg, "manifold") else None
self.interpolant = cfg.interpolant
def training_step(self, batch, batch_idx):
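        """Compute and log the mean training loss on a preprocessed batch."""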
with torch.no_grad():
batch = self.data_preprocessing(batch)
batch = self.cond_preprocessing(batch)
batch_size = batch["x_0"].shape[0]
loss = self.loss(self.preconditioning, self.network, batch).mean()
self.log(
"train/loss",
loss,
sync_dist=True,
on_step=True,
on_epoch=True,
batch_size=batch_size,
)
return loss
def on_before_optimizer_step(self, optimizer):
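        """On the first step, report parameters that received no gradient."""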
if self.global_step == 0:
no_grad = []
for name, param in self.network.named_parameters():
if param.grad is None:
no_grad.append(name)
if len(no_grad) > 0:
print("Parameters without grad:")
print(no_grad)
def on_validation_start(self):
self.validation_generator = torch.Generator(device=self.device).manual_seed(
3407
)
self.validation_generator_ema = torch.Generator(device=self.device).manual_seed(
3407
)
def validation_step(self, batch, batch_idx):
batch = self.data_preprocessing(batch)
batch = self.cond_preprocessing(batch)
batch_size = batch["x_0"].shape[0]
loss = self.loss(
self.preconditioning,
self.network,
batch,
generator=self.validation_generator,
).mean()
self.log(
"val/loss",
loss,
sync_dist=True,
on_step=False,
on_epoch=True,
batch_size=batch_size,
)
if hasattr(self, "ema_model"):
loss_ema = self.loss(
self.preconditioning,
self.ema_network,
batch,
generator=self.validation_generator_ema,
).mean()
self.log(
"val/loss_ema",
loss_ema,
sync_dist=True,
on_step=False,
on_epoch=True,
batch_size=batch_size,
)
# nll = -self.compute_exact_loglikelihood(batch).mean()
# self.log(
# "val/nll",
# nll,
# sync_dist=True,
# on_step=False,
# on_epoch=True,
# batch_size=batch_size,
# )
# def on_validation_epoch_end(self):
# metrics = self.val_metrics.compute()
# for metric_name, metric_value in metrics.items():
# self.log(
# f"val/{metric_name}",
# metric_value,
# sync_dist=True,
# on_step=False,
# on_epoch=True,
# )
def on_test_start(self):
self.test_generator = torch.Generator(device=self.device).manual_seed(3407)
def test_step_simple(self, batch, batch_idx):
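        """Draw one sample per conditioning embedding and update the test
        metrics; optionally also log the exact negative log-likelihood."""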
batch = self.data_preprocessing(batch)
batch = self.cond_preprocessing(batch)
batch_size = batch["x_0"].shape[0]
if isinstance(self.manifold, Sphere):
x_N = self.manifold.random_base(
batch_size,
self.input_dim,
device=self.device,
)
x_N = x_N.reshape(batch_size, self.input_dim)
else:
x_N = torch.randn(
batch_size,
self.input_dim,
device=self.device,
generator=self.test_generator,
)
cond = batch[self.cfg.cond_preprocessing.output_key]
samples = self.sample(
x_N=x_N,
cond=cond,
stage="val",
generator=self.test_generator,
cfg=self.cfg.cfg_rate,
)
self.test_metrics.update({"gps": samples}, batch)
if self.cfg.compute_nll:
nll = -self.compute_exact_loglikelihood(batch, cfg=0).mean()
self.log(
"test/NLL",
nll,
sync_dist=True,
on_step=False,
on_epoch=True,
batch_size=batch_size,
)
def test_best_nll(self, batch, batch_idx):
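        """Best-of-N evaluation: draw 32 samples per conditioning embedding,
        keep the one with the lowest exact NLL and feed it to the test
        metrics."""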
batch = self.data_preprocessing(batch)
batch = self.cond_preprocessing(batch)
batch_size = batch["x_0"].shape[0]
num_sample_per_cond = 32
if isinstance(self.manifold, Sphere):
x_N = self.manifold.random_base(
batch_size * num_sample_per_cond,
self.input_dim,
device=self.device,
)
x_N = x_N.reshape(batch_size * num_sample_per_cond, self.input_dim)
else:
x_N = torch.randn(
batch_size * num_sample_per_cond,
self.input_dim,
device=self.device,
generator=self.test_generator,
)
cond = (
batch[self.cfg.cond_preprocessing.output_key]
.unsqueeze(1)
.repeat(1, num_sample_per_cond, 1)
.view(-1, batch[self.cfg.cond_preprocessing.output_key].shape[-1])
)
samples = self.sample_distribution(
x_N,
cond,
sampling_batch_size=32768,
stage="val",
generator=self.test_generator,
cfg=0,
)
samples = samples.view(batch_size * num_sample_per_cond, -1)
batch_swarm = {"gps": samples, "emb": cond}
nll_batch = -self.compute_exact_loglikelihood(batch_swarm, cfg=0)
nll_batch = nll_batch.view(batch_size, num_sample_per_cond, -1)
nll_best = nll_batch[
torch.arange(batch_size), nll_batch.argmin(dim=1).squeeze(1)
]
self.log(
"test/best_nll",
nll_best.mean(),
sync_dist=True,
on_step=False,
on_epoch=True,
)
samples = samples.view(batch_size, num_sample_per_cond, -1)[
torch.arange(batch_size), nll_batch.argmin(dim=1).squeeze(1)
]
self.test_metrics.update({"gps": samples}, batch)
def test_step(self, batch, batch_idx):
if self.cfg.compute_swarms:
self.test_best_nll(batch, batch_idx)
else:
self.test_step_simple(batch, batch_idx)
def on_test_epoch_end(self):
metrics = self.test_metrics.compute()
for metric_name, metric_value in metrics.items():
self.log(
f"test/{metric_name}",
metric_value,
sync_dist=True,
on_step=False,
on_epoch=True,
)
def configure_optimizers(self):
if self.cfg.optimizer.exclude_ln_and_biases_from_weight_decay:
parameters_names_wd = get_parameter_names(self.network, [nn.LayerNorm])
parameters_names_wd = [
name for name in parameters_names_wd if "bias" not in name
]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in self.network.named_parameters()
if n in parameters_names_wd
],
"weight_decay": self.cfg.optimizer.optim.weight_decay,
"layer_adaptation": True,
},
{
"params": [
p
for n, p in self.network.named_parameters()
if n not in parameters_names_wd
],
"weight_decay": 0.0,
"layer_adaptation": False,
},
]
optimizer = instantiate(
self.cfg.optimizer.optim, optimizer_grouped_parameters
)
else:
optimizer = instantiate(self.cfg.optimizer.optim, self.network.parameters())
if "lr_scheduler" in self.cfg:
scheduler = instantiate(self.cfg.lr_scheduler)(optimizer)
return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
else:
return optimizer
def lr_scheduler_step(self, scheduler, metric):
scheduler.step(self.global_step)
def sample(
self,
batch_size=None,
cond=None,
x_N=None,
num_steps=None,
stage="test",
cfg=0,
generator=None,
return_trajectories=False,
postprocessing=True,
):
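        """Sample locations from the EMA model.

        Either `x_N` (initial noise) or `batch_size` must be given; `cond` is
        run through `cond_preprocessing` before being passed to the val/test
        sampler. Returns post-processed samples and, if requested, the
        sampling trajectories.
        """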
if x_N is None:
assert batch_size is not None
if isinstance(self.manifold, Sphere):
x_N = self.manifold.random_base(
batch_size, self.input_dim, device=self.device
)
x_N = x_N.reshape(batch_size, self.input_dim)
else:
x_N = torch.randn(batch_size, self.input_dim, device=self.device)
batch = {"y": x_N}
if stage == "val":
sampler = self.val_sampler
elif stage == "test":
sampler = self.test_sampler
else:
raise ValueError(f"Unknown stage {stage}")
batch[self.cfg.cond_preprocessing.input_key] = cond
batch = self.cond_preprocessing(batch, device=self.device)
if num_steps is None:
output = sampler(
self.ema_model,
batch,
conditioning_keys=self.cfg.cond_preprocessing.output_key,
scheduler=self.inference_noise_scheduler,
cfg_rate=cfg,
generator=generator,
return_trajectories=return_trajectories,
)
else:
output = sampler(
self.ema_model,
batch,
conditioning_keys=self.cfg.cond_preprocessing.output_key,
scheduler=self.inference_noise_scheduler,
num_steps=num_steps,
cfg_rate=cfg,
generator=generator,
return_trajectories=return_trajectories,
)
if return_trajectories:
return (
self.postprocessing(output[0]) if postprocessing else output[0],
[
self.postprocessing(frame) if postprocessing else frame
for frame in output[1]
],
)
else:
return self.postprocessing(output) if postprocessing else output
def sample_distribution(
self,
x_N,
cond,
sampling_batch_size=2048,
num_steps=None,
stage="test",
cfg=0,
generator=None,
return_trajectories=False,
):
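        """Sample a large population by splitting `x_N` and `cond` into chunks
        of at most `sampling_batch_size` and concatenating the chunk
        outputs."""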
if return_trajectories:
x_0 = []
trajectories = []
i = -1
for i in range(x_N.shape[0] // sampling_batch_size):
x_N_batch = x_N[i * sampling_batch_size : (i + 1) * sampling_batch_size]
cond_batch = cond[
i * sampling_batch_size : (i + 1) * sampling_batch_size
]
                out, traj = self.sample(
cond=cond_batch,
x_N=x_N_batch,
num_steps=num_steps,
stage=stage,
cfg=cfg,
generator=generator,
return_trajectories=return_trajectories,
)
x_0.append(out)
                trajectories.append(traj)
if x_N.shape[0] % sampling_batch_size != 0:
x_N_batch = x_N[(i + 1) * sampling_batch_size :]
cond_batch = cond[(i + 1) * sampling_batch_size :]
                out, traj = self.sample(
cond=cond_batch,
x_N=x_N_batch,
num_steps=num_steps,
stage=stage,
cfg=cfg,
generator=generator,
return_trajectories=return_trajectories,
)
x_0.append(out)
                trajectories.append(traj)
            # Concatenate per-chunk samples along the batch dimension and merge
            # the per-chunk trajectories frame by frame.
            x_0 = torch.cat(x_0, dim=0)
            trajectories = [torch.cat(frames, dim=0) for frames in zip(*trajectories)]
return x_0, trajectories
else:
x_0 = []
i = -1
for i in range(x_N.shape[0] // sampling_batch_size):
x_N_batch = x_N[i * sampling_batch_size : (i + 1) * sampling_batch_size]
cond_batch = cond[
i * sampling_batch_size : (i + 1) * sampling_batch_size
]
out = self.sample(
cond=cond_batch,
x_N=x_N_batch,
num_steps=num_steps,
stage=stage,
cfg=cfg,
generator=generator,
return_trajectories=return_trajectories,
)
x_0.append(out)
if x_N.shape[0] % sampling_batch_size != 0:
x_N_batch = x_N[(i + 1) * sampling_batch_size :]
cond_batch = cond[(i + 1) * sampling_batch_size :]
out = self.sample(
cond=cond_batch,
x_N=x_N_batch,
num_steps=num_steps,
stage=stage,
cfg=cfg,
generator=generator,
return_trajectories=return_trajectories,
)
x_0.append(out)
x_0 = torch.cat(x_0, dim=0)
return x_0
def model(self, *args, **kwargs):
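        """Apply the preconditioning wrapper around the raw network."""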
return self.preconditioning(self.network, *args, **kwargs)
def ema_model(self, *args, **kwargs):
return self.preconditioning(self.ema_network, *args, **kwargs)
def compute_exact_loglikelihood(
self,
batch=None,
x_1=None,
cond=None,
t1=1.0,
num_steps=1000,
rademacher=False,
data_preprocessing=True,
cfg=0,
):
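        """Compute the exact log-likelihood of `batch` (or of `x_1` / `cond`).

        Integrates the probability-flow ODE of the learned vector field from
        the data to the base distribution while accumulating the divergence
        (exact Jacobian trace by default, Hutchinson estimate when
        `rademacher=True`), then adds the base log-density and divides by
        `input_dim * log(2)`.
        """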
nfe = [0]
if batch is None:
batch = {"x_0": x_1, "emb": cond}
if data_preprocessing:
batch = self.data_preprocessing(batch)
batch = self.cond_preprocessing(batch)
timesteps = self.inference_noise_scheduler(
torch.linspace(0, t1, 2).to(batch["x_0"])
)
with torch.inference_mode(mode=False):
def odefunc(t, tensor):
nfe[0] += 1
t = t.to(tensor)
gamma = self.inference_noise_scheduler(t)
x = tensor[..., : self.input_dim]
y = batch["emb"]
def vecfield(x, y):
if cfg > 0:
batch_vecfield = {
"y": x,
"emb": y,
"gamma": gamma.reshape(-1),
}
model_output_cond = self.ema_model(batch_vecfield)
batch_vecfield_uncond = {
"y": x,
"emb": torch.zeros_like(y),
"gamma": gamma.reshape(-1),
}
model_output_uncond = self.ema_model(batch_vecfield_uncond)
model_output = model_output_cond + cfg * (
model_output_cond - model_output_uncond
)
else:
batch_vecfield = {
"y": x,
"emb": y,
"gamma": gamma.reshape(-1),
}
model_output = self.ema_model(batch_vecfield)
if self.interpolant == "flow_matching":
d_gamma = self.inference_noise_scheduler.derivative(t).reshape(
-1, 1
)
return d_gamma * model_output
elif self.interpolant == "diffusion":
alpha_t = self.inference_noise_scheduler.alpha(t).reshape(-1, 1)
return (
-1 / 2 * (alpha_t * x - torch.abs(alpha_t) * model_output)
)
else:
raise ValueError(f"Unknown interpolant {self.interpolant}")
if rademacher:
v = torch.randint_like(x, 2) * 2 - 1
else:
v = None
dx, div = output_and_div(vecfield, x, y, v=v)
div = div.reshape(-1, 1)
del t, x
return torch.cat([dx, div], dim=-1)
x_1 = batch["x_0"]
state1 = torch.cat([x_1, torch.zeros_like(x_1[..., :1])], dim=-1)
with torch.no_grad():
                # The Riemannian flow sampler path is deliberately disabled
                # ("if False"); the generic dopri5 ODE solver below is used.
                if False and isinstance(self.manifold, Sphere):
print("Riemannian flow sampler")
product_man = ProductManifold(
(self.manifold, self.input_dim), (Euclidean(), 1)
)
state0 = ode_riemannian_flow_sampler(
odefunc,
state1,
manifold=product_man,
scheduler=self.inference_noise_scheduler,
num_steps=num_steps,
)
else:
print("ODE solver")
state0 = odeint(
odefunc,
state1,
t=torch.linspace(0, t1, 2).to(batch["x_0"]),
atol=1e-6,
rtol=1e-6,
method="dopri5",
options={"min_step": 1e-5},
)[-1]
x_0, logdetjac = state0[..., : self.input_dim], state0[..., -1]
if self.manifold is not None:
x_0 = self.manifold.projx(x_0)
logp0 = self.manifold.base_logprob(x_0)
else:
logp0 = (
-1 / 2 * (x_0**2).sum(dim=-1)
- self.input_dim
* torch.log(torch.tensor(2 * np.pi, device=x_0.device))
/ 2
)
print(f"nfe: {nfe[0]}")
logp1 = logp0 + logdetjac
            # Rescale from nats to bits per dimension.
            logp1 = logp1 / (self.input_dim * np.log(2))
return logp1
def get_parameter_names(model, forbidden_layer_types):
"""
Returns the names of the model parameters that are not inside a forbidden layer.
Taken from HuggingFace transformers.
"""
result = []
for name, child in model.named_children():
result += [
f"{name}.{n}"
for n in get_parameter_names(child, forbidden_layer_types)
if not isinstance(child, tuple(forbidden_layer_types))
]
# Add model specific parameters (defined with nn.Parameter) since they are not in any child.
result += list(model._parameters.keys())
return result
# for likelihood computation
def div_fn(u):
"""Accepts a function u:R^D -> R^D."""
J = jacrev(u, argnums=0)
return lambda x, y: torch.trace(J(x, y).squeeze(0))
def output_and_div(vecfield, x, y, v=None):
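    """Return the vector field output and its divergence at (x, y).

    With `v=None` the divergence is the exact Jacobian trace (via `jacrev`);
    otherwise it is the Hutchinson estimate v^T J v obtained from a single
    vector-Jacobian product.
    """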
if v is None:
dx = vecfield(x, y)
div = vmap(div_fn(vecfield))(x, y)
else:
vecfield_x = lambda x: vecfield(x, y)
dx, vjpfunc = vjp(vecfield_x, x)
vJ = vjpfunc(v)[0]
div = torch.sum(vJ * v, dim=-1)
return dx, div
class VonFisherGeolocalizer(L.LightningModule):
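    """Geolocalizer that directly predicts a directional distribution (von
    Mises-Fisher, per the class name). Training minimizes the configured
    likelihood loss, and the reported NLL is simply that loss evaluated with
    the EMA network."""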
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.network = instantiate(cfg.network)
# self.network = torch.compile(self.network, fullgraph=True)
self.input_dim = cfg.network.input_dim
self.data_preprocessing = instantiate(cfg.data_preprocessing)
self.cond_preprocessing = instantiate(cfg.cond_preprocessing)
self.preconditioning = instantiate(cfg.preconditioning)
self.ema_network = copy.deepcopy(self.network).requires_grad_(False)
self.ema_network.eval()
self.postprocessing = instantiate(cfg.postprocessing)
self.val_sampler = instantiate(cfg.val_sampler)
self.test_sampler = instantiate(cfg.test_sampler)
self.loss = instantiate(cfg.loss)()
self.val_metrics = instantiate(cfg.val_metrics)
self.test_metrics = instantiate(cfg.test_metrics)
def training_step(self, batch, batch_idx):
with torch.no_grad():
batch = self.data_preprocessing(batch)
batch = self.cond_preprocessing(batch)
batch_size = batch["x_0"].shape[0]
loss = self.loss(self.preconditioning, self.network, batch).mean()
self.log(
"train/loss",
loss,
sync_dist=True,
on_step=True,
on_epoch=True,
batch_size=batch_size,
)
return loss
def on_before_optimizer_step(self, optimizer):
if self.global_step == 0:
no_grad = []
for name, param in self.network.named_parameters():
if param.grad is None:
no_grad.append(name)
if len(no_grad) > 0:
print("Parameters without grad:")
print(no_grad)
def on_validation_start(self):
self.validation_generator = torch.Generator(device=self.device).manual_seed(
3407
)
self.validation_generator_ema = torch.Generator(device=self.device).manual_seed(
3407
)
def validation_step(self, batch, batch_idx):
batch = self.data_preprocessing(batch)
batch = self.cond_preprocessing(batch)
batch_size = batch["x_0"].shape[0]
loss = self.loss(
self.preconditioning,
self.network,
batch,
generator=self.validation_generator,
).mean()
self.log(
"val/loss",
loss,
sync_dist=True,
on_step=False,
on_epoch=True,
batch_size=batch_size,
)
if hasattr(self, "ema_model"):
loss_ema = self.loss(
self.preconditioning,
self.ema_network,
batch,
generator=self.validation_generator_ema,
).mean()
self.log(
"val/loss_ema",
loss_ema,
sync_dist=True,
on_step=False,
on_epoch=True,
batch_size=batch_size,
)
def on_test_start(self):
self.test_generator = torch.Generator(device=self.device).manual_seed(3407)
def test_step(self, batch, batch_idx):
batch = self.data_preprocessing(batch)
batch = self.cond_preprocessing(batch)
batch_size = batch["x_0"].shape[0]
cond = batch[self.cfg.cond_preprocessing.output_key]
samples = self.sample(cond=cond, stage="test")
self.test_metrics.update({"gps": samples}, batch)
nll = -self.compute_exact_loglikelihood(batch).mean()
self.log(
"test/NLL",
nll,
sync_dist=True,
on_step=False,
on_epoch=True,
batch_size=batch_size,
)
def on_test_epoch_end(self):
metrics = self.test_metrics.compute()
for metric_name, metric_value in metrics.items():
self.log(
f"test/{metric_name}",
metric_value,
sync_dist=True,
on_step=False,
on_epoch=True,
)
def configure_optimizers(self):
if self.cfg.optimizer.exclude_ln_and_biases_from_weight_decay:
parameters_names_wd = get_parameter_names(self.network, [nn.LayerNorm])
parameters_names_wd = [
name for name in parameters_names_wd if "bias" not in name
]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in self.network.named_parameters()
if n in parameters_names_wd
],
"weight_decay": self.cfg.optimizer.optim.weight_decay,
"layer_adaptation": True,
},
{
"params": [
p
for n, p in self.network.named_parameters()
if n not in parameters_names_wd
],
"weight_decay": 0.0,
"layer_adaptation": False,
},
]
optimizer = instantiate(
self.cfg.optimizer.optim, optimizer_grouped_parameters
)
else:
optimizer = instantiate(self.cfg.optimizer.optim, self.network.parameters())
if "lr_scheduler" in self.cfg:
scheduler = instantiate(self.cfg.lr_scheduler)(optimizer)
return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
else:
return optimizer
def lr_scheduler_step(self, scheduler, metric):
scheduler.step(self.global_step)
def sample(
self,
batch_size=None,
cond=None,
postprocessing=True,
stage="val",
):
batch = {}
if stage == "val":
sampler = self.val_sampler
elif stage == "test":
sampler = self.test_sampler
else:
raise ValueError(f"Unknown stage {stage}")
batch[self.cfg.cond_preprocessing.input_key] = cond
batch = self.cond_preprocessing(batch, device=self.device)
output = sampler(
self.ema_model,
batch,
)
return self.postprocessing(output) if postprocessing else output
def model(self, *args, **kwargs):
return self.preconditioning(self.network, *args, **kwargs)
def ema_model(self, *args, **kwargs):
return self.preconditioning(self.ema_network, *args, **kwargs)
def compute_exact_loglikelihood(
self,
batch=None,
):
batch = self.data_preprocessing(batch)
batch = self.cond_preprocessing(batch)
return -self.loss(self.preconditioning, self.ema_network, batch)
class RandomGeolocalizer(L.LightningModule):
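    """Baseline that predicts points drawn uniformly at random on the unit
    sphere."""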
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.test_metrics = instantiate(cfg.test_metrics)
self.data_preprocessing = instantiate(cfg.data_preprocessing)
self.cond_preprocessing = instantiate(cfg.cond_preprocessing)
self.postprocessing = instantiate(cfg.postprocessing)
def test_step(self, batch, batch_idx):
batch = self.data_preprocessing(batch)
batch = self.cond_preprocessing(batch)
batch_size = batch["x_0"].shape[0]
samples = torch.randn(batch_size, 3, device=self.device)
samples = samples / samples.norm(dim=-1, keepdim=True)
samples = self.postprocessing(samples)
self.test_metrics.update({"gps": samples}, batch)
def on_test_epoch_end(self):
metrics = self.test_metrics.compute()
for metric_name, metric_value in metrics.items():
self.log(
f"test/{metric_name}",
metric_value,
sync_dist=True,
on_step=False,
on_epoch=True,
)