import os
import sys
import numpy as np
import argparse
import h5py
import math
import time
import logging
import pickle
import matplotlib.pyplot as plt


def load_sdrs(workspace, task_name, filename, config, gpus):
    """Load the stored SDR values of one training run.

    The statistics file is read from::

        <workspace>/statistics/<task_name>/<filename>/config=<config>,gpus=<gpus>/statistics.pkl

    Args:
        workspace: Root directory containing the ``statistics`` folder.
        task_name: Task sub-directory name (``"vctk-musdb18"`` in this script).
        filename: Sub-directory under the task folder (``"train"`` in this script).
        config: Value formatted into the ``config=`` part of the run directory.
        gpus: Value formatted into the ``gpus=`` part of the run directory.

    Returns:
        A list containing the ``'sdr'`` entry of each record in
        ``stat_dict['test']``, in stored order.

    Note:
        ``pickle.load`` can execute arbitrary code while unpickling; only load
        statistics files produced by a trusted run.
    """
    stat_path = os.path.join(
        workspace,
        "statistics",
        task_name,
        filename,
        "config={},gpus={}".format(config, gpus),
        "statistics.pkl",
    )

    # Context manager guarantees the handle is closed even if unpickling fails.
    with open(stat_path, 'rb') as f:
        stat_dict = pickle.load(f)

    median_sdrs = [e['sdr'] for e in stat_dict['test']]
    return median_sdrs


def plot_statistics(args):
    """Plot the SDR curve(s) of a selected experiment and save them as a PDF.

    The figure is written to ``results/<task_name>/sdr_<select>.pdf``.

    Args:
        args: Parsed command-line namespace providing ``workspace`` (root
            directory passed to ``load_sdrs``) and ``select`` (experiment
            selector; only ``'1a'`` is currently supported).

    Raises:
        ValueError: If ``args.select`` is not a supported selector.
    """
    # arguments & parameters
    workspace = args.workspace
    select = args.select

    task_name = "vctk-musdb18"
    filename = "train"

    # paths
    fig_path = os.path.join('results', task_name, "sdr_{}.pdf".format(select))

    linewidth = 1
    lines = []
    ylim = 30

    # NOTE(review): presumably the number of training iterations between two
    # consecutive stored SDR values — confirm against the training script.
    eval_every_iterations = 10000
    total_ticks = 50
    ticks_freq = 10

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    try:
        if select == '1a':
            sdrs = load_sdrs(workspace, task_name, filename, config='unet', gpus=1)
            (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
            lines.append(line)
        else:
            raise ValueError(
                "Unsupported --select value: {!r} (expected '1a')".format(select)
            )

        ax.set_ylim(0, ylim)
        ax.set_xlim(0, total_ticks)

        # Tick positions are indices into the SDR list; the labels rescale
        # them by eval_every_iterations so the axis reads as iteration counts.
        ax.xaxis.set_ticks(np.arange(0, total_ticks + 1, ticks_freq))
        ax.xaxis.set_ticklabels(
            np.arange(
                0,
                total_ticks * eval_every_iterations + 1,
                ticks_freq * eval_every_iterations,
            )
        )
        ax.yaxis.set_ticks(np.arange(ylim + 1))
        ax.yaxis.set_ticklabels(np.arange(ylim + 1))
        ax.grid(color='b', linestyle='solid', linewidth=0.3)

        # loc=4 is matplotlib's 'lower right' legend position.
        ax.legend(handles=lines, loc=4)

        # Only create the output directory once there is something to save.
        os.makedirs(os.path.dirname(fig_path), exist_ok=True)
        fig.savefig(fig_path)
        print('Save figure to {}'.format(fig_path))
    finally:
        # Release the figure so repeated calls (or errors) do not leak memory.
        plt.close(fig)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--workspace', type=str, required=True)
    parser.add_argument('--select', type=str, required=True)

    args = parser.parse_args()
    plot_statistics(args)