# wavjepa-base / utils.py
# Provenance: Hugging Face model upload by GokseninYuksel (commit fefd7ae, verified).
import torch
def normalize(audio):
    """Standardize ``audio`` to zero mean and unit std over its last two dims.

    Statistics are computed jointly across the final two dimensions with
    ``keepdim=True`` so each leading (e.g. batch) entry is normalized
    independently and broadcasting applies cleanly.
    """
    mu = audio.mean(dim=(-2, -1), keepdim=True)
    sigma = audio.std(dim=(-2, -1), keepdim=True)
    # Epsilon keeps the division finite when the input is constant (std == 0).
    return (audio - mu) / (sigma + 1e-5)
def calculate_padding_mask(pad_frames, total_frames, sr, output_steps, process_seconds, device, B):
    """Build a boolean mask flagging the output steps that cover padding.

    The audio is processed in ``process_seconds``-long chunks, each producing
    ``output_steps`` model outputs; padded input frames map onto the tail of
    the flattened output sequence.

    Returns:
        (mask, valid_steps): ``mask`` is a (B, total_output_steps) bool tensor
        with ``True`` at padded positions; ``valid_steps`` is the number of
        non-padded leading steps.
    """
    # Whole number of process_seconds chunks in the audio, then the total
    # output steps the model emits across all of them.
    n_chunks = int((total_frames / sr) / process_seconds)
    n_steps = output_steps * n_chunks
    # Model output rate in steps per second of audio.
    steps_per_sec = int(output_steps / process_seconds)
    # Convert the padded input frames into padded output steps.
    n_pad_steps = int((pad_frames / sr) * steps_per_sec)
    valid_steps = n_steps - n_pad_steps
    # Padding always occupies the tail of the sequence.
    mask = torch.zeros((B, n_steps), dtype=torch.bool, device=device)
    mask[..., valid_steps:] = True
    return mask, valid_steps
def get_timestamps(sample_rate, B, input_audio_len, x):
    """Return per-frame timestamps in milliseconds, shaped (B, x.shape[1]).

    The audio duration ``input_audio_len / sample_rate`` is divided evenly
    across the frames of ``x``; frame ``i`` gets timestamp ``i * step`` ms.
    """
    duration_sec = input_audio_len / sample_rate
    num_frames = x.shape[1]
    # Milliseconds covered by each output frame.
    ms_per_frame = duration_sec / num_frames * 1000
    timestamps = torch.tensor([ms_per_frame * idx for idx in range(num_frames)])
    # Broadcast the single timeline to every item in the batch.
    return timestamps.unsqueeze(0).repeat(B, 1)