# audio-diffusion/scripts/audio_to_images.py
# Last change: commit f29faf1 by teticio, "add conditional training" (3.9 kB).
import argparse
import io
import logging
import os
import re
import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict, Features, Image, Value
from diffusers.pipelines.audio_diffusion import Mel
from tqdm.auto import tqdm
# Module-level logger used by main() for its warnings.
logger = logging.getLogger("audio_to_images")
# Configure the root logger to show WARNING and above (logging.WARN is an alias of WARNING).
logging.basicConfig(level=logging.WARNING)
def _find_audio_files(input_dir):
    """Return paths of all .mp3/.wav/.m4a files (case-insensitive) found recursively under *input_dir*."""
    # Raw string: "\." in a plain literal is an invalid escape sequence.
    pattern = re.compile(r"\.(mp3|wav|m4a)$", re.IGNORECASE)
    return [
        os.path.join(root, file_name)
        for root, _, file_names in os.walk(input_dir)
        for file_name in file_names
        if pattern.search(file_name)
    ]


def _collect_examples(mel, audio_files, resolution, examples):
    """Append one dataset row per non-silent spectrogram slice of each audio file to *examples*.

    Rows are appended to the caller's list (instead of being returned) so that
    everything collected before an exception or interrupt remains available.

    Args:
        mel: ``Mel`` converter used to load audio and render slices.
        audio_files: Paths of the audio files to process.
        resolution: Expected ``(width, height)`` of every rendered slice.
        examples: List that receives ``{"image", "audio_file", "slice"}`` dicts.

    Raises:
        ValueError: If a rendered slice does not have the expected resolution.
    """
    width, height = resolution
    for audio_file in tqdm(audio_files):
        try:
            mel.load_audio(audio_file)
        except Exception as error:
            # Best effort: one unreadable file should not abort the whole run,
            # but it should not vanish silently either.
            logger.warning("Could not load %s, skipping: %s", audio_file, error)
            continue
        for slice_index in range(mel.get_number_of_slices()):
            image = mel.audio_slice_to_image(slice_index)
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if image.width != width or image.height != height:
                raise ValueError(
                    f"Wrong resolution: expected {width}x{height}, got {image.width}x{image.height}"
                )
            # Skip completely silent slices (every pixel is 255); np.all is
            # vectorized, unlike the builtin all() over a numpy array.
            if np.all(np.frombuffer(image.tobytes(), dtype=np.uint8) == 255):
                logger.warning("File %s slice %d is completely silent", audio_file, slice_index)
                continue
            with io.BytesIO() as output:
                image.save(output, format="PNG")
                png_bytes = output.getvalue()
            examples.append(
                {
                    "image": {"bytes": png_bytes},
                    "audio_file": audio_file,
                    "slice": slice_index,
                }
            )


def _save_examples(examples, output_dir, push_to_hub):
    """Save *examples* as the "train" split of a DatasetDict in *output_dir*; push it if *push_to_hub* is set.

    Logs a warning and saves nothing when *examples* is empty.
    """
    if not examples:
        logger.warning("No valid audio files were found.")
        return
    ds = Dataset.from_pandas(
        pd.DataFrame(examples),
        features=Features(
            {
                "image": Image(),
                "audio_file": Value(dtype="string"),
                "slice": Value(dtype="int16"),
            }
        ),
    )
    dsd = DatasetDict({"train": ds})
    dsd.save_to_disk(output_dir)
    if push_to_hub:
        dsd.push_to_hub(push_to_hub)


def main(args):
    """Build a Mel-spectrogram image dataset from a directory of audio files.

    Every .mp3/.wav/.m4a file under ``args.input_dir`` is sliced by ``Mel``.
    Each non-silent slice becomes a PNG row (``image``, ``audio_file``,
    ``slice``) of a ``DatasetDict`` with a single ``"train"`` split. The
    dataset is saved to ``args.output_dir`` and, if ``args.push_to_hub`` is
    set, pushed with ``push_to_hub``.

    Args:
        args: Parsed command-line namespace. Reads ``input_dir``,
            ``output_dir``, ``resolution`` (a ``(width, height)`` tuple),
            ``hop_length``, ``sample_rate``, ``n_fft`` and ``push_to_hub``.

    If processing fails with an exception, the error is logged and the slices
    collected so far are still saved. On ``KeyboardInterrupt`` the partial
    dataset is saved and the interrupt is then re-raised.
    """
    mel = Mel(
        x_res=args.resolution[0],
        y_res=args.resolution[1],
        hop_length=args.hop_length,
        sample_rate=args.sample_rate,
        n_fft=args.n_fft,
    )
    os.makedirs(args.output_dir, exist_ok=True)
    audio_files = _find_audio_files(args.input_dir)
    examples = []
    try:
        _collect_examples(mel, audio_files, args.resolution, examples)
    except Exception:
        # Keep the work already done instead of discarding it.
        logger.exception("Processing stopped early; saving the %d slices collected so far", len(examples))
    finally:
        # Runs on KeyboardInterrupt too, so an interrupted run keeps its partial
        # dataset. There is deliberately no `return` in this clause: a return in
        # `finally` would silently swallow the in-flight interrupt.
        _save_examples(examples, args.output_dir, args.push_to_hub)
def parse_resolution(value):
    """Parse a ``--resolution`` string into a ``(width, height)`` tuple of positive ints.

    Accepts a single integer ("256" -> (256, 256)) or two comma-separated
    integers ("512,256" -> (512, 256)).

    Raises:
        argparse.ArgumentTypeError: On anything else, so argparse reports it as
            a normal usage error instead of a traceback.
    """
    try:
        sizes = tuple(int(part) for part in value.split(","))
    except ValueError:
        sizes = ()
    if len(sizes) == 1:
        # A single value means a square resolution.
        sizes *= 2
    if len(sizes) != 2 or min(sizes) <= 0:
        raise argparse.ArgumentTypeError(
            "Resolution must be a single integer or two comma-separated integers "
            f"(width,height), all positive; got {value!r}."
        )
    return sizes


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Create dataset of Mel spectrograms from directory of audio files.")
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Directory searched recursively for .mp3/.wav/.m4a files.",
    )
    parser.add_argument("--output_dir", type=str, default="data", help="Directory the dataset is saved to.")
    parser.add_argument(
        "--resolution",
        # argparse also runs string defaults through `type`, so "256" -> (256, 256).
        type=parse_resolution,
        default="256",
        help="Either square resolution or width,height.",
    )
    parser.add_argument("--hop_length", type=int, default=512)
    parser.add_argument(
        "--push_to_hub",
        type=str,
        default=None,
        help="If set, the dataset is also pushed to this repository.",
    )
    parser.add_argument("--sample_rate", type=int, default=22050)
    parser.add_argument("--n_fft", type=int, default=2048)
    args = parser.parse_args()
    main(args)