# De-limiter / separate_func/conv_tasnet_separate.py
# Uploaded by jeonchangbin49 -- "first commit" (a00b67a), 3.13 kB
import os
import soundfile as sf
import torch
import pyloudnorm as pyln
import librosa
import matplotlib
import matplotlib.pyplot as plt
from dataloader import SingleTrackSet
from utils import db2linear
def conv_tasnet_separate(
    args, our_model, device, track_audio, track_name, meter=None, augmented_gain=None
):
    """Run ``our_model`` on one track, write the result as WAV and return it.

    Two inference modes are supported:

    * ``args.use_singletrackset`` truthy: the track is wrapped in a
      ``SingleTrackSet``, each item it yields is processed separately, the
      first and last ``db.trim_length`` samples of every output are dropped,
      and the pieces are concatenated along the last axis.
    * otherwise: the whole ``track_audio`` is passed to the model in one call
      (it is NOT moved to ``device`` here; the caller is expected to do so).

    When ``args.task_params.dataset == "delimit"`` the second model output is
    used as the estimate instead of the first one.

    Args:
        args: Configuration namespace. Attributes read here:
            ``use_singletrackset``, ``data_params.nhop``,
            ``data_params.sample_rate``, ``target``, ``task_params.dataset``,
            ``save_histogram``, ``test_output_dir``, ``save_name_as_target``,
            ``save_output_loudnorm``, ``save_16k_mono``.
        our_model: Callable returning a sequence whose first element is the
            primary output tensor (further elements are auxiliary outputs).
        device: Torch device the per-segment inputs are moved to.
        track_audio: Input tensor; presumably (batch=1, channels, samples)
            -- TODO confirm against the caller.
        track_name: Name used to build the output file/directory names.
        meter: Object exposing ``integrated_loudness(audio)``; only used when
            ``args.save_output_loudnorm`` is truthy. If omitted, a
            ``pyln.Meter`` at ``args.data_params.sample_rate`` is created.
        augmented_gain: Gain in dB that was applied to the input; it is undone
            on the output when ``args.save_output_loudnorm`` is ``None``.

    Returns:
        2-D NumPy array holding the first batch element of the estimate,
        truncated to ``track_audio.shape[-1]`` along the last axis, after any
        loudness normalisation / gain compensation (presumably laid out as
        (channels, samples), since its transpose is handed to ``sf.write``).
    """
    is_delimit = args.task_params.dataset == "delimit"

    # Pure inference: every output is detached immediately, so there is no
    # reason to let autograd record the forward passes.
    with torch.no_grad():
        if args.use_singletrackset:
            db = SingleTrackSet(
                track_audio.squeeze(dim=0),
                hop_length=args.data_params.nhop,
                num_frame=128,
                target_name=args.target,
            )
            # ``x[..., t:-t]`` would be empty for t == 0, so leave the end
            # open in that case instead of slicing to "-0".
            trim = db.trim_length
            keep = slice(trim, -trim if trim else None)

            separated = []
            for item in db:
                item = item.unsqueeze(0).to(device)
                estimates, *estimates_vars = our_model(item)
                if is_delimit:
                    estimates = estimates_vars[0]
                estimates = estimates.detach().cpu()
                # clone() so only the kept part of each segment stays alive.
                separated.append(estimates[..., keep].clone())

            estimates = torch.cat(separated, dim=-1)
            estimates = estimates[0, :, : track_audio.shape[-1]].numpy()
        else:
            estimates, *estimates_vars = our_model(track_audio)
            if args.save_histogram and is_delimit:
                # The histogram is drawn from the FIRST model output, before
                # it is replaced by the auxiliary output below.
                fig = plt.figure(figsize=(10, 10))
                try:
                    plt.hist(estimates.cpu().detach().numpy().flatten(), bins=100)
                    os.makedirs(f"{args.test_output_dir}/{track_name}", exist_ok=True)
                    plt.savefig(
                        f"{args.test_output_dir}/{track_name}/{args.target}_histogram.png"
                    )
                finally:
                    # pyplot keeps every figure alive until it is closed.
                    plt.close(fig)
            if is_delimit:
                estimates = estimates_vars[0]

            estimates = estimates.cpu().detach().numpy()
            estimates = estimates[0, :, : track_audio.shape[-1]]

    if args.save_output_loudnorm:
        print("SAVE Loudness normalized OUTPUT ")
        if meter is None:
            meter = pyln.Meter(args.data_params.sample_rate)
        loudness = meter.integrated_loudness(estimates.T)
        estimates = estimates * db2linear(args.save_output_loudnorm - loudness, eps=0.0)
    elif augmented_gain is not None and args.save_output_loudnorm is None:
        estimates = estimates * db2linear(-augmented_gain, eps=0.0)

    # Path of the output file relative to the output root; identical for the
    # full-rate and the 16 kHz mono outputs.
    if args.save_name_as_target:
        rel_path = f"{track_name}/{args.target}.wav"
    else:
        rel_path = f"{track_name}.wav"

    out_path = f"{args.test_output_dir}/{rel_path}"
    # Create exactly the directory that will contain the file.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    sf.write(out_path, estimates.T, samplerate=args.data_params.sample_rate)

    if args.save_16k_mono:
        estimates_16k_mono = librosa.to_mono(estimates)
        estimates_16k_mono = librosa.resample(
            estimates_16k_mono,
            orig_sr=args.data_params.sample_rate,
            target_sr=16000,
        )
        mono_path = f"{args.test_output_dir}_16k_mono/{rel_path}"
        os.makedirs(os.path.dirname(mono_path), exist_ok=True)
        sf.write(mono_path, estimates_16k_mono, samplerate=16000)

    return estimates