"""Create a dataset of Mel spectrogram images from a directory of audio files."""

import argparse
import io
import logging
import os
import re
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict, Features, Image, Value
from diffusers.pipelines.audio_diffusion import Mel
from tqdm.auto import tqdm

logging.basicConfig(level=logging.WARN)
logger = logging.getLogger("audio_to_images")

# Raw string: "\." in a plain string literal is an invalid escape sequence.
# Compiled once because it is applied to every file name found on disk.
AUDIO_FILE_PATTERN = re.compile(r"\.(mp3|wav|m4a)$", re.IGNORECASE)

RESOLUTION_ERROR = "Resolution must be a tuple of two integers or a single integer."


def _find_audio_files(input_dir: str) -> List[str]:
    """Return the paths of all mp3/wav/m4a files found by walking *input_dir*."""
    return [
        os.path.join(root, name)
        for root, _, names in os.walk(input_dir)
        for name in names
        if AUDIO_FILE_PATTERN.search(name)
    ]


def _save_dataset(examples: List[dict], output_dir: str, push_to_hub: Optional[str] = None) -> None:
    """Wrap *examples* in a single-split ("train") DatasetDict and save it.

    The dataset is written to *output_dir* and, when *push_to_hub* is truthy,
    also pushed to the Hub under that name.
    """
    dataset = Dataset.from_pandas(
        pd.DataFrame(examples),
        features=Features(
            {
                "image": Image(),
                "audio_file": Value(dtype="string"),
                "slice": Value(dtype="int16"),
            }
        ),
    )
    dataset_dict = DatasetDict({"train": dataset})
    dataset_dict.save_to_disk(output_dir)
    if push_to_hub:
        dataset_dict.push_to_hub(push_to_hub)


def _parse_resolution(value: str) -> Tuple[int, int]:
    """Parse ``"N"`` or ``"W,H"`` into a ``(width, height)`` tuple of ints.

    Raises:
        ValueError: if *value* is neither a single integer nor exactly two
            comma-separated integers.
    """
    try:
        parts = tuple(int(part) for part in value.split(","))
    except ValueError:
        raise ValueError(RESOLUTION_ERROR) from None
    if len(parts) == 1:
        # A single number means a square image.
        return (parts[0], parts[0])
    if len(parts) != 2:
        raise ValueError(RESOLUTION_ERROR)
    return parts


def main(args):
    """Convert every audio file under ``args.input_dir`` into image slices.

    Each audio file is loaded with ``Mel``, cut into slices, and every slice
    that is not completely silent is stored as a PNG together with its source
    file path and slice index. The collected examples are saved as a dataset
    in ``args.output_dir`` and optionally pushed to the Hub.

    Args:
        args: namespace providing ``resolution`` (indexable, width then
            height), ``hop_length``, ``sample_rate``, ``n_fft``,
            ``input_dir``, ``output_dir`` and ``push_to_hub``.
    """
    width, height = args.resolution[0], args.resolution[1]
    mel = Mel(
        x_res=width,
        y_res=height,
        hop_length=args.hop_length,
        sample_rate=args.sample_rate,
        n_fft=args.n_fft,
    )
    os.makedirs(args.output_dir, exist_ok=True)
    audio_files = _find_audio_files(args.input_dir)

    examples = []
    try:
        for audio_file in tqdm(audio_files):
            try:
                mel.load_audio(audio_file)
            except Exception as err:
                # Best effort: one unreadable file must not abort the whole run.
                # KeyboardInterrupt / SystemExit are not Exceptions and propagate.
                logger.warning("Skipping %s: %s", audio_file, err)
                continue

            for slice_index in range(mel.get_number_of_slices()):
                image = mel.audio_slice_to_image(slice_index)
                # Explicit raise instead of assert so the check survives `python -O`.
                if image.width != width or image.height != height:
                    raise ValueError("Wrong resolution")

                # skip completely silent slices (every byte of the image is 255)
                pixels = np.frombuffer(image.tobytes(), dtype=np.uint8)
                if np.all(pixels == 255):
                    logger.warning("File %s slice %d is completely silent", audio_file, slice_index)
                    continue

                with io.BytesIO() as output:
                    image.save(output, format="PNG")
                    png_bytes = output.getvalue()
                examples.append(
                    {
                        "image": {"bytes": png_bytes},
                        "audio_file": audio_file,
                        "slice": slice_index,
                    }
                )
    except Exception:
        logger.exception("Processing stopped early; keeping the slices collected so far.")
    finally:
        # Runs on interruption too, so partial work is saved. There is
        # deliberately no `return` here: returning from `finally` would
        # silently discard an in-flight exception such as KeyboardInterrupt.
        if examples:
            _save_dataset(examples, args.output_dir, args.push_to_hub)
        else:
            logger.warning("No valid audio files were found.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Create dataset of Mel spectrograms from directory of audio files.")
    parser.add_argument("--input_dir", type=str)
    parser.add_argument("--output_dir", type=str, default="data")
    parser.add_argument(
        "--resolution",
        type=str,
        default="256",
        help="Either square resolution or width,height.",
    )
    parser.add_argument("--hop_length", type=int, default=512)
    parser.add_argument("--push_to_hub", type=str, default=None)
    parser.add_argument("--sample_rate", type=int, default=22050)
    parser.add_argument("--n_fft", type=int, default=2048)
    args = parser.parse_args()

    if args.input_dir is None:
        raise ValueError("You must specify an input directory for the audio files.")

    # Handle the resolutions.
    args.resolution = _parse_resolution(args.resolution)

    main(args)