import math
import os
import sys
import time

import numpy as np
import torch
from ema_pytorch import EMA
from pathlib import Path
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from tqdm.auto import tqdm

# Make the repository root importable before pulling in local utilities.
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
from utils.io_utils import get_model_parameters_info, instantiate_from_config
def cycle(dl):
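    """Yield batches from `dl` indefinitely, restarting the iterator when exhausted."""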
    while True:
        for data in dl:
            yield data
class Trainer(object):
    def __init__(self, config, args, model, dataloader, logger=None):
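        """Build the training harness around `model`.

        Reads the "solver" section of `config` for the step budget, gradient
        accumulation, checkpoint cadence, learning rate, EMA settings, and LR
        scheduler. `dataloader` must be a dict with key "dataloader", and
        `args` needs at least a `name` attribute for log messages.
        """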
        super().__init__()
        if os.getenv("WANDB_ENABLED") == "true":
            import wandb

            self.run = wandb.init(project="tiffusion-revenue", config=config)
        else:
            self.run = None
        self.model = model
        self.device = self.model.betas.device
        # Note: the config's "max_epochs" is consumed as the total step count.
        self.train_num_steps = config["solver"]["max_epochs"]
        self.gradient_accumulate_every = config["solver"]["gradient_accumulate_every"]
        self.save_cycle = config["solver"]["save_cycle"]
        self.dl = cycle(dataloader["dataloader"])
        self.step = 0
        self.milestone = 0
        self.args = args
        self.logger = logger
        self.results_folder = Path(
            config["solver"]["results_folder"] + f"_{model.seq_length}"
        )
        os.makedirs(self.results_folder, exist_ok=True)
        start_lr = config["solver"].get("base_lr", 1.0e-4)
        ema_decay = config["solver"]["ema"]["decay"]
        ema_update_every = config["solver"]["ema"]["update_interval"]
        self.opt = Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=start_lr,
            betas=(0.9, 0.96),
        )
        self.ema = EMA(self.model, beta=ema_decay, update_every=ema_update_every).to(
            self.device
        )
        sc_cfg = config["solver"]["scheduler"]
        sc_cfg["params"]["optimizer"] = self.opt
        self.sch = instantiate_from_config(sc_cfg)
        if self.logger is not None:
            self.logger.log_info(str(get_model_parameters_info(self.model)))
        self.log_frequency = 100
    def save(self, milestone, verbose=False):
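        """Write step counter, model, EMA, and optimizer state to checkpoint-{milestone}.pt."""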
        if self.logger is not None and verbose:
            self.logger.log_info(
                "Save current model to {}".format(
                    str(self.results_folder / f"checkpoint-{milestone}.pt")
                )
            )
        data = {
            "step": self.step,
            "model": self.model.state_dict(),
            "ema": self.ema.state_dict(),
            "opt": self.opt.state_dict(),
        }
        torch.save(data, str(self.results_folder / f"checkpoint-{milestone}.pt"))
    def load(self, milestone, verbose=False, from_folder=None):
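        """Load checkpoint-{milestone}.pt from `from_folder` (or the results folder) and restore all state."""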
        if self.logger is not None and verbose:
            self.logger.log_info(
                "Resume from {}".format(
                    os.path.join(from_folder, f"checkpoint-{milestone}.pt")
                )
            )
        device = self.device
        checkpoint_path = (
            os.path.join(from_folder, f"checkpoint-{milestone}.pt")
            if from_folder
            else str(self.results_folder / f"checkpoint-{milestone}.pt")
        )
        data = torch.load(checkpoint_path, map_location=device, weights_only=True)
        self.model.load_state_dict(data["model"])
        self.step = data["step"]
        self.opt.load_state_dict(data["opt"])
        self.ema.load_state_dict(data["ema"])
        self.milestone = milestone
    def train(self):
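        """Run the main optimization loop: accumulate gradients over
        `gradient_accumulate_every` batches, clip to unit norm, step the
        optimizer and LR scheduler, update the EMA copy, and checkpoint
        every `save_cycle` steps."""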
        device = self.device
        step = 0
        if self.logger is not None:
            tic = time.time()
            self.logger.log_info(
                "{}: start training...".format(self.args.name), check_primary=False
            )
        with tqdm(initial=step, total=self.train_num_steps) as pbar:
            while step < self.train_num_steps:
                total_loss = 0.0
                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl).to(device)
                    loss = self.model(data, target=data)
                    loss = loss / self.gradient_accumulate_every
                    loss.backward()
                    total_loss += loss.item()
                pbar.set_description(
                    f'loss: {total_loss:.6f} lr: {self.opt.param_groups[0]["lr"]:.6f}'
                )
                if self.run is not None:
                    # Log through the run handle; `wandb` itself is only
                    # imported locally inside __init__.
                    self.run.log(
                        {
                            "step": step,
                            "loss": total_loss,
                            "lr": self.opt.param_groups[0]["lr"],
                        },
                        step=self.step,
                    )
                clip_grad_norm_(self.model.parameters(), 1.0)
                self.opt.step()
                self.sch.step(total_loss)
                self.opt.zero_grad()
                self.step += 1
                step += 1
                self.ema.update()
                with torch.no_grad():
                    if self.step != 0 and self.step % self.save_cycle == 0:
                        self.milestone += 1
                        self.save(self.milestone)
                    if self.logger is not None and self.step % self.log_frequency == 0:
                        self.logger.add_scalar(
                            tag="train/loss",
                            scalar_value=total_loss,
                            global_step=self.step,
                        )
                pbar.update(1)
        print("training complete")
        if self.logger is not None:
            self.logger.log_info(
                "Training done, time: {:.2f}".format(time.time() - tic)
            )
    def sample(self, num, size_every, shape=None):
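        """Generate `num` unconditional series, batched `size_every` at a time,
        using the EMA model; `shape` is (seq_length, feature_dim)."""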
        if self.logger is not None:
            tic = time.time()
            self.logger.log_info("Begin to sample...")
        samples = np.empty([0, shape[0], shape[1]])
        num_cycle = math.ceil(num / size_every)
        for _ in range(num_cycle):
            sample = self.ema.ema_model.generate_mts(batch_size=size_every)
            samples = np.vstack([samples, sample.detach().cpu().numpy()])
            torch.cuda.empty_cache()
        samples = samples[:num]  # drop the overshoot from the last batch
        if self.logger is not None:
            self.logger.log_info(
                "Sampling done, time: {:.2f}".format(time.time() - tic)
            )
        return samples
    def control_sample(self, num, size_every, shape=None, model_kwargs=None, target=None, partial_mask=None):
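        """Generate series constrained to match `target` wherever `partial_mask`
        is set; `target` and `partial_mask` must be given together or not at all."""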
        if model_kwargs is None:
            model_kwargs = {}
        samples = np.empty([0, shape[0], shape[1]])
        num_cycle = math.ceil(num / size_every)
        assert not ((target is None) ^ (partial_mask is None)), "target and partial_mask must be provided together"
        if self.logger is not None:
            tic = time.time()
            self.logger.log_info("Begin to infill sample...")
        target = torch.as_tensor(target).to(self.device) if target is not None else torch.zeros(shape).to(self.device)
        target = target.repeat(size_every, 1, 1) if len(target.shape) == 2 else target
        partial_mask = torch.as_tensor(partial_mask).to(self.device) if partial_mask is not None else torch.zeros(shape).to(self.device)
        partial_mask = partial_mask.repeat(size_every, 1, 1) if len(partial_mask.shape) == 2 else partial_mask
        for _ in range(num_cycle):
            sample = self.ema.ema_model.generate_mts_infill(target, partial_mask, model_kwargs=model_kwargs)
            samples = np.vstack([samples, sample.detach().cpu().numpy()])
            torch.cuda.empty_cache()
        if self.logger is not None:
            self.logger.log_info(
                "Sampling done, time: {:.2f}".format(time.time() - tic)
            )
        return samples
    def predict(
        self,
        observed_points: torch.Tensor,
        coef=1e-1,
        stepsize=1e-1,
        sampling_steps=50,
        **kwargs,
    ):
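        """Impute a single 2D series: non-zero entries of `observed_points` are
        treated as observed and held fixed; `coef` and `stepsize` are forwarded
        to the sampler via model_kwargs. Returns a (seq_length, feature_dim)
        NumPy array in [0, 1]."""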
        model_kwargs = {}
        model_kwargs["coef"] = coef
        model_kwargs["learning_rate"] = stepsize
        model_kwargs = {**model_kwargs, **kwargs}
        assert len(observed_points.shape) == 2, "observed_points should be 2D, batch size = 1"
        x = observed_points.unsqueeze(0)
        t_m = x != 0  # observation mask: 1 for observed, 0 for missing
        x = x * 2 - 1  # rescale from [0, 1] to [-1, 1]
        x, t_m = x.to(self.device), t_m.to(self.device)
        if sampling_steps == self.model.num_timesteps:
            print("normal sampling")
            sample = self.ema.ema_model.sample_infill(
                shape=x.shape,  # (batch_size, seq_length, feature_dim)
                target=x * t_m,
                partial_mask=t_m,
                model_kwargs=model_kwargs,
            )
        else:
            print("fast sampling")
            sample = self.ema.ema_model.fast_sample_infill(
                shape=x.shape,
                target=x * t_m,
                partial_mask=t_m,
                model_kwargs=model_kwargs,
                sampling_timesteps=sampling_steps,
            )
        sample = (sample + 1) / 2  # unnormalize back to [0, 1]
        return sample.squeeze(0).detach().cpu().numpy()
    def predict_weighted_points(
        self,
        observed_points: torch.Tensor,
        observed_mask: torch.Tensor,
        coef=1e-1,
        stepsize=1e-1,
        sampling_steps=50,
        **kwargs,
    ):
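        """Like `predict`, but `observed_mask` carries per-point confidence
        weights in [0, 1]; only the fast (float-mask) sampling path is
        implemented."""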
        model_kwargs = {}
        model_kwargs["coef"] = coef
        model_kwargs["learning_rate"] = stepsize
        model_kwargs = {**model_kwargs, **kwargs}
        assert len(observed_points.shape) == 2, "observed_points should be 2D, batch size = 1"
        x = observed_points.unsqueeze(0)
        float_mask = observed_mask.unsqueeze(0)  # per-point confidence weights in [0, 1]
        binary_mask = float_mask.clone()
        binary_mask[binary_mask > 0] = 1  # 1 for observed, 0 for missing
        x = x * 2 - 1  # rescale from [0, 1] to [-1, 1]
        x, float_mask, binary_mask = x.to(self.device), float_mask.to(self.device), binary_mask.to(self.device)
        if sampling_steps == self.model.num_timesteps:
            print("normal sampling")
            raise NotImplementedError("full-step sampling with float masks (sample_infill_float_mask) is not supported yet")
        else:
            print("fast sampling")
            sample = self.ema.ema_model.fast_sample_infill_float_mask(
                shape=x.shape,
                target=x * binary_mask,  # observed values, zeros elsewhere
                partial_mask=float_mask,
                model_kwargs=model_kwargs,
                sampling_timesteps=sampling_steps,
            )
        sample = (sample + 1) / 2  # unnormalize back to [0, 1]
        return sample.squeeze(0).detach().cpu().numpy()
    def restore(
        self,
        raw_dataloader,
        shape=None,
        coef=1e-1,
        stepsize=1e-1,
        sampling_steps=50,
        **kwargs,
    ):
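        """Impute the first batch of `raw_dataloader` and return
        (samples, reals, masks) as NumPy arrays of shape (batch, *shape)."""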
        if self.logger is not None:
            tic = time.time()
            self.logger.log_info("Begin to restore...")
        model_kwargs = {}
        model_kwargs["coef"] = coef
        model_kwargs["learning_rate"] = stepsize
        model_kwargs = {**model_kwargs, **kwargs}
        test = kwargs.get("test", False)
        samples = np.empty([0, shape[0], shape[1]])  # (0, seq_length, feature_dim)
        reals = np.empty([0, shape[0], shape[1]])
        masks = np.empty([0, shape[0], shape[1]])
        for x, _ in raw_dataloader:  # the loader's mask is recomputed from x below
            t_m = x != 0  # observation mask: 1 for observed, 0 for missing
            if test:
                t_m = t_m.type_as(x)
                binary_mask = t_m.clone()
                binary_mask[binary_mask > 0] = 1
            else:
                binary_mask = t_m
            x, t_m = x.to(self.device), t_m.to(self.device)
            binary_mask = binary_mask.to(self.device)
            if sampling_steps == self.model.num_timesteps:
                print("normal sampling")
                sample = self.ema.ema_model.sample_infill(
                    shape=x.shape,  # (batch_size, seq_length, feature_dim)
                    target=x * t_m,
                    partial_mask=t_m,
                    model_kwargs=model_kwargs,
                )
            else:
                print("fast sampling")
                if test:
                    sample = self.ema.ema_model.fast_sample_infill_float_mask(
                        shape=x.shape,
                        target=x * binary_mask,  # observed values, zeros elsewhere
                        partial_mask=t_m,
                        model_kwargs=model_kwargs,
                        sampling_timesteps=sampling_steps,
                    )
                else:
                    sample = self.ema.ema_model.fast_sample_infill(
                        shape=x.shape,
                        target=x * t_m,
                        partial_mask=t_m,
                        model_kwargs=model_kwargs,
                        sampling_timesteps=sampling_steps,
                    )
            samples = np.vstack([samples, sample.detach().cpu().numpy()])
            reals = np.vstack([reals, x.detach().cpu().numpy()])
            masks = np.vstack([masks, t_m.detach().cpu().numpy()])
            break  # only the first batch is restored
        if self.logger is not None:
            self.logger.log_info(
                "Imputation done, time: {:.2f}".format(time.time() - tic)
            )
        return samples, reals, masks
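

# A minimal usage sketch, not part of the original module. The config keys
# mirror those consumed in Trainer.__init__; the "model" and "dataloader"
# config sections and the CLI flags here are illustrative assumptions, not
# the project's actual entry point.
if __name__ == "__main__":
    import argparse

    import yaml  # assumption: configs are YAML and PyYAML is installed

    parser = argparse.ArgumentParser()
    parser.add_argument("--name", type=str, default="tiffusion")
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    # Hypothetical wiring: the repo builds the model and dataloader elsewhere.
    model = instantiate_from_config(config["model"])
    dataloader = {"dataloader": instantiate_from_config(config["dataloader"])}

    trainer = Trainer(config, args, model, dataloader)
    trainer.train()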