| import tqdm |
| import os |
| import textwrap |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| from torch.utils.data import DataLoader |
| import warnings |
| from torch.utils.data import Dataset |
| from sklearn.preprocessing import StandardScaler |
|
|
| warnings.filterwarnings("ignore") |
|
|
| from .time_rcd.dataset import ChatTSTimeRCDPretrainDataset |
| from .time_rcd.TimeRCD_pretrain_multi import TimeSeriesPretrainModel, create_random_mask, collate_fn, test_collate_fn |
| from .time_rcd.time_rcd_config import TimeRCDConfig, default_config |
| from utils.dataset import TimeRCDDataset |
|
|
class TimeRCDPretrainTester:
    """Load a pretrained TimeRCD checkpoint and run inference/visualization with it.

    The tester builds a ``TimeSeriesPretrainModel`` from ``config``, loads the
    weights from ``checkpoint_path``, switches the model to eval mode and then
    offers helpers for batch prediction, zero-shot scoring and plotting.
    """

    # Legacy output directory that ``zero_shot_reconstruct`` used to hard-code.
    # Kept as the fallback so existing callers behave as before; pass
    # ``save_path=...`` to ``zero_shot_reconstruct`` to write elsewhere.
    DEFAULT_RECONSTRUCTION_PLOT_DIR = "/home/lihaoyang/Huawei/TSB-AD/Synthetic/random_mask_anomaly_head_Time_RCD_Reconstruction_5000/plot/"

    def __init__(self, checkpoint_path: str, config: "TimeRCDConfig"):
        """Build the model described by ``config`` and load ``checkpoint_path`` into it.

        Args:
            checkpoint_path: Path of the checkpoint file to load.
            config: Model configuration; ``win_size`` and ``batch_size`` are
                read from it and cached on the instance.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.win_size = config.win_size
        self.batch_size = config.batch_size

        self.model = TimeSeriesPretrainModel(config).to(self.device)
        self.load_checkpoint(checkpoint_path)
        self.model.eval()

        print(f"Model loaded on device: {self.device}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load model weights from ``checkpoint_path`` into ``self.model``.

        The checkpoint may either be a raw state dict or a dict holding the
        state dict under ``'model_state_dict'``. A leading ``'module.'`` on a
        key (as produced by DataParallel-style wrappers) is stripped.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        prefix = 'module.'
        cleaned_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        self.model.load_state_dict(cleaned_state_dict)
        print(f"Successfully loaded checkpoint from {checkpoint_path}")

    @staticmethod
    def _standardize(series):
        """Z-normalize ``series`` along dim 1 (epsilon guards constant series)."""
        mean = series.mean(dim=1, keepdim=True)
        std = series.std(dim=1, keepdim=True)
        return (series - mean) / (std + 1e-8)

    @staticmethod
    def _true_runs(flags):
        """Return inclusive ``(start, end)`` index pairs for each run of truthy ``flags``."""
        runs = []
        run_start = None
        for idx, flag in enumerate(flags):
            if flag and run_start is None:
                run_start = idx
            elif not flag and run_start is not None:
                runs.append((run_start, idx - 1))
                run_start = None
        if run_start is not None:
            runs.append((run_start, len(flags) - 1))
        return runs

    @staticmethod
    def _shade_regions(ax, regions, color, alpha, label, with_label=True):
        """Shade each ``(start, end)`` region on ``ax``; only the first one gets ``label``.

        A region covering a single step would be a zero-width (invisible)
        span, so it is widened by half a step on each side.
        """
        for region_idx, (start, end) in enumerate(regions):
            if start == end:
                left, right = start - 0.5, start + 0.5
            else:
                left, right = start, end
            span_label = label if (with_label and region_idx == 0) else ""
            ax.axvspan(left, right, alpha=alpha, color=color, label=span_label)

    def predict(self, batch):
        """Run inference on one collated batch and return CPU tensors for plotting.

        Reads ``time_series``, ``normal_time_series``, ``masked_time_series``,
        ``attribute``, ``mask``, ``labels`` and ``attention_mask`` from
        ``batch``. ``time_series`` and ``masked_time_series`` are z-normalized
        along dim 1 before use; the model itself is fed the normalized
        ``time_series`` together with ``attention_mask``.

        Returns:
            dict with keys ``original``, ``normal``, ``masked``,
            ``reconstructed``, ``mask``, ``anomaly_probs``, ``true_labels``,
            ``attention_mask`` and ``attribute``.
        """
        with torch.no_grad():
            time_series = batch['time_series'].to(self.device)
            normal_time_series = batch['normal_time_series'].to(self.device)
            masked_time_series = batch['masked_time_series'].to(self.device)
            attribute = batch['attribute']
            batch_size, seq_len, num_features = time_series.shape

            time_series = self._standardize(time_series)
            masked_time_series = self._standardize(masked_time_series)

            mask = batch['mask'].to(self.device)
            labels = batch['labels'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)

            local_embeddings = self.model(
                time_series=time_series,
                mask=attention_mask)

            # Reconstruction branch: bring the head output back to the input layout.
            reconstructed = self.model.reconstruction_head(local_embeddings)
            reconstructed = reconstructed.reshape(batch_size, seq_len, num_features)

            # Anomaly branch: average the logits over dim -2 (presumably the
            # feature axis - TODO confirm), then take the class-1 probability.
            anomaly_logits = self.model.anomaly_head(local_embeddings)
            anomaly_logits = torch.mean(anomaly_logits, dim=-2)
            anomaly_probs = F.softmax(anomaly_logits, dim=-1)[..., 1]

            return {
                'original': time_series.cpu(),
                'normal': normal_time_series.cpu(),
                'masked': masked_time_series.cpu(),
                'reconstructed': reconstructed.cpu(),
                'mask': mask.cpu(),
                'anomaly_probs': anomaly_probs.cpu(),
                'true_labels': labels.cpu(),
                'attention_mask': attention_mask.cpu(),
                'attribute': attribute
            }

    def visualize_single_sample(self, results, sample_idx=0, save_path=None):
        """Plot one sample out of a ``predict`` result dict.

        Univariate samples (1-D after dropping a trailing singleton axis) get
        three panels: reconstruction, anomaly detection, reconstruction error.
        Multivariate samples (2-D) get one panel per feature.

        Args:
            results: Dict as returned by ``predict``.
            sample_idx: Index of the sample inside the batch.
            save_path: If truthy, the figure is also written to this file.

        Raises:
            ValueError: If the series disagree in dimensionality or are
                neither 1-D nor 2-D.
        """
        original = results['original'][sample_idx].squeeze(-1).numpy()
        normal = results['normal'][sample_idx].squeeze(-1).numpy()
        masked = results['masked'][sample_idx].squeeze(-1).numpy()
        reconstructed = results['reconstructed'][sample_idx].squeeze(-1).numpy()
        mask = results['mask'][sample_idx].numpy().astype(bool)
        anomaly_probs = results['anomaly_probs'][sample_idx].numpy()
        true_labels = results['true_labels'][sample_idx].numpy()
        attention_mask = results['attention_mask'][sample_idx].numpy().astype(bool)
        attribute = results['attribute'][sample_idx]

        # Keep only the leading steps that the attention mask marks as valid.
        valid_length = attention_mask.sum()
        original = original[:valid_length]
        normal = normal[:valid_length]
        masked = masked[:valid_length]
        reconstructed = reconstructed[:valid_length]
        mask = mask[:valid_length]
        anomaly_probs = anomaly_probs[:valid_length]
        true_labels = true_labels[:valid_length]

        if not (original.ndim == normal.ndim == reconstructed.ndim == masked.ndim):
            raise ValueError("Original, normal, reconstructed, and masked time series must have the same dimensions.")

        if original.ndim == 1:
            fig = self._plot_univariate_sample(
                original, normal, masked, reconstructed, mask, anomaly_probs, true_labels)
            savefig_kwargs = {}
        elif original.ndim == 2:
            fig = self._plot_multivariate_sample(
                original, normal, anomaly_probs, true_labels, attribute)
            savefig_kwargs = {'facecolor': 'white'}
        else:
            raise ValueError("Unsupported original data shape: {}".format(original.shape))

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight', **savefig_kwargs)

        plt.show()
        # Release the figure; otherwise repeated calls pile up open figures.
        plt.close(fig)

    def _plot_univariate_sample(self, original, normal, masked, reconstructed,
                                mask, anomaly_probs, true_labels):
        """Draw the three-panel univariate figure and return it."""
        time_axis = np.arange(len(original))
        fig, axes = plt.subplots(3, 1, figsize=(15, 12))

        # Panel 1: original vs. masked input, reconstruction at masked steps.
        ax1 = axes[0]
        ax1.plot(time_axis, original, 'b-', label='Original', linewidth=2, alpha=0.8)
        ax1.plot(time_axis, masked, 'g--', label='Masked Input', linewidth=1.5, alpha=0.7)
        ax1.plot(time_axis[mask], reconstructed[mask], 'ro',
                 label='Reconstructed', markersize=4, alpha=0.8)
        self._shade_regions(ax1, self._true_runs(mask),
                            color='red', alpha=0.2, label='Masked Region')
        ax1.set_title('Time Series Reconstruction', fontsize=14, fontweight='bold')
        ax1.set_xlabel('Time Steps')
        ax1.set_ylabel('Value')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Panel 2: series with labelled anomalies and predicted probability.
        ax2 = axes[1]
        ax2.plot(time_axis, normal, 'g-', label='Normal Time Series', linewidth=1, alpha=0.6)
        ax2.plot(time_axis, original, 'b-', label='Anomalous Time Series', linewidth=1, alpha=0.6)
        self._shade_regions(ax2, self._true_runs(true_labels > 0.5),
                            color='red', alpha=0.3, label='True Anomaly')

        ax2_twin = ax2.twinx()
        ax2_twin.plot(time_axis, anomaly_probs, 'r-', label='Anomaly Probability',
                      linewidth=2, alpha=0.8)
        ax2_twin.axhline(y=0.5, color='orange', linestyle='--', alpha=0.7,
                         label='Threshold (0.5)')
        ax2_twin.set_ylabel('Anomaly Probability', color='red')
        ax2_twin.set_ylim(0, 1)

        ax2.set_title('Anomaly Detection Results', fontsize=14, fontweight='bold')
        ax2.set_xlabel('Time Steps')
        ax2.set_ylabel('Time Series Value', color='blue')

        # One legend for both y-axes.
        lines1, labels1 = ax2.get_legend_handles_labels()
        lines2, labels2 = ax2_twin.get_legend_handles_labels()
        ax2.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
        ax2.grid(True, alpha=0.3)

        # Panel 3: absolute reconstruction error at masked steps only.
        ax3 = axes[2]
        if mask.sum() > 0:
            recon_error = np.abs(original[mask] - reconstructed[mask])
            ax3.bar(np.arange(len(recon_error)), recon_error,
                    alpha=0.7, color='orange', label='Reconstruction Error')
            ax3.set_title('Reconstruction Error (Masked Regions Only)',
                          fontsize=14, fontweight='bold')
            ax3.set_xlabel('Masked Time Step Index')
            ax3.set_ylabel('Absolute Error')
            ax3.legend()
            ax3.grid(True, alpha=0.3)
        else:
            ax3.text(0.5, 0.5, 'No masked regions in this sample',
                     ha='center', va='center', transform=ax3.transAxes, fontsize=12)
            ax3.set_title('Reconstruction Error', fontsize=14, fontweight='bold')

        plt.tight_layout()
        return fig

    def _plot_multivariate_sample(self, original, normal, anomaly_probs, true_labels, attribute):
        """Draw one panel per feature for a multivariate sample and return the figure."""
        time_axis = np.arange(len(original))
        _, num_features = original.shape

        fig_height = 4 * num_features + 2
        fig, axes = plt.subplots(num_features, 1, figsize=(16, fig_height))
        plt.subplots_adjust(top=0.85, hspace=0.2, left=0.08, right=0.92, bottom=0.08)

        anomaly_regions = self._true_runs(true_labels > 0.5)

        for feature_idx in range(num_features):
            ax = axes[feature_idx]
            ax.plot(time_axis, original[:, feature_idx], 'b-',
                    linewidth=1, label='Anomalous Time Series', alpha=0.8)
            ax.plot(time_axis, normal[:, feature_idx], 'g-',
                    linewidth=1, label='Normal Time Series', alpha=0.8)
            # Extend the y-range downwards by its own height so the series sits
            # in the upper half and the score curve (twin axis) in the lower.
            y_min, y_max = ax.get_ylim()
            shift = y_max - y_min
            ax.set_ylim(y_min - shift, y_max)

            self._shade_regions(ax, anomaly_regions, color='grey', alpha=0.3,
                                label='True Anomaly Region',
                                with_label=(feature_idx == 0))

            ax2 = ax.twinx()
            ax2.plot(time_axis, anomaly_probs, 'r-', linewidth=1,
                     label='Anomaly Score', alpha=0.9)
            ax2.set_ylim(0, 1.5)
            ax2.set_ylabel('Anomaly Score', fontsize=12)
            ax.set_ylabel('Value', fontsize=12)
            if feature_idx == num_features - 1:
                ax.set_xlabel('Time Steps', fontsize=12)
            else:
                ax.set_xticklabels([])

            ax.set_title(f'Feature {feature_idx} - Time Series & Anomaly Score',
                         fontsize=16, pad=10)
            ax.grid(True, alpha=0.3)

            if feature_idx == 0:
                lines1, labels1 = ax.get_legend_handles_labels()
                lines2, labels2 = ax2.get_legend_handles_labels()
                ax.legend(lines1 + lines2, labels1 + labels2,
                          loc='upper right', bbox_to_anchor=(0.98, 0.98), fontsize=14)

        # Figure title: summarises the sample's attribute record.
        isendo = attribute['is_endogenous']
        edges = attribute['dag']
        anomalies = [
            (f"feature_{idx}_{k[2:]}", v)
            for idx, item in enumerate(attribute['attribute_list'])
            for k, v in item['anomalies'].items()
        ]
        anomalies_str = ', '.join([f"{k}: {v}" for k, v in anomalies])
        wrap_width = 100
        wrapped_anomalies = textwrap.fill(f"Anomalies: {anomalies_str}", width=wrap_width)
        wrapped_edges = textwrap.fill(f"Edges: {str(edges)}", width=wrap_width)
        title = f"Multivariate Time Series Visualization\n{isendo}_{wrapped_anomalies}\n{wrapped_edges}"
        fig.suptitle(title, fontsize=22, y=0.95, ha='center', va='top')
        return fig

    def test_model(self, data_path: str, filename: str, num_samples: int = 5, save_dir: str = None,
                   max_test_data: int = 100):
        """Load a test dataset, run ``predict`` on a few samples and plot each one.

        Args:
            data_path: Passed to ``ChatTSTimeRCDPretrainDataset``.
            filename: Passed to ``ChatTSTimeRCDPretrainDataset``.
            num_samples: Maximum number of samples to visualize.
            save_dir: If given, figures are saved there as
                ``sample_<k>_results.png`` (directory is created on demand).
            max_test_data: Datasets larger than this are randomly subsampled
                to this many items before visualization.
        """
        full_test_dataset = ChatTSTimeRCDPretrainDataset(data_path, filename, split="test", train_ratio=0)
        print(f'Length of dataset: {len(full_test_dataset)}')

        if len(full_test_dataset) > max_test_data:
            indices = torch.randperm(len(full_test_dataset))[:max_test_data].tolist()
            test_dataset = torch.utils.data.Subset(full_test_dataset, indices)
            print("random")
        else:
            test_dataset = full_test_dataset

        # One sample per batch so every figure corresponds to one dataset item.
        vis_loader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=0
        )

        num_visualize = min(num_samples, len(test_dataset))
        if save_dir and num_visualize:
            os.makedirs(save_dir, exist_ok=True)

        # zip() stops after ``num_visualize`` batches or when the loader runs dry.
        for i, vis_batch in zip(range(num_visualize), vis_loader):
            vis_results = self.predict(vis_batch)
            save_path = None
            if save_dir:
                save_path = os.path.join(save_dir, f"sample_{i + 1}_results.png")
            self.visualize_single_sample(vis_results, sample_idx=0, save_path=save_path)

    def zero_shot(self, data):
        """Score ``data`` window by window with the anomaly head.

        Returns:
            ``(scores, logits)``: two lists with one numpy array per batch;
            ``scores`` holds the class-1 softmax probability, ``logits`` the
            difference ``logit[1] - logit[0]``.
        """
        # NOTE(review): this permanently shrinks ``self.win_size`` when ``data``
        # is shorter than the window, so later calls on longer data keep the
        # smaller window. Left as-is because callers may read ``win_size``
        # afterwards - confirm before changing.
        if len(data) <= self.win_size:
            self.win_size = len(data)

        test_loader = DataLoader(
            dataset=TimeRCDDataset(data, window_size=self.win_size, stride=self.win_size, normalize=True),
            batch_size=self.batch_size,
            collate_fn=test_collate_fn,
            num_workers=0,
            shuffle=False,
        )

        scores = []
        logits = []
        with torch.no_grad():
            for batch in tqdm.tqdm(test_loader, total=len(test_loader), leave=True):
                time_series = batch['time_series'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)

                local_embeddings = self.model(
                    time_series=time_series,
                    mask=attention_mask)

                anomaly_logits = self.model.anomaly_head(local_embeddings)
                anomaly_logits = torch.mean(anomaly_logits, dim=-2)
                anomaly_probs = F.softmax(anomaly_logits, dim=-1)[..., 1]
                scores.append(anomaly_probs.cpu().numpy())
                logit = anomaly_logits[..., 1] - anomaly_logits[..., 0]
                logits.append(logit.cpu().numpy())
        return scores, logits

    def evaluate(self, time_series, mask):
        """Reconstruct ``time_series`` and return the squared error at masked steps.

        Args:
            time_series: Float tensor; the reconstruction is reshaped to this
                tensor's shape (presumably ``(batch, seq, features)`` - TODO confirm).
            mask: Boolean tensor with one entry per ``(batch, step)``; it is
                passed to the model and also selects which steps are scored.

        Returns:
            ``(scores, reconstructed)``: ``scores`` is the flat tensor of
            squared errors at the selected positions (batch-major order),
            ``reconstructed`` has the same shape as ``time_series``.
        """
        with torch.no_grad():
            time_series = time_series.to(self.device)
            mask = mask.to(self.device)
            local_embeddings = self.model(time_series=time_series, mask=mask)

            reconstructed = self.model.reconstruction_head(local_embeddings)
            # Reshape explicitly (as ``predict`` does) instead of ``squeeze(-1)``:
            # the result is the same when the head adds a trailing singleton
            # axis, and it stays correct when the head returns the input shape.
            reconstructed = reconstructed.reshape(time_series.shape)

            mask_expand = mask.unsqueeze(-1).expand_as(reconstructed)
            anomaly_probs = ((reconstructed - time_series) ** 2)[mask_expand]
        return anomaly_probs, reconstructed

    def zero_shot_reconstruct(self, data, visualize=True, data_index=None, save_path=None):
        """Score ``data`` by reconstruction error, window by window.

        Args:
            data: Series to score; wrapped in ``Dataset_UCR``.
            visualize: If true, the first sample of every batch is plotted via
                ``visualize_reconstruction``.
            data_index: Index used in the output file name of the plots. The
                same name is used for every batch, so later batches overwrite
                earlier ones.
            save_path: Directory for the plots. Defaults to
                ``DEFAULT_RECONSTRUCTION_PLOT_DIR`` (the previously hard-coded
                location).

        Returns:
            List with one tensor of squared reconstruction errors per batch.
        """
        # NOTE(review): permanently shrinks ``self.win_size`` for short inputs;
        # kept for backward compatibility (see ``zero_shot``).
        if len(data) <= self.win_size:
            self.win_size = len(data)
        if save_path is None:
            save_path = self.DEFAULT_RECONSTRUCTION_PLOT_DIR

        test_loader = DataLoader(
            dataset=Dataset_UCR(data, window_size=self.win_size),
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=False,
        )

        scores = []
        with torch.no_grad():
            for x, mask in tqdm.tqdm(test_loader, total=len(test_loader), leave=True):
                # The default collate already yields tensors; as_tensor only
                # converts dtype instead of copy-constructing a new tensor.
                time_series = torch.as_tensor(x, dtype=torch.float32).to(self.device)
                mask_tensor = torch.as_tensor(mask, dtype=torch.bool).to(self.device)

                score, reconstructed = self.evaluate(time_series, mask_tensor)
                scores.append(score)

                if visualize:
                    # ``score`` is flattened batch-major, so the first sample's
                    # entries are the leading (valid steps x features) values.
                    first_mask = mask_tensor[0]
                    first_count = int(first_mask.sum().item()) * time_series.shape[-1]
                    self.visualize_reconstruction(
                        original=time_series[0].cpu().numpy(),
                        reconstructed=reconstructed[0].cpu().numpy(),
                        mask=first_mask.cpu().numpy(),
                        scores=score[:first_count].cpu().numpy(),
                        save_path=save_path,
                        index=data_index)

        return scores

    def visualize_reconstruction(self, original, reconstructed, mask, scores, index, save_path=None):
        """Plot original vs. reconstructed series and the reconstruction error.

        Args:
            original: Array of one sample, ``(seq,)``, ``(seq, 1)`` or ``(seq, features)``.
            reconstructed: Array with the same number of elements as
                ``original`` (an extra leading batch axis of size 1 is accepted).
            mask: Boolean array over the time steps; true runs are shaded.
            scores: Per-step scores for the true entries of ``mask``; drawn
                only when their count equals ``mask.sum()``.
            index: Used in the output file name.
            save_path: Directory to save into. When given, the figure is
                written there and then closed; otherwise it is left open.
        """
        original = np.asarray(original)
        seq_len = len(original)
        time_axis = np.arange(seq_len)
        mask = np.asarray(mask, dtype=bool)

        # Align the reconstruction with the sample, then drop a singleton
        # feature axis so univariate series are plotted as 1-D lines.
        reconstructed = np.asarray(reconstructed).reshape(original.shape)
        if original.ndim > 1 and original.shape[-1] == 1:
            original = original[..., 0]
            reconstructed = reconstructed[..., 0]
        scores = np.atleast_1d(np.asarray(scores).squeeze())

        fig, axes = plt.subplots(2, 1, figsize=(15, 10))

        # Panel 1: original vs. reconstruction with masked regions shaded.
        ax1 = axes[0]
        ax1.plot(time_axis, original, 'b-', label='Original', linewidth=2, alpha=0.8)
        ax1.plot(time_axis, reconstructed, 'r--', label='Reconstructed', linewidth=2, alpha=0.8)
        self._shade_regions(ax1, self._true_runs(mask),
                            color='red', alpha=0.2, label='Masked Region')
        ax1.set_title('Time Series Reconstruction', fontsize=14, fontweight='bold')
        ax1.set_xlabel('Time Steps')
        ax1.set_ylabel('Value')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Panel 2: absolute error, plus the scores scattered back onto the
        # time axis when there is exactly one score per masked step.
        ax2 = axes[1]
        recon_error = np.abs(original - reconstructed)
        ax2.plot(time_axis, recon_error, 'g-', label='Reconstruction Error', linewidth=2, alpha=0.8)

        if len(scores) == mask.sum():
            full_scores = np.zeros(seq_len)
            full_scores[mask] = scores
            ax2_twin = ax2.twinx()
            ax2_twin.plot(time_axis, full_scores, 'orange', label='Anomaly Scores', linewidth=1.5, alpha=0.7)
            ax2_twin.set_ylabel('Anomaly Score', color='orange')
            ax2_twin.legend(loc='upper right')

        ax2.set_title('Reconstruction Error', fontsize=14, fontweight='bold')
        ax2.set_xlabel('Time Steps')
        ax2.set_ylabel('Absolute Error')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()

        if save_path:
            os.makedirs(save_path, exist_ok=True)
            save_path = os.path.join(save_path, f"reconstruction_sample_{index}_results.png")
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print("Visualization saved to: ", save_path)
            # The figure is on disk now; close it so per-batch calls do not
            # accumulate open figures.
            plt.close(fig)
|
|
| |
|
|
|
|
class Dataset_UCR(Dataset):
    """Standardized, non-overlapping windows over a single time series.

    Each item is ``(window, mask)``: ``window`` is a float32 array of shape
    ``(window_size, n_features)`` (or the whole series when it is not longer
    than one window) and ``mask`` is a boolean array that is true for real
    time steps and false for zero padding appended to fill the last window.
    """

    def __init__(self, data, window_size: int = 1000):
        """Standardize ``data`` and cut it into windows of ``window_size`` steps.

        Args:
            data: Array-like of shape ``(n_steps,)`` or ``(n_steps, n_features)``.
            window_size: Number of time steps per window.
        """
        super().__init__()
        data = np.asarray(data)
        # A 1-D series is treated as a single feature column.
        self.data = data.reshape(-1, 1) if data.ndim == 1 else data
        self.window_size = window_size
        self._load_data()
        self._process_windows()

    def _load_data(self):
        """Standardize the series per feature; the scaler is fitted on the series itself."""
        self.raw_test = StandardScaler().fit_transform(self.data)

    def _process_windows(self):
        """Split ``raw_test`` into windows and build the matching validity mask."""
        n_steps = len(self.raw_test)
        if n_steps <= self.window_size:
            # Short series: a single window holding the whole series, no padding.
            self.test = np.expand_dims(self.raw_test, axis=0)
            self.mask = np.expand_dims(np.ones(n_steps, dtype=bool), axis=0)
            return

        self.raw_masks = np.ones(n_steps, dtype=bool)
        # Steps needed to reach the next multiple of window_size (0 if aligned).
        padding = -n_steps % self.window_size
        if padding:
            self.raw_test = np.pad(self.raw_test, ((0, padding), (0, 0)), mode='constant')
            # Padded steps are marked False so they are never scored.
            self.raw_masks = np.pad(self.raw_masks, (0, padding), mode='constant')

        self.test = self.raw_test.reshape(-1, self.window_size, self.raw_test.shape[1])
        self.mask = self.raw_masks.reshape(-1, self.window_size)
        # The previous check also read ``self.test_labels``, which this class
        # never defines, so every multi-window series raised AttributeError.
        assert self.test.shape[0] == self.mask.shape[0], "Inconsistent window sizes"

    def __len__(self):
        """Return the number of windows."""
        return len(self.test)

    def __getitem__(self, index):
        """Return ``(float32 window, boolean mask)`` for window ``index``."""
        return np.float32(self.test[index]), self.mask[index]