# beatalignment / scripts / compound-stats.py
# Author: william590y
# Commit 151b875 (verified): "Upload folder using huggingface_hub"
import os
from argparse import ArgumentParser
from collections import Counter
from concurrent.futures import ProcessPoolExecutor
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from tqdm import tqdm

from anticipation.config import *
from anticipation.convert import compound_to_events
from anticipation.tokenize import maybe_tokenize
# Global figure typography: 16pt serif text, preferring Computer Modern.
# NOTE(review): if 'Computer Modern' is not installed, matplotlib presumably
# falls back to another serif font with a findfont warning — confirm.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Computer Modern'],
    'font.size': 16,
})
def dataset_stats(filename):
    """Compute summary statistics for one compound-format token file.

    The file is read as whitespace-separated integers and interpreted as
    consecutive 5-token records. The first two fields of each record are
    used as an onset time and a duration (presumably the record layout is
    time, duration, note, instrument, velocity — TODO confirm against
    anticipation.convert).

    Args:
        filename: path to a ``*.compound.txt`` file.

    Returns:
        A tuple ``(token_length, time_length, status)``:

        - token_length: ``3 * (number of complete 5-token records)``
          (assumes each record maps to 3 event tokens — TODO confirm).
        - time_length: the latest ``onset + duration`` over all complete
          records, in the file's time units (0 when there are none).
        - status: the third value returned by ``maybe_tokenize``; ``main``
          treats 0 as kept and 1/2/3 as too short / too long / too many
          instruments.
    """
    with open(filename, 'r') as f:
        compound_tokens = [int(token) for token in f.read().split()]

    _, _, status = maybe_tokenize(compound_tokens)

    # Only complete 5-token records are meaningful. A truncated tail would
    # otherwise misalign the fields, or raise IndexError for 1-4 tokens.
    n_records = len(compound_tokens) // 5
    records = compound_tokens[:5 * n_records]

    # Even if records are ordered by onset, an earlier, longer note can end
    # after the last one, so take the maximum end time over all records.
    time_length = max(
        (onset + duration
         for onset, duration in zip(records[0::5], records[1::5])),
        default=0)

    return (3 * n_records, time_length, status)
def loghist(filename, data, title, xlabel, show_title=False):
    """Save a kernel-density plot of ``data`` on a log-scaled x axis.

    Args:
        filename: output image path. Missing parent directories are created.
        data: sequence of (positive) values to plot.
        title: plot title. It is drawn only when ``show_title`` is true; the
            default keeps the previous, untitled output.
        xlabel: x-axis label.
        show_title: whether to draw ``title`` above the plot.
    """
    # The style must be applied before the axes are created to take effect.
    sns.set_style('whitegrid')

    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        if show_title:
            ax.set_title(title)
        # Keep the original ordering: the axis is log-scaled *before*
        # kdeplot runs, because seaborn consults the existing axis scale when
        # computing the density.
        ax.set_xscale('log')
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Density')
        ax.grid(True, which='both', linestyle='-', linewidth=0.5)
        sns.kdeplot(data, bw_adjust=1.0, ax=ax)
        fig.tight_layout()

        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)
        fig.savefig(filename, dpi=300)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, each call would leak one figure.
        plt.close(fig)
def _lengths(results):
    """Split per-file ``(tokens, time, status)`` tuples into two lists.

    Returns ``(token_lengths, time_lengths)``, with times divided by
    ``TIME_RESOLUTION`` to convert them to seconds.
    """
    token_lengths = [tokens for tokens, _, _ in results]
    time_lengths = [ticks / float(TIME_RESOLUTION) for _, ticks, _ in results]
    return token_lengths, time_lengths


def _percent(part, whole):
    """Format ``part / whole`` as a whole-number percentage ('n/a' if whole is 0)."""
    if whole == 0:
        return 'n/a'
    return f'{100 * part / whole:.0f}%'


def main(args):
    """Compute, print and plot length statistics for a compound-token dataset.

    Recursively finds ``*.compound.txt`` files under ``args.dir`` and runs
    ``dataset_stats`` on each in a process pool. It then prints counts per
    filter status and writes KDE plots of sequence lengths (in tokens and in
    seconds) to ``output/``, both before and after filtering. Plots and
    summaries whose input would be empty are skipped instead of crashing.
    """
    filenames = glob(os.path.join(args.dir, '**', '*.compound.txt'),
                     recursive=True)

    print(f'Calculating statistics for the dataset rooted at {args.dir}')
    if not filenames:
        print('No *.compound.txt files found; nothing to do.')
        return

    with ProcessPoolExecutor(max_workers=PREPROC_WORKERS) as executor:
        results = list(tqdm(
            executor.map(dataset_stats, filenames),
            desc='Computing statistics',
            total=len(filenames)))

    # NOTE(review): the "one hour" label assumes
    # MAX_TRACK_TIME_IN_SECONDS == 3600 — confirm against anticipation.config.
    max_track_time = TIME_RESOLUTION * MAX_TRACK_TIME_IN_SECONDS
    print('Sequences over one hour: ',
          sum(1 for r in results if r[1] > max_track_time))

    null_sequences = sum(1 for r in results if r[0] == 0)
    print('Sequences with zero tokens: ', null_sequences)

    # Status codes come from maybe_tokenize via dataset_stats.
    status_counts = Counter(r[2] for r in results)
    print('Filtering statistics: ')
    print(' ==> too short:', status_counts[1])
    print(' ==> too long:', status_counts[2])
    print(' ==> too many instruments:', status_counts[3])

    total_sequences = len(results)

    # prefiltering (can't plot zero-length sequences on the log scale)
    nonempty = [r for r in results if r[0] != 0]
    token_lengths, time_lengths = _lengths(nonempty)
    token_count = sum(token_lengths)
    if nonempty:
        loghist('output/unfiltered_length_tokens.png',
                token_lengths,
                'Unfiltered Distribution of Sequence Lengths',
                'Length in Tokens (log10 scale)')
        loghist('output/unfiltered_length_seconds.png',
                time_lengths,
                'Unfiltered Distribution of Sequences Length',
                'Time in Seconds (log10 scale)')

    filtered_results = [r for r in nonempty if r[2] == 0]
    filtered_tokens, filtered_times = _lengths(filtered_results)
    filtered_token_count = sum(filtered_tokens)
    if filtered_results:
        loghist('output/filtered_length_tokens.png',
                filtered_tokens,
                'Distribution of Sequence Lengths (in tokens)',
                'Length in Tokens (log10 scale)')
        loghist('output/filtered_length_seconds.png',
                filtered_times,
                'Distribution of Sequence Lengths (in seconds)',
                'Time in Seconds (log10 scale)')

    print('Successfully calculated statistics: detailed results available at output/')
    print(' => Number of sequences: ', total_sequences)
    print(' => Number of tokens (unfiltered): ', token_count)
    print(' => Number of sequences (filtered): {} ({})'.format(
        len(filtered_results),
        _percent(len(filtered_results), total_sequences)))
    print(' => Number of tokens (filtered): {} ({})'.format(
        filtered_token_count, _percent(filtered_token_count, token_count)))

    if not filtered_results:
        # mean/std of an empty list would be NaN; report that explicitly.
        print(' => No sequences survived filtering; skipping time/token summaries.')
        return

    print(f' => Total time (filtered): {round(sum(filtered_times)/3600., 2)}h')
    print(f' - mean time: {round(np.mean(filtered_times))}s')
    print(f' - std time: {round(np.std(filtered_times))}s')
    print(f' - mean tokens: {round(np.mean(filtered_tokens))}')
    print(f' - std tokens: {round(np.std(filtered_tokens))}')
if __name__ == '__main__':
    # Command-line entry point: the single positional argument is the dataset
    # root that main() searches recursively for *.compound.txt files.
    parser = ArgumentParser(description='calculate statistics of the intermediate compound representation')
    parser.add_argument('dir', help='root directory searched recursively for *.compound.txt files')
    main(parser.parse_args())