Spaces:
Runtime error
Runtime error
import os | |
import sys | |
import numpy as np | |
import argparse | |
import h5py | |
import math | |
import time | |
import logging | |
import pickle | |
import matplotlib.pyplot as plt | |
def load_sdrs(workspace, task_name, filename, config, gpus):
    """Load the per-evaluation test SDR values from a pickled statistics file.

    The file is read from
    ``<workspace>/statistics/<task_name>/<filename>/config=<config>,gpus=<gpus>/statistics.pkl``.

    Args:
        workspace: Root directory that contains the ``statistics`` tree.
        task_name: Sub-directory name of the task.
        filename: Sub-directory name below the task directory.
        config: Value substituted into the ``config=...`` directory name.
        gpus: Value substituted into the ``gpus=...`` directory name.

    Returns:
        list: The ``'sdr'`` value of every entry in the ``'test'`` list of
        the loaded statistics dictionary, in stored order.

    Raises:
        FileNotFoundError: If the statistics file does not exist.
        KeyError: If the dictionary has no ``'test'`` key or an entry has no
            ``'sdr'`` key.
    """
    stat_path = os.path.join(
        workspace,
        "statistics",
        task_name,
        filename,
        "config={},gpus={}".format(config, gpus),
        "statistics.pkl",
    )

    # NOTE(security): pickle can execute arbitrary code while loading; only
    # point this at statistics files produced by a trusted training run.
    # The context manager guarantees the handle is closed even if
    # unpickling fails.
    with open(stat_path, 'rb') as stat_file:
        stat_dict = pickle.load(stat_file)

    # NOTE(review): the name suggests each 'sdr' is a median over the test
    # set for one evaluation step - confirm against the code that writes
    # statistics.pkl.
    median_sdrs = [e['sdr'] for e in stat_dict['test']]
    return median_sdrs
def plot_statistics(args):
    """Plot test SDR curves from saved statistics and save the figure as PDF.

    The figure is written to ``results/vctk-musdb18/sdr_<select>.pdf``
    (relative to the current working directory); the directory is created
    if it does not exist.

    Args:
        args: Namespace with two attributes:
            ``workspace`` - root directory passed on to ``load_sdrs``;
            ``select`` - identifier of the curve set to draw. Only ``'1a'``
            is currently defined.

    Raises:
        ValueError: If ``args.select`` is not a known curve set. (A
            ``ValueError`` is still an ``Exception``, so existing
            ``except Exception`` handlers keep working.)
    """
    # arguments & parameters
    workspace = args.workspace
    select = args.select
    task_name = "vctk-musdb18"
    filename = "train"

    # Curves drawn for each supported ``select`` value, as
    # (config, gpus, legend label) tuples. Add new experiments here.
    curves_by_select = {
        '1a': [('unet', 1, 'UNet,l1_wav')],
    }

    # Validate before creating the output directory or a pyplot figure so
    # that an invalid selection leaves nothing behind.
    if select not in curves_by_select:
        raise ValueError(
            "Unsupported select value {!r}; expected one of {}".format(
                select, sorted(curves_by_select)
            )
        )

    # paths
    fig_path = os.path.join('results', task_name, "sdr_{}.pdf".format(select))
    os.makedirs(os.path.dirname(fig_path), exist_ok=True)

    linewidth = 1
    ylim = 30  # upper bound of the y axis
    lines = []
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    for config, gpus, label in curves_by_select[select]:
        sdrs = load_sdrs(workspace, task_name, filename, config=config, gpus=gpus)
        (line,) = ax.plot(sdrs, label=label, linewidth=linewidth)
        lines.append(line)

    # The x axis is indexed by position in the SDR list; tick labels are
    # that index multiplied by ``eval_every_iterations``.
    # NOTE(review): presumably statistics are recorded once every 10000
    # training iterations - confirm against the training script.
    eval_every_iterations = 10000
    total_ticks = 50
    ticks_freq = 10

    ax.set_ylim(0, ylim)
    ax.set_xlim(0, total_ticks)
    ax.xaxis.set_ticks(np.arange(0, total_ticks + 1, ticks_freq))
    ax.xaxis.set_ticklabels(
        np.arange(
            0,
            total_ticks * eval_every_iterations + 1,
            ticks_freq * eval_every_iterations,
        )
    )
    ax.yaxis.set_ticks(np.arange(ylim + 1))
    ax.yaxis.set_ticklabels(np.arange(ylim + 1))
    ax.grid(color='b', linestyle='solid', linewidth=0.3)

    # loc=4 places the legend in the lower-right corner.
    ax.legend(handles=lines, loc=4)
    fig.savefig(fig_path)
    print('Save figure to {}'.format(fig_path))
if __name__ == '__main__':
    # Command-line entry point: both options are mandatory strings and are
    # handed to plot_statistics as a parsed namespace.
    cli = argparse.ArgumentParser()
    for option in ('--workspace', '--select'):
        cli.add_argument(option, type=str, required=True)

    plot_statistics(cli.parse_args())