Spaces:
Runtime error
Runtime error
import argparse | |
import io | |
import logging | |
import os | |
import re | |
import numpy as np | |
import pandas as pd | |
from datasets import Dataset, DatasetDict, Features, Image, Value | |
from diffusers.pipelines.audio_diffusion import Mel | |
from tqdm.auto import tqdm | |
# Configure the root logger (a no-op if it already has handlers) so that only
# warnings and above are shown, then get this script's named logger.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("audio_to_images")
def _find_audio_files(input_dir):
    """Return the paths of all audio files below *input_dir*.

    Walks the tree recursively with ``os.walk`` and keeps files whose name
    ends in ``.mp3``, ``.wav`` or ``.m4a`` (case-insensitive). A missing
    *input_dir* yields an empty list, because ``os.walk`` ignores errors by
    default.
    """
    audio_extension = re.compile(r"\.(mp3|wav|m4a)$", re.IGNORECASE)
    return [
        os.path.join(root, name)
        for root, _, names in os.walk(input_dir)
        for name in names
        if audio_extension.search(name)
    ]


def _save_dataset(examples, output_dir, push_to_hub=None):
    """Save *examples* to *output_dir* as a ``DatasetDict`` with one ``train`` split.

    Each example is a dict with ``image`` (``{"bytes": <PNG bytes>}``),
    ``audio_file`` and ``slice`` keys. If *push_to_hub* is set, the dataset
    is also pushed to that Hub repository. When *examples* is empty, only a
    warning is logged and nothing is written.
    """
    if not examples:
        logger.warning("No valid audio files were found.")
        return
    ds = Dataset.from_pandas(
        pd.DataFrame(examples),
        features=Features(
            {
                "image": Image(),
                "audio_file": Value(dtype="string"),
                "slice": Value(dtype="int16"),
            }
        ),
    )
    dsd = DatasetDict({"train": ds})
    dsd.save_to_disk(output_dir)
    if push_to_hub:
        dsd.push_to_hub(push_to_hub)


def main(args):
    """Convert a directory of audio files into a dataset of mel-spectrogram images.

    Every ``.mp3``/``.wav``/``.m4a`` file found recursively under
    ``args.input_dir`` is loaded with :class:`Mel` and cut into slices. Each
    slice is rendered as an image that must be exactly ``args.resolution``
    (width, height) and is stored as PNG bytes. The result is saved to
    ``args.output_dir`` as a ``DatasetDict`` with a single ``train`` split and
    is optionally pushed to the Hub.

    Files that fail to load are skipped with a warning. Slices whose pixel
    bytes are all 255 are treated as silent and dropped.

    Processing is best-effort. If it stops early, whatever was collected so
    far is still saved:
    - After an unexpected error, the error is logged with its traceback.
    - After Ctrl-C, the ``KeyboardInterrupt`` is re-raised once saving finishes.

    Args:
        args: parsed command-line namespace with ``input_dir``,
            ``output_dir``, ``resolution`` (a ``(width, height)`` pair),
            ``hop_length``, ``sample_rate``, ``n_fft`` and ``push_to_hub``.
    """
    width, height = args.resolution[0], args.resolution[1]
    mel = Mel(
        x_res=width,
        y_res=height,
        hop_length=args.hop_length,
        sample_rate=args.sample_rate,
        n_fft=args.n_fft,
    )
    os.makedirs(args.output_dir, exist_ok=True)
    audio_files = _find_audio_files(args.input_dir)
    examples = []
    try:
        for audio_file in tqdm(audio_files):
            try:
                mel.load_audio(audio_file)
            except Exception as err:  # KeyboardInterrupt is a BaseException, so it still propagates
                logger.warning("Skipping %s: could not load audio (%s)", audio_file, err)
                continue
            for slice_idx in range(mel.get_number_of_slices()):
                image = mel.audio_slice_to_image(slice_idx)
                if image.width != width or image.height != height:
                    raise ValueError(
                        f"Wrong resolution: expected {width}x{height}, got "
                        f"{image.width}x{image.height} ({audio_file} slice {slice_idx})"
                    )
                # Skip completely silent slices (every pixel byte is 255).
                if (np.frombuffer(image.tobytes(), dtype=np.uint8) == 255).all():
                    logger.warning("File %s slice %d is completely silent", audio_file, slice_idx)
                    continue
                with io.BytesIO() as buffer:
                    image.save(buffer, format="PNG")
                    png_bytes = buffer.getvalue()
                examples.append(
                    {
                        "image": {"bytes": png_bytes},
                        "audio_file": audio_file,
                        "slice": slice_idx,
                    }
                )
    except Exception:
        # Deliberately best-effort: report the failure with its traceback,
        # then fall through and keep the examples gathered before it.
        logger.exception(
            "Processing stopped early; saving %d examples collected so far", len(examples)
        )
    finally:
        # Runs on success, after a handled error, and on KeyboardInterrupt.
        # There is no `return` in here, so an interrupt is always re-raised
        # afterwards instead of being swallowed.
        _save_dataset(examples, args.output_dir, args.push_to_hub)
def parse_resolution(value):
    """Parse a ``--resolution`` string into a ``(width, height)`` tuple.

    Accepts two forms:
    - a single integer for a square image (``"256"`` -> ``(256, 256)``);
    - two comma-separated integers (``"512,256"`` -> ``(512, 256)``).

    Whitespace around the numbers is ignored.

    Raises:
        argparse.ArgumentTypeError: if *value* is not one or two integers, or
            if either dimension is not positive. When this function is used
            as an argparse ``type=``, argparse reports this as a usage error.
    """
    try:
        dims = tuple(int(part) for part in value.split(","))
    except ValueError:
        dims = ()
    if len(dims) == 1:
        dims = dims * 2  # square resolution
    if len(dims) != 2:
        raise argparse.ArgumentTypeError(
            "Resolution must be a tuple of two integers or a single integer."
        )
    if min(dims) <= 0:
        raise argparse.ArgumentTypeError(f"Resolution must be positive, got {value!r}.")
    return dims


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Create dataset of Mel spectrograms from directory of audio files."
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Directory searched recursively for .mp3/.wav/.m4a files.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="data",
        help="Directory the dataset is saved to (created if missing).",
    )
    parser.add_argument(
        "--resolution",
        # argparse also runs `type` on string defaults, so "256" becomes (256, 256).
        type=parse_resolution,
        default="256",
        help="Either square resolution or width,height.",
    )
    parser.add_argument("--hop_length", type=int, default=512)
    parser.add_argument(
        "--push_to_hub",
        type=str,
        default=None,
        help="Optional Hub repository to push the dataset to.",
    )
    parser.add_argument("--sample_rate", type=int, default=22050)
    parser.add_argument("--n_fft", type=int, default=2048)
    args = parser.parse_args()
    # os.walk silently yields nothing for a missing directory, so check up front.
    if not os.path.isdir(args.input_dir):
        parser.error(f"--input_dir {args.input_dir!r} is not a directory.")
    main(args)