| | import os,csv
|
| |
|
| | from argparse import ArgumentParser
|
| | from glob import glob
|
| |
|
| | import numpy as np
|
| |
|
| | from tqdm import tqdm
|
| |
|
| | from anticipation import ops
|
| | from anticipation.visuals import visualize
|
| | from anticipation.tokenize import extract_instruments
|
| | from anticipation.convert import midi_to_events, events_to_midi
|
| | from anticipation.config import TIME_RESOLUTION, EVENT_SIZE
|
| | from anticipation.vocab import TIME_OFFSET, NOTE_OFFSET
|
| |
|
def select_sample(filenames, prompt_length, clip_length, verbose=False, max_attempts=None):
    """Randomly draw a MIDI file and cut a clip plus its opening prompt from it.

    Files are drawn uniformly at random using the global NumPy RNG. A draw is
    rejected (and another one made) when the file fails to parse, is shorter
    than ``clip_length``, the clip uses more than 15 instruments, or the
    prompt contains fewer than 10 events.

    Args:
        filenames: sequence of MIDI file paths to sample from.
        prompt_length: prompt length in seconds, passed to
            ``ops.clip(clip, 0, prompt_length, clip_duration=False)``.
        clip_length: full clip length in seconds.
        verbose: if True, print each drawn index and the reason for every
            rejection.
        max_attempts: optional cap on the number of draws. ``None`` (the
            default) keeps retrying indefinitely, as before.

    Returns:
        ``(basename, clip, prompt)``: the basename of the chosen file, the
        clipped event sequence (shifted earlier by the random start time,
        converted with ``TIME_RESOLUTION``), and the prompt events cut from
        the start of that clip.

    Raises:
        ValueError: if ``filenames`` is empty.
        RuntimeError: if ``max_attempts`` draws yield no acceptable sample.
    """
    if len(filenames) == 0:
        # np.random.randint(0) would otherwise fail with an opaque message.
        raise ValueError('select_sample: no filenames to sample from')

    attempts = 0
    while max_attempts is None or attempts < max_attempts:
        attempts += 1

        idx = np.random.randint(len(filenames))
        if verbose:
            print('Loading index: ', idx)

        # Best-effort: corrupt or unreadable files are skipped, not fatal.
        try:
            events = midi_to_events(filenames[idx])
        except Exception as exc:
            if verbose:
                print(f' rejected: could not load {filenames[idx]}: {exc}')
            continue

        track_length = ops.max_time(events)
        # Latest start time that still leaves room for a full clip.
        max_time = track_length - clip_length
        if max_time < 0:
            if verbose:
                print(f' rejected: track is too short (length {track_length} < {clip_length})')
            continue

        start_time = max_time*np.random.rand(1)[0]
        clip = ops.clip(events, start_time, start_time+clip_length, clip_duration=True)
        # Shift the clip so it begins at (approximately) time zero, in ticks.
        clip = ops.translate(clip, -int(TIME_RESOLUTION*start_time))

        n_instruments = len(ops.get_instruments(clip))
        if n_instruments > 15:
            if verbose:
                print(f' rejected: track instrument count out of bounds: {n_instruments}')
            continue

        prompt = ops.clip(clip, 0, prompt_length, clip_duration=False)
        # Each event occupies EVENT_SIZE tokens; require at least 10 events.
        if len(prompt) < EVENT_SIZE*10:
            if verbose:
                print(f' rejected: track has {len(prompt)//EVENT_SIZE} < 10 events in the prompt')
            continue

        return os.path.basename(filenames[idx]), clip, prompt

    raise RuntimeError(
        f'select_sample: no acceptable sample found in {max_attempts} attempts')
|
| |
|
| |
|
def main(args):
    """Sample ``args.count`` clips and write prompts, ground truth and an index.

    Everything is written under ``args.output``:
      * ``index.csv`` with columns ``idx, original, prompt, parts``;
      * ``{i}-conditional.mid`` (the prompt) and, if ``args.visualize``,
        ``{i}-conditional.png``;
      * ``groundtruth/{i}-clip.mid`` (the full clip) and, if
        ``args.visualize``, ``groundtruth/{i}-clip.png``.

    Args:
        args: parsed command-line namespace with ``dir``, ``output``, ``seed``,
            ``count``, ``prompt_length``, ``clip_length`` and ``visualize``.

    Raises:
        SystemExit: if no ``.mid``/``.midi`` files are found under ``args.dir``
            (checked before anything is written).
    """
    # select_sample draws from the global NumPy RNG; seeding makes runs reproducible.
    np.random.seed(args.seed)

    print(f'Selecting clips for accompaniment from: {args.dir}')
    # Sorted so that the seeded draws map to the same files on every run.
    filenames = sorted(
        glob(os.path.join(args.dir, '**', '*.mid'), recursive=True)
        + glob(os.path.join(args.dir, '**', '*.midi'), recursive=True))
    if not filenames:
        raise SystemExit(f'No .mid/.midi files found under: {args.dir}')

    print(f'Saving clips to: {args.output}')
    groundtruth_dir = os.path.join(args.output, 'groundtruth')
    # Creates args.output too; exist_ok tolerates re-runs into the same directory.
    os.makedirs(groundtruth_dir, exist_ok=True)

    with open(os.path.join(args.output, 'index.csv'), 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['idx', 'original', 'prompt', 'parts'])

        for i in tqdm(range(args.count)):
            filename, clip, prompt = select_sample(filenames, args.prompt_length, args.clip_length)
            parts = ops.get_instruments(clip)
            writer.writerow([i, filename, f'{i}-conditional.mid', len(parts)])

            mid = events_to_midi(clip)
            mid.save(os.path.join(groundtruth_dir, f'{i}-clip.mid'))
            if args.visualize:
                visualize(clip, os.path.join(groundtruth_dir, f'{i}-clip.png'))

            mid = events_to_midi(prompt)
            mid.save(os.path.join(args.output, f'{i}-conditional.mid'))
            if args.visualize:
                visualize(prompt, os.path.join(args.output, f'{i}-conditional.png'))
|
| |
|
| |
|
if __name__ == '__main__':
    # Command-line interface: each option becomes an attribute that main() reads.
    cli = ArgumentParser(description='select prompts for infilling completion human eval')

    # (flags, keyword arguments) pairs, registered in order.
    option_specs = [
        (('dir',), dict(help='directory containing MIDI files to sample')),
        (('-o', '--output'), dict(type=str, default='output',
                                  help='output directory')),
        (('-s', '--seed'), dict(type=int, default=0,
                                help='random seed for sampling')),
        (('-c', '--count'), dict(type=int, default=10,
                                 help='number of clips to sample')),
        (('-p', '--prompt_length'), dict(type=int, default=5,
                                         help='length of the prompt (in seconds)')),
        (('-l', '--clip_length'), dict(type=int, default=20,
                                       help='length of the full clip (in seconds)')),
        (('-v', '--visualize'), dict(action='store_true',
                                     help='plot visualizations')),
    ]
    for flags, spec in option_specs:
        cli.add_argument(*flags, **spec)

    parsed_args = cli.parse_args()
    main(parsed_args)
|
| |
|