Refactor audio generation scripts to streamline processing and enhance functionality
Browse filesThis commit introduces significant improvements to the audio generation workflow in both `audiogen_medium.py` and `stable_audio.py`. Key changes include:
- Removal of redundant seed extraction logic and integration of a new `process_audio_generations` function to handle audio generation in a more organized manner.
- Consolidation of argument preparation for audio generation into a dedicated `prepare_args` function, improving code clarity and maintainability.
- Enhanced user feedback during the audio generation process, ensuring clearer communication of the actions being performed.
These modifications optimize the audio generation process, improve code organization, and enhance the overall user experience.
- audio/audiogen_medium.py +11 -66
- audio/stable_audio.py +13 -66
- audio/tango_audio.py +42 -0
- caption/jtp2.py +0 -2
- caption/wdv3.py +0 -1
- utils/audio_utils.py +91 -0
audio/audiogen_medium.py
CHANGED
@@ -4,18 +4,10 @@
|
|
4 |
import sys
|
5 |
import os
|
6 |
import torch
|
7 |
-
import torchaudio
|
8 |
import random
|
9 |
-
import multiprocessing as mp
|
10 |
from audiocraft.models import AudioGen
|
11 |
from audiocraft.data.audio import audio_write
|
12 |
-
|
13 |
-
def get_seed_from_filename(filename):
|
14 |
-
"""Extract seed from filename like '12345.wav'"""
|
15 |
-
try:
|
16 |
-
return int(filename.split('.')[0])
|
17 |
-
except:
|
18 |
-
return None
|
19 |
|
20 |
def generate_audio(args):
|
21 |
description, seed, prompt_dir = args
|
@@ -37,61 +29,14 @@ def generate_audio(args):
|
|
37 |
# Will save with loudness normalization at -14 db LUFS
|
38 |
audio_write(file_path, wav[0].cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
descriptions = sys.argv[1:]
|
45 |
-
if not descriptions:
|
46 |
-
print('At least one prompt should be provided')
|
47 |
-
sys.exit(1)
|
48 |
-
|
49 |
-
# Base output directory
|
50 |
-
base_output_dir = 'generated_audio'
|
51 |
-
os.makedirs(base_output_dir, exist_ok=True)
|
52 |
-
|
53 |
-
# Generate 25 variations for each prompt
|
54 |
-
num_variations = 25
|
55 |
-
num_processes = 3 # Number of parallel models to run
|
56 |
-
seed_range = (0, 1000000) # Use seeds between 0 and 1,000,000
|
57 |
-
|
58 |
-
for description in descriptions:
|
59 |
-
# Create a safe folder name from the description
|
60 |
-
folder_name = description.replace(' ', '_').replace('/', '_').replace('\\', '_')
|
61 |
-
folder_name = ''.join(c for c in folder_name if c.isalnum() or c in '_-')
|
62 |
-
prompt_dir = os.path.join(base_output_dir, folder_name)
|
63 |
-
os.makedirs(prompt_dir, exist_ok=True)
|
64 |
-
|
65 |
-
print(f"\nGenerating variations for prompt: {description}")
|
66 |
-
print(f"Saving in directory: {prompt_dir}")
|
67 |
-
|
68 |
-
# Get existing seeds
|
69 |
-
existing_seeds = set()
|
70 |
-
for filename in os.listdir(prompt_dir):
|
71 |
-
if filename.endswith('.wav'):
|
72 |
-
seed = get_seed_from_filename(filename)
|
73 |
-
if seed is not None:
|
74 |
-
existing_seeds.add(seed)
|
75 |
-
|
76 |
-
if len(existing_seeds) >= num_variations:
|
77 |
-
print(f"All {num_variations} variations already exist in {prompt_dir}, skipping...")
|
78 |
-
continue
|
79 |
-
|
80 |
-
# Generate new random seeds that haven't been used yet
|
81 |
-
needed_variations = num_variations - len(existing_seeds)
|
82 |
-
new_seeds = set()
|
83 |
-
while len(new_seeds) < needed_variations:
|
84 |
-
seed = random.randint(*seed_range)
|
85 |
-
if seed not in existing_seeds and seed not in new_seeds:
|
86 |
-
new_seeds.add(seed)
|
87 |
-
|
88 |
-
print(f"Generating {needed_variations} new variations using {num_processes} parallel processes...")
|
89 |
-
print(f"Using seeds: {sorted(new_seeds)}")
|
90 |
-
|
91 |
-
# Prepare arguments for parallel processing
|
92 |
-
args_list = [(description, seed, prompt_dir) for seed in new_seeds]
|
93 |
-
|
94 |
-
# Use multiprocessing to distribute the work
|
95 |
-
with mp.Pool(processes=num_processes) as pool:
|
96 |
-
pool.map(generate_audio, args_list)
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import sys
|
5 |
import os
|
6 |
import torch
|
|
|
7 |
import random
|
|
|
8 |
from audiocraft.models import AudioGen
|
9 |
from audiocraft.data.audio import audio_write
|
10 |
+
from utils.audio_utils import process_audio_generations
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
def generate_audio(args):
|
13 |
description, seed, prompt_dir = args
|
|
|
29 |
# Will save with loudness normalization at -14 db LUFS
|
30 |
audio_write(file_path, wav[0].cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
|
31 |
|
32 |
+
def prepare_args(description, seeds, prompt_dir):
|
33 |
+
"""Prepare arguments for the generate_audio function"""
|
34 |
+
return [(description, seed, prompt_dir) for seed in seeds]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
+
if __name__ == '__main__':
|
37 |
+
process_audio_generations(
|
38 |
+
descriptions=sys.argv[1:],
|
39 |
+
model_name='audiogen',
|
40 |
+
generate_fn=generate_audio,
|
41 |
+
prepare_args_fn=prepare_args
|
42 |
+
)
|
audio/stable_audio.py
CHANGED
@@ -6,15 +6,11 @@ import os
|
|
6 |
import torch
|
7 |
import soundfile as sf
|
8 |
import random
|
9 |
-
import multiprocessing as mp
|
10 |
from diffusers import StableAudioPipeline
|
|
|
11 |
|
12 |
-
|
13 |
-
|
14 |
-
try:
|
15 |
-
return int(filename.split('.')[0])
|
16 |
-
except:
|
17 |
-
return None
|
18 |
|
19 |
def generate_audio(args):
|
20 |
description, negative_prompt, seed, prompt_dir = args
|
@@ -46,63 +42,14 @@ def generate_audio(args):
|
|
46 |
print(f"Saving audio to: {file_path}")
|
47 |
sf.write(file_path, output, pipe.vae.sampling_rate)
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
descriptions = sys.argv[1:]
|
54 |
-
if not descriptions:
|
55 |
-
print('At least one prompt should be provided')
|
56 |
-
sys.exit(1)
|
57 |
-
|
58 |
-
# Default negative prompt
|
59 |
-
negative_prompt = "Low quality, noise, distortion, low fidelity"
|
60 |
-
|
61 |
-
# Base output directory
|
62 |
-
base_output_dir = 'generated_audio/sa'
|
63 |
-
os.makedirs(base_output_dir, exist_ok=True)
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
folder_name = description.replace(' ', '_').replace('/', '_').replace('\\', '_')
|
73 |
-
folder_name = ''.join(c for c in folder_name if c.isalnum() or c in '_-')
|
74 |
-
prompt_dir = os.path.join(base_output_dir, folder_name)
|
75 |
-
os.makedirs(prompt_dir, exist_ok=True)
|
76 |
-
|
77 |
-
print(f"\nGenerating variations for prompt: {description}")
|
78 |
-
print(f"Saving in directory: {prompt_dir}")
|
79 |
-
|
80 |
-
# Get existing seeds
|
81 |
-
existing_seeds = set()
|
82 |
-
for filename in os.listdir(prompt_dir):
|
83 |
-
if filename.endswith('.wav'):
|
84 |
-
seed = get_seed_from_filename(filename)
|
85 |
-
if seed is not None:
|
86 |
-
existing_seeds.add(seed)
|
87 |
-
|
88 |
-
if len(existing_seeds) >= num_variations:
|
89 |
-
print(f"All {num_variations} variations already exist in {prompt_dir}, skipping...")
|
90 |
-
continue
|
91 |
-
|
92 |
-
# Generate new random seeds that haven't been used yet
|
93 |
-
needed_variations = num_variations - len(existing_seeds)
|
94 |
-
new_seeds = set()
|
95 |
-
while len(new_seeds) < needed_variations:
|
96 |
-
seed = random.randint(*seed_range)
|
97 |
-
if seed not in existing_seeds and seed not in new_seeds:
|
98 |
-
new_seeds.add(seed)
|
99 |
-
|
100 |
-
print(f"Generating {needed_variations} new variations using {num_processes} parallel processes...")
|
101 |
-
print(f"Using seeds: {sorted(new_seeds)}")
|
102 |
-
|
103 |
-
# Prepare arguments for parallel processing
|
104 |
-
args_list = [(description, negative_prompt, seed, prompt_dir) for seed in new_seeds]
|
105 |
-
|
106 |
-
# Use multiprocessing to distribute the work
|
107 |
-
with mp.Pool(processes=num_processes) as pool:
|
108 |
-
pool.map(generate_audio, args_list)
|
|
|
6 |
import torch
|
7 |
import soundfile as sf
|
8 |
import random
|
|
|
9 |
from diffusers import StableAudioPipeline
|
10 |
+
from utils.audio_utils import process_audio_generations
|
11 |
|
12 |
+
# Default negative prompt
|
13 |
+
NEGATIVE_PROMPT = "Low quality, noise, distortion, low fidelity"
|
|
|
|
|
|
|
|
|
14 |
|
15 |
def generate_audio(args):
|
16 |
description, negative_prompt, seed, prompt_dir = args
|
|
|
42 |
print(f"Saving audio to: {file_path}")
|
43 |
sf.write(file_path, output, pipe.vae.sampling_rate)
|
44 |
|
45 |
+
def prepare_args(description, seeds, prompt_dir):
|
46 |
+
"""Prepare arguments for the generate_audio function"""
|
47 |
+
return [(description, NEGATIVE_PROMPT, seed, prompt_dir) for seed in seeds]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
+
if __name__ == '__main__':
|
50 |
+
process_audio_generations(
|
51 |
+
descriptions=sys.argv[1:],
|
52 |
+
model_name='sa',
|
53 |
+
generate_fn=generate_audio,
|
54 |
+
prepare_args_fn=prepare_args
|
55 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audio/tango_audio.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import sys
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import soundfile as sf
|
8 |
+
from tango import Tango
|
9 |
+
from utils.audio_utils import process_audio_generations
|
10 |
+
|
11 |
+
def generate_audio(args):
|
12 |
+
description, seed, prompt_dir = args
|
13 |
+
wav_path = os.path.join(prompt_dir, f"{seed}.wav")
|
14 |
+
if os.path.exists(wav_path):
|
15 |
+
print(f"Skipping seed {seed} - file already exists")
|
16 |
+
return
|
17 |
+
|
18 |
+
# Initialize model for this process
|
19 |
+
tango = Tango("declare-lab/tango")
|
20 |
+
|
21 |
+
# Set random seed for reproducibility
|
22 |
+
random.seed(seed)
|
23 |
+
|
24 |
+
# Generate audio
|
25 |
+
audio = tango.generate(description)
|
26 |
+
|
27 |
+
# Save the audio
|
28 |
+
file_path = os.path.join(prompt_dir, f"{seed}.wav")
|
29 |
+
print(f"Saving audio to: {file_path}")
|
30 |
+
sf.write(file_path, audio, samplerate=16000)
|
31 |
+
|
32 |
+
def prepare_args(description, seeds, prompt_dir):
|
33 |
+
"""Prepare arguments for the generate_audio function"""
|
34 |
+
return [(description, seed, prompt_dir) for seed in seeds]
|
35 |
+
|
36 |
+
if __name__ == '__main__':
|
37 |
+
process_audio_generations(
|
38 |
+
descriptions=sys.argv[1:],
|
39 |
+
model_name='tango',
|
40 |
+
generate_fn=generate_audio,
|
41 |
+
prepare_args_fn=prepare_args
|
42 |
+
)
|
caption/jtp2.py
CHANGED
@@ -447,5 +447,3 @@ def create_tags(threshold):
|
|
447 |
|
448 |
if __name__ == "__main__":
|
449 |
process_directory(args.directory, args.threshold, args.cpu, args.no_grad)
|
450 |
-
|
451 |
-
|
|
|
447 |
|
448 |
if __name__ == "__main__":
|
449 |
process_directory(args.directory, args.threshold, args.cpu, args.no_grad)
|
|
|
|
caption/wdv3.py
CHANGED
@@ -395,4 +395,3 @@ if __name__ == "__main__":
|
|
395 |
print(f"Available models: {list(MODEL_REPO_MAP.keys())}")
|
396 |
raise ValueError(f"Unknown model name '{opts.model}'")
|
397 |
main(opts)
|
398 |
-
|
|
|
395 |
print(f"Available models: {list(MODEL_REPO_MAP.keys())}")
|
396 |
raise ValueError(f"Unknown model name '{opts.model}'")
|
397 |
main(opts)
|
|
utils/audio_utils.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import random
|
7 |
+
import multiprocessing as mp
|
8 |
+
|
9 |
+
def get_seed_from_filename(filename):
|
10 |
+
"""Extract seed from filename like '12345.wav'"""
|
11 |
+
try:
|
12 |
+
return int(filename.split('.')[0])
|
13 |
+
except:
|
14 |
+
return None
|
15 |
+
|
16 |
+
def setup_generation_dir(base_output_dir, description):
|
17 |
+
"""Setup and return the directory for a given prompt"""
|
18 |
+
os.makedirs(base_output_dir, exist_ok=True)
|
19 |
+
|
20 |
+
# Create a safe folder name from the description
|
21 |
+
folder_name = description.replace(' ', '_').replace('/', '_').replace('\\', '_')
|
22 |
+
folder_name = ''.join(c for c in folder_name if c.isalnum() or c in '_-')
|
23 |
+
prompt_dir = os.path.join(base_output_dir, folder_name)
|
24 |
+
os.makedirs(prompt_dir, exist_ok=True)
|
25 |
+
return prompt_dir
|
26 |
+
|
27 |
+
def get_existing_seeds(prompt_dir):
|
28 |
+
"""Get set of seeds from existing wav files in directory"""
|
29 |
+
existing_seeds = set()
|
30 |
+
for filename in os.listdir(prompt_dir):
|
31 |
+
if filename.endswith('.wav'):
|
32 |
+
seed = get_seed_from_filename(filename)
|
33 |
+
if seed is not None:
|
34 |
+
existing_seeds.add(seed)
|
35 |
+
return existing_seeds
|
36 |
+
|
37 |
+
def generate_new_seeds(needed_variations, existing_seeds, seed_range=(0, 1000000)):
|
38 |
+
"""Generate new unique random seeds"""
|
39 |
+
new_seeds = set()
|
40 |
+
while len(new_seeds) < needed_variations:
|
41 |
+
seed = random.randint(*seed_range)
|
42 |
+
if seed not in existing_seeds and seed not in new_seeds:
|
43 |
+
new_seeds.add(seed)
|
44 |
+
return new_seeds
|
45 |
+
|
46 |
+
def process_audio_generations(descriptions, model_name, generate_fn, prepare_args_fn, num_variations=25, num_processes=3):
|
47 |
+
"""
|
48 |
+
Shared logic for processing audio generations across different models.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
descriptions: List of text prompts to generate audio for
|
52 |
+
model_name: Name of the model (used for output directory)
|
53 |
+
generate_fn: Function that generates a single audio sample
|
54 |
+
prepare_args_fn: Function that prepares arguments for generate_fn
|
55 |
+
num_variations: Number of variations to generate per prompt
|
56 |
+
num_processes: Number of parallel processes to use
|
57 |
+
"""
|
58 |
+
# Set start method for multiprocessing
|
59 |
+
mp.set_start_method('spawn', force=True)
|
60 |
+
|
61 |
+
if not descriptions:
|
62 |
+
print('At least one prompt should be provided')
|
63 |
+
sys.exit(1)
|
64 |
+
|
65 |
+
# Base output directory
|
66 |
+
base_output_dir = f'generated_audio/{model_name}'
|
67 |
+
|
68 |
+
for description in descriptions:
|
69 |
+
prompt_dir = setup_generation_dir(base_output_dir, description)
|
70 |
+
print(f"\nGenerating variations for prompt: {description}")
|
71 |
+
print(f"Saving in directory: {prompt_dir}")
|
72 |
+
|
73 |
+
# Get existing seeds and check if we need to generate more
|
74 |
+
existing_seeds = get_existing_seeds(prompt_dir)
|
75 |
+
if len(existing_seeds) >= num_variations:
|
76 |
+
print(f"All {num_variations} variations already exist in {prompt_dir}, skipping...")
|
77 |
+
continue
|
78 |
+
|
79 |
+
# Generate new random seeds that haven't been used yet
|
80 |
+
needed_variations = num_variations - len(existing_seeds)
|
81 |
+
new_seeds = generate_new_seeds(needed_variations, existing_seeds)
|
82 |
+
|
83 |
+
print(f"Generating {needed_variations} new variations using {num_processes} parallel processes...")
|
84 |
+
print(f"Using seeds: {sorted(new_seeds)}")
|
85 |
+
|
86 |
+
# Prepare arguments for parallel processing
|
87 |
+
args_list = prepare_args_fn(description, new_seeds, prompt_dir)
|
88 |
+
|
89 |
+
# Use multiprocessing to distribute the work
|
90 |
+
with mp.Pool(processes=num_processes) as pool:
|
91 |
+
pool.map(generate_fn, args_list)
|