import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from salad.data.dataset import SALADDataset
from salad.utils.train_util import PolyDecayScheduler

class BaseModel(pl.LightningModule):
    """Shared diffusion-training skeleton: wraps a denoising network and a
    variance schedule, and provides the noising, loss, and dataloader logic."""

    def __init__(
        self,
        network,
        variance_schedule,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters(logger=False)
        self.net = network
        self.var_sched = variance_schedule

    def forward(self, x):
        return self.get_loss(x)
    def step(self, x, stage: str):
        loss = self(x)
        self.log(
            f"{stage}/loss",
            loss,
            on_step=stage == "train",
            prog_bar=True,
        )
        return loss

    def training_step(self, batch, batch_idx):
        x = batch
        return self.step(x, "train")
    def add_noise(self, x, t):
        """
        Input:
            x: [B,D] or [B,G,D]
            t: list of size B
        Output:
            x_noisy: same shape as x
            beta: [B]
            e_rand: same shape as x
        """
        alpha_bar = self.var_sched.alpha_bars[t]
        beta = self.var_sched.betas[t]

        c0 = torch.sqrt(alpha_bar).view(-1, 1)  # [B,1]
        c1 = torch.sqrt(1 - alpha_bar).view(-1, 1)

        e_rand = torch.randn_like(x)
        if e_rand.dim() == 3:
            # Broadcast the per-sample coefficients over the G dimension.
            c0 = c0.unsqueeze(1)
            c1 = c1.unsqueeze(1)

        x_noisy = c0 * x + c1 * e_rand
        return x_noisy, beta, e_rand
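
    # Note: the computation in add_noise is the closed-form forward (noising)
    # step of a DDPM-style diffusion process,
    #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I),
    # so the network in get_loss below is trained to recover eps from x_t and beta_t.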
    def get_loss(
        self,
        x0,
        t=None,
        noisy_in=False,
        beta_in=None,
        e_rand_in=None,
    ):
        if x0.dim() == 2:
            B, D = x0.shape
        else:
            B, G, D = x0.shape

        if not noisy_in:
            # Sample timesteps (if not given) and noise the clean input.
            if t is None:
                t = self.var_sched.uniform_sample_t(B)
            x_noisy, beta, e_rand = self.add_noise(x0, t)
        else:
            # Caller already provides the noisy input, its beta, and the noise.
            x_noisy = x0
            beta = beta_in
            e_rand = e_rand_in

        # The network predicts the injected noise; train with plain MSE.
        e_theta = self.net(x_noisy, beta=beta)
        loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
        return loss
    @torch.no_grad()
    def sample(
        self,
        batch_size=0,
        return_traj=False,
    ):
        raise NotImplementedError

    def validation_epoch_end(self, outputs):
        if self.hparams.no_run_validation:
            return
        if not self.trainer.sanity_checking:
            if self.current_epoch % self.hparams.validation_step == 0:
                self.validation()
    def _build_dataset(self, stage):
        if hasattr(self, f"data_{stage}"):
            return getattr(self, f"data_{stage}")
        if stage == "train":
            ds = SALADDataset(**self.hparams.dataset_kwargs)
        else:
            # Validation/test data is not repeated.
            dataset_kwargs = self.hparams.dataset_kwargs.copy()
            dataset_kwargs["repeat"] = 1
            ds = SALADDataset(**dataset_kwargs)
        setattr(self, f"data_{stage}", ds)
        return ds
    def _build_dataloader(self, stage):
        try:
            ds = getattr(self, f"data_{stage}")
        except AttributeError:
            ds = self._build_dataset(stage)
        return torch.utils.data.DataLoader(
            ds,
            batch_size=self.hparams.batch_size,
            shuffle=stage == "train",
            drop_last=stage == "train",
            num_workers=4,
        )
    def train_dataloader(self):
        return self._build_dataloader("train")

    def val_dataloader(self):
        return self._build_dataloader("val")

    def test_dataloader(self):
        return self._build_dataloader("test")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = PolyDecayScheduler(optimizer, self.hparams.lr, power=0.999)
        return [optimizer], [scheduler]
    # TODO: move get_wandb_logger to logutil.py
    def get_wandb_logger(self):
        for logger in self.logger:
            if isinstance(logger, pl.loggers.wandb.WandbLogger):
                return logger
        return None
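

# --------------------------------------------------------------------------- #
# Hedged usage sketch (not part of the original module): a minimal stand-in
# variance schedule exposing only the attributes BaseModel actually touches
# (alpha_bars, betas, uniform_sample_t) and a dummy noise predictor, just to
# show how get_loss is exercised. The real schedule and networks live
# elsewhere in the SALAD codebase; names prefixed with "_Toy" are hypothetical.
#
#   import torch
#   import torch.nn as nn
#
#   class _ToyVarianceSchedule(nn.Module):
#       def __init__(self, num_steps=100, beta_1=1e-4, beta_T=0.02):
#           super().__init__()
#           betas = torch.linspace(beta_1, beta_T, num_steps)
#           self.register_buffer("betas", betas)
#           self.register_buffer("alpha_bars", torch.cumprod(1.0 - betas, dim=0))
#
#       def uniform_sample_t(self, batch_size):
#           return torch.randint(0, len(self.betas), (batch_size,)).tolist()
#
#   net = lambda x, beta: torch.zeros_like(x)  # placeholder noise predictor
#   model = BaseModel(network=net, variance_schedule=_ToyVarianceSchedule())
#   loss = model.get_loss(torch.randn(8, 512))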