# Plotting utilities: spectrogram and detection overlays, precision-recall curves, confusion matrices and loss curves.
import json
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import patches
from matplotlib.collections import PatchCollection
from sklearn.metrics import confusion_matrix

from . import audio_utils as au
def create_box_image(spec, fig, detections_ip, start_time, end_time, duration, params, max_val, hide_axis=True, plot_class_names=False):
    """Draw a spectrogram window with boxes around the detections that start inside it.

    Args:
        spec: 2D array passed to ``plt.imshow``.
        fig: Figure that receives a borderless, full-size axes when
            ``hide_axis`` is true.
        detections_ip: Sequence of detection dicts. Each needs the keys read by
            ``plot_bounding_box_patch_ann`` ('start_time', 'end_time',
            'low_freq', 'high_freq', optionally 'det_prob') and 'class' when
            ``plot_class_names`` is true.
        start_time: Time at the left edge of the window; box x positions are
            drawn relative to it.
        end_time: Unused; kept so existing callers keep working.
        duration: Length of the window, used as the x extent of the image.
        params: Dict providing 'min_freq' and 'max_freq' in Hz.
        max_val: Upper colour limit for the spectrogram (lower limit is 0).
        hide_axis: If true, draw on a new axes that fills ``fig`` with the axis
            decorations switched off; otherwise draw on the current axes.
        plot_class_names: If true, write an abbreviated class name (first
            three letters of every word) at the top-left of each box.
    """
    # keep only the detections whose onset falls inside the window; the 0.02
    # margin drops detections that begin right at the trailing edge
    stop_time = start_time + duration
    detections = [
        bb for bb in detections_ip
        if start_time <= bb['start_time'] < stop_time - 0.02
    ]

    # create figure
    freq_scale = 1000  # turn Hz to kHz
    min_freq = params['min_freq']//freq_scale
    max_freq = params['max_freq']//freq_scale
    y_extent = [0, duration, min_freq, max_freq]

    if hide_axis:
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)
    else:
        ax = plt.gca()

    plt.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent, vmin=0, vmax=max_val)
    boxes = plot_bounding_box_patch_ann(detections, freq_scale, start_time)
    ax.add_collection(PatchCollection(boxes, match_original=True))
    plt.grid(False)

    if plot_class_names:
        # boxes were built from the *filtered* list, so the label must be read
        # from the same list - indexing detections_ip would mislabel boxes as
        # soon as one detection has been filtered out
        for det, bb in zip(detections, boxes):
            txt = ' '.join([sp[:3] for sp in det['class'].split(' ')])
            font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
            # place the label at the top of the box, but keep it at least
            # 10 kHz below the upper edge of the plot
            y_pos = bb.get_xy()[1] + bb.get_height()
            if y_pos > (max_freq - 10):
                y_pos = max_freq - 10
            plt.gca().text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)
def save_ann_spec(op_path, spec, min_freq, max_freq, duration, start_time, title_text='', anns=None):
    """Render a spectrogram, optionally with annotation boxes, and save it to ``op_path``.

    The figure is sized so that one spectrogram bin maps to one pixel at
    100 dpi. ``min_freq`` / ``max_freq`` are given in Hz and shown in kHz.
    When ``anns`` is provided each annotation is drawn as a box with its
    abbreviated class name (first three letters of every word).
    All open figures are closed before drawing.
    """
    khz = 1000  # turn Hz to kHz
    low_khz = min_freq//khz
    high_khz = max_freq//khz

    plt.close('all')
    n_rows, n_cols = spec.shape[0], spec.shape[1]
    plt.figure(0, figsize=(n_cols/100, n_rows/100), dpi=100)
    plt.imshow(spec, aspect='auto', cmap='plasma',
               extent=[0, duration, low_khz, high_khz],
               vmin=0, vmax=spec.max()*1.1)

    plt.ylabel('Freq - kHz')
    plt.xlabel('Time - secs')
    if title_text != '':
        plt.title(title_text)
    plt.tight_layout()

    if anns is not None:
        # drawing bounding boxes and class names
        rects = plot_bounding_box_patch_ann(anns, khz, start_time)
        plt.gca().add_collection(PatchCollection(rects, match_original=True))
        label_ceiling = high_khz - 10
        for idx, rect in enumerate(rects):
            words = anns[idx]['class'].split(' ')
            short_name = ' '.join([word[:3] for word in words])
            style = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': rect.get_alpha()}
            # label sits on the top edge of the box, capped below the plot top
            top_edge = rect.get_xy()[1] + rect.get_height()
            if top_edge > label_ceiling:
                top_edge = label_ceiling
            plt.gca().text(rect.get_xy()[0], top_edge, short_name, fontdict=style)

    print('Saving figure to:', op_path)
    plt.savefig(op_path)
def plot_pts(fig_id, feats, class_names, colors, marker_size=4.0, plot_legend=False):
    """Scatter-plot the first two feature columns, one colour per class.

    Args:
        fig_id: Identifier of the figure to draw into (passed to ``plt.figure``).
        feats: Array indexed as ``feats[rows, 0]`` / ``feats[rows, 1]``; only
            the first two columns are plotted.
        class_names: One class name per row of ``feats``.
        colors: Sequence of colours, one per unique class. If it holds fewer
            entries than there are classes, colours are sampled from the
            'jet' colormap instead.
        marker_size: Marker area forwarded to ``plt.scatter``.
        plot_legend: If true, add a legend listing the class names.
    """
    plt.figure(fig_id)
    un_class, labels = np.unique(class_names, return_inverse=True)
    num_classes = un_class.shape[0]
    if num_classes > len(colors):
        colors = [plt.cm.jet(float(ii)/num_classes) for ii in range(num_classes)]

    # `labels` indexes into `un_class`, so class ii owns the rows where labels == ii
    for ii, name in enumerate(un_class):
        inds = np.where(labels == ii)[0]
        # `color=` (not `c=`): a single RGB(A) tuple given as `c` is ambiguous
        # with per-point values to be colour-mapped when a class happens to
        # contain exactly 3 or 4 points
        plt.scatter(feats[inds, 0], feats[inds, 1], color=colors[ii], label=str(name), s=marker_size)

    if plot_legend:
        plt.legend()
    plt.xticks([])
    plt.yticks([])
    plt.title('downsampled features')
def plot_bounding_box_patch(pred, freq_scale, ecolor='w'):
    """Build one unfilled ``Rectangle`` per entry of a column-style prediction dict.

    ``pred`` holds parallel sequences under 'start_times', 'end_times',
    'low_freqs' and 'high_freqs'. Frequencies are divided by ``freq_scale``.
    When 'det_probs' is present it is used as the per-box alpha, otherwise
    boxes are fully opaque. Returns the rectangles as a list.
    """
    has_probs = 'det_probs' in pred
    rects = []
    for idx, t_start in enumerate(pred['start_times']):
        box_width = pred['end_times'][idx] - t_start
        f_low = pred['low_freqs'][idx]
        y_bottom = f_low / freq_scale
        box_height = (pred['high_freqs'][idx] - f_low) / freq_scale
        opacity = pred['det_probs'][idx] if has_probs else 1.0
        rect = patches.Rectangle((t_start, y_bottom), box_width, box_height,
                                 linewidth=1, edgecolor=ecolor,
                                 facecolor='none', alpha=opacity)
        rects.append(rect)
    return rects
def plot_bounding_box_patch_ann(anns, freq_scale, start_time):
    """Build one unfilled white ``Rectangle`` per annotation dict.

    Each annotation provides 'start_time', 'end_time', 'low_freq' and
    'high_freq'. The x position is shifted by ``start_time`` and frequencies
    are divided by ``freq_scale``. An optional 'det_prob' entry sets the box
    alpha; without it the box is fully opaque. Returns the rectangles as a
    list, in the order of ``anns``.
    """
    rects = []
    for idx in range(len(anns)):
        ann = anns[idx]
        x_left = ann['start_time'] - start_time
        box_width = ann['end_time'] - ann['start_time']
        y_bottom = ann['low_freq'] / freq_scale
        box_height = (ann['high_freq'] - ann['low_freq']) / freq_scale
        opacity = ann['det_prob'] if 'det_prob' in ann else 1.0
        rects.append(
            patches.Rectangle((x_left, y_bottom), box_width, box_height,
                              linewidth=1, edgecolor='w',
                              facecolor='none', alpha=opacity)
        )
    return rects
def plot_spec(spec, sampling_rate, duration, gt, pred, params, plot_title,
              op_file_name, pred_2d_hm, plot_boxes=True, fixed_aspect=True):
    """Plot ground truth, predictions and an optional heatmap as three stacked panels.

    Args:
        spec: 2D spectrogram shown in the top two panels.
        sampling_rate: Unused; the time extent comes from ``duration``.
        duration: Length of the x axis of every panel.
        gt: Column-style dict consumed by ``plot_bounding_box_patch``; also
            needs 'class_ids' (negative ids are labelled with
            ``params['generic_class'][0]``).
        pred: Column-style dict consumed by ``plot_bounding_box_patch``; also
            needs 'class_probs', indexed as ``[class, detection]``.
        params: Dict providing 'min_freq', 'max_freq' (Hz),
            'class_names_short' and 'generic_class'.
        plot_title: Figure title.
        op_file_name: Where to save the figure; nothing is saved when None.
        pred_2d_hm: Optional 2D heatmap for the bottom panel.
        plot_boxes: If true, draw boxes and class labels on the top two panels.
        fixed_aspect: If true the figure is 12 wide regardless of
            ``duration``; otherwise the width scales with ``duration``.

    The figure (number 1) is always closed before returning.
    """
    # output image will be this width irrespective of the duration of the audio file
    width = 12 if fixed_aspect else 12*duration

    fig = plt.figure(1, figsize=(width, 8))
    ax0 = plt.axes([0.05, 0.65, 0.9, 0.30])  # l b w h
    ax1 = plt.axes([0.05, 0.33, 0.9, 0.30])
    ax2 = plt.axes([0.05, 0.01, 0.9, 0.30])

    freq_scale = 1000  # turn Hz in kHz
    min_freq = params['min_freq']//freq_scale
    y_extent = [0, duration, min_freq, params['max_freq']//freq_scale]
    panel_font = {'color': 'white', 'size': 12, 'weight': 'bold'}

    # plot gt boxes
    ax0.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent)
    ax0.xaxis.set_ticklabels([])
    ax0.text(0, min_freq, 'Ground Truth', fontdict=panel_font)
    if plot_boxes:
        boxes = plot_bounding_box_patch(gt, freq_scale)
        ax0.add_collection(PatchCollection(boxes, match_original=True))
        for ii, bb in enumerate(boxes):
            class_id = int(gt['class_ids'][ii])
            if class_id < 0:
                txt = params['generic_class'][0]
            else:
                txt = params['class_names_short'][class_id]
            font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
            y_pos = bb.get_xy()[1] + bb.get_height()
            ax0.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)

    # plot predicted boxes
    ax1.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent)
    ax1.xaxis.set_ticklabels([])
    ax1.text(0, min_freq, 'Prediction', fontdict=panel_font)
    if plot_boxes:
        boxes = plot_bounding_box_patch(pred, freq_scale)
        ax1.add_collection(PatchCollection(boxes, match_original=True))
        # when class_probs has more rows than there are class names, the last
        # row is left out of the argmax
        # NOTE(review): presumably that row is a background/generic class - confirm
        drop_last = pred['class_probs'].shape[0] > len(params['class_names_short'])
        for ii, bb in enumerate(boxes):
            if drop_last:
                class_id = pred['class_probs'][:-1, ii].argmax()
            else:
                class_id = pred['class_probs'][:, ii].argmax()
            txt = params['class_names_short'][class_id]
            font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
            y_pos = bb.get_xy()[1] + bb.get_height()
            ax1.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)

    # plot 2D heatmap
    if pred_2d_hm is not None:
        # colour limits cover at least [0, 1] and widen to fit the data
        min_val = 0.0 if pred_2d_hm.min() > 0.0 else pred_2d_hm.min()
        max_val = 1.0 if pred_2d_hm.max() < 1.0 else pred_2d_hm.max()
        ax2.imshow(pred_2d_hm, aspect='auto', cmap='plasma', extent=y_extent, clim=[min_val, max_val])
        ax2.text(0, min_freq, 'Heatmap', fontdict=panel_font)

    # switch the grid off on every panel explicitly: plt.grid() only reaches
    # the current axes, which is ax2 (the last one created)
    for ax in (ax0, ax1, ax2):
        ax.grid(False)

    fig.suptitle(plot_title)
    if op_file_name is not None:
        fig.savefig(op_file_name)

    plt.close(1)
def plot_pr_curve(op_dir, plt_title, file_name, results, file_type='png', title_text=''):
    """Plot a single precision-recall curve and save it.

    ``results`` supplies 'precision', 'recall' and 'avg_prec'. The title is
    ``title_text`` when given, otherwise ``plt_title`` followed by the average
    precision. The image is written to ``op_dir + file_name + '.' + file_type``
    (plain string concatenation, so ``op_dir`` needs its own trailing
    separator) and figure 0 is closed afterwards.
    """
    prec_vals, rec_vals, ap = results['precision'], results['recall'], results['avg_prec']

    plt.figure(0, figsize=(10,8))
    plt.plot(rec_vals, prec_vals)
    plt.ylabel('Precision', fontsize=20)
    plt.xlabel('Recall', fontsize=20)

    if title_text == '':
        plt.title(plt_title + ' {:.3f}\n'.format(ap))
    else:
        plt.title(title_text, fontdict={'fontsize': 28})

    # leave a little headroom above 1.0 on both axes
    plt.xlim(0,1.02)
    plt.ylim(0,1.02)
    plt.grid(True)
    plt.tight_layout()

    out_path = op_dir + file_name + '.' + file_type
    plt.savefig(out_path)
    plt.close(0)
def plot_pr_curve_class(op_dir, plt_title, file_name, results, file_type='png', title_text=''):
    """Plot one precision-recall curve per class and save the figure.

    Args:
        op_dir: Output directory prefix; concatenated as-is with ``file_name``,
            so it needs its own trailing separator.
        plt_title: Title prefix used when ``title_text`` is empty; the value of
            ``results['avg_prec_class']`` is appended to it.
        file_name: Output file name without extension.
        results: Dict with 'class_pr', a list of per-class dicts providing
            'name', 'recall', 'precision', 'thresholds' and 'thresholds_inds',
            and 'avg_prec_class' (only read when ``title_text`` is empty).
        file_type: File extension / image format.
        title_text: Explicit title; overrides the default one when non-empty.

    Colours change with every class and the line style changes every 10
    classes. All style lookups wrap around, so any number of classes and
    thresholds can be drawn.
    """
    plt.figure(0, figsize=(10,8))
    plt.ylabel('Precision', fontsize=20)
    plt.xlabel('Recall', fontsize=20)
    plt.xlim(0,1.02)
    plt.ylim(0,1.02)
    plt.grid(True)
    linestyles = ['-', ':', '--']
    markers = ['o', 'v', '>', '^', '<', 's', 'P', 'X', '*']
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    # plot the PR curves
    for ii, rr in enumerate(results['class_pr']):
        class_name = ' '.join([sp[:3] for sp in rr['name'].split(' ')])
        # wrap every index so that a short colour cycle, more than 30 classes
        # or more than 9 thresholds cannot raise an IndexError
        cur_color = colors[(ii % 10) % len(colors)]
        cur_style = linestyles[(ii // 10) % len(linestyles)]
        plt.plot(rr['recall'], rr['precision'], label=class_name, color=cur_color,
                 linestyle=cur_style, lw=2.5)

        # plot the location of the confidence threshold values;
        # an index of -1 or lower means there is no point to mark
        for jj, ind in enumerate(rr['thresholds_inds'][:len(rr['thresholds'])]):
            if ind > -1:
                plt.plot(rr['recall'][ind], rr['precision'][ind], markers[jj % len(markers)],
                         color=cur_color, ms=10)

    if title_text != '':
        plt.title(title_text, fontdict={'fontsize': 28})
    else:
        plt.title(plt_title + ' {:.3f}\n'.format(results['avg_prec_class']))
    plt.legend(loc='lower left', prop={'size': 14})
    plt.tight_layout()
    plt.savefig(op_dir + file_name + '.' + file_type)
    plt.close(0)
def plot_confusion_matrix(op_dir, op_file, gt, pred, file_acc, class_names_long, verbose=False, file_type='png', title_text=''):
    """Plot a row-normalised confusion matrix and save it.

    Args:
        op_dir: Output directory prefix; concatenated as-is with ``op_file``,
            so it needs its own trailing separator.
        op_file: Output file name without extension; also used in the default
            title.
        gt: Ground-truth class indices.
        pred: Predicted class indices.
        file_acc: Accuracy value shown in the default title.
        class_names_long: Full class names; index ``i`` names class ``i``.
            Tick labels use the first three letters of every word.
        verbose: If true, print the per-class accuracy (the matrix diagonal).
        file_type: File extension / image format.
        title_text: Explicit title; overrides the default one when non-empty.

    Rows are divided by the number of ground-truth samples of that class.
    Classes without any ground-truth sample get a NaN row. All figures are
    closed afterwards.
    """
    # shorten the class names for plotting
    class_names = [' '.join(word[:3] for word in cc.split(' ')) for cc in class_names_long]

    num_classes = len(class_names)
    cm = confusion_matrix(gt, pred, labels=np.arange(num_classes)).astype(np.float32)
    cm_norm = cm.sum(1)
    has_gt = cm_norm > 0
    cm[has_gt, :] = cm[has_gt, :] / cm_norm[has_gt][..., np.newaxis]
    cm[~has_gt, :] = np.nan  # no ground truth for this class: accuracy undefined

    if verbose:
        print('Per class accuracy:')
        str_len = max((len(cc) for cc in class_names_long), default=0) + 5
        accs = np.diag(cm)
        for ii, cc in enumerate(class_names_long):
            line = str(ii).ljust(5) + cc.ljust(str_len)
            if not np.isnan(accs[ii]):
                line += '{:.2f}'.format(accs[ii]*100)
            print(line)

    plt.figure(0, figsize=(10,8))
    plt.imshow(cm, vmin=0, vmax=1, cmap='plasma')
    plt.colorbar()
    plt.xticks(np.arange(cm.shape[1]), class_names, rotation='vertical')
    plt.yticks(np.arange(cm.shape[0]), class_names)
    plt.xlabel('Predicted', fontsize=20)
    plt.ylabel('Ground Truth', fontsize=20)
    if title_text != '':
        plt.title(title_text, fontdict={'fontsize': 28})
    else:
        plt.title(op_file + ' {:.3f}\n'.format(file_acc))
    plt.tight_layout()
    plt.savefig(op_dir + op_file + '.' + file_type)
    plt.close('all')
class LossPlotter(object):
    """Collect per-epoch values and rewrite a plot and a JSON log after every update.

    Every call to ``update_and_save`` appends one epoch and regenerates the
    line plot at ``op_file_name``, a '.json' file next to it and, when labels
    are supplied, a '_cm.png' confusion matrix next to it.
    """

    def __init__(self, op_file_name, duration, labels, ylim, class_names, axis_labels=None, logy=False):
        """Store the plot configuration.

        Args:
            op_file_name: Path of the plot image; sibling outputs reuse it
                with the extension replaced.
            duration: Length of the x axis.
            labels: One legend label per tracked value.
            ylim: ``(low, high)`` y limits, or None for automatic limits.
            class_names: Class names used for the confusion-matrix ticks.
            axis_labels: Optional ``(xlabel, ylabel)`` pair.
            logy: If true, use a logarithmic y axis.
        """
        self.reset()
        self.op_file_name = op_file_name
        self.duration = duration  # length of x axis
        self.labels = labels
        self.ylim = ylim
        self.class_names = class_names
        self.axis_labels = axis_labels
        self.logy = logy

    def reset(self):
        """Forget all recorded epochs and values."""
        self.epochs = []
        self.vals = []

    def update_and_save(self, epoch, val, gt=None, pred=None):
        """Record one epoch and rewrite every output file.

        Args:
            epoch: x-axis value of the new entry.
            val: Sequence with one value per entry of ``labels``.
            gt: Optional ground-truth class indices; when given, a confusion
                matrix against ``pred`` is saved as well.
            pred: Predicted class indices; only used together with ``gt``.
        """
        self.epochs.append(epoch)
        self.vals.append(val)
        self.save_plot()
        self.save_json()
        if gt is not None:
            self.save_confusion_matrix(gt, pred)

    def _op_file_stem(self):
        """Return ``op_file_name`` without its extension."""
        return os.path.splitext(self.op_file_name)[0]

    def save_plot(self):
        """Draw one line per tracked value and save the figure to ``op_file_name``."""
        linestyles = ['-', ':', '--']
        plt.figure(0, figsize=(8,5))
        for ii in range(len(self.vals[0])):
            l_vals = [vv[ii] for vv in self.vals]
            # line style changes every 10 series and wraps around so that any
            # number of series can be drawn
            cur_style = linestyles[(ii // 10) % len(linestyles)]
            plt.plot(self.epochs, l_vals, label=self.labels[ii], linestyle=cur_style)
        # grow the x axis if more entries were recorded than `duration`
        plt.xlim(0, np.maximum(self.duration, len(self.vals)))
        if self.ylim is not None:
            plt.ylim(self.ylim[0], self.ylim[1])
        if self.axis_labels is not None:
            plt.xlabel(self.axis_labels[0])
            plt.ylabel(self.axis_labels[1])
        if self.logy:
            plt.gca().set_yscale('log')
        plt.grid(True)
        # legend goes outside the axes, to the right
        plt.legend(bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0.0)
        plt.tight_layout()
        plt.savefig(self.op_file_name)
        plt.close(0)

    def save_json(self):
        """Write the recorded epochs and values (rounded to 4 decimals) as JSON."""
        data = {'epochs': self.epochs}
        for ii in range(len(self.vals[0])):
            # float() first: numpy scalars such as np.float32 survive round()
            # unchanged in type and are not JSON serialisable
            data[self.labels[ii]] = [round(float(vv[ii]), 4) for vv in self.vals]
        with open(self._op_file_stem() + '.json', 'w') as da:
            json.dump(data, da, indent=2)

    def save_confusion_matrix(self, gt, pred):
        """Save a row-normalised confusion matrix next to the plot as '<stem>_cm.png'."""
        plt.figure(0)
        # `labels` must be passed by keyword: recent scikit-learn releases
        # reject it as a positional argument
        cm = confusion_matrix(gt, pred, labels=np.arange(len(self.class_names))).astype(np.float32)
        cm_norm = cm.sum(1)
        # only normalise rows that contain samples, to avoid dividing by zero
        valid_inds = np.where(cm_norm > 0)[0]
        cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
        plt.imshow(cm, vmin=0, vmax=1, cmap='plasma')
        plt.colorbar()
        plt.xticks(np.arange(cm.shape[1]), self.class_names, rotation='vertical')
        plt.yticks(np.arange(cm.shape[0]), self.class_names)
        plt.xlabel('Predicted')
        plt.ylabel('Ground Truth')
        plt.tight_layout()
        plt.savefig(self._op_file_stem() + '_cm.png')
        plt.close(0)