Spaces:
Running
Running
File size: 6,999 Bytes
9ace58a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches
from sklearn.svm import LinearSVC
from matplotlib.axes._axes import _log as matplotlib_axes_logger
# Raise the threshold of matplotlib's axes logger so that only messages at
# ERROR level or above are emitted (lower-severity messages are dropped).
# NOTE(review): this relies on the private module `matplotlib.axes._axes`,
# whose layout may change between matplotlib releases - confirm on upgrade.
matplotlib_axes_logger.setLevel('ERROR')
# Palette of 20 hex colours. Entries are looked up by integer class label,
# so position in this list matters: index k is the colour of label k.
colors = [
    '#e6194B', '#3cb44b', '#ffe119', '#4363d8',
    '#f58231', '#911eb4', '#42d4f4', '#f032e6',
    '#bfef45', '#fabebe', '#469990', '#e6beff',
    '#9A6324', '#fffac8', '#800000', '#aaffc3',
    '#808000', '#ffd8b1', '#000075', '#a9a9a9',
]
class InteractivePlotter:
    """Interactive figure linking a 2D feature scatter to spectrogram clips.

    Hovering a point in the left-hand scatter shows the matching spectrogram
    slice on the right. Digit keys assign a class label to the point last
    hovered; when training is allowed, 'enter' fits a linear SVM on the
    annotated points and 'x' fetches the fitted parameters.
    """

    def __init__(self, feats_ds, feats, spec_slices, call_info, freq_lims, allow_training):
        """
        Plots 2D low dimensional features on left and corresponding spectrograms on
        the right.

        Args:
            feats_ds: 2D array; columns 0 and 1 are used as the scatter x/y
                coordinates, one row per call.
            feats: 2D array of features (one row per call) used to fit and
                apply the classifier.
            spec_slices: sequence of 2D arrays, one per call. They are pasted
                into a canvas with ``spec_slices[0].shape[0]`` rows, so all
                slices presumably share that row count - TODO confirm.
            call_info: sequence of dicts, one per call, read with the keys
                'low_freq', 'high_freq', 'file_name', 'start_time' and
                'det_prob'.
            freq_lims: two-element sequence giving the lower/upper limits of
                the spectrogram's y axis (the axis is labelled 'kHz').
            allow_training: when truthy, the 'enter' and 'x' keys are enabled.

        Raises:
            ValueError: if ``spec_slices`` is empty.
        """
        self.feats_ds = feats_ds
        self.feats = feats
        self.clf = None
        self.spec_slices = spec_slices
        self.call_info = call_info

        # `np.int` was only an alias of the builtin `int` and has been removed
        # from NumPy; the builtin yields the same dtype on every version.
        self.labels = np.zeros(len(call_info), dtype=int)
        # can populate this with 1's where we have labels
        self.annotated = np.zeros(self.labels.shape[0], dtype=int)
        self.labels_cols = [colors[lab] for lab in self.labels]

        self.freq_lims = freq_lims
        self.allow_training = allow_training

        self.pt_size = 5.0
        self.spec_pad = 0.2  # this much padding has been applied to the spec slices
        self.fig_width = 12
        self.fig_height = 8
        self.current_id = 0

        # Blank canvas as wide as the widest slice; narrower slices are
        # centred inside a copy of it when displayed.
        self.max_width = max(ss.shape[1] for ss in self.spec_slices)
        self.blank_spec = np.zeros((self.spec_slices[0].shape[0], self.max_width))

    def plot(self, fig_id):
        """Create the two-panel figure and connect the mouse/keyboard handlers.

        Args:
            fig_id: passed to ``plt.subplots`` as ``num`` to identify the figure.
        """
        self.fig, self.ax = plt.subplots(
            nrows=1, ncols=2, num=fig_id,
            figsize=(self.fig_width, self.fig_height),
            gridspec_kw={'width_ratios': [2, 1]})
        plt.tight_layout()

        # plot 2D TSNE features
        self.low_dim_plt = self.ax[0].scatter(
            self.feats_ds[:, 0], self.feats_ds[:, 1],
            c=self.labels_cols, s=self.pt_size, picker=5)
        self.ax[0].set_title('TSNE of Call Features')
        self.ax[0].set_xticks([])
        self.ax[0].set_yticks([])

        # plot clip from spectrogram (blank until a point is hovered)
        spec_min_max = (0, self.blank_spec.shape[1], self.freq_lims[0], self.freq_lims[1])
        self.ax[1].imshow(self.blank_spec, extent=spec_min_max, cmap='plasma', aspect='auto')
        self.spec_im = self.ax[1].get_images()[0]
        self.ax[1].set_title('Spectrogram')
        self.ax[1].grid(color='w', linewidth=0.5)
        self.ax[1].set_xticks([])
        self.ax[1].set_ylabel('kHz')

        # Invisible placeholder so mouse_hover always has a patch to replace.
        bbox_orig = patches.Rectangle((0, 0), 0, 0, edgecolor='w', linewidth=0, fill=False)
        self.ax[1].add_patch(bbox_orig)

        self.annot = self.ax[0].annotate(
            '', xy=(0, 0), xytext=(20, 20), textcoords='offset points',
            bbox=dict(boxstyle='round', fc='w'), arrowprops=dict(arrowstyle='->'))
        self.annot.set_visible(False)

        self.fig.canvas.mpl_connect('motion_notify_event', self.mouse_hover)
        self.fig.canvas.mpl_connect('key_press_event', self.key_press)

    def mouse_hover(self, event):
        """Show the spectrogram and call info for the scatter point under the cursor.

        Does nothing unless the cursor is over a point in the left-hand axes.

        Args:
            event: matplotlib mouse event from 'motion_notify_event'.
        """
        if event.inaxes is not self.ax[0]:
            return
        cont, ind = self.low_dim_plt.contains(event)
        if not cont:
            return

        self.current_id = ind['ind'][0]
        call = self.call_info[self.current_id]
        spec = self.spec_slices[self.current_id]
        spec_width = spec.shape[1]

        # copy spec into full window, centred horizontally
        new_spec = self.blank_spec.copy()
        w_diff = (self.blank_spec.shape[1] - spec_width) // 2
        new_spec[:, w_diff:w_diff + spec_width] = spec
        self.spec_im.set_data(new_spec)
        self.spec_im.set_clim(vmin=0, vmax=new_spec.max())

        # draw bounding box around call, replacing the previous box
        self.ax[1].patches[0].remove()
        # The slice holds the call plus `spec_pad` of its width on each side,
        # so divide the padding back out to get the call's own width.
        spec_width_orig = spec_width / (1.0 + 2.0 * self.spec_pad)
        xx = w_diff + self.spec_pad * spec_width_orig
        ww = spec_width_orig
        # Divided by 1000 to match the axis labelled 'kHz'
        # (presumably call_info stores frequencies in Hz - TODO confirm).
        yy = call['low_freq'] / 1000
        hh = (call['high_freq'] - call['low_freq']) / 1000
        bbox = patches.Rectangle((xx, yy), ww, hh, edgecolor='r', linewidth=0.5, fill=False)
        self.ax[1].add_patch(bbox)

        # update annotation arrow
        self.annot.xy = self.low_dim_plt.get_offsets()[self.current_id]
        self.annot.set_visible(True)

        # write call info
        info_str = (call['file_name']
                    + ', time=' + str(round(call['start_time'], 3))
                    + ', prob=' + str(round(call['det_prob'], 3)))
        self.ax[0].set_xlabel(info_str)

        # redraw
        self.fig.canvas.draw_idle()

    def key_press(self, event):
        """Handle keyboard input for labelling and training.

        A decimal digit key labels the point last hovered with that class.
        With ``allow_training`` set, 'enter' trains the classifier and 'x'
        calls ``get_classifier_params`` (its result is not used here).

        Args:
            event: matplotlib key event from 'key_press_event'.
        """
        key = event.key
        if key is None:
            # matplotlib reports keys it cannot identify as None.
            return

        # isdecimal() (not isdigit()) so that only characters int() can parse
        # are accepted; e.g. a superscript two is a "digit" but not an int.
        if key.isdecimal():
            label = int(key)
            self.labels_cols[self.current_id] = colors[label]
            self.labels[self.current_id] = label
            self.annotated[self.current_id] = 1
        elif key == 'enter' and self.allow_training:
            self.train_classifier()
        elif key == 'x' and self.allow_training:
            self.get_classifier_params()

        # Recolour the existing scatter in place rather than stacking a new
        # collection on the axes for every key press.
        self.low_dim_plt.set_facecolor(self.labels_cols)
        self.fig.canvas.draw_idle()

    def train_classifier(self):
        """Fit a linear SVM on the annotated points and label the rest with it.

        Needs annotations from at least two distinct classes; otherwise a
        message is printed and nothing changes. Labels and colours of
        annotated points are never overwritten.
        """
        # TODO maybe it's better to classify in 2D space - but then can't be linear ...
        inds = np.where(self.annotated == 1)[0]
        labs_un = np.unique(self.labels[inds])
        if labs_un.shape[0] <= 1:  # needs at least 2 classes
            print('Not enough data - please label more classes.')
            return

        self.clf = LinearSVC(C=1.0, penalty='l2', loss='squared_hinge', tol=0.0001,
                             intercept_scaling=1.0, max_iter=2000)
        self.clf.fit(self.feats[inds, :], self.labels[inds])

        # update labels of the points the user has not annotated
        inds_unlab = np.where(self.annotated == 0)[0]
        if inds_unlab.size == 0:
            # Everything is annotated; predicting on zero samples would raise.
            return
        self.labels[inds_unlab] = self.clf.predict(self.feats[inds_unlab])
        for ii in inds_unlab:
            self.labels_cols[ii] = colors[self.labels[ii]]

    def get_classifier_params(self):
        """Return the trained classifier's parameters as float32 arrays.

        Returns:
            dict: with keys 'weights' (from ``clf.coef_``) and 'biases' (from
            ``clf.intercept_``); empty if no classifier has been trained, in
            which case a message is also printed.
        """
        res = {}
        if self.clf is None:
            print('Model not trained!')
        else:
            res['weights'] = self.clf.coef_.astype(np.float32)
            res['biases'] = self.clf.intercept_.astype(np.float32)
        return res
|