Spaces:
Sleeping
Sleeping
import pytorch_lightning as pl | |
import torch | |
import torch.nn.functional as F | |
from salad.data.dataset import SALADDataset | |
from salad.utils.train_util import PolyDecayScheduler | |
class BaseModel(pl.LightningModule):
    """Base Lightning module that trains ``network`` to predict injected noise.

    ``get_loss`` corrupts a clean batch with ``add_noise`` (driven by the
    variance schedule) and regresses the network output onto the Gaussian
    noise with an MSE loss.

    Subclasses are expected to implement ``sample`` and a ``validation``
    method (the latter is called from ``validation_epoch_end``); neither is
    provided here.

    Hyperparameters read back from ``self.hparams``: ``lr``, ``batch_size``,
    ``dataset_kwargs``, ``no_run_validation``, ``validation_step`` and,
    optionally, ``num_workers`` (defaults to 4 when absent).
    """

    def __init__(
        self,
        network,
        variance_schedule,
        **kwargs,
    ):
        """Store the denoising network and the variance schedule.

        Args:
            network: Module called as ``network(x_noisy, beta=beta)``.
            variance_schedule: Object exposing ``alpha_bars`` and ``betas``
                (indexable by timestep) and ``uniform_sample_t(batch_size)``.
            **kwargs: Extra hyperparameters; captured by
                ``save_hyperparameters`` and read via ``self.hparams``.
        """
        super().__init__()
        # logger=False: keep hyperparameters on the module without sending
        # them to the attached logger.
        self.save_hyperparameters(logger=False)
        self.net = network
        self.var_sched = variance_schedule

    def forward(self, x):
        """Return the training loss for the clean batch ``x``."""
        return self.get_loss(x)

    def step(self, x, stage: str):
        """Compute the loss for ``x`` and log it under ``"<stage>/loss"``.

        Per-step logging is enabled only for the ``"train"`` stage.
        """
        loss = self(x)
        self.log(
            f"{stage}/loss",
            loss,
            on_step=stage == "train",
            prog_bar=True,
        )
        return loss

    def training_step(self, batch, batch_idx):
        """Lightning training hook; the batch itself is the model input."""
        x = batch
        return self.step(x, "train")

    def add_noise(self, x, t):
        """Corrupt ``x`` with Gaussian noise at timesteps ``t``.

        Computes ``sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * eps`` with
        ``eps ~ N(0, I)``.

        Args:
            x: Tensor whose first dimension is the batch, e.g. ``[B, D]`` or
                ``[B, G, D]``.
            t: Timestep indices, one per batch element (list or tensor of
                size ``B``).

        Returns:
            Tuple ``(x_noisy, beta, e_rand)``:
                ``x_noisy`` has the same shape as ``x``;
                ``beta`` is ``var_sched.betas[t]`` (size ``B``);
                ``e_rand`` is the sampled noise, same shape as ``x``.
        """
        alpha_bar = self.var_sched.alpha_bars[t]
        beta = self.var_sched.betas[t]
        # Reshape per-sample coefficients to [B, 1, ..., 1] so they broadcast
        # over every non-batch dimension of x (equals [B,1] / [B,1,1] for
        # rank-2 / rank-3 inputs).
        coef_shape = (-1,) + (1,) * (x.dim() - 1)
        c0 = torch.sqrt(alpha_bar).view(coef_shape)
        c1 = torch.sqrt(1 - alpha_bar).view(coef_shape)
        e_rand = torch.randn_like(x)
        x_noisy = c0 * x + c1 * e_rand
        return x_noisy, beta, e_rand

    def get_loss(
        self,
        x0,
        t=None,
        noisy_in=False,
        beta_in=None,
        e_rand_in=None,
    ):
        """Return the MSE between predicted and true noise.

        Args:
            x0: Clean batch (batch first). When ``noisy_in`` is True it is
                treated as an already-noised input instead.
            t: Optional timestep indices; sampled uniformly per batch element
                via the variance schedule when ``None``.
            noisy_in: If True, skip ``add_noise`` and use ``x0``, ``beta_in``
                and ``e_rand_in`` as given.
            beta_in: Noise levels passed to the network when ``noisy_in``.
            e_rand_in: Target noise used when ``noisy_in``.

        Returns:
            Scalar mean-squared-error loss tensor.

        Raises:
            ValueError: If ``noisy_in`` is True but ``e_rand_in`` is None.
        """
        if not noisy_in:
            if t is None:
                t = self.var_sched.uniform_sample_t(x0.shape[0])
            x_noisy, beta, e_rand = self.add_noise(x0, t)
        else:
            if e_rand_in is None:
                # Fail before running the network rather than on None.flatten().
                raise ValueError("e_rand_in must be provided when noisy_in=True")
            x_noisy = x0
            beta = beta_in
            e_rand = e_rand_in

        e_theta = self.net(x_noisy, beta=beta)
        loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
        return loss

    def sample(
        self,
        batch_size=0,
        return_traj=False,
    ):
        """Generate samples; must be implemented by subclasses."""
        raise NotImplementedError

    def validation_epoch_end(self, outputs):
        """Run ``self.validation()`` every ``validation_step`` epochs.

        Skipped entirely when ``no_run_validation`` is set, and during
        Lightning's sanity-check pass. ``validation`` is expected to be
        defined by a subclass.
        """
        if self.hparams.no_run_validation:
            return
        if not self.trainer.sanity_checking:
            if (self.current_epoch) % self.hparams.validation_step == 0:
                self.validation()

    def _build_dataset(self, stage):
        """Build (once) and cache the ``SALADDataset`` for ``stage``.

        The dataset is cached as ``self.data_<stage>``. Non-train stages reuse
        ``dataset_kwargs`` with ``repeat`` forced to 1; val and test are
        therefore built from identical arguments.
        """
        if hasattr(self, f"data_{stage}"):
            return getattr(self, f"data_{stage}")
        if stage == "train":
            ds = SALADDataset(**self.hparams.dataset_kwargs)
        else:
            dataset_kwargs = self.hparams.dataset_kwargs.copy()
            dataset_kwargs["repeat"] = 1
            ds = SALADDataset(**dataset_kwargs)
        setattr(self, f"data_{stage}", ds)
        return ds

    def _build_dataloader(self, stage):
        """Wrap the (cached) dataset for ``stage`` in a ``DataLoader``.

        Training loaders shuffle and drop the last incomplete batch; other
        stages do neither. Worker count comes from ``hparams.num_workers``
        when present, otherwise 4.
        """
        # _build_dataset already returns the cached instance when one exists,
        # so no separate lookup (and no blanket except) is needed here.
        ds = self._build_dataset(stage)
        return torch.utils.data.DataLoader(
            ds,
            batch_size=self.hparams.batch_size,
            shuffle=stage == "train",
            drop_last=stage == "train",
            num_workers=self.hparams.get("num_workers", 4),
        )

    def train_dataloader(self):
        """Lightning hook: training ``DataLoader``."""
        return self._build_dataloader("train")

    def val_dataloader(self):
        """Lightning hook: validation ``DataLoader``."""
        return self._build_dataloader("val")

    def test_dataloader(self):
        """Lightning hook: test ``DataLoader``."""
        return self._build_dataloader("test")

    def configure_optimizers(self):
        """Return Adam at ``hparams.lr`` with the project's ``PolyDecayScheduler``.

        NOTE(review): the decay behaviour (``power=0.999``) is defined in
        ``salad.utils.train_util`` and not visible here.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = PolyDecayScheduler(optimizer, self.hparams.lr, power=0.999)
        return [optimizer], [scheduler]

    # TODO: move get_wandb_logger to logutil.py
    def get_wandb_logger(self):
        """Return the first attached ``WandbLogger``, or ``None`` if there is none.

        ``self.logger`` is a single logger (or ``None``) unless several loggers
        are wrapped in an iterable collection, so iterating it directly fails
        in the common single-logger case. The ``loggers`` list is preferred
        when this Lightning version provides it.
        """
        candidates = getattr(self, "loggers", None)
        if candidates is None:
            single = self.logger
            if single is None:
                candidates = []
            elif hasattr(single, "__iter__"):
                # Older Lightning wraps multiple loggers in an iterable
                # collection (e.g. LoggerCollection).
                candidates = single
            else:
                candidates = [single]
        for logger in candidates:
            if isinstance(logger, pl.loggers.wandb.WandbLogger):
                return logger
        return None