# File size: 5,254 Bytes
# d03866e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import json
import numpy as np
import torch
from torch.utils.data import Dataset
import random
import os
import pickle
from typing import Dict, List, Union, Optional, Tuple
from pathlib import Path


class ChatTSTimeRCDPretrainDataset(Dataset):
    """Pre-training dataset backed by a single pickle file.

    The pickle is loaded fully into memory, its sample indices are shuffled
    deterministically with ``seed`` and then divided into a leading 'train'
    portion and a trailing 'test' portion.

    Each stored sample is indexed with the keys ``'time_series'``,
    ``'normal_time_series'``, ``'labels'`` and ``'attribute'``.

    Attributes:
        data (list): Samples belonging to the selected split.
    """

    def __init__(self,
                 dataset_dir: str,
                 filename: str,
                 split: str = 'train',
                 train_ratio: float = 0.95,
                 seed: int = 42):
        """Load the pickle and keep only the samples of the requested split.

        Args:
            dataset_dir: Directory containing the pickle file.
            filename: Name of the pickle file inside ``dataset_dir``.
            split: Either 'train' or 'test'.
            train_ratio: Fraction of the shuffled samples assigned to 'train'
                (default: 0.95).
            seed: Seed of the shuffle that defines the split (default: 42).

        Raises:
            ValueError: If ``split`` is neither 'train' nor 'test'.
        """
        super().__init__()
        # Fail fast: reject a bad split before paying for the file load.
        if split not in ('train', 'test'):
            raise ValueError("split must be 'train' or 'test'")

        file_path = os.path.join(dataset_dir, filename)
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(file_path, 'rb') as f:
            dataset = pickle.load(f)

        num_samples = len(dataset)
        indices = list(range(num_samples))
        # A private generator yields the same permutation as seeding the
        # global `random` module, but leaves the process-wide RNG state alone.
        random.Random(seed).shuffle(indices)

        num_train = int(num_samples * train_ratio)
        if split == 'train':
            selected_indices = indices[:num_train]
        else:
            selected_indices = indices[num_train:]
        self.data = [dataset[i] for i in selected_indices]

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, object]:
        """Return one sample as tensors plus its attribute entry.

        Args:
            idx: Index of the sample within the selected split.

        Returns:
            Tuple ``(time_series, normal_time_series, labels, attribute)``.
            The two series are float32 tensors, ``labels`` is an int64 tensor
            and ``attribute`` is passed through exactly as stored.
        """
        sample = self.data[idx]
        time_series = torch.tensor(sample['time_series'], dtype=torch.float32)
        normal_time_series = torch.tensor(sample['normal_time_series'], dtype=torch.float32)
        labels = torch.tensor(sample['labels'], dtype=torch.long)
        return time_series, normal_time_series, labels, sample['attribute']


class ChatTSTimeRCDQADataset(Dataset):
    """Dataset class for time series anomaly detection with QA pairs.

    Every sample is one ``series_*.json`` file found under
    ``<dataset_dir>/series``. The files are sorted, shuffled deterministically
    with ``seed`` and divided into a leading training portion and a trailing
    held-out portion. Parsed files are kept in a bounded, per-instance LRU
    cache.

    Attributes:
        split (str): Split name given at construction.
        series_dir (Path): Directory that is searched for series JSON files.
        series_files (List[Path]): Series files belonging to the selected split.
    """

    def __init__(
            self,
            dataset_dir: str,
            split: str = 'train',
            train_ratio: float = 0.95,
            seed: int = 42,
            cache_size: int = 1000
    ) -> None:
        """Initialize the dataset.

        Args:
            dataset_dir: Path to the dataset directory containing series/
            split: 'train' selects the leading portion of the shuffled files;
                any other value (conventionally 'val') selects the remainder
            train_ratio: Ratio of training samples (default: 0.95)
            seed: Random seed for reproducibility (default: 42)
            cache_size: Number of series files to keep in memory
                (default: 1000); zero or a negative value disables caching
        """
        super().__init__()
        self.split = split
        self.series_dir = Path(dataset_dir) / 'series'

        # Sort first so the shuffle does not depend on directory listing order.
        series_files = sorted(self.series_dir.glob('series_*.json'))
        # A private generator yields the same permutation as seeding the
        # global `random` module, but leaves the process-wide RNG state alone.
        random.Random(seed).shuffle(series_files)

        # Split into train/val
        split_idx = int(len(series_files) * train_ratio)
        if split == 'train':
            self.series_files = series_files[:split_idx]
        else:
            self.series_files = series_files[split_idx:]

        # LRU cache for parsed series data. Dicts preserve insertion order, so
        # the first key is always the least recently used entry.
        self._cache: Dict[Path, Dict] = {}
        self._cache_size = cache_size

    def _load_series(self, file_path: Path) -> Dict:
        """Load a series file with caching.

        Args:
            file_path: Path to the series JSON file

        Returns:
            Dictionary containing the series data
        """
        cache = self._cache
        try:
            data = cache.pop(file_path)
        except KeyError:
            pass
        else:
            # Cache hit: re-insert so the entry becomes the most recently used.
            cache[file_path] = data
            return data

        # Cache miss: parse the file from disk.
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Caching disabled: there is nothing to evict and nothing to store.
        if self._cache_size <= 0:
            return data

        # Evict least recently used entries until there is room for one more.
        while len(cache) >= self._cache_size:
            del cache[next(iter(cache))]

        cache[file_path] = data
        return data

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.series_files)

    def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, List[Dict]]]:
        """Get a sample from the dataset.

        Args:
            idx: Index of the sample to retrieve

        Returns:
            Dictionary containing:
                - time_series: ``data['original_data']['time_series']`` as a
                  float32 torch.Tensor
                - analysis_data: The ``'windows'`` entry of the series file,
                  returned as stored (the same object the cache holds)
        """
        data = self._load_series(self.series_files[idx])

        # torch.tensor always copies, so the cached raw data is never aliased
        # by the returned tensor.
        time_series_tensor = torch.tensor(
            data['original_data']['time_series'], dtype=torch.float32
        )

        return {
            'time_series': time_series_tensor,
            'analysis_data': data['windows']
        }