# TSEditor/engine/solver.py
import math
import os
import sys
import time
import torch
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm
from ema_pytorch import EMA
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
from utils.io_utils import instantiate_from_config, get_model_parameters_info
def cycle(dl):
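    """Yield batches from ``dl`` indefinitely, restarting the iterator whenever it is exhausted."""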
while True:
for data in dl:
yield data
class Trainer(object):
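    """Training and sampling wrapper around a diffusion-based time-series model.

    Wires up the optimizer, EMA weights, LR scheduler, checkpointing, and the
    conditional / unconditional sampling entry points used at inference time.

    Example (hypothetical wiring; the concrete config, dataloader, and feature
    dimension come from the project's own configs)::

        trainer = Trainer(config, args, model, dataloader={"dataloader": dl})
        trainer.train()
        samples = trainer.sample(num=1000, size_every=100, shape=(model.seq_length, feat_dim))
    """
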
def __init__(self, config, args, model, dataloader, logger=None):
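        """Build the trainer from a parsed config.

        Reads ``config["solver"]`` for ``max_epochs`` (used as the number of training
        steps), ``gradient_accumulate_every``, ``save_cycle``, ``results_folder``,
        ``base_lr`` (default 1e-4), the ``ema`` block (``decay``, ``update_interval``)
        and the ``scheduler`` block.  ``dataloader["dataloader"]`` is cycled
        indefinitely; wandb logging is enabled when ``WANDB_ENABLED=true``.
        """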
super().__init__()
if os.getenv("WANDB_ENABLED") == "true":
import wandb
self.run = wandb.init(project="tiffusion-revenue", config=config)
else:
self.run = None
self.model = model
self.device = self.model.betas.device
self.train_num_steps = config["solver"]["max_epochs"]
self.gradient_accumulate_every = config["solver"]["gradient_accumulate_every"]
self.save_cycle = config["solver"]["save_cycle"]
self.dl = cycle(dataloader["dataloader"])
self.step = 0
self.milestone = 0
self.args = args
self.logger = logger
self.results_folder = Path(
config["solver"]["results_folder"] + f"_{model.seq_length}"
)
os.makedirs(self.results_folder, exist_ok=True)
start_lr = config["solver"].get("base_lr", 1.0e-4)
ema_decay = config["solver"]["ema"]["decay"]
ema_update_every = config["solver"]["ema"]["update_interval"]
self.opt = Adam(
filter(lambda p: p.requires_grad, self.model.parameters()),
lr=start_lr,
betas=[0.9, 0.96],
)
self.ema = EMA(self.model, beta=ema_decay, update_every=ema_update_every).to(
self.device
)
sc_cfg = config["solver"]["scheduler"]
sc_cfg["params"]["optimizer"] = self.opt
self.sch = instantiate_from_config(sc_cfg)
if self.logger is not None:
self.logger.log_info(str(get_model_parameters_info(self.model)))
self.log_frequency = 100
def save(self, milestone, verbose=False):
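        """Save the current step, model, EMA, and optimizer state to ``checkpoint-{milestone}.pt`` in the results folder."""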
if self.logger is not None and verbose:
self.logger.log_info(
"Save current model to {}".format(
str(self.results_folder / f"checkpoint-{milestone}.pt")
)
)
data = {
"step": self.step,
"model": self.model.state_dict(),
"ema": self.ema.state_dict(),
"opt": self.opt.state_dict(),
}
torch.save(data, str(self.results_folder / f"checkpoint-{milestone}.pt"))
def load(self, milestone, verbose=False, from_folder=None):
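        """Load ``checkpoint-{milestone}.pt`` from ``from_folder`` if given (otherwise from the results folder) and restore model, EMA, optimizer, and step."""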
        ckpt_path = (
            os.path.join(from_folder, f"checkpoint-{milestone}.pt")
            if from_folder is not None
            else str(self.results_folder / f"checkpoint-{milestone}.pt")
        )
        if self.logger is not None and verbose:
            self.logger.log_info("Resume from {}".format(ckpt_path))
        device = self.device
        data = torch.load(ckpt_path, map_location=device, weights_only=True)
        self.model.load_state_dict(data["model"])
self.step = data["step"]
self.opt.load_state_dict(data["opt"])
self.ema.load_state_dict(data["ema"])
self.milestone = milestone
def train(self):
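        """Run the training loop for ``train_num_steps`` steps.

        Each step accumulates gradients over ``gradient_accumulate_every`` batches,
        clips the gradient norm to 1.0, steps the optimizer and the scheduler with the
        accumulated loss, and updates the EMA weights.  Checkpoints are written every
        ``save_cycle`` steps and the loss is logged every ``log_frequency`` steps
        (and to wandb when enabled).
        """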
device = self.device
step = 0
if self.logger is not None:
tic = time.time()
self.logger.log_info(
"{}: start training...".format(self.args.name), check_primary=False
)
with tqdm(initial=step, total=self.train_num_steps) as pbar:
while step < self.train_num_steps:
total_loss = 0.0
for _ in range(self.gradient_accumulate_every):
data = next(self.dl).to(device)
loss = self.model(data, target=data)
loss = loss / self.gradient_accumulate_every
loss.backward()
total_loss += loss.item()
pbar.set_description(
f'loss: {total_loss:.6f} lr: {self.opt.param_groups[0]["lr"]:.6f}'
)
                if self.run is not None:
                    self.run.log(
                        {
                            "step": step,
                            "loss": total_loss,
                            "lr": self.opt.param_groups[0]["lr"],
                        },
                        step=self.step,
                    )
clip_grad_norm_(self.model.parameters(), 1.0)
self.opt.step()
self.sch.step(total_loss)
self.opt.zero_grad()
self.step += 1
step += 1
self.ema.update()
with torch.no_grad():
if self.step != 0 and self.step % self.save_cycle == 0:
self.milestone += 1
self.save(self.milestone)
if self.logger is not None and self.step % self.log_frequency == 0:
self.logger.add_scalar(
tag="train/loss",
scalar_value=total_loss,
global_step=self.step,
)
pbar.update(1)
print("training complete")
if self.logger is not None:
self.logger.log_info(
"Training done, time: {:.2f}".format(time.time() - tic)
)
def sample(self, num, size_every, shape=None):
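        """Generate at least ``num`` unconditional samples in chunks of ``size_every``.

        ``shape`` is ``(seq_length, feature_dim)`` and must be provided despite the
        ``None`` default.  Returns a numpy array of stacked samples; the last chunk
        may overshoot ``num``.
        """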
if self.logger is not None:
tic = time.time()
self.logger.log_info("Begin to sample...")
samples = np.empty([0, shape[0], shape[1]])
num_cycle = int(num // size_every) + 1
for _ in range(num_cycle):
sample = self.ema.ema_model.generate_mts(batch_size=size_every)
            samples = np.vstack([samples, sample.detach().cpu().numpy()])
torch.cuda.empty_cache()
if self.logger is not None:
self.logger.log_info(
"Sampling done, time: {:.2f}".format(time.time() - tic)
)
return samples
def control_sample(self, num, size_every, shape=None, model_kwargs={}, target=None, partial_mask=None):
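        """Conditionally generate samples that agree with ``target`` where ``partial_mask`` is set.

        ``target`` and ``partial_mask`` must be given together (or both omitted); 2-D
        inputs of shape ``(seq_length, feature_dim)`` are broadcast to a batch of
        ``size_every``.  Returns at least ``num`` samples stacked over
        ``ceil(num / size_every)`` batches as a numpy array.
        """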
samples = np.empty([0, shape[0], shape[1]])
num_cycle = math.ceil(num / size_every)
        assert not ((target is None) ^ (partial_mask is None)), (
            "target and partial_mask must be provided together (or both omitted)"
        )
if self.logger is not None:
tic = time.time()
self.logger.log_info("Begin to infill sample...")
target = torch.tensor(target).to(self.device) if target is not None else torch.zeros(shape).to(self.device)
target = target.repeat(size_every, 1, 1) if len(target.shape) == 2 else target
partial_mask = torch.tensor(partial_mask).to(self.device) if partial_mask is not None else torch.zeros(shape).to(self.device)
partial_mask = partial_mask.repeat(size_every, 1, 1) if len(partial_mask.shape) == 2 else partial_mask
for _ in range(num_cycle):
sample = self.ema.ema_model.generate_mts_infill(target, partial_mask, model_kwargs=model_kwargs)
            samples = np.vstack([samples, sample.detach().cpu().numpy()])
torch.cuda.empty_cache()
if self.logger is not None:
self.logger.log_info(
"Sampling done, time: {:.2f}".format(time.time() - tic)
)
return samples
def predict(
self,
observed_points: torch.Tensor,
coef=1e-1,
stepsize=1e-1,
sampling_steps=50,
**kargs,
):
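        """Infill a single series: nonzero entries of ``observed_points`` are treated as observed.

        The input is assumed to lie in ``[0, 1]``; it is rescaled to ``[-1, 1]`` before
        sampling and rescaled back afterwards.  ``coef`` and ``stepsize`` are forwarded
        to the sampler via ``model_kwargs``.  Fast sampling is used unless
        ``sampling_steps`` equals the model's ``num_timesteps``.
        """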
model_kwargs = {}
model_kwargs["coef"] = coef
model_kwargs["learning_rate"] = stepsize
model_kwargs = {**model_kwargs, **kargs}
assert len(observed_points.shape) == 2, "observed_points should be 2D, batch size = 1"
x = observed_points.unsqueeze(0)
t_m = x != 0
x = x * 2 - 1 # normalize
x, t_m = x.to(self.device), t_m.to(self.device)
if sampling_steps == self.model.num_timesteps:
print("normal sampling")
sample = self.ema.ema_model.sample_infill(
shape=x.shape,
target=x * t_m,
partial_mask=t_m,
model_kwargs=model_kwargs,
)
            # sample shape: (batch_size, seq_length, feature_dim)
else:
print("fast sampling")
sample = self.ema.ema_model.fast_sample_infill(
shape=x.shape,
target=x * t_m,
partial_mask=t_m,
model_kwargs=model_kwargs,
sampling_timesteps=sampling_steps,
)
# unnormalize
sample = (sample + 1) / 2
return sample.squeeze(0).detach().cpu().numpy()
def predict_weighted_points(
self,
observed_points: torch.Tensor,
observed_mask: torch.Tensor,
coef=1e-1,
stepsize=1e-1,
sampling_steps=50,
**kargs,
):
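        """Infill a single series using a weighted (float) observation mask.

        ``observed_mask`` carries per-point weights (nonzero = observed); a binarized
        copy selects the conditioning values.  Only fast sampling is supported:
        ``sampling_steps == num_timesteps`` raises ``NotImplementedError``.  As in
        ``predict``, values are assumed to lie in ``[0, 1]`` and are rescaled to
        ``[-1, 1]`` for sampling.
        """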
model_kwargs = {}
model_kwargs["coef"] = coef
model_kwargs["learning_rate"] = stepsize
model_kwargs = {**model_kwargs, **kargs}
assert len(observed_points.shape) == 2, "observed_points should be 2D, batch size = 1"
x = observed_points.unsqueeze(0)
        float_mask = observed_mask.unsqueeze(0)  # weighted mask: nonzero where observed, 0 where missing
binary_mask = float_mask.clone()
binary_mask[binary_mask > 0] = 1
x = x * 2 - 1 # normalize
x, float_mask, binary_mask = x.to(self.device), float_mask.to(self.device), binary_mask.to(self.device)
if sampling_steps == self.model.num_timesteps:
print("normal sampling")
raise NotImplementedError
sample = self.ema.ema_model.sample_infill_float_mask(
shape=x.shape,
target=x * binary_mask, # x * t_m, 1 for observed, 0 for missing
partial_mask=float_mask,
model_kwargs=model_kwargs,
)
# x: partially noise : (batch_size, seq_length, feature_dim)
else:
print("fast sampling")
sample = self.ema.ema_model.fast_sample_infill_float_mask(
shape=x.shape,
target=x * binary_mask, # x * t_m, 1 for observed, 0 for missing
partial_mask=float_mask,
model_kwargs=model_kwargs,
sampling_timesteps=sampling_steps,
)
# unnormalize
sample = (sample + 1) / 2
return sample.squeeze(0).detach().cpu().numpy()
def restore(
self,
raw_dataloader,
shape=None,
coef=1e-1,
stepsize=1e-1,
sampling_steps=50,
**kargs,
):
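        """Run conditional imputation over batches from ``raw_dataloader``.

        ``shape`` is ``(seq_length, feature_dim)``.  The observation mask is recomputed
        from ``x`` inside the loop, and only the first batch is processed before the
        loop breaks.  Returns ``(samples, reals, masks)`` as numpy arrays.
        """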
if self.logger is not None:
tic = time.time()
self.logger.log_info("Begin to restore...")
model_kwargs = {}
model_kwargs["coef"] = coef
model_kwargs["learning_rate"] = stepsize
model_kwargs = {**model_kwargs, **kargs}
test = kargs.get("test", False)
samples = np.empty([0, shape[0], shape[1]]) # seq_length, feature_dim
reals = np.empty([0, shape[0], shape[1]])
masks = np.empty([0, shape[0], shape[1]])
for idx, (x, t_m) in enumerate(raw_dataloader):
            # NOTE: the dataloader's mask is discarded; t_m is rebuilt as True where x == 0
            t_m = x == 0
if test:
t_m = t_m.type_as(x)
binary_mask = t_m.clone()
binary_mask[binary_mask > 0] = 1
else:
binary_mask = t_m
# x = x * 2 - 1
x, t_m = x.to(self.device), t_m.to(self.device)
binary_mask = binary_mask.to(self.device)
if sampling_steps == self.model.num_timesteps:
print("normal sampling")
sample = self.ema.ema_model.sample_infill(
shape=x.shape,
target=x * t_m,
partial_mask=t_m,
model_kwargs=model_kwargs,
)
                # sample shape: (batch_size, seq_length, feature_dim)
else:
print("fast sampling")
if test:
sample = self.ema.ema_model.fast_sample_infill_float_mask(
shape=x.shape,
target=x * binary_mask, # x * t_m, 1 for observed, 0 for missing
partial_mask=t_m,
model_kwargs=model_kwargs,
sampling_timesteps=sampling_steps,
)
else:
sample = self.ema.ema_model.fast_sample_infill(
shape=x.shape,
target=x * t_m,
partial_mask=t_m,
model_kwargs=model_kwargs,
sampling_timesteps=sampling_steps,
)
            samples = np.vstack([samples, sample.detach().cpu().numpy()])
            reals = np.vstack([reals, x.detach().cpu().numpy()])
            masks = np.vstack([masks, t_m.detach().cpu().numpy()])
            break  # only the first batch from raw_dataloader is processed
if self.logger is not None:
self.logger.log_info(
"Imputation done, time: {:.2f}".format(time.time() - tic)
)
return samples, reals, masks