File size: 2,518 Bytes
6dff871
 
 
 
 
9c0c5c8
 
1dea888
9c0c5c8
 
1dea888
9c0c5c8
1dea888
9c0c5c8
 
 
 
 
1dea888
9c0c5c8
 
 
 
 
 
 
1dea888
9c0c5c8
1dea888
9c0c5c8
 
 
 
 
 
 
 
1dea888
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c0c5c8
1dea888
 
 
 
 
 
 
 
 
 
 
 
e97d748
 
9c0c5c8
 
 
1dea888
 
 
9c0c5c8
 
 
1dea888
e97d748
9c0c5c8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# TODO
# run on sagemaker
# run with deepspeed


import os
import re
import io
import argparse

import pandas as pd
from tqdm.auto import tqdm
from datasets import Dataset, DatasetDict, Features, Image, Value

from mel import Mel


def _find_audio_files(input_dir):
    """Return paths of all .mp3/.wav/.m4a files found recursively under *input_dir*."""
    # Raw string: ``\.`` must reach the regex engine as an escaped dot rather
    # than being treated as an (invalid) string escape sequence.
    audio_pattern = re.compile(r"\.(mp3|wav|m4a)$", re.IGNORECASE)
    return [
        os.path.join(root, name)
        for root, _, names in os.walk(input_dir)
        for name in names
        if audio_pattern.search(name)
    ]


def _image_to_png_bytes(image):
    """Encode *image* as PNG in memory and return the resulting bytes."""
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        return buffer.getvalue()


def _save_dataset(examples, output_dir, repo_id):
    """Save *examples* as a one-split ("train") dataset and optionally push it.

    Does nothing (apart from printing a notice) when *examples* is empty: an
    empty DataFrame has no columns to match the declared features against.
    """
    if not examples:
        print("No spectrograms were generated; nothing to save.")
        return
    ds = Dataset.from_pandas(
        pd.DataFrame(examples),
        features=Features(
            {
                "image": Image(),
                "audio_file": Value(dtype="string"),
                "slice": Value(dtype="int16"),
            }
        ),
    )
    dsd = DatasetDict({"train": ds})
    dsd.save_to_disk(output_dir)
    if repo_id:
        dsd.push_to_hub(repo_id)


def main(args):
    """Create a dataset of Mel spectrogram images from a directory of audio files.

    Every ``.mp3``/``.wav``/``.m4a`` file under ``args.input_dir`` is loaded
    with ``Mel.load_audio``; each slice reported by
    ``Mel.get_number_of_slices`` is rendered with ``Mel.audio_slice_to_image``
    and stored as PNG bytes together with its source path and slice index.

    Args:
        args: Parsed command-line namespace providing ``input_dir``,
            ``output_dir``, ``resolution``, ``hop_length`` and ``push_to_hub``.

    Raises:
        AssertionError: If a generated image is not
            ``args.resolution`` x ``args.resolution`` pixels.
    """
    mel = Mel(x_res=args.resolution, y_res=args.resolution, hop_length=args.hop_length)
    os.makedirs(args.output_dir, exist_ok=True)
    audio_files = _find_audio_files(args.input_dir)
    examples = []
    try:
        for audio_file in tqdm(audio_files):
            try:
                mel.load_audio(audio_file)
            except Exception as error:
                # Best effort: an unreadable file is skipped, but reported.
                # KeyboardInterrupt/SystemExit are not Exception subclasses,
                # so they still propagate to the ``finally`` below.
                tqdm.write(f"Skipping {audio_file}: {error}")
                continue
            for slice_index in range(mel.get_number_of_slices()):
                image = mel.audio_slice_to_image(slice_index)
                assert (
                    image.width == args.resolution and image.height == args.resolution
                ), "Generated image does not match the requested resolution"
                examples.append(
                    {
                        "image": {"bytes": _image_to_png_bytes(image)},
                        "audio_file": audio_file,
                        "slice": slice_index,
                    }
                )
    finally:
        # Runs on interruption or error as well, so that whatever has been
        # processed so far is still written out.
        _save_dataset(examples, args.output_dir, args.push_to_hub)


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(
        description="Create dataset of Mel spectrograms from directory of audio files."
    )
    # Required: without it main() would pass None to os.walk and fail with an
    # unhelpful TypeError instead of a usage message.
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Directory searched recursively for .mp3/.wav/.m4a files.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="data",
        help="Directory the dataset is saved to (created if missing).",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=256,
        help="Width and height, in pixels, of each spectrogram image.",
    )
    parser.add_argument(
        "--hop_length",
        type=int,
        default=512,
        help="Hop length passed to the Mel converter.",
    )
    parser.add_argument(
        "--push_to_hub",
        type=str,
        default=None,
        help="Hub repository to push the dataset to; omit to skip pushing.",
    )
    args = parser.parse_args()
    main(args)