|
|
import torch |
|
|
|
|
|
def normalize(audio):
    """Standardize *audio* to zero mean and unit variance.

    Statistics are computed jointly over the last two dimensions
    (kept for broadcasting), and a small epsilon guards against
    division by zero for near-constant inputs.

    Args:
        audio: tensor with at least 2 dims; normalization is per
            leading-batch element over the trailing (-2, -1) dims.

    Returns:
        Tensor of the same shape, standardized.
    """
    reduce_dims = (-2, -1)
    mu = audio.mean(dim=reduce_dims, keepdim=True)
    sigma = audio.std(dim=reduce_dims, keepdim=True)
    return (audio - mu) / (sigma + 1e-5)
|
|
|
|
|
def calculate_padding_mask(pad_frames, total_frames, sr, output_steps, process_seconds, device, B): |
|
|
|
|
|
|
|
|
total_frames = int((total_frames / sr) / process_seconds) |
|
|
total_output_steps = output_steps * total_frames |
|
|
mask = torch.zeros((B, total_output_steps), dtype = torch.bool, device = device) |
|
|
|
|
|
|
|
|
output_sr = int(output_steps / process_seconds) |
|
|
pad_seconds = pad_frames / sr |
|
|
pad_steps = int(pad_seconds * output_sr) |
|
|
|
|
|
|
|
|
mask[..., total_output_steps - pad_steps:] = True |
|
|
return mask, total_output_steps - pad_steps |
|
|
|
|
|
|
|
|
def get_timestamps(sample_rate, B, input_audio_len, x):
    """Compute per-step timestamps (in milliseconds) for a model output.

    Distributes the audio duration evenly over the output steps, so step
    ``i`` is stamped with ``i * (duration_ms / num_steps)``.

    Args:
        sample_rate: audio sample rate in Hz.
        B: batch size; the timestamp row is repeated B times.
        input_audio_len: length of the input audio in samples.
        x: model output tensor; only ``x.shape[1]`` (number of output
            steps) is read.

    Returns:
        Tensor of shape (B, x.shape[1]) with millisecond timestamps.
    """
    sec = input_audio_len / sample_rate
    x_len = x.shape[1]
    # Milliseconds per output step.
    step = sec / x_len * 1000
    # Vectorized equivalent of [step * i for i in range(x_len)]:
    # int arange * python float promotes to the default float dtype.
    ts = (torch.arange(x_len) * step).unsqueeze(0)
    ts = ts.repeat(B, 1)
    return ts