File size: 6,999 Bytes
9ace58a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches
from sklearn.svm import LinearSVC
from matplotlib.axes._axes import _log as matplotlib_axes_logger
matplotlib_axes_logger.setLevel('ERROR')


# Hex colour palette used for the scatter points; a point's colour is looked
# up by indexing this list with its integer class label.
colors = [
    '#e6194B', '#3cb44b', '#ffe119', '#4363d8',
    '#f58231', '#911eb4', '#42d4f4', '#f032e6',
    '#bfef45', '#fabebe', '#469990', '#e6beff',
    '#9A6324', '#fffac8', '#800000', '#aaffc3',
    '#808000', '#ffd8b1', '#000075', '#a9a9a9',
]


class InteractivePlotter:
    """
    Interactive matplotlib viewer for annotating calls.

    The left axis shows a scatter of 2D features (one point per call); hovering
    over a point shows the matching spectrogram slice on the right axis.

    Key bindings (handled by `key_press`):
      * a digit ``0``-``9`` assigns that class to the most recently hovered point,
      * ``enter`` trains a linear SVM on the annotated points (only when
        `allow_training` is truthy),
      * ``x`` calls `get_classifier_params` (only when `allow_training` is truthy).
    """

    def __init__(self, feats_ds, feats, spec_slices, call_info, freq_lims, allow_training):
        """
        Plots 2D low dimensional features on left and corresponding spectrograms on
        the right.

        feats_ds       : 2D array; columns 0 and 1 are used as scatter x/y.
        feats          : 2D array of per-call features used to fit the classifier
                         (rows are indexed by call).
        spec_slices    : non-empty sequence of 2D arrays, one per call. They are
                         copied into a common canvas, so they must share the same
                         number of rows.
        call_info      : sequence of dicts, one per call. The keys read by this
                         class are 'low_freq', 'high_freq', 'file_name',
                         'start_time' and 'det_prob'.
        freq_lims      : pair (low, high) used as the y extent of the spectrogram
                         image. The axis is labelled 'kHz' and call frequencies
                         are divided by 1000, so presumably freq_lims is in kHz
                         and call_info frequencies are in Hz - TODO confirm.
        allow_training : if truthy, the 'enter' and 'x' keys are enabled.

        Raises ValueError if spec_slices is empty.
        """
        if len(spec_slices) == 0:
            raise ValueError('spec_slices must contain at least one spectrogram slice.')

        self.feats_ds = feats_ds
        self.feats = feats
        self.clf = None

        self.spec_slices = spec_slices
        self.call_info = call_info
        # Every call starts in class 0. `np.int` was removed from NumPy, so the
        # builtin `int` is used as the dtype.
        self.labels = np.zeros(len(call_info), dtype=int)
        # 1 where the user has labelled a point by hand, 0 otherwise
        # (can be pre-populated with 1's where labels already exist).
        self.annotated = np.zeros(self.labels.shape[0], dtype=int)
        self.labels_cols = [colors[lab] for lab in self.labels]
        self.freq_lims = freq_lims

        self.allow_training = allow_training
        self.pt_size = 5.0
        self.spec_pad = 0.2  # this much padding has been applied to the spec slices
        self.fig_width = 12
        self.fig_height = 8

        # index of the most recently hovered point
        self.current_id = 0
        # blank canvas wide enough to hold the widest slice
        self.max_width = max(ss.shape[1] for ss in self.spec_slices)
        self.blank_spec = np.zeros((self.spec_slices[0].shape[0], self.max_width))

    def plot(self, fig_id):
        """
        Build the figure (scatter on the left, spectrogram on the right) and
        connect the mouse-motion and key-press callbacks.

        fig_id is passed to `plt.subplots` as `num`.
        """
        self.fig, self.ax = plt.subplots(nrows=1, ncols=2, num=fig_id,
                                         figsize=(self.fig_width, self.fig_height),
                                         gridspec_kw={'width_ratios': [2, 1]})
        plt.tight_layout()

        # plot 2D t-SNE features
        self.low_dim_plt = self.ax[0].scatter(self.feats_ds[:, 0], self.feats_ds[:, 1],
                                              c=self.labels_cols, s=self.pt_size, picker=5)
        self.ax[0].set_title('TSNE of Call Features')
        self.ax[0].set_xticks([])
        self.ax[0].set_yticks([])

        # plot clip from spectrogram (blank until a point is hovered)
        spec_min_max = (0, self.blank_spec.shape[1], self.freq_lims[0], self.freq_lims[1])
        self.spec_im = self.ax[1].imshow(self.blank_spec, extent=spec_min_max,
                                         cmap='plasma', aspect='auto')
        self.ax[1].set_title('Spectrogram')
        self.ax[1].grid(color='w', linewidth=0.5)
        self.ax[1].set_xticks([])
        self.ax[1].set_ylabel('kHz')

        # invisible placeholder box; replaced by the real call box on hover
        self.bbox = patches.Rectangle((0, 0), 0, 0, edgecolor='w', linewidth=0, fill=False)
        self.ax[1].add_patch(self.bbox)

        self.annot = self.ax[0].annotate('', xy=(0, 0), xytext=(20, 20), textcoords='offset points',
                                         bbox=dict(boxstyle='round', fc='w'),
                                         arrowprops=dict(arrowstyle='->'))
        self.annot.set_visible(False)

        self.fig.canvas.mpl_connect('motion_notify_event', self.mouse_hover)
        self.fig.canvas.mpl_connect('key_press_event', self.key_press)

    def mouse_hover(self, event):
        """
        Mouse-motion callback: when the cursor is over a scatter point, show its
        spectrogram slice, bounding box, annotation arrow and call info.
        """
        if event.inaxes != self.ax[0]:
            return
        cont, ind = self.low_dim_plt.contains(event)
        if not cont:
            return

        # several points may be under the cursor; take the first one reported
        self.current_id = ind['ind'][0]
        spec = self.spec_slices[self.current_id]
        info = self.call_info[self.current_id]
        slice_width = spec.shape[1]

        # copy spec into full window, centred horizontally
        new_spec = self.blank_spec.copy()
        w_diff = (self.blank_spec.shape[1] - slice_width)//2
        new_spec[:, w_diff:slice_width+w_diff] = spec
        self.spec_im.set_data(new_spec)
        self.spec_im.set_clim(vmin=0, vmax=new_spec.max())

        # draw bounding box around call; the slice carries `spec_pad` of padding
        # on each side, so undo that to recover the width of the call itself
        self.bbox.remove()
        spec_width_orig = slice_width/(1.0+2.0*self.spec_pad)
        xx = w_diff + self.spec_pad*spec_width_orig
        ww = spec_width_orig
        yy = info['low_freq']/1000
        hh = (info['high_freq']-info['low_freq'])/1000
        self.bbox = patches.Rectangle((xx, yy), ww, hh, edgecolor='r', linewidth=0.5, fill=False)
        self.ax[1].add_patch(self.bbox)

        # update annotation arrow
        self.annot.xy = self.low_dim_plt.get_offsets()[self.current_id]
        self.annot.set_visible(True)

        # write call info
        info_str = info['file_name'] + ', time=' \
                    + str(round(info['start_time'], 3)) \
                    + ', prob=' + str(round(info['det_prob'], 3))
        self.ax[0].set_xlabel(info_str)

        # redraw
        self.fig.canvas.draw_idle()

    def key_press(self, event):
        """
        Key-press callback: label the current point (digit keys), train the
        classifier ('enter') or fetch its parameters ('x'), then refresh the
        scatter colours.
        """
        key = event.key
        if key is None:
            # matplotlib reports None for keys it cannot identify
            return

        if len(key) == 1 and key.isdecimal():
            class_id = int(key)
            self.labels_cols[self.current_id] = colors[class_id]
            self.labels[self.current_id] = class_id
            self.annotated[self.current_id] = 1
        elif key == 'enter' and self.allow_training:
            self.train_classifier()
        elif key == 'x' and self.allow_training:
            # NOTE(review): the returned parameters are discarded here - confirm
            # whether they should be stored or exported.
            self.get_classifier_params()

        # Recolour the existing scatter rather than stacking a new collection
        # on the axes for every key press.
        self.low_dim_plt.set_facecolor(self.labels_cols)
        self.fig.canvas.draw_idle()

    def train_classifier(self):
        """
        Fit a linear SVM on the hand-annotated points and use it to relabel (and
        recolour) every point that has not been annotated.

        Prints a message and leaves the state unchanged when fewer than two
        distinct classes have been annotated.
        """
        # TODO maybe it's better to classify in 2D space - but then can't be linear ...
        inds = np.flatnonzero(self.annotated == 1)

        if np.unique(self.labels[inds]).shape[0] < 2:  # needs at least 2 classes
            print('Not enough data - please label more classes.')
            return

        self.clf = LinearSVC(C=1.0, penalty='l2', loss='squared_hinge', tol=0.0001,
                             intercept_scaling=1.0, max_iter=2000)
        self.clf.fit(self.feats[inds, :], self.labels[inds])

        # update labels of the points the user has not annotated
        inds_unlab = np.flatnonzero(self.annotated == 0)
        if inds_unlab.size == 0:
            # everything is hand-labelled; predicting on zero samples would raise
            return
        self.labels[inds_unlab] = self.clf.predict(self.feats[inds_unlab])
        for ii in inds_unlab:
            self.labels_cols[ii] = colors[self.labels[ii]]

    def get_classifier_params(self):
        """
        Return the trained classifier's parameters as a dict with float32 arrays
        under 'weights' and 'biases'.

        Prints a message and returns an empty dict if no classifier has been
        trained yet.
        """
        res = {}
        if self.clf is None:
            print('Model not trained!')
        else:
            res['weights'] = self.clf.coef_.astype(np.float32)
            res['biases'] = self.clf.intercept_.astype(np.float32)
        return res