import random
from operator import itemgetter

from data_enrich import DataEnrich


class DataLoader:
    """Load enriched GPS trajectories, split them into train/val/test sets,
    and serve them as fixed-size batches sorted by sequence length
    (longest first, as required by packed-sequence style consumers).
    """

    # Transport-mode label -> integer class id fed to the model.
    label_mapping = {
        'car': 0,
        'walk': 1,
        'bus': 2,
        'train': 3,
        'subway': 4,
        'bike': 5,
        'run': 6,
        'boat': 7,
        'airplane': 8,
        'motorcycle': 9,
        'taxi': 10
    }

    # Per-point feature columns extracted from each trajectory frame.
    fields_to_feed = ["dist", "speed", "accel", "timedelta", "jerk", "bearing", "bearing_rate"]
    # Trajectories containing any of these modes are discarded entirely.
    labels_to_remove = ["boat", "motorcycle", "airplane", "run", "bike"]


    def __init__(self, test_ratio=0.2, val_ratio=0.1, batchsize=4, read_from_pickle=True):
        """Fetch enriched raw data and store the split/batch parameters.

        :param test_ratio: fraction of trajectories held out for testing
        :param val_ratio: fraction of trajectories held out for validation
        :param batchsize: number of trajectories per yielded batch
        :param read_from_pickle: forwarded to DataEnrich.get_enriched_data
        """
        de = DataEnrich()
        self._raw = de.get_enriched_data(read_from_pickle)
        self._test_ratio = test_ratio
        self._val_ratio = val_ratio
        self._batchsize = batchsize

    def _remove_traj_containing_labels(self):
        """Drop empty trajectories and any containing a label from
        ``labels_to_remove``."""
        cleaned = []
        for elem in self._raw:
            if len(elem) == 0:
                continue
            if all(x not in list(elem["label"]) for x in self.labels_to_remove):
                cleaned.append(elem)
        self._raw = cleaned

    def _merge_labels(self, target_label, label_to_remove):
        """Rewrite every occurrence of ``label_to_remove`` to
        ``target_label`` (e.g. fold 'taxi' into 'car')."""
        for elem in self._raw:
            if label_to_remove in list(elem["label"]):
                elem["label"] = elem["label"].replace(to_replace=label_to_remove, value=target_label)

    def _labels_to_int_repr(self):
        """Map the string labels of every trajectory to their integer ids."""
        for elem in self._raw:
            elem["label"] = elem["label"].apply(lambda x: self.label_mapping[x])

    def _get_split_indices(self, traj):
        """Return disjoint (train, test, val) index lists covering ``traj``.

        Validation indices are sampled out of the training sample; the test
        set is everything that belongs to neither.

        BUG FIX: the test set was previously computed as ``all - train``
        AFTER the validation indices had been removed from train, so every
        validation index also appeared in the test set (val/test leakage).
        """
        train_size = int((1 - self._test_ratio) * len(traj))
        val_size = len(traj) - int((1 - self._val_ratio) * len(traj))

        indices = list(range(len(traj)))

        sampled_train = random.sample(indices, train_size)
        indices_for_validation = random.sample(sampled_train, val_size)
        indices_for_training = set(sampled_train) - set(indices_for_validation)
        # Test = everything that is neither training nor validation.
        indices_for_testing = set(indices) - indices_for_training - set(indices_for_validation)

        return list(indices_for_training), list(indices_for_testing), list(indices_for_validation)

    def _set_splitted_data(self, traj, labels):
        """Materialize train/val/test data and label lists from the split
        indices. ``traj`` and ``labels`` are parallel lists."""
        i_train, i_test, i_val = self._get_split_indices(traj)

        random.shuffle(i_train)

        # Plain comprehensions instead of itemgetter(*idx): itemgetter
        # returns a bare element (not a tuple) for a single index and
        # raises TypeError for an empty index list.
        self.test_data = [traj[i] for i in i_test]
        self.val_data = [traj[i] for i in i_val]
        self.train_data = [traj[i] for i in i_train]
        self.test_labels = [labels[i] for i in i_test]
        self.val_labels = [labels[i] for i in i_val]
        self.train_labels = [labels[i] for i in i_train]

    def _split_too_long_traj(self, traj, labels, max_points):
        """Split a trajectory longer than ``2 * max_points`` into chunks of
        ``max_points`` points; the last chunk absorbs the remainder (so it
        may hold up to ``2 * max_points - 1`` points). Shorter trajectories
        are returned as a single chunk.

        BUG FIX: the final chunk previously ended at ``len(traj) - 1``
        (slice end is exclusive), silently dropping the last point of every
        split trajectory.
        """
        if len(traj) > max_points * 2:
            splitted_traj, splitted_labels = [], []
            num_subsets = len(traj) // max_points
            print("Splitting trajectory with length ", len(traj), "in ", num_subsets, "trajectories")
            for i in range(num_subsets):
                start = i * max_points
                # Last chunk when another full chunk would not fit after it.
                is_last = (i + 2) * max_points > len(traj)
                end_pointer = len(traj) if is_last else start + max_points
                traj_subset = traj[start:end_pointer]
                labels_subset = labels[start:end_pointer]
                assert len(traj_subset) == len(labels_subset)
                splitted_traj.append(traj_subset)
                splitted_labels.append(labels_subset)
            return splitted_traj, splitted_labels
        return [traj], [labels]

    def prepare_data(self):
        """Clean the raw data, convert labels to ints, chunk long
        trajectories and populate the train/val/test splits."""
        trajs = []
        labels = []

        self._remove_traj_containing_labels()
        self._merge_labels("car", "taxi")
        self._labels_to_int_repr()

        for elem in self._raw:
            assert len(elem) > 0
            data_ = elem[self.fields_to_feed].values.tolist()
            label_ = elem["label"].values.tolist()
            data_, label_ = self._split_too_long_traj(data_, label_, 350)
            trajs.extend(data_)
            labels.extend(label_)

        self._set_splitted_data(trajs, labels)

    def _iter_batches(self, data, labels):
        """Yield (data, labels) batches of exactly ``self._batchsize``
        trajectories, each batch sorted by sequence length descending.
        The trailing incomplete batch is dropped. Shared implementation
        behind batches/val_batches/test_batches (previously triplicated).
        """
        for i in range(0, len(data), self._batchsize):
            data_slice = data[i:i + self._batchsize]
            labels_slice = labels[i:i + self._batchsize]

            if len(data_slice) < self._batchsize:
                break  # drop last incomplete batch

            # Sorting data and labels separately keeps pairs aligned
            # because each sequence has the same length as its labels
            # (asserted below) and Python's sort is stable.
            labels_sorted = sorted(labels_slice, key=len, reverse=True)
            data_sorted = sorted(data_slice, key=len, reverse=True)
            for d, l in zip(data_sorted, labels_sorted):
                assert len(d) == len(l)
            yield data_sorted, labels_sorted

    def batches(self):
        """Yield training batches (see _iter_batches)."""
        yield from self._iter_batches(self.train_data, self.train_labels)

    def val_batches(self):
        """Yield validation batches (see _iter_batches)."""
        yield from self._iter_batches(self.val_data, self.val_labels)

    def test_batches(self):
        """Yield test batches (see _iter_batches)."""
        yield from self._iter_batches(self.test_data, self.test_labels)

    def get_train_size(self):
        """Number of training trajectories."""
        return len(self.train_data)

    def get_val_size(self):
        """Number of validation trajectories."""
        return len(self.val_data)

    def get_test_size(self):
        """Number of test trajectories."""
        return len(self.test_data)