import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from .models.encoder import TSEncoder
from .models.losses import hierarchical_contrastive_loss
from .utils import (
    take_per_row,
    split_with_nan,
    centerize_vary_length_series,
    torch_pad_nan,
)


class TS2Vec:
    """The TS2Vec model."""
    def __init__(
        self,
        input_dims,
        output_dims=320,
        hidden_dims=64,
        depth=10,
        device="cuda",
        lr=0.001,
        batch_size=16,
        max_train_length=None,
        temporal_unit=0,
        after_iter_callback=None,
        after_epoch_callback=None,
    ):
        """Initialize a TS2Vec model.

        Args:
            input_dims (int): The input dimension. For a univariate time series, this should be set to 1.
            output_dims (int): The representation dimension.
            hidden_dims (int): The hidden dimension of the encoder.
            depth (int): The number of hidden residual blocks in the encoder.
            device (str): The device used for training and inference, e.g. 'cuda' or 'cpu'.
            lr (float): The learning rate.
            batch_size (int): The batch size.
            max_train_length (Union[int, NoneType]): The maximum allowed sequence length for training. A sequence longer than <max_train_length> is split into several sequences, each no longer than <max_train_length>.
            temporal_unit (int): The minimum unit at which temporal contrast is performed. When training on very long sequences, raising this value reduces time and memory cost.
            after_iter_callback (Union[Callable, NoneType]): A callback function called after each iteration.
            after_epoch_callback (Union[Callable, NoneType]): A callback function called after each epoch.
        """
        super().__init__()
        self.device = device
        self.lr = lr
        self.batch_size = batch_size
        self.max_train_length = max_train_length
        self.temporal_unit = temporal_unit

        self._net = TSEncoder(
            input_dims=input_dims,
            output_dims=output_dims,
            hidden_dims=hidden_dims,
            depth=depth,
        ).to(self.device)
        # self.net keeps a running average of self._net's parameters (SWA);
        # the averaged weights are the ones used for inference.
        self.net = torch.optim.swa_utils.AveragedModel(self._net)
        self.net.update_parameters(self._net)

        self.after_iter_callback = after_iter_callback
        self.after_epoch_callback = after_epoch_callback

        self.n_epochs = 0
        self.n_iters = 0
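
    # A minimal construction sketch (illustrative values, not from the source):
    #
    #     model = TS2Vec(input_dims=1, device="cpu")            # univariate input
    #     model = TS2Vec(input_dims=7, output_dims=128,         # multivariate input,
    #                    max_train_length=3000, device="cuda")  # long sequences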

    def fit(self, train_data, n_epochs=None, n_iters=None, verbose=False):
        """Train the TS2Vec model.

        Args:
            train_data (numpy.ndarray): The training data, with a shape of (n_instance, n_timestamps, n_features). All missing data should be set to NaN.
            n_epochs (Union[int, NoneType]): The number of epochs. Training stops when this number is reached.
            n_iters (Union[int, NoneType]): The number of iterations. Training stops when this number is reached. If neither n_epochs nor n_iters is specified, a default is used: n_iters is set to 200 for datasets with size <= 100000, and to 600 otherwise.
            verbose (bool): Whether to print the training loss after each epoch.

        Returns:
            loss_log: a list containing the training losses on each epoch.
        """
        assert train_data.ndim == 3

        if n_iters is None and n_epochs is None:
            # Default number of iterations, scaled with dataset size.
            n_iters = 200 if train_data.size <= 100000 else 600

        if self.max_train_length is not None:
            # Split sequences longer than max_train_length into shorter chunks.
            sections = train_data.shape[1] // self.max_train_length
            if sections >= 2:
                train_data = np.concatenate(
                    split_with_nan(train_data, sections, axis=1), axis=0
                )
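
        # Illustration of the split above (hypothetical numbers, and assuming
        # split_with_nan pads uneven chunks with NaN): with max_train_length=1000
        # and train_data of shape (8, 3500, 5), sections = 3 and the result has
        # shape (24, 1167, 5). Chunks may still exceed max_train_length; the
        # random window crop in the training loop below handles the remainder.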

        # If the series begin or end with missing timestamps, center the valid
        # segments so the NaN padding is balanced on both ends.
        temporal_missing = np.isnan(train_data).all(axis=-1).any(axis=0)
        if temporal_missing[0] or temporal_missing[-1]:
            train_data = centerize_vary_length_series(train_data)

        # Drop instances that are entirely NaN.
        train_data = train_data[~np.isnan(train_data).all(axis=2).all(axis=1)]

        train_dataset = TensorDataset(torch.from_numpy(train_data).to(torch.float))
        train_loader = DataLoader(
            train_dataset,
            batch_size=min(self.batch_size, len(train_dataset)),
            shuffle=True,
            drop_last=True,
        )
        optimizer = torch.optim.AdamW(self._net.parameters(), lr=self.lr)

        loss_log = []

        while True:
            if n_epochs is not None and self.n_epochs >= n_epochs:
                break

            cum_loss = 0
            n_epoch_iters = 0

            interrupted = False
            for batch in train_loader:
                if n_iters is not None and self.n_iters >= n_iters:
                    interrupted = True
                    break

                x = batch[0]
                if (
                    self.max_train_length is not None
                    and x.size(1) > self.max_train_length
                ):
                    # Randomly crop a window of length max_train_length.
                    window_offset = np.random.randint(
                        x.size(1) - self.max_train_length + 1
                    )
                    x = x[:, window_offset : window_offset + self.max_train_length]
                x = x.to(self.device)

                # Sample two overlapping crops of the series. Their overlap
                # [crop_left, crop_right) of length crop_l is where the two
                # views are contrasted; the sampled boundaries satisfy
                # crop_eleft <= crop_left < crop_right <= crop_eright.
                ts_l = x.size(1)
                crop_l = np.random.randint(
                    low=2 ** (self.temporal_unit + 1), high=ts_l + 1
                )
                crop_left = np.random.randint(ts_l - crop_l + 1)
                crop_right = crop_left + crop_l
                crop_eleft = np.random.randint(crop_left + 1)
                crop_eright = np.random.randint(low=crop_right, high=ts_l + 1)
                # Per-instance random shift applied to both crops.
                crop_offset = np.random.randint(
                    low=-crop_eleft, high=ts_l - crop_eright + 1, size=x.size(0)
                )

                optimizer.zero_grad()

                # View 1 covers [crop_eleft, crop_right); keep its last crop_l steps.
                out1 = self._net(
                    take_per_row(x, crop_offset + crop_eleft, crop_right - crop_eleft)
                )
                out1 = out1[:, -crop_l:]

                # View 2 covers [crop_left, crop_eright); keep its first crop_l steps.
                out2 = self._net(
                    take_per_row(x, crop_offset + crop_left, crop_eright - crop_left)
                )
                out2 = out2[:, :crop_l]

                loss = hierarchical_contrastive_loss(
                    out1, out2, temporal_unit=self.temporal_unit
                )

                loss.backward()
                optimizer.step()
                self.net.update_parameters(self._net)

                cum_loss += loss.item()
                n_epoch_iters += 1

                self.n_iters += 1

                if self.after_iter_callback is not None:
                    self.after_iter_callback(self, loss.item())

            if interrupted:
                break

            cum_loss /= n_epoch_iters
            loss_log.append(cum_loss)
            if verbose:
                print(f"Epoch #{self.n_epochs}: loss={cum_loss}")
            self.n_epochs += 1

            if self.after_epoch_callback is not None:
                self.after_epoch_callback(self, cum_loss)

        return loss_log
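
    # Worked example of the cropping geometry above (illustrative numbers, not
    # from the source; per-instance offset taken as 0): with ts_l=100 and
    # temporal_unit=0, one draw might give crop_l=30, crop_left=20,
    # crop_right=50, crop_eleft=5, crop_eright=80. View 1 then spans [5, 50)
    # and view 2 spans [20, 80); both contain the overlap [20, 50), whose
    # timestamp representations are aligned by the hierarchical loss.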

    def _eval_with_pooling(self, x, mask=None, slicing=None, encoding_window=None):
        out = self.net(x.to(self.device, non_blocking=True), mask)

        if encoding_window == "full_series":
            if slicing is not None:
                out = out[:, slicing]
            # Max-pool over the whole time axis: one vector per instance.
            out = F.max_pool1d(
                out.transpose(1, 2),
                kernel_size=out.size(1),
            ).transpose(1, 2)

        elif isinstance(encoding_window, int):
            # Sliding max-pool over the time axis with the given kernel size.
            out = F.max_pool1d(
                out.transpose(1, 2),
                kernel_size=encoding_window,
                stride=1,
                padding=encoding_window // 2,
            ).transpose(1, 2)
            if encoding_window % 2 == 0:
                out = out[:, :-1]
            if slicing is not None:
                out = out[:, slicing]

        elif encoding_window == "multiscale":
            # Concatenate max-pooled representations at kernel sizes
            # 2^(p+1) + 1 for p = 0, 1, ... along the feature axis.
            p = 0
            reprs = []
            while (1 << p) + 1 < out.size(1):
                t_out = F.max_pool1d(
                    out.transpose(1, 2),
                    kernel_size=(1 << (p + 1)) + 1,
                    stride=1,
                    padding=1 << p,
                ).transpose(1, 2)
                if slicing is not None:
                    t_out = t_out[:, slicing]
                reprs.append(t_out)
                p += 1
            out = torch.cat(reprs, dim=-1)

        else:
            if slicing is not None:
                out = out[:, slicing]

        return out.cpu()
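
    # Output shapes of _eval_with_pooling for an input batch of shape
    # (B, T, input_dims), with representation size C and no slicing
    # (a summary of the branches above, for illustration):
    #
    #     encoding_window=None           -> (B, T, C)
    #     encoding_window='full_series'  -> (B, 1, C)
    #     encoding_window=<int k>        -> (B, T, C)  (per-timestamp, pooled over k)
    #     encoding_window='multiscale'   -> (B, T, C * n_scales)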

    def encode(
        self,
        data,
        mask=None,
        encoding_window=None,
        causal=False,
        sliding_length=None,
        sliding_padding=0,
        batch_size=None,
    ):
        """Compute representations using the model.

        Args:
            data (numpy.ndarray): This should have a shape of (n_instance, n_timestamps, n_features). All missing data should be set to NaN.
            mask (str): The mask used by the encoder. This can be set to 'binomial', 'continuous', 'all_true', 'all_false' or 'mask_last'.
            encoding_window (Union[str, int]): When this param is specified, the computed representation is max-pooled over this window. This can be set to 'full_series', 'multiscale' or an integer specifying the pooling kernel size.
            causal (bool): When this param is set to True, no future information is encoded into the representation of each timestamp.
            sliding_length (Union[int, NoneType]): The length of the sliding window. When this param is specified, sliding inference is applied to the time series.
            sliding_padding (int): The length of contextual data used for inference on each sliding window.
            batch_size (Union[int, NoneType]): The batch size used for inference. If not specified, the training batch size is used.

        Returns:
            repr: The representations for data.
        """
        assert self.net is not None, "please train or load a net first"
        assert data.ndim == 3
        if batch_size is None:
            batch_size = self.batch_size
        n_samples, ts_l, _ = data.shape

        org_training = self.net.training
        self.net.eval()

        dataset = TensorDataset(torch.from_numpy(data).to(torch.float))
        loader = DataLoader(dataset, batch_size=batch_size)

        with torch.no_grad():
            output = []
            for batch in loader:
                x = batch[0]
                if sliding_length is not None:
                    reprs = []
                    if n_samples < batch_size:
                        # Small datasets: buffer several sliding windows and
                        # encode them together to fill the batch.
                        calc_buffer = []
                        calc_buffer_l = 0
                    for i in range(0, ts_l, sliding_length):
                        # Window [i, i + sliding_length), extended by
                        # sliding_padding of context on the left (and on the
                        # right as well, unless causal).
                        l = i - sliding_padding
                        r = i + sliding_length + (sliding_padding if not causal else 0)
                        x_sliding = torch_pad_nan(
                            x[:, max(l, 0) : min(r, ts_l)],
                            left=-l if l < 0 else 0,
                            right=r - ts_l if r > ts_l else 0,
                            dim=1,
                        )
                        if n_samples < batch_size:
                            if calc_buffer_l + n_samples > batch_size:
                                out = self._eval_with_pooling(
                                    torch.cat(calc_buffer, dim=0),
                                    mask,
                                    slicing=slice(
                                        sliding_padding,
                                        sliding_padding + sliding_length,
                                    ),
                                    encoding_window=encoding_window,
                                )
                                reprs += torch.split(out, n_samples)
                                calc_buffer = []
                                calc_buffer_l = 0
                            calc_buffer.append(x_sliding)
                            calc_buffer_l += n_samples
                        else:
                            out = self._eval_with_pooling(
                                x_sliding,
                                mask,
                                slicing=slice(
                                    sliding_padding, sliding_padding + sliding_length
                                ),
                                encoding_window=encoding_window,
                            )
                            reprs.append(out)

                    if n_samples < batch_size:
                        if calc_buffer_l > 0:
                            # Flush the remaining buffered windows.
                            out = self._eval_with_pooling(
                                torch.cat(calc_buffer, dim=0),
                                mask,
                                slicing=slice(
                                    sliding_padding, sliding_padding + sliding_length
                                ),
                                encoding_window=encoding_window,
                            )
                            reprs += torch.split(out, n_samples)
                            calc_buffer = []
                            calc_buffer_l = 0

                    out = torch.cat(reprs, dim=1)
                    if encoding_window == "full_series":
                        out = F.max_pool1d(
                            out.transpose(1, 2).contiguous(),
                            kernel_size=out.size(1),
                        ).squeeze(-1)  # (B, C, 1) -> (B, C): drop the pooled time axis
                else:
                    out = self._eval_with_pooling(
                        x, mask, encoding_window=encoding_window
                    )
                    if encoding_window == "full_series":
                        out = out.squeeze(1)

                output.append(out)

            output = torch.cat(output, dim=0)

        self.net.train(org_training)
        return output.numpy()
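
    # A sliding-inference sketch (hypothetical shapes; forecasting-style usage
    # where each timestamp only sees its past):
    #
    #     repr = model.encode(
    #         data,                 # (n_instance, n_timestamps, n_features)
    #         causal=True,          # repr[i, t] encodes data[i, : t + 1]
    #         sliding_length=1,
    #         sliding_padding=200,  # context length per window
    #     )                         # -> (n_instance, n_timestamps, output_dims)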

    def save(self, fn):
        """Save the model to a file.

        Args:
            fn (str): filename.
        """
        torch.save(self.net.state_dict(), fn)

    def load(self, fn):
        """Load the model from a file.

        Args:
            fn (str): filename.
        """
        state_dict = torch.load(fn, map_location=self.device)
        self.net.load_state_dict(state_dict)
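

# An end-to-end sketch (illustrative only; the toy data, iteration count, and
# checkpoint path are assumptions, not part of the original module):
#
#     data = np.random.randn(8, 200, 1)            # 8 univariate toy series
#     model = TS2Vec(input_dims=1, device="cpu")
#     loss_log = model.fit(data, n_iters=10, verbose=True)
#     repr_full = model.encode(data, encoding_window="full_series")  # (8, 320)
#
#     # Save/load round-trip with a fresh instance of the same configuration.
#     model.save("ts2vec_checkpoint.pt")
#     model2 = TS2Vec(input_dims=1, device="cpu")
#     model2.load("ts2vec_checkpoint.pt")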