# bark_with_batch_inference / generate_audio_semantic_dataset.py
# sleeper371's picture
# add code
# 37a9836
import argparse
import logging
import os
from typing import Optional
from core.bark.generate_audio_semantic_dataset import (
generate_wav_semantic_dataset,
BarkGenerationConfig,
)
from core.utils import upload_file_to_hf, zip_folder
# Configure root logging once at import time: timestamped, level-tagged messages
# at INFO and above. NOTE(review): basicConfig is a no-op if the root logger is
# already configured by the importing application.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
# Module-level logger named after this module, per stdlib convention.
logger = logging.getLogger(__name__)
def parse_dataset_args(args_list=None):
    """Parse command-line arguments for audio-semantic dataset creation.

    Args:
        args_list: Optional list of argument strings to parse instead of
            ``sys.argv`` (useful for programmatic invocation and testing).

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Audio Semantic Dataset Creation")
    parser.add_argument(
        "--text-file",
        type=str,
        default="data/test_data.txt",
        help="Path to text file for dataset generation",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        # Fix: help text previously claimed "(default: 1)" while the actual
        # default is 2; keep help consistent with the real default.
        help="Batch size for processing (default: 2)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./dataset",
        help="Output directory for generated files (default: ./dataset)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=256,
        help="Maximum tokens per example (default: 256)",
    )
    parser.add_argument(
        "--use-small-model",
        action="store_true",
        help="Use small model for generation",
    )
    parser.add_argument(
        "--save-raw-audio",
        action="store_true",
        help="Store generated audio as .wav instead of .npz",
    )
    parser.add_argument(
        "--publish-hf",
        action="store_true",
        help="Publish dataset to HuggingFace Hub",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        help="HuggingFace repo ID to publish to",
    )
    parser.add_argument(
        "--path-in-repo",
        type=str,
        help="Path in HF repo",
        default=None,
    )
    parser.add_argument(
        "--silent", action="store_true", help="Suppress progress output"
    )
    return parser.parse_args(args_list)
def create_audio_semantic_dataset(
    text_file: str,
    output_dir: str = "./dataset",
    batch_size: int = 1,
    max_tokens: int = 256,
    use_small_model: bool = False,
    save_raw_audio: bool = False,
    publish_hf: bool = False,
    repo_id: Optional[str] = None,
    path_in_repo: Optional[str] = None,
    silent: bool = False,
) -> None:
    """Create an audio-semantic dataset from a text file.

    Can be called directly with parameters or via command line using
    parse_dataset_args().

    Args:
        text_file: Path to input text file.
        output_dir: Directory to save generated dataset.
        batch_size: Batch size for processing.
        max_tokens: Maximum tokens per example.
            NOTE(review): currently accepted but never forwarded to
            generate_wav_semantic_dataset — TODO confirm whether the
            downstream API takes a token limit and wire it through.
        use_small_model: Whether to use the small Bark model.
        save_raw_audio: Save as raw audio (.wav) instead of .npz.
        publish_hf: Whether to publish the zipped dataset to HuggingFace Hub.
        repo_id: HF repo ID to publish to (required when publish_hf is True).
        path_in_repo: Path inside the HF repo.
        silent: Suppress progress output.

    Raises:
        FileNotFoundError: If ``text_file`` does not exist.
        RuntimeError: If the output folder cannot be zipped for publishing.
    """
    os.makedirs(output_dir, exist_ok=True)
    if not os.path.isfile(text_file):
        raise FileNotFoundError(f"Text file not found: {text_file}")
    logger.info(f"Starting dataset generation from {text_file}")
    # Temperatures are left as None; the downstream generator decides the
    # defaults in that case (presumably deterministic sampling — verify).
    generation_config = BarkGenerationConfig(
        temperature=None,
        generate_coarse_temperature=None,
        generate_fine_temperature=None,
        use_small_model=use_small_model,
    )
    generate_wav_semantic_dataset(
        text_file_path=text_file,
        generation_config=generation_config,
        batch_size=batch_size,
        save_path=output_dir,
        save_data_as_raw_audio=save_raw_audio,
        silent=silent,
    )
    logger.info("Dataset generation completed")
    if publish_hf:
        if not repo_id:
            # Previously this combination was silently ignored; keep the
            # skip behavior but make the misconfiguration visible.
            logger.warning(
                "publish_hf is set but no repo_id was provided; skipping upload"
            )
            return
        logger.info("Publishing dataset to huggingface hub")
        zip_path = "./dataset.zip"
        success = zip_folder(output_dir, zip_path)
        if not success:
            raise RuntimeError(f"Unable to zip folder {output_dir}")
        upload_file_to_hf(zip_path, repo_id, "dataset", path_in_repo=path_in_repo)
if __name__ == "__main__":
    # Parse CLI flags and forward them one-to-one to the library entry point.
    opts = vars(parse_dataset_args())
    create_audio_semantic_dataset(
        text_file=opts["text_file"],
        output_dir=opts["output_dir"],
        batch_size=opts["batch_size"],
        max_tokens=opts["max_tokens"],
        use_small_model=opts["use_small_model"],
        save_raw_audio=opts["save_raw_audio"],
        publish_hf=opts["publish_hf"],
        repo_id=opts["repo_id"],
        path_in_repo=opts["path_in_repo"],
        silent=opts["silent"],
    )