"""Create a dataset of Mel spectrogram images from a directory of audio files.

Every ``.mp3``/``.wav``/``.m4a`` file below ``--input_dir`` is loaded with
``audiodiffusion.mel.Mel``, cut into slices, and each slice is rendered to a
PNG image.  The images are stored in a ``datasets.DatasetDict`` (split
``"train"``) that is written to ``--output_dir`` and optionally pushed to the
hub.
"""
import os
import re
import io
import logging
import argparse

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from datasets import Dataset, DatasetDict, Features, Image, Value
from audiodiffusion.mel import Mel

logging.basicConfig(level=logging.WARN)
logger = logging.getLogger('audio_to_images')

# Raw string: "\." in a normal string literal is an invalid escape sequence.
_AUDIO_FILE_PATTERN = re.compile(r"\.(mp3|wav|m4a)$", re.IGNORECASE)

_RESOLUTION_ERROR = (
    "Resolution must be a tuple of two integers or a single integer.")


def _find_audio_files(input_dir):
    """Return paths of all mp3/wav/m4a files found recursively in input_dir."""
    return [
        os.path.join(root, file_name)
        for root, _, file_names in os.walk(input_dir)
        for file_name in file_names
        if _AUDIO_FILE_PATTERN.search(file_name)
    ]


def _save_dataset(examples, output_dir, push_to_hub):
    """Write the collected examples to output_dir, optionally pushing to the hub.

    Args:
        examples: list of dicts with keys "image" ({"bytes": PNG bytes}),
            "audio_file" (str) and "slice" (int).
        output_dir: directory passed to ``DatasetDict.save_to_disk``.
        push_to_hub: repository id passed to ``DatasetDict.push_to_hub``;
            nothing is pushed when it is falsy.
    """
    ds = Dataset.from_pandas(
        pd.DataFrame(examples),
        features=Features({
            "image": Image(),
            "audio_file": Value(dtype="string"),
            "slice": Value(dtype="int16"),
        }),
    )
    dsd = DatasetDict({"train": ds})
    dsd.save_to_disk(output_dir)
    if push_to_hub:
        dsd.push_to_hub(push_to_hub)


def _parse_resolution(value):
    """Parse "N" or "W,H" into a (width, height) tuple of ints.

    Raises:
        ValueError: if value is neither a single integer nor exactly two
            comma-separated integers.
    """
    try:
        dims = tuple(int(part) for part in value.split(","))
    except ValueError:
        raise ValueError(_RESOLUTION_ERROR) from None
    if len(dims) == 1:
        # A single number means a square image.
        return dims * 2
    if len(dims) != 2:
        raise ValueError(_RESOLUTION_ERROR)
    return dims


def main(args):
    """Convert the audio files in args.input_dir into a spectrogram dataset.

    Args:
        args: namespace with the attributes ``resolution`` (a
            ``(width, height)`` tuple), ``hop_length``, ``sample_rate``,
            ``input_dir``, ``output_dir`` and ``push_to_hub``.

    Files that cannot be loaded are skipped with a warning.  If processing is
    aborted (unexpected error or Ctrl-C), the slices collected so far are
    still saved; a KeyboardInterrupt is re-raised after saving.
    """
    mel = Mel(x_res=args.resolution[0],
              y_res=args.resolution[1],
              hop_length=args.hop_length,
              sample_rate=args.sample_rate)
    os.makedirs(args.output_dir, exist_ok=True)
    audio_files = _find_audio_files(args.input_dir)
    examples = []
    try:
        for audio_file in tqdm(audio_files):
            try:
                mel.load_audio(audio_file)
            except Exception as error:
                # Best effort: an unreadable file must not abort the run, but
                # leave a trace of which file was dropped and why.
                logger.warning('Skipping %s, could not load audio: %s',
                               audio_file, error)
                continue
            for slice_index in range(mel.get_number_of_slices()):
                image = mel.audio_slice_to_image(slice_index)
                assert (image.width == args.resolution[0]
                        and image.height == args.resolution[1]), \
                    "Wrong resolution"
                # skip completely silent slices
                pixels = np.frombuffer(image.tobytes(), dtype=np.uint8)
                if (pixels == 255).all():
                    logger.warning('File %s slice %d is completely silent',
                                   audio_file, slice_index)
                    continue
                with io.BytesIO() as output:
                    image.save(output, format="PNG")
                    png_bytes = output.getvalue()
                examples.append({
                    "image": {
                        "bytes": png_bytes
                    },
                    "audio_file": audio_file,
                    "slice": slice_index,
                })
    except Exception:
        # Keep going to the save step below, but record the full traceback.
        logger.exception(
            'Conversion aborted early; saving the slices collected so far.')
    finally:
        # Runs on success, on errors and on Ctrl-C so partial work is kept.
        # No `return` here: returning from `finally` would silently swallow
        # an in-flight KeyboardInterrupt.
        if examples:
            _save_dataset(examples, args.output_dir, args.push_to_hub)
        else:
            logger.warning('No valid audio files were found.')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=
        "Create dataset of Mel spectrograms from directory of audio files.")
    parser.add_argument("--input_dir", type=str)
    parser.add_argument("--output_dir", type=str, default="data")
    parser.add_argument("--resolution",
                        type=str,
                        default="256",
                        help="Either square resolution or width,height.")
    parser.add_argument("--hop_length", type=int, default=512)
    parser.add_argument("--push_to_hub", type=str, default=None)
    parser.add_argument("--sample_rate", type=int, default=22050)
    args = parser.parse_args()

    if args.input_dir is None:
        raise ValueError(
            "You must specify an input directory for the audio files.")

    # Handle the resolutions.
    args.resolution = _parse_resolution(args.resolution)

    main(args)