# TSEditor/models/ts2vec/ts2vec.py
# PeterYu — "update" — 2875fe6
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from .models.encoder import TSEncoder
from .models.losses import hierarchical_contrastive_loss
from .utils import (
take_per_row,
split_with_nan,
centerize_vary_length_series,
torch_pad_nan,
)
class TS2Vec:
    """The TS2Vec model.

    Holds two networks:

    * ``self._net`` -- the ``TSEncoder`` that is optimized by :meth:`fit`;
    * ``self.net`` -- a ``torch.optim.swa_utils.AveragedModel`` wrapping it,
      whose parameters are updated after every optimizer step. This averaged
      copy is the one used by :meth:`encode`, :meth:`save` and :meth:`load`.
    """

    def __init__(
        self,
        input_dims,
        output_dims=320,
        hidden_dims=64,
        depth=10,
        device="cuda",
        lr=0.001,
        batch_size=16,
        max_train_length=None,
        temporal_unit=0,
        after_iter_callback=None,
        after_epoch_callback=None,
    ):
        """Initialize a TS2Vec model.

        Args:
            input_dims (int): The input dimension. For a univariate time series,
                this should be set to 1.
            output_dims (int): The representation dimension.
            hidden_dims (int): The hidden dimension of the encoder.
            depth (int): The number of hidden residual blocks in the encoder.
            device (Union[str, torch.device]): The device used for training and
                inference (e.g. "cuda" or "cpu").
            lr (float): The learning rate.
            batch_size (int): The batch size.
            max_train_length (Union[int, NoneType]): The maximum allowed sequence
                length for training. Sequences longer than this are split into
                several shorter sequences, and batches that are still longer are
                randomly cropped to this length.
            temporal_unit (int): The minimum unit to perform temporal contrast.
                When training on a very long sequence, this param helps to reduce
                the cost of time and memory.
            after_iter_callback (Union[Callable, NoneType]): Called as
                ``callback(model, loss)`` after each training iteration.
            after_epoch_callback (Union[Callable, NoneType]): Called as
                ``callback(model, epoch_loss)`` after each completed epoch.
        """
        super().__init__()
        self.device = device
        self.lr = lr
        self.batch_size = batch_size
        self.max_train_length = max_train_length
        self.temporal_unit = temporal_unit
        self._net = TSEncoder(
            input_dims=input_dims,
            output_dims=output_dims,
            hidden_dims=hidden_dims,
            depth=depth,
        ).to(self.device)
        # Averaged copy of the encoder used for inference and persistence.
        self.net = torch.optim.swa_utils.AveragedModel(self._net)
        self.net.update_parameters(self._net)
        self.after_iter_callback = after_iter_callback
        self.after_epoch_callback = after_epoch_callback
        # Cumulative counters; they are not reset between calls to ``fit``.
        self.n_epochs = 0
        self.n_iters = 0

    def fit(self, train_data, n_epochs=None, n_iters=None, verbose=False):
        """Train the TS2Vec model.

        Args:
            train_data (numpy.ndarray): The training data, shaped
                (n_instance, n_timestamps, n_features). All missing data should
                be set to NaN.
            n_epochs (Union[int, NoneType]): Training stops once this many epochs
                have been completed (counted cumulatively across calls).
            n_iters (Union[int, NoneType]): Training stops once this many
                iterations have been run (counted cumulatively across calls).
                If neither n_epochs nor n_iters is given, n_iters defaults to
                200 when ``train_data.size <= 100000`` and 600 otherwise.
            verbose (bool): Whether to print the training loss after each epoch.

        Returns:
            loss_log: a list with the mean training loss of each completed
            epoch. An epoch cut short by the ``n_iters`` limit is not recorded.
        """
        assert train_data.ndim == 3
        if n_iters is None and n_epochs is None:
            # Default iteration budget depends on the dataset size.
            n_iters = 200 if train_data.size <= 100000 else 600

        train_data = self._prepare_train_data(train_data)
        train_dataset = TensorDataset(torch.from_numpy(train_data).to(torch.float))
        train_loader = DataLoader(
            train_dataset,
            batch_size=min(self.batch_size, len(train_dataset)),
            shuffle=True,
            drop_last=True,
        )
        optimizer = torch.optim.AdamW(self._net.parameters(), lr=self.lr)

        loss_log = []
        while True:
            if n_epochs is not None and self.n_epochs >= n_epochs:
                break
            cum_loss = 0
            n_epoch_iters = 0
            interrupted = False
            for batch in train_loader:
                if n_iters is not None and self.n_iters >= n_iters:
                    interrupted = True
                    break
                loss_value = self._train_step(batch[0], optimizer)
                cum_loss += loss_value
                n_epoch_iters += 1
                self.n_iters += 1
                if self.after_iter_callback is not None:
                    self.after_iter_callback(self, loss_value)
            if interrupted:
                break
            cum_loss /= n_epoch_iters
            loss_log.append(cum_loss)
            if verbose:
                print(f"Epoch #{self.n_epochs}: loss={cum_loss}")
            self.n_epochs += 1
            if self.after_epoch_callback is not None:
                self.after_epoch_callback(self, cum_loss)
        return loss_log

    def _prepare_train_data(self, train_data):
        """Split over-long series, re-align edge-missing ones, drop all-NaN rows."""
        if self.max_train_length is not None:
            sections = train_data.shape[1] // self.max_train_length
            if sections >= 2:
                # Split along time and stack the pieces as extra instances.
                train_data = np.concatenate(
                    split_with_nan(train_data, sections, axis=1), axis=0
                )
        # True at timestamp t if some instance has all features NaN at t.
        temporal_missing = np.isnan(train_data).all(axis=-1).any(axis=0)
        if temporal_missing[0] or temporal_missing[-1]:
            train_data = centerize_vary_length_series(train_data)
        # Drop instances that are NaN everywhere.
        return train_data[~np.isnan(train_data).all(axis=2).all(axis=1)]

    def _train_step(self, x, optimizer):
        """Run one contrastive optimization step on batch ``x``; return its loss."""
        if self.max_train_length is not None and x.size(1) > self.max_train_length:
            # Randomly crop an over-long batch down to max_train_length steps.
            window_offset = np.random.randint(x.size(1) - self.max_train_length + 1)
            x = x[:, window_offset : window_offset + self.max_train_length]
        x = x.to(self.device)

        # Sample two views, [crop_eleft, crop_right) and [crop_left, crop_eright),
        # which share the segment [crop_left, crop_right) of length crop_l.
        # crop_offset draws one extra shift per instance of the batch.
        ts_l = x.size(1)
        crop_l = np.random.randint(low=2 ** (self.temporal_unit + 1), high=ts_l + 1)
        crop_left = np.random.randint(ts_l - crop_l + 1)
        crop_right = crop_left + crop_l
        crop_eleft = np.random.randint(crop_left + 1)
        crop_eright = np.random.randint(low=crop_right, high=ts_l + 1)
        crop_offset = np.random.randint(
            low=-crop_eleft, high=ts_l - crop_eright + 1, size=x.size(0)
        )

        optimizer.zero_grad()
        # NOTE(review): take_per_row(x, starts, length) is assumed to slice
        # ``length`` timestamps from each row starting at its own offset.
        out1 = self._net(
            take_per_row(x, crop_offset + crop_eleft, crop_right - crop_eleft)
        )
        out1 = out1[:, -crop_l:]  # the shared segment is the tail of view 1
        out2 = self._net(
            take_per_row(x, crop_offset + crop_left, crop_eright - crop_left)
        )
        out2 = out2[:, :crop_l]  # ...and the head of view 2
        loss = hierarchical_contrastive_loss(
            out1, out2, temporal_unit=self.temporal_unit
        )
        loss.backward()
        optimizer.step()
        # Keep the averaged inference copy in sync with the trained encoder.
        self.net.update_parameters(self._net)
        return loss.item()

    def _eval_with_pooling(self, x, mask=None, slicing=None, encoding_window=None):
        """Encode ``x`` with the averaged network and apply optional pooling.

        The network output is treated as (batch, time, channels): pooling runs
        over axis 1. ``slicing`` selects a range of timestamps (applied before
        pooling for "full_series", after pooling otherwise). The result is
        moved to the CPU.
        """
        out = self.net(x.to(self.device, non_blocking=True), mask)
        if encoding_window == "full_series":
            if slicing is not None:
                out = out[:, slicing]
            # Max over the whole time axis -> (batch, 1, channels).
            out = F.max_pool1d(
                out.transpose(1, 2),
                kernel_size=out.size(1),
            ).transpose(1, 2)
        elif isinstance(encoding_window, int):
            out = F.max_pool1d(
                out.transpose(1, 2),
                kernel_size=encoding_window,
                stride=1,
                padding=encoding_window // 2,
            ).transpose(1, 2)
            # An even kernel with padding k//2 yields one extra step; drop it.
            if encoding_window % 2 == 0:
                out = out[:, :-1]
            if slicing is not None:
                out = out[:, slicing]
        elif encoding_window == "multiscale":
            # Concatenate max-pools with kernels 3, 5, 9, 17, ... (length-preserving).
            p = 0
            reprs = []
            while (1 << p) + 1 < out.size(1):
                t_out = F.max_pool1d(
                    out.transpose(1, 2),
                    kernel_size=(1 << (p + 1)) + 1,
                    stride=1,
                    padding=1 << p,
                ).transpose(1, 2)
                if slicing is not None:
                    t_out = t_out[:, slicing]
                reprs.append(t_out)
                p += 1
            out = torch.cat(reprs, dim=-1)
        else:
            if slicing is not None:
                out = out[:, slicing]
        return out.cpu()

    def encode(
        self,
        data,
        mask=None,
        encoding_window=None,
        casual=False,
        sliding_length=None,
        sliding_padding=0,
        batch_size=None,
    ):
        """Compute representations using the (averaged) model.

        Args:
            data (numpy.ndarray): Shaped (n_instance, n_timestamps, n_features).
                All missing data should be set to NaN.
            mask (str): The mask used by the encoder. This can be set to
                'binomial', 'continuous', 'all_true', 'all_false' or 'mask_last'.
            encoding_window (Union[str, int]): When given, the representation is
                max-pooled over this window. This can be 'full_series',
                'multiscale' or an integer pooling kernel size.
            casual (bool): (sic, "causal") When True, sliding windows get no
                right-hand context, so future information is not encoded into
                each timestamp's representation.
            sliding_length (Union[int, NoneType]): The length of the sliding
                window. When given, sliding inference is applied.
            sliding_padding (int): The amount of context data added around each
                sliding window for inference.
            batch_size (Union[int, NoneType]): The batch size used for inference.
                Defaults to the training batch size.

        Returns:
            repr: The representations for data as a numpy array. With
            ``encoding_window="full_series"`` the shape is
            (n_instance, n_channels) for both the plain and the sliding path.
        """
        assert self.net is not None, "please train or load a net first"
        assert data.ndim == 3
        if batch_size is None:
            batch_size = self.batch_size
        n_samples, ts_l, _ = data.shape

        org_training = self.net.training
        self.net.eval()
        try:
            dataset = TensorDataset(torch.from_numpy(data).to(torch.float))
            loader = DataLoader(dataset, batch_size=batch_size)
            with torch.no_grad():
                output = []
                for batch in loader:
                    x = batch[0]
                    if sliding_length is not None:
                        out = self._encode_sliding(
                            x,
                            n_samples,
                            ts_l,
                            mask,
                            encoding_window,
                            casual,
                            sliding_length,
                            sliding_padding,
                            batch_size,
                        )
                    else:
                        out = self._eval_with_pooling(
                            x, mask, encoding_window=encoding_window
                        )
                        if encoding_window == "full_series":
                            out = out.squeeze(1)
                    output.append(out)
                output = torch.cat(output, dim=0)
        finally:
            # Restore the caller's train/eval mode even if encoding fails.
            self.net.train(org_training)
        return output.numpy()

    def _encode_sliding(
        self,
        x,
        n_samples,
        ts_l,
        mask,
        encoding_window,
        casual,
        sliding_length,
        sliding_padding,
        batch_size,
    ):
        """Encode batch ``x`` window by window and join the results along time."""
        window_slice = slice(sliding_padding, sliding_padding + sliding_length)

        def encode_window(t):
            # Encode a (possibly stacked) window and keep only its central part.
            return self._eval_with_pooling(
                t, mask, slicing=window_slice, encoding_window=encoding_window
            )

        # If the whole dataset is smaller than one batch, stack several windows
        # along the batch axis and encode them together.
        use_buffer = n_samples < batch_size
        reprs = []
        calc_buffer = []
        calc_buffer_l = 0
        for i in range(0, ts_l, sliding_length):
            left = i - sliding_padding
            right = i + sliding_length + (sliding_padding if not casual else 0)
            # NaN-pad windows that extend past either end of the series.
            x_sliding = torch_pad_nan(
                x[:, max(left, 0) : min(right, ts_l)],
                left=-left if left < 0 else 0,
                right=right - ts_l if right > ts_l else 0,
                dim=1,
            )
            if not use_buffer:
                reprs.append(encode_window(x_sliding))
                continue
            if calc_buffer_l + n_samples > batch_size:
                out = encode_window(torch.cat(calc_buffer, dim=0))
                reprs += torch.split(out, n_samples)
                calc_buffer = []
                calc_buffer_l = 0
            calc_buffer.append(x_sliding)
            calc_buffer_l += n_samples
        if calc_buffer_l > 0:
            out = encode_window(torch.cat(calc_buffer, dim=0))
            reprs += torch.split(out, n_samples)

        out = torch.cat(reprs, dim=1)
        if encoding_window == "full_series":
            # Max over the per-window vectors: (B, C, 1) -> (B, C). Squeeze the
            # pooled (last) axis, not the channel axis.
            out = F.max_pool1d(
                out.transpose(1, 2).contiguous(),
                kernel_size=out.size(1),
            ).squeeze(-1)
        return out

    def save(self, fn):
        """Save the averaged model's state dict to a file.

        Args:
            fn (str): filename.
        """
        torch.save(self.net.state_dict(), fn)

    def load(self, fn):
        """Load the averaged model's state dict from a file.

        Args:
            fn (str): filename.
        """
        state_dict = torch.load(fn, map_location=self.device)
        self.net.load_state_dict(state_dict)