Spaces:
Runtime error
Runtime error
File size: 3,901 Bytes
9f76d9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import argparse
import io
import logging
import os
import re
import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict, Features, Image, Value
from diffusers.pipelines.audio_diffusion import Mel
from tqdm.auto import tqdm
# Configure the root logger so that warnings and more severe records are emitted.
logging.basicConfig(level=logging.WARNING)
# Named logger used by every diagnostic message in this script.
logger = logging.getLogger("audio_to_images")
# Audio file extensions picked up from the input directory (case-insensitive).
_AUDIO_FILE_PATTERN = re.compile(r"\.(mp3|wav|m4a)$", re.IGNORECASE)


def _find_audio_files(input_dir):
    """Return the paths of all .mp3/.wav/.m4a files found recursively under ``input_dir``."""
    return [
        os.path.join(root, file_name)
        for root, _, file_names in os.walk(input_dir)
        for file_name in file_names
        if _AUDIO_FILE_PATTERN.search(file_name)
    ]


def _iter_examples(mel, audio_files, resolution):
    """Yield one dataset row per non-silent spectrogram slice of each audio file.

    Args:
        mel: ``Mel`` converter used to load audio and render slices as images.
        audio_files: Paths of the audio files to convert.
        resolution: Expected ``(width, height)`` of every rendered image.

    Yields:
        Dicts with keys ``"image"`` (``{"bytes": <PNG bytes>}``), ``"audio_file"``
        and ``"slice"`` (the slice index within that file).

    Raises:
        ValueError: If a rendered image does not have the expected resolution.
    """
    width, height = resolution
    for audio_file in tqdm(audio_files):
        try:
            mel.load_audio(audio_file)
        except Exception as err:
            # Unreadable or unsupported file: skip it, but say why.
            # KeyboardInterrupt is not an Exception subclass, so it still propagates.
            logger.warning("Skipping %s: %s", audio_file, err)
            continue
        for slice_index in range(mel.get_number_of_slices()):
            image = mel.audio_slice_to_image(slice_index)
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if image.width != width or image.height != height:
                raise ValueError("Wrong resolution")
            # Skip completely silent slices: every byte of the image is 255.
            if (np.frombuffer(image.tobytes(), dtype=np.uint8) == 255).all():
                logger.warning("File %s slice %d is completely silent", audio_file, slice_index)
                continue
            with io.BytesIO() as output:
                image.save(output, format="PNG")
                png_bytes = output.getvalue()
            yield {
                "image": {"bytes": png_bytes},
                "audio_file": audio_file,
                "slice": slice_index,
            }


def _save_dataset(examples, output_dir, push_to_hub=None):
    """Save ``examples`` as a ``DatasetDict`` with a single "train" split.

    The dataset is written to ``output_dir`` and, when ``push_to_hub`` is truthy,
    also pushed to that Hub repository. If ``examples`` is empty, a warning is
    logged and nothing is written.
    """
    if not examples:
        logger.warning("No valid audio files were found.")
        return
    ds = Dataset.from_pandas(
        pd.DataFrame(examples),
        features=Features(
            {
                "image": Image(),
                "audio_file": Value(dtype="string"),
                "slice": Value(dtype="int16"),
            }
        ),
    )
    dsd = DatasetDict({"train": ds})
    dsd.save_to_disk(output_dir)
    if push_to_hub:
        dsd.push_to_hub(push_to_hub)


def main(args):
    """Convert a directory of audio files into a dataset of Mel-spectrogram images.

    Args:
        args: Parsed CLI namespace. Reads ``input_dir``, ``output_dir``,
            ``resolution`` (a ``(width, height)`` pair), ``hop_length``,
            ``sample_rate``, ``n_fft`` and ``push_to_hub``.

    Processing is best-effort. If an error stops the conversion early, the error
    is logged and the examples collected so far are still saved. On
    KeyboardInterrupt, partial progress is saved and the interrupt is re-raised.
    """
    mel = Mel(
        x_res=args.resolution[0],
        y_res=args.resolution[1],
        hop_length=args.hop_length,
        sample_rate=args.sample_rate,
        n_fft=args.n_fft,
    )
    os.makedirs(args.output_dir, exist_ok=True)
    audio_files = _find_audio_files(args.input_dir)

    examples = []
    try:
        # Append one row at a time so that every row produced before an
        # error or interrupt is kept and saved below.
        for example in _iter_examples(mel, audio_files, args.resolution):
            examples.append(example)
    except Exception:
        logger.exception("Stopped processing audio files early")
    finally:
        # Runs on success, after a handled error, and on KeyboardInterrupt.
        # There is deliberately no `return` in this clause: a return inside
        # `finally` would discard an in-flight KeyboardInterrupt.
        _save_dataset(examples, args.output_dir, args.push_to_hub)
def parse_resolution(value):
    """Parse a ``--resolution`` argument into a ``(width, height)`` tuple.

    Accepts a single integer for a square image (``"256"`` -> ``(256, 256)``)
    or a comma-separated pair (``"256,128"`` -> ``(256, 128)``). Both dimensions
    must be positive.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is in neither form. argparse
            then reports this as a regular usage error instead of a traceback.
    """
    try:
        dims = tuple(int(part) for part in value.split(","))
    except ValueError:
        dims = ()
    if len(dims) == 1:
        dims = dims * 2  # single value means a square resolution
    if len(dims) != 2 or min(dims) <= 0:
        raise argparse.ArgumentTypeError(
            "Resolution must be a single positive integer or two comma-separated "
            "positive integers (width,height)."
        )
    return dims


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Create dataset of Mel spectrograms from directory of audio files."
    )
    parser.add_argument("--input_dir", type=str)
    parser.add_argument("--output_dir", type=str, default="data")
    parser.add_argument(
        "--resolution",
        # argparse also runs string defaults through `type`, so the
        # default "256" arrives in args as (256, 256).
        type=parse_resolution,
        default="256",
        help="Either square resolution or width,height.",
    )
    parser.add_argument("--hop_length", type=int, default=512)
    parser.add_argument("--push_to_hub", type=str, default=None)
    parser.add_argument("--sample_rate", type=int, default=22050)
    parser.add_argument("--n_fft", type=int, default=2048)
    args = parser.parse_args()
    if args.input_dir is None:
        # Report as a usage error (prints usage, exits with status 2).
        parser.error("You must specify an input directory for the audio files.")
    main(args)
|