import json

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import patches
from matplotlib.collections import PatchCollection
from sklearn.metrics import confusion_matrix

from . import audio_utils as au


def _abbreviate_class_name(class_name):
    """Return `class_name` with every space-separated word cut to 3 characters."""
    return ' '.join([word[:3] for word in class_name.split(' ')])


def _add_class_labels(ax, boxes, anns, max_freq):
    """Write the abbreviated class name of each annotation next to its box.

    `boxes[i]` must be the rectangle that was built from `anns[i]`; the text
    is anchored at the box's left edge and top edge, with the vertical
    position capped at `max_freq - 10`.
    """
    for box, ann in zip(boxes, anns):
        font_info = {'color': 'white', 'size': 10, 'weight': 'bold',
                     'alpha': box.get_alpha()}
        x_pos, y_bottom = box.get_xy()
        y_pos = min(y_bottom + box.get_height(), max_freq - 10)
        ax.text(x_pos, y_pos, _abbreviate_class_name(ann['class']),
                fontdict=font_info)


def create_box_image(spec, fig, detections_ip, start_time, end_time, duration,
                     params, max_val, hide_axis=True, plot_class_names=False):
    """Draw a spectrogram window with the detections that start inside it.

    Args:
        spec: image array handed to `plt.imshow`.
        fig: figure that receives a full-size, axis-free Axes when
            `hide_axis` is True.
        detections_ip: iterable of dicts with the keys 'start_time',
            'end_time', 'low_freq', 'high_freq', optionally 'det_prob', and
            'class' (only read when `plot_class_names` is True).
        start_time: start of the window; subtracted from box x positions.
        end_time: unused, kept for interface compatibility.
        duration: window length; also the upper bound of the x extent.
        params: dict providing 'min_freq' and 'max_freq' (divided by 1000 for
            the y extent).
        max_val: upper colour limit (`vmax`) for the image.
        hide_axis: if True draw on a new borderless Axes covering `fig`,
            otherwise draw on the current Axes.
        plot_class_names: if True, label every box with its abbreviated class.
    """
    # keep only detections that start inside the window; those starting in
    # the last 0.02 of the window are dropped as well
    stop_time = start_time + duration
    detections = [bb for bb in detections_ip
                  if start_time <= bb['start_time'] < stop_time - 0.02]

    # create figure
    freq_scale = 1000  # turn Hz to kHz
    min_freq = params['min_freq'] // freq_scale
    max_freq = params['max_freq'] // freq_scale
    y_extent = [0, duration, min_freq, max_freq]

    if hide_axis:
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)
    else:
        ax = plt.gca()

    plt.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent,
               vmin=0, vmax=max_val)
    boxes = plot_bounding_box_patch_ann(detections, freq_scale, start_time)
    ax.add_collection(PatchCollection(boxes, match_original=True))
    plt.grid(False)

    if plot_class_names:
        # boxes were built from the filtered list, so the labels must come
        # from that same list (not from detections_ip) to stay aligned
        _add_class_labels(plt.gca(), boxes, detections, max_freq)


def save_ann_spec(op_path, spec, min_freq, max_freq, duration, start_time,
                  title_text='', anns=None):
    """Save a spectrogram image, optionally overlaid with labelled annotations.

    All existing figures are closed first and figure 0 is left open after
    saving.

    Args:
        op_path: destination passed to `plt.savefig`.
        spec: 2D array; its shape (divided by 100, at 100 dpi) sets the
            figure size and its maximum sets the colour range.
        min_freq, max_freq: frequency limits, divided by 1000 for the y axis.
        duration: upper bound of the x extent.
        start_time: offset subtracted from annotation start times.
        title_text: figure title; skipped when empty.
        anns: optional sequence of annotation dicts (same keys as used by
            `plot_bounding_box_patch_ann`, plus 'class').
    """
    # create figure and plot boxes
    freq_scale = 1000  # turn Hz to kHz
    min_freq = min_freq // freq_scale
    max_freq = max_freq // freq_scale
    y_extent = [0, duration, min_freq, max_freq]

    plt.close('all')
    plt.figure(0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100)
    plt.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent,
               vmin=0, vmax=spec.max() * 1.1)
    plt.ylabel('Freq - kHz')
    plt.xlabel('Time - secs')
    if title_text != '':
        plt.title(title_text)
    plt.tight_layout()

    if anns is not None:
        # drawing bounding boxes and class names
        boxes = plot_bounding_box_patch_ann(anns, freq_scale, start_time)
        plt.gca().add_collection(PatchCollection(boxes, match_original=True))
        _add_class_labels(plt.gca(), boxes, anns, max_freq)

    print('Saving figure to:', op_path)
    plt.savefig(op_path)


def plot_pts(fig_id, feats, class_names, colors, marker_size=4.0,
             plot_legend=False):
    """Scatter the first two columns of `feats`, one colour per class.

    Args:
        fig_id: identifier of the figure to draw on.
        feats: array indexed as `feats[rows, 0]` / `feats[rows, 1]`.
        class_names: per-row class name; unique values define the groups.
        colors: one colour per class; replaced by samples of the 'jet'
            colormap when there are more classes than colours.
        marker_size: scatter marker size.
        plot_legend: if True, add a legend with the class names.
    """
    plt.figure(fig_id)
    un_class, labels = np.unique(class_names, return_inverse=True)
    un_labels = np.unique(labels)
    if un_labels.shape[0] > len(colors):
        colors = [plt.cm.jet(float(ii) / un_labels.shape[0]) for ii in un_labels]

    for ii, u in enumerate(un_labels):
        inds = np.where(labels == u)[0]
        plt.scatter(feats[inds, 0], feats[inds, 1], c=colors[ii],
                    label=str(un_class[ii]), s=marker_size)
    if plot_legend:
        plt.legend()
    plt.xticks([])
    plt.yticks([])
    plt.title('downsampled features')


def plot_bounding_box_patch(pred, freq_scale, ecolor='w'):
    """Build one unfilled Rectangle per entry of a column-style prediction dict.

    Args:
        pred: mapping with parallel sequences 'start_times', 'end_times',
            'low_freqs', 'high_freqs' and optionally 'det_probs' (used as the
            rectangle alpha; 1.0 when absent).
        freq_scale: divisor applied to the frequency values.
        ecolor: rectangle edge colour.

    Returns:
        List of `matplotlib.patches.Rectangle`.
    """
    has_probs = 'det_probs' in pred
    patch_collect = []
    for bb in range(len(pred['start_times'])):
        xx = pred['start_times'][bb]
        ww = pred['end_times'][bb] - pred['start_times'][bb]
        yy = pred['low_freqs'][bb] / freq_scale
        hh = (pred['high_freqs'][bb] - pred['low_freqs'][bb]) / freq_scale
        alpha_val = pred['det_probs'][bb] if has_probs else 1.0
        patch_collect.append(patches.Rectangle(
            (xx, yy), ww, hh, linewidth=1, edgecolor=ecolor,
            facecolor='none', alpha=alpha_val))
    return patch_collect


def plot_bounding_box_patch_ann(anns, freq_scale, start_time):
    """Build one unfilled white Rectangle per annotation dict.

    Args:
        anns: sequence of mappings with 'start_time', 'end_time', 'low_freq',
            'high_freq' and optionally 'det_prob' (used as the rectangle
            alpha; 1.0 when absent).
        freq_scale: divisor applied to the frequency values.
        start_time: offset subtracted from each annotation's start time.

    Returns:
        List of `matplotlib.patches.Rectangle`, in the same order as `anns`.
    """
    patch_collect = []
    for ann in anns:
        xx = ann['start_time'] - start_time
        ww = ann['end_time'] - ann['start_time']
        yy = ann['low_freq'] / freq_scale
        hh = (ann['high_freq'] - ann['low_freq']) / freq_scale
        alpha = ann['det_prob'] if 'det_prob' in ann else 1.0
        patch_collect.append(patches.Rectangle(
            (xx, yy), ww, hh, linewidth=1, edgecolor='w',
            facecolor='none', alpha=alpha))
    return patch_collect


def plot_spec(spec, sampling_rate, duration, gt, pred, params, plot_title,
              op_file_name, pred_2d_hm, plot_boxes=True, fixed_aspect=True):
    """Plot ground truth, predictions and an optional heatmap in three rows.

    Figure 1 is used and closed before returning.

    Args:
        spec: spectrogram image shown in the top two rows.
        sampling_rate: unused, kept for interface compatibility.
        duration: upper bound of the x extent; also scales the figure width
            when `fixed_aspect` is False.
        gt: column-style dict accepted by `plot_bounding_box_patch`, plus
            'class_ids' (negative ids are labelled with the generic class).
        pred: column-style dict accepted by `plot_bounding_box_patch`, plus
            'class_probs' indexed as `[class, detection]`.
        params: dict providing 'min_freq', 'max_freq', 'generic_class' and
            'class_names_short'.
        plot_title: figure suptitle.
        op_file_name: output path, or None to skip saving.
        pred_2d_hm: heatmap image for the bottom row, or None to leave it
            empty.
        plot_boxes: if True, draw and label the boxes.
        fixed_aspect: if True the figure is 12 wide regardless of duration.
    """
    if fixed_aspect:
        # output image will be this width irrespective of the duration of the audio file
        width = 12
    else:
        width = 12 * duration

    fig = plt.figure(1, figsize=(width, 8))
    ax0 = plt.axes([0.05, 0.65, 0.9, 0.30])  # l b w h
    ax1 = plt.axes([0.05, 0.33, 0.9, 0.30])
    ax2 = plt.axes([0.05, 0.01, 0.9, 0.30])

    freq_scale = 1000  # turn Hz in kHz
    # duration is supplied by the caller rather than derived from spec width
    min_freq = params['min_freq'] // freq_scale
    y_extent = [0, duration, min_freq, params['max_freq'] // freq_scale]
    title_font = {'color': 'white', 'size': 12, 'weight': 'bold'}

    # plt.grid() only acts on the current Axes (ax2, the last one created),
    # so the grid is switched off on each Axes explicitly instead.
    ax2.grid(False)

    # plot gt boxes
    ax0.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent)
    ax0.xaxis.set_ticklabels([])
    ax0.text(0, min_freq, 'Ground Truth', fontdict=title_font)
    ax0.grid(False)
    if plot_boxes:
        boxes = plot_bounding_box_patch(gt, freq_scale)
        ax0.add_collection(PatchCollection(boxes, match_original=True))
        for ii, bb in enumerate(boxes):
            class_id = int(gt['class_ids'][ii])
            if class_id < 0:
                txt = params['generic_class'][0]
            else:
                txt = params['class_names_short'][class_id]
            font_info = {'color': 'white', 'size': 10, 'weight': 'bold',
                         'alpha': bb.get_alpha()}
            y_pos = bb.get_xy()[1] + bb.get_height()
            ax0.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)

    # plot predicted boxes
    ax1.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent)
    ax1.xaxis.set_ticklabels([])
    ax1.text(0, min_freq, 'Prediction', fontdict=title_font)
    ax1.grid(False)
    if plot_boxes:
        boxes = plot_bounding_box_patch(pred, freq_scale)
        ax1.add_collection(PatchCollection(boxes, match_original=True))
        for ii, bb in enumerate(boxes):
            # when there is one more probability row than short class names,
            # the last row is excluded from the argmax
            if pred['class_probs'].shape[0] > len(params['class_names_short']):
                class_id = pred['class_probs'][:-1, ii].argmax()
            else:
                class_id = pred['class_probs'][:, ii].argmax()
            txt = params['class_names_short'][class_id]
            font_info = {'color': 'white', 'size': 10, 'weight': 'bold',
                         'alpha': bb.get_alpha()}
            y_pos = bb.get_xy()[1] + bb.get_height()
            ax1.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)

    # plot 2D heatmap
    if pred_2d_hm is not None:
        # colour limits cover at least [0, 1], widened to the data range
        min_val = 0.0 if pred_2d_hm.min() > 0.0 else pred_2d_hm.min()
        max_val = 1.0 if pred_2d_hm.max() < 1.0 else pred_2d_hm.max()

        ax2.imshow(pred_2d_hm, aspect='auto', cmap='plasma', extent=y_extent,
                   clim=[min_val, max_val])
        ax2.text(0, min_freq, 'Heatmap', fontdict=title_font)
        ax2.grid(False)

    plt.suptitle(plot_title)
    if op_file_name is not None:
        fig.savefig(op_file_name)

    plt.close(1)


def plot_pr_curve(op_dir, plt_title, file_name, results, file_type='png',
                  title_text=''):
    """Plot a single precision-recall curve and save it.

    The output path is `op_dir + file_name + '.' + file_type` (plain string
    concatenation, so `op_dir` needs its own trailing separator).

    Args:
        op_dir: output directory prefix.
        plt_title: default title prefix, followed by the average precision.
        file_name: output file name without extension.
        results: dict with 'precision', 'recall' and 'avg_prec'.
        file_type: output file extension.
        title_text: if non-empty, used as the title instead of the default.
    """
    precision = results['precision']
    recall = results['recall']
    avg_prec = results['avg_prec']

    plt.figure(0, figsize=(10, 8))
    plt.plot(recall, precision)
    plt.ylabel('Precision', fontsize=20)
    plt.xlabel('Recall', fontsize=20)

    if title_text != '':
        plt.title(title_text, fontdict={'fontsize': 28})
    else:
        plt.title(plt_title + ' {:.3f}\n'.format(avg_prec))

    plt.xlim(0, 1.02)
    plt.ylim(0, 1.02)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(op_dir + file_name + '.' + file_type)
    plt.close(0)


def plot_pr_curve_class(op_dir, plt_title, file_name, results,
                        file_type='png', title_text=''):
    """Plot one precision-recall curve per class and save the figure.

    The output path is `op_dir + file_name + '.' + file_type` (plain string
    concatenation, so `op_dir` needs its own trailing separator).

    Args:
        op_dir: output directory prefix.
        plt_title: default title prefix, followed by
            `results['avg_prec_class']`.
        file_name: output file name without extension.
        results: dict with 'avg_prec_class' and 'class_pr', a sequence of
            dicts with 'name', 'recall', 'precision', 'thresholds' and
            'thresholds_inds' (an index of -1 or lower skips that marker).
        file_type: output file extension.
        title_text: if non-empty, used as the title instead of the default.
    """
    plt.figure(0, figsize=(10, 8))
    plt.ylabel('Precision', fontsize=20)
    plt.xlabel('Recall', fontsize=20)
    plt.xlim(0, 1.02)
    plt.ylim(0, 1.02)
    plt.grid(True)
    linestyles = ['-', ':', '--']
    markers = ['o', 'v', '>', '^', '<', 's', 'P', 'X', '*']
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    # plot the PR curves
    for ii, rr in enumerate(results['class_pr']):
        class_name = _abbreviate_class_name(rr['name'])
        # colours repeat every 10 classes and the line style advances with
        # each repeat; all indices wrap so large inputs cannot overrun
        cur_color = colors[(ii % 10) % len(colors)]
        cur_style = linestyles[(ii // 10) % len(linestyles)]
        plt.plot(rr['recall'], rr['precision'], label=class_name,
                 color=cur_color, linestyle=cur_style, lw=2.5)

        # plot the location of the confidence threshold values
        for jj, _ in enumerate(rr['thresholds']):
            ind = rr['thresholds_inds'][jj]
            if ind > -1:
                plt.plot(rr['recall'][ind], rr['precision'][ind],
                         markers[jj % len(markers)], color=cur_color, ms=10)

    if title_text != '':
        plt.title(title_text, fontdict={'fontsize': 28})
    else:
        plt.title(plt_title + ' {:.3f}\n'.format(results['avg_prec_class']))
    plt.legend(loc='lower left', prop={'size': 14})
    plt.tight_layout()
    plt.savefig(op_dir + file_name + '.' + file_type)
    plt.close(0)


def plot_confusion_matrix(op_dir, op_file, gt, pred, file_acc,
                          class_names_long, verbose=False, file_type='png',
                          title_text=''):
    """Plot a row-normalised confusion matrix and save it.

    Rows with no ground-truth samples are set to NaN. The output path is
    `op_dir + op_file + '.' + file_type` (plain string concatenation). All
    figures are closed before returning.

    Args:
        op_dir: output directory prefix.
        op_file: output file name without extension; also the default title
            prefix.
        gt: ground-truth labels; the matrix covers labels
            0..len(class_names_long)-1.
        pred: predicted labels, same convention as `gt`.
        file_acc: number appended to the default title.
        class_names_long: full class names; abbreviated for the tick labels.
        verbose: if True, print the per-class accuracy (matrix diagonal).
        file_type: output file extension.
        title_text: if non-empty, used as the title instead of the default.
    """
    # shorten the class names for plotting
    class_names = [_abbreviate_class_name(cc) for cc in class_names_long]

    num_classes = len(class_names)
    cm = confusion_matrix(gt, pred, labels=np.arange(num_classes)).astype(np.float32)
    cm_norm = cm.sum(1)

    valid_inds = np.where(cm_norm > 0)[0]
    cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
    cm[np.where(cm_norm == 0)[0], :] = np.nan

    if verbose:
        print('Per class accuracy:')
        str_len = np.max([len(cc) for cc in class_names_long]) + 5
        accs = np.diag(cm)
        for ii, cc in enumerate(class_names_long):
            if np.isnan(accs[ii]):
                print(str(ii).ljust(5) + cc.ljust(str_len))
            else:
                print(str(ii).ljust(5) + cc.ljust(str_len)
                      + '{:.2f}'.format(accs[ii] * 100))

    plt.figure(0, figsize=(10, 8))
    plt.imshow(cm, vmin=0, vmax=1, cmap='plasma')
    plt.colorbar()
    plt.xticks(np.arange(cm.shape[1]), class_names, rotation='vertical')
    plt.yticks(np.arange(cm.shape[0]), class_names)
    plt.xlabel('Predicted', fontsize=20)
    plt.ylabel('Ground Truth', fontsize=20)
    if title_text != '':
        plt.title(title_text, fontdict={'fontsize': 28})
    else:
        plt.title(op_file + ' {:.3f}\n'.format(file_acc))
    plt.tight_layout()
    plt.savefig(op_dir + op_file + '.' + file_type)
    plt.close('all')


class LossPlotter(object):
    """Accumulate per-epoch values and write them out as a plot and JSON.

    Each update re-saves the line plot to `op_file_name` and a JSON file next
    to it; the sibling paths are derived by dropping the last four characters
    of `op_file_name` (i.e. a three-letter extension is expected).
    """

    def __init__(self, op_file_name, duration, labels, ylim, class_names,
                 axis_labels=None, logy=False):
        """Store the plot configuration and start with an empty history.

        Args:
            op_file_name: path of the line plot image.
            duration: minimum upper limit of the x axis.
            labels: one name per tracked value, used for the legend and as
                JSON keys.
            ylim: `(low, high)` y limits, or None for automatic limits.
            class_names: tick labels for the confusion matrix.
            axis_labels: optional `(xlabel, ylabel)` pair.
            logy: if True, use a logarithmic y axis.
        """
        self.reset()
        self.op_file_name = op_file_name
        self.duration = duration  # length of x axis
        self.labels = labels
        self.ylim = ylim
        self.class_names = class_names
        self.axis_labels = axis_labels
        self.logy = logy

    def reset(self):
        """Clear the recorded epochs and values."""
        self.epochs = []
        self.vals = []

    def update_and_save(self, epoch, val, gt=None, pred=None):
        """Record one epoch and rewrite the plot, the JSON and, if given, the confusion matrix.

        Args:
            epoch: x value for this update.
            val: sequence with one value per entry of `labels`.
            gt: optional ground-truth labels; when not None a confusion
                matrix is saved as well.
            pred: predicted labels matching `gt`.
        """
        self.epochs.append(epoch)
        self.vals.append(val)
        self.save_plot()
        self.save_json()
        if gt is not None:
            self.save_confusion_matrix(gt, pred)

    def save_plot(self):
        """Plot every tracked value against the epochs and save the figure."""
        linestyles = ['-', ':', '--']
        plt.figure(0, figsize=(8, 5))
        for ii in range(len(self.vals[0])):
            l_vals = [vv[ii] for vv in self.vals]
            # the line style advances every 10 series and wraps around so
            # that 30 or more series do not raise an IndexError
            plt.plot(self.epochs, l_vals, label=self.labels[ii],
                     linestyle=linestyles[(ii // 10) % len(linestyles)])
        plt.xlim(0, np.maximum(self.duration, len(self.vals)))
        if self.ylim is not None:
            plt.ylim(self.ylim[0], self.ylim[1])
        if self.axis_labels is not None:
            plt.xlabel(self.axis_labels[0])
            plt.ylabel(self.axis_labels[1])
        if self.logy:
            plt.gca().set_yscale('log')
        plt.grid(True)
        plt.legend(bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0.0)
        plt.tight_layout()
        plt.savefig(self.op_file_name)
        plt.close(0)

    def save_json(self):
        """Write the epochs and the values (rounded to 4 decimals) as JSON."""
        data = {'epochs': self.epochs}
        for ii in range(len(self.vals[0])):
            # float() keeps numpy scalar types such as float32, which the
            # json module cannot encode, from breaking the dump
            data[self.labels[ii]] = [round(float(vv[ii]), 4) for vv in self.vals]
        with open(self.op_file_name[:-4] + '.json', 'w') as da:
            json.dump(data, da, indent=2)

    def save_confusion_matrix(self, gt, pred):
        """Save a row-normalised confusion matrix next to the line plot.

        Rows without ground-truth samples are left at zero. The image is
        written to `op_file_name[:-4] + '_cm.png'`.
        """
        plt.figure(0)
        # `labels` must be passed by keyword: it is keyword-only in current
        # scikit-learn releases
        cm = confusion_matrix(
            gt, pred, labels=np.arange(len(self.class_names))).astype(np.float32)
        cm_norm = cm.sum(1)
        valid_inds = np.where(cm_norm > 0)[0]
        cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
        plt.imshow(cm, vmin=0, vmax=1, cmap='plasma')
        plt.colorbar()
        plt.xticks(np.arange(cm.shape[1]), self.class_names, rotation='vertical')
        plt.yticks(np.arange(cm.shape[0]), self.class_names)
        plt.xlabel('Predicted')
        plt.ylabel('Ground Truth')
        plt.tight_layout()
        plt.savefig(self.op_file_name[:-4] + '_cm.png')
        plt.close(0)