# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Command-line for audio compression."""
import argparse
import os
import sys
import typing as tp
from collections import OrderedDict
from pathlib import Path

import librosa
import soundfile as sf
import torch

from academicodec.models.encodec.net3 import SoundStream


def save_audio(wav: torch.Tensor,
               path: tp.Union[Path, str],
               sample_rate: int,
               rescale: bool=False):
    """Write a waveform tensor to disk as an audio file.

    Args:
        wav: audio tensor; leading singleton axes are squeezed away
            before writing.
        path: output file path (format inferred by soundfile from the
            file suffix).
        sample_rate: sample rate written into the file header.
        rescale: when True, scale the peak amplitude down to ``limit``
            instead of hard-clipping samples that exceed it.
    """
    limit = 0.99
    mx = wav.abs().max()
    if rescale:
        # Guard mx > 0: an all-zero (silent) signal would otherwise
        # divide by zero; leaving it untouched is equivalent.
        if mx > 0:
            wav = wav * min(limit / mx, 1)
    else:
        wav = wav.clamp(-limit, limit)
    wav = wav.squeeze().cpu().numpy()
    sf.write(path, wav, sample_rate)


def get_parser():
    """Build the argparse parser for the compression CLI."""
    parser = argparse.ArgumentParser(
        'encodec',
        description='High fidelity neural audio codec. '
        'If input is a .ecdc, decompresses it. '
        'If input is .wav, compresses it. If output is also wav, '
        'do a compression/decompression cycle.')
    parser.add_argument(
        '--input',
        type=Path,
        help='Input file, whatever is supported by torchaudio on your system.')
    parser.add_argument(
        '--output',
        type=Path,
        nargs='?',
        help='Output file, otherwise inferred from input file.')
    parser.add_argument(
        '--resume_path', type=str, default='resume_path', help='resume_path')
    parser.add_argument(
        '--sr', type=int, default=16000, help='sample rate of model')
    parser.add_argument(
        '-r',
        '--rescale',
        action='store_true',
        help='Automatically rescale the output to avoid clipping.')
    parser.add_argument(
        '--ratios',
        type=int,
        nargs='+',
        # prod(ratios) = hop_size of the codec
        default=[8, 5, 4, 2],
        help='ratios of SoundStream, should be set for different hop_size (32d, 320, 240d, ...)'
    )
    parser.add_argument(
        '--target_bandwidths',
        type=float,
        nargs='+',
        # default for 16k_320d
        default=[1, 1.5, 2, 4, 6, 12],
        help='target_bandwidths of net3.py')
    parser.add_argument(
        '--target_bw',
        type=float,
        # default for 16k_320d
        default=12,
        help='target_bw of net3.py')
    return parser
def fatal(*args): print(*args, file=sys.stderr) sys.exit(1) # 这只是打印了但是没有真的 clip def check_clipping(wav, rescale): if rescale: return mx = wav.abs().max() limit = 0.99 if mx > limit: print( f"Clipping!! max scale {mx}, limit is {limit}. " "To avoid clipping, use the `-r` option to rescale the output.", file=sys.stderr) def test_one(args, wav_root, store_root, rescale, soundstream): # torchaudio.load 的采样率为原始音频的采样率,不会自动下采样 # wav, sr = torchaudio.load(wav_root) # # 取单声道, output shape [1, T] # wav = wav[0].unsqueeze(0) # # 重采样为模型的采样率 # wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=args.sr)(wav) # load wav with librosa wav, sr = librosa.load(wav_root, sr=args.sr) wav = torch.tensor(wav).unsqueeze(0) # add batch axis wav = wav.unsqueeze(1).cuda() # compressing compressed = soundstream.encode(wav, target_bw=args.target_bw) print('finish compressing') out = soundstream.decode(compressed) out = out.detach().cpu().squeeze(0) check_clipping(out, rescale) save_audio(wav=out, path=store_root, sample_rate=args.sr, rescale=rescale) print('finish decompressing') def remove_encodec_weight_norm(model): from academicodec.modules import SConv1d from academicodec.modules.seanet import SConvTranspose1d from academicodec.modules.seanet import SEANetResnetBlock from torch.nn.utils import remove_weight_norm encoder = model.encoder.model for key in encoder._modules: if isinstance(encoder._modules[key], SEANetResnetBlock): remove_weight_norm(encoder._modules[key].shortcut.conv.conv) block_modules = encoder._modules[key].block._modules for skey in block_modules: if isinstance(block_modules[skey], SConv1d): remove_weight_norm(block_modules[skey].conv.conv) elif isinstance(encoder._modules[key], SConv1d): remove_weight_norm(encoder._modules[key].conv.conv) decoder = model.decoder.model for key in decoder._modules: if isinstance(decoder._modules[key], SEANetResnetBlock): remove_weight_norm(decoder._modules[key].shortcut.conv.conv) block_modules = decoder._modules[key].block._modules 
for skey in block_modules: if isinstance(block_modules[skey], SConv1d): remove_weight_norm(block_modules[skey].conv.conv) elif isinstance(decoder._modules[key], SConvTranspose1d): remove_weight_norm(decoder._modules[key].convtr.convtr) elif isinstance(decoder._modules[key], SConv1d): remove_weight_norm(decoder._modules[key].conv.conv) def test_batch(): args = get_parser().parse_args() print("args.target_bandwidths:", args.target_bandwidths) if not args.input.exists(): fatal(f"Input file {args.input} does not exist.") input_lists = os.listdir(args.input) input_lists.sort() soundstream = SoundStream( n_filters=32, D=512, ratios=args.ratios, sample_rate=args.sr, target_bandwidths=args.target_bandwidths) parameter_dict = torch.load(args.resume_path) new_state_dict = OrderedDict() # k 为 module.xxx.weight, v 为权重 for k, v in parameter_dict.items(): # 截取`module.`后面的xxx.weight name = k[7:] new_state_dict[name] = v soundstream.load_state_dict(new_state_dict) # load model remove_encodec_weight_norm(soundstream) soundstream.cuda() soundstream.eval() os.makedirs(args.output, exist_ok=True) for audio in input_lists: test_one( args=args, wav_root=os.path.join(args.input, audio), store_root=os.path.join(args.output, audio), rescale=args.rescale, soundstream=soundstream) if __name__ == '__main__': test_batch()