Spaces:
Runtime error
Runtime error
File size: 2,370 Bytes
9c0c5c8 1dea888 9c0c5c8 1dea888 9c0c5c8 1dea888 9c0c5c8 c17b696 9c0c5c8 c17b696 9c0c5c8 c17b696 9c0c5c8 1dea888 9c0c5c8 1dea888 9c0c5c8 c17b696 1dea888 c17b696 9c0c5c8 1dea888 c17b696 1dea888 e97d748 9c0c5c8 1dea888 c17b696 9c0c5c8 1dea888 e97d748 9c0c5c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import argparse
import io
import logging
import os
import re

import pandas as pd
from datasets import Dataset, DatasetDict, Features, Image, Value
from tqdm.auto import tqdm

from audiodiffusion.mel import Mel
logger = logging.getLogger(__name__)

# Extensions of the audio files picked up from the input directory.
# Raw string: "\." in a normal string literal is an invalid escape sequence.
_AUDIO_FILE_PATTERN = re.compile(r"\.(mp3|wav|m4a)$", re.IGNORECASE)


def _find_audio_files(input_dir):
    """Return the paths of all mp3/wav/m4a files found under *input_dir*."""
    return [
        os.path.join(root, file_name)
        for root, _, file_names in os.walk(input_dir)
        for file_name in file_names
        if _AUDIO_FILE_PATTERN.search(file_name)
    ]


def _image_to_png_bytes(image):
    """Encode *image* as PNG and return the encoded bytes."""
    with io.BytesIO() as output:
        image.save(output, format="PNG")
        return output.getvalue()


def _save_dataset(examples, output_dir, push_to_hub):
    """Save *examples* as a single-split ("train") dataset, optionally pushing it.

    Does nothing except log a warning when *examples* is empty: an empty
    DataFrame has no columns and cannot be matched against the feature schema.
    """
    if not examples:
        logger.warning("No valid audio files were found; no dataset was written.")
        return
    ds = Dataset.from_pandas(
        pd.DataFrame(examples),
        features=Features({
            "image": Image(),
            "audio_file": Value(dtype="string"),
            "slice": Value(dtype="int16"),
        }),
    )
    dsd = DatasetDict({"train": ds})
    dsd.save_to_disk(output_dir)
    if push_to_hub:
        dsd.push_to_hub(push_to_hub)


def main(args):
    """Convert a directory of audio files into a dataset of spectrogram images.

    Every mp3/wav/m4a file under ``args.input_dir`` is loaded with ``Mel``;
    for each slice index reported by ``mel.get_number_of_slices()`` the image
    returned by ``mel.audio_slice_to_image`` is PNG-encoded and recorded
    together with the source file path and the slice index.

    Args:
        args: Parsed command-line namespace providing ``input_dir``,
            ``output_dir``, ``resolution``, ``hop_length`` and ``push_to_hub``.

    Raises:
        AssertionError: If a generated image is not
            ``resolution`` x ``resolution`` pixels.
    """
    mel = Mel(
        x_res=args.resolution,
        y_res=args.resolution,
        hop_length=args.hop_length,
    )
    os.makedirs(args.output_dir, exist_ok=True)
    audio_files = _find_audio_files(args.input_dir)

    examples = []
    try:
        for audio_file in tqdm(audio_files):
            try:
                mel.load_audio(audio_file)
            except Exception as err:
                # Best effort: one unreadable file must not abort the run, but
                # say so instead of dropping it silently. KeyboardInterrupt and
                # SystemExit are not Exception subclasses and still propagate.
                logger.warning("Skipping %s: could not load audio (%s)",
                               audio_file, err)
                continue
            for slice_index in range(mel.get_number_of_slices()):
                image = mel.audio_slice_to_image(slice_index)
                # Explicit raise rather than `assert`, which is stripped
                # when Python runs with -O.
                if (image.width != args.resolution
                        or image.height != args.resolution):
                    raise AssertionError(
                        f"Expected a {args.resolution}x{args.resolution} image "
                        f"for slice {slice_index} of {audio_file}, got "
                        f"{image.width}x{image.height}")
                examples.append({
                    "image": {
                        "bytes": _image_to_png_bytes(image)
                    },
                    "audio_file": audio_file,
                    "slice": slice_index,
                })
    finally:
        # Runs on interruption or failure too, so whatever was converted
        # so far is still written out.
        _save_dataset(examples, args.output_dir, args.push_to_hub)
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(
        description=
        "Create dataset of Mel spectrograms from directory of audio files.")
    # Required: without it os.walk() would receive None and fail with an
    # opaque TypeError instead of a usage message.
    parser.add_argument("--input_dir",
                        type=str,
                        required=True,
                        help="Directory searched recursively for "
                        "mp3/wav/m4a files.")
    parser.add_argument("--output_dir",
                        type=str,
                        default="data",
                        help="Directory the dataset is saved to.")
    parser.add_argument("--resolution",
                        type=int,
                        default=256,
                        help="Width and height, in pixels, of each "
                        "generated image.")
    parser.add_argument("--hop_length",
                        type=int,
                        default=512,
                        help="Hop length passed to Mel.")
    parser.add_argument("--push_to_hub",
                        type=str,
                        default=None,
                        help="If set, the dataset is also pushed to the hub "
                        "under this name.")
    args = parser.parse_args()
    main(args)
|