# update tokenizer (commit 24d0b1d)
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
cpu:
s3tokenizer --root_path /path/to/audio/files \
--model speech_tokenizer_v1 \
--device "cpu" \
--batch_size 32
gpu:
torchrun --nproc_per_node=8 --nnodes=1 \
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
`which s3tokenizer` --root_path /path/to/audio/files \
--model speech_tokenizer_v1 \
--device "cuda" \
--batch_size 32
"""
import argparse
import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
import s3tokenizer
class AudioDataset(Dataset):
    """Dataset that recursively scans ``root_path`` for audio files.

    The directory walk is parallelised with a thread pool, and the
    resulting file list can be cached on disk (pickle) so subsequent runs
    skip the scan entirely.
    """

    def __init__(self, root_path, extensions=('.wav', '.flac', '.mp3'),
                 use_cache=True, cache_file=None, max_workers=8):
        """
        Args:
            root_path: directory scanned recursively for audio files.
            extensions: accepted file suffixes (any iterable of str).
            use_cache: if True, load/save the file list via ``cache_file``.
            cache_file: cache path; defaults to
                ``root_path/.audio_file_cache.pkl``.
            max_workers: thread count for the parallel directory scan.
        """
        self.data = []
        # Tuple default avoids the shared-mutable-default-argument pitfall
        # (the original used a list literal as the default).
        extensions = tuple(extensions)
        if cache_file is None:
            cache_file = os.path.join(root_path, '.audio_file_cache.pkl')
        # Fast path: reuse a previously cached file list.
        if use_cache and os.path.exists(cache_file):
            import pickle
            print(f"Loading file list from cache: {cache_file}")
            try:
                with open(cache_file, 'rb') as f:
                    self.data = pickle.load(f)
                print(f"Loaded {len(self.data)} files from cache")
                return
            except Exception as e:
                # Corrupt/stale cache: fall through to a fresh scan.
                print(f"Failed to load cache: {e}, scanning directory...")
        print(f"Scanning directory: {root_path}")
        print(f"Looking for extensions: {list(extensions)}")
        from concurrent.futures import ThreadPoolExecutor, as_completed

        def scan_directory(dirpath):
            """Return matching files directly inside one directory."""
            files = []
            try:
                with os.scandir(dirpath) as entries:
                    for entry in entries:
                        # str.endswith accepts a tuple of suffixes, so no
                        # per-extension loop is needed.
                        if entry.is_file() and entry.name.endswith(extensions):
                            files.append(entry.path)
            except PermissionError:
                # Unreadable directories are silently skipped (best effort).
                pass
            return files

        # Collect every directory first so the scans can run in parallel.
        all_dirs = [root_path]
        for dirpath, dirnames, _ in os.walk(root_path):
            all_dirs.extend(os.path.join(dirpath, d) for d in dirnames)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(scan_directory, d) for d in all_dirs]
            with tqdm(total=len(all_dirs), desc="Scanning directories") as pbar:
                for future in as_completed(futures):
                    self.data.extend(future.result())
                    pbar.update(1)
        # Sort for a deterministic ordering across runs and ranks.
        self.data.sort()
        if len(self.data) == 0:
            raise ValueError(f"No audio files found in {root_path}")
        print(f"Found {len(self.data)} audio files")
        # Persist the list for the next run; failures are non-fatal.
        if use_cache:
            try:
                import pickle
                print(f"Saving file list to cache: {cache_file}")
                cache_dir = os.path.dirname(cache_file)
                if cache_dir and not os.path.exists(cache_dir):
                    os.makedirs(cache_dir, exist_ok=True)
                with open(cache_file, 'wb') as f:
                    pickle.dump(self.data, f)
            except Exception as e:
                print(f"Failed to save cache: {e}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load one file; return ``(path, mel)``, or ``(None, None)`` on error.

        Failed items are filtered out later by ``collate_fn``.
        """
        file_path = self.data[idx]
        try:
            audio = s3tokenizer.load_audio(file_path)
            mel = s3tokenizer.log_mel_spectrogram(audio)
            return file_path, mel
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            return None, None
def collate_fn(batch):
    """Drop failed items, then pad the surviving mels into one batch.

    Returns ``(file_paths, mels, mels_lens)``; ``([], None, None)`` when
    every item in the batch failed to load.
    """
    valid = [(path, mel) for path, mel in batch if path is not None]
    if not valid:
        return [], None, None
    file_paths = [path for path, _ in valid]
    mels = [mel for _, mel in valid]
    padded_mels, mels_lens = s3tokenizer.padding(mels)
    return file_paths, padded_mels, mels_lens
def init_distributed():
    """Join the NCCL process group using torchrun's environment variables.

    Returns:
        ``(world_size, local_rank, rank)`` as ints (defaults 1/0/0 when the
        variables are absent).
    """
    world_size, local_rank, rank = (
        int(os.environ.get(key, default))
        for key, default in (('WORLD_SIZE', 1), ('LOCAL_RANK', 0), ('RANK', 0))
    )
    print('Inference on multiple gpus, this gpu {}'.format(local_rank) +
          ', rank {}, world_size {}'.format(rank, world_size))
    # Pin this process to its local GPU before joining the group.
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank
def get_args():
    """Build and parse the command line for speech-code extraction."""
    parser = argparse.ArgumentParser(description='extract speech code')
    # Required options.
    parser.add_argument('--model', required=True, type=str,
                        choices=[
                            "speech_tokenizer_v1", "speech_tokenizer_v1_25hz",
                            "speech_tokenizer_v2_25hz"
                        ],
                        help='model version')
    parser.add_argument('--root_path', required=True, type=str,
                        help='root directory containing audio files')
    parser.add_argument('--device', required=True, type=str,
                        choices=["cuda", "cpu"],
                        help='device for inference')
    parser.add_argument('--batch_size', required=True, type=int,
                        help='batch size (per-device) for inference')
    # DataLoader tuning.
    parser.add_argument('--num_workers', type=int, default=4,
                        help='workers for dataloader')
    parser.add_argument('--prefetch', type=int, default=5,
                        help='prefetch for dataloader')
    # File-discovery behaviour.
    parser.add_argument('--extensions', nargs='+',
                        default=['.wav', '.flac', '.mp3'],
                        help='audio file extensions to process')
    parser.add_argument('--use_cache', action='store_true',
                        help='use cached file list to avoid re-scanning')
    parser.add_argument('--no_cache', action='store_true',
                        help='force re-scan even if cache exists')
    parser.add_argument('--cache_file', type=str, default=None,
                        help='path to cache file (default: root_path/.audio_file_cache.pkl)')
    parser.add_argument('--scan_workers', type=int, default=8,
                        help='number of workers for directory scanning')
    parser.add_argument('--file_list', type=str, default=None,
                        help='path to pre-generated file list (one file per line)')
    parser.add_argument('--skip_existing', action='store_true',
                        help='skip files that already have _fsq.pt output')
    return parser.parse_args()
def save_tokens(file_path, codes, codes_len):
    """Persist the first ``codes_len`` tokens next to the source audio.

    The output file shares the audio file's basename, with the extension
    replaced by ``_fsq.pt``.

    Args:
        file_path: path of the source audio file.
        codes: token tensor for this file (padded).
        codes_len: number of valid tokens at the front of ``codes``.

    Returns:
        Path of the written ``.pt`` file.
    """
    stem, _ = os.path.splitext(file_path)
    output_path = f"{stem}_fsq.pt"
    # Drop the padding tail before saving.
    torch.save(codes[:codes_len], output_path)
    return output_path
def main():
    """Entry point: tokenize every discovered audio file and save the codes.

    Runs single-process on CPU, or one process per GPU under torchrun
    (NCCL + DistributedSampler) on CUDA. Each input file produces a
    ``<stem>_fsq.pt`` tensor next to it.
    """
    args = get_args()

    if args.device == "cuda":
        assert torch.cuda.is_available()
        world_size, local_rank, rank = init_distributed()
    else:
        world_size, local_rank, rank = 1, 0, 0

    device = torch.device(args.device)
    model = s3tokenizer.load_model(args.model).to(device)
    model.eval()  # inference only; disable dropout/batchnorm updates

    # Build the dataset from either an explicit file list or a directory scan.
    if args.file_list:
        print(f"Loading file list from: {args.file_list}")
        with open(args.file_list, 'r') as f:
            file_paths = [line.strip() for line in f if line.strip()]

        class FileListDataset(Dataset):
            """Dataset over an explicit list of audio file paths."""

            def __init__(self, file_paths, skip_existing=False):
                self.data = []
                skipped_existing = 0
                for fp in file_paths:
                    if skip_existing:
                        # BUG FIX: the original used fp.replace('.wav', ...),
                        # which only handled .wav inputs and corrupted paths
                        # containing '.wav' mid-string. Derive the output path
                        # the same way save_tokens does.
                        output_path = os.path.splitext(fp)[0] + '_fsq.pt'
                        if os.path.exists(output_path):
                            print(f'*******skip file {output_path}')
                            skipped_existing += 1
                            continue
                    self.data.append(fp)
                print(f"Will process {len(self.data)} files")
                if skip_existing and skipped_existing > 0:
                    print(f"Skipped {skipped_existing} already processed files")

            def __len__(self):
                return len(self.data)

            def __getitem__(self, idx):
                """Load one file; (None, None) marks a failure for collate_fn."""
                file_path = self.data[idx]
                try:
                    if not os.path.exists(file_path):
                        print(f"File not found: {file_path}")
                        return None, None
                    audio = s3tokenizer.load_audio(file_path)
                    mel = s3tokenizer.log_mel_spectrogram(audio)
                    return file_path, mel
                except Exception as e:
                    print(f"Error processing {file_path}: {e}")
                    return None, None

        dataset = FileListDataset(file_paths, skip_existing=args.skip_existing)
    else:
        # Directory scan with optional on-disk caching of the file list.
        dataset = AudioDataset(
            args.root_path,
            args.extensions,
            use_cache=not args.no_cache,
            cache_file=args.cache_file,
            max_workers=args.scan_workers,
        )
        if args.skip_existing:
            original_count = len(dataset.data)
            # Same output-path derivation as save_tokens.
            dataset.data = [
                fp for fp in dataset.data
                if not os.path.exists(os.path.splitext(fp)[0] + '_fsq.pt')
            ]
            print(f"Skipping {original_count - len(dataset.data)} already processed files")

    if args.device == "cuda":
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank])
        # Shard the dataset across ranks.
        sampler = DistributedSampler(dataset,
                                     num_replicas=world_size,
                                     rank=rank)
    else:
        sampler = None

    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            sampler=sampler,
                            shuffle=False,
                            num_workers=args.num_workers,
                            prefetch_factor=args.prefetch,
                            collate_fn=collate_fn)

    total_steps = len(dataset)
    if rank == 0:
        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")

    processed_count = 0
    failed_count = 0
    failed_files = []
    # No autograd state is needed for pure inference; saves memory.
    with torch.no_grad():
        for file_paths, mels, mels_lens in dataloader:
            # Skip batches in which every file failed to load.
            if len(file_paths) == 0:
                continue
            codes, codes_lens = model(mels.to(device), mels_lens.to(device))
            # Move to CPU so the saved .pt files load on machines without a
            # GPU (torch.save would otherwise pin them to a CUDA device).
            codes = codes.cpu()
            for i, file_path in enumerate(file_paths):
                try:
                    output_path = save_tokens(file_path, codes[i],
                                              codes_lens[i].item())
                    # Only show the first few saves to avoid log spam.
                    if rank == 0 and processed_count < 10:
                        tqdm.write(f"Saved: {file_path} -> {output_path}")
                    processed_count += 1
                except Exception as e:
                    failed_count += 1
                    failed_files.append(file_path)
                    if rank == 0:
                        tqdm.write(f"Failed to save {file_path}: {e}")
            if rank == 0:
                # BUG FIX: the original added the *cumulative* failed_count on
                # every batch, over-counting progress. Each step covers
                # len(file_paths) items on each of world_size ranks.
                progress_bar.update(world_size * len(file_paths))

    if rank == 0:
        progress_bar.close()

    print(f"\nProcessed {processed_count} files successfully on rank {rank}")
    if failed_count > 0:
        print(f"Failed to process {failed_count} files")
        # Record failures so they can be retried via --file_list.
        failed_list_path = os.path.join(
            args.root_path if not args.file_list else ".", "failed_files.txt")
        with open(failed_list_path, 'w') as f:
            for ff in failed_files:
                f.write(f"{ff}\n")
        print(f"Failed files saved to: {failed_list_path}")

    if args.device == "cuda":
        # Synchronise all ranks before tearing down the process group.
        dist.barrier()
        dist.destroy_process_group()
# Script entry point (run directly, or via the `s3tokenizer` console script
# shown in the module docstring).
if __name__ == "__main__":
    main()