# RACK808 / model.py
import torch
import julius
from typing import Optional, Tuple

from audiocraft.data.audio import audio_read, audio_write
from transformers import EncodecModel, AutoProcessor


class EncodecNoQuantizeModel(EncodecModel):
    """EnCodec variant that bypasses the residual vector quantizer and works directly on the
    continuous encoder embeddings."""

    def _encode_frame(
        self, input_values: torch.Tensor, bandwidth: float, padding_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Encodes the given input using the underlying VQ-VAE encoder, returning the raw
        embeddings instead of quantized codes. If `config.normalize` is set to `True`, the
        input is first normalized. The padding mask is required to compute the correct scale.
        """
        length = input_values.shape[-1]
        duration = length / self.config.sampling_rate
        if self.config.chunk_length_s is not None and duration > 1e-5 + self.config.chunk_length_s:
            raise RuntimeError(f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}")

        scale = None
        if self.config.normalize:
            # Zero out padded samples, then compute the normalization scale from the mono mix.
            input_values = input_values * padding_mask
            mono = torch.sum(input_values, 1, keepdim=True) / input_values.shape[1]
            scale = mono.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-8
            input_values = input_values / scale

        embeddings = self.encoder(input_values)
        # Unlike the stock EncodecModel, the embeddings are not quantized:
        # codes = self.quantizer.encode(embeddings, bandwidth)
        # codes = codes.transpose(0, 1)
        return embeddings, scale

    def _decode_frame(self, embeddings: torch.Tensor, scale: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Decode the continuous embeddings directly; no codebook lookup is performed:
        # codes = codes.transpose(0, 1)
        # embeddings = self.quantizer.decode(codes)
        outputs = self.decoder(embeddings)
        if scale is not None:
            outputs = outputs * scale.view(-1, 1, 1)
        return outputs
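

# Illustrative sketch (not part of the original module): with quantization bypassed, encode()
# returns continuous float embeddings in `audio_codes` rather than integer codebook indices,
# and decode() maps such embeddings straight back to audio. `inputs` is assumed to be a batch
# prepared by the EnCodec processor, as in invert_audio() below.
def _no_quantize_roundtrip(model: EncodecNoQuantizeModel, inputs: dict) -> torch.Tensor:
    enc = model.encode(
        inputs["input_values"],
        inputs["padding_mask"],
        bandwidth=max(model.config.target_bandwidths))
    # enc.audio_codes is a float tensor of shape (num_frames, batch, embedding_dim, frame_length).
    return model.decode(enc.audio_codes, enc.audio_scales, inputs["padding_mask"])[0]

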
MODEL_SAMPLING_RATE = 48000


def load_model():
    # Load the model and the processor (used for pre-processing the audio).
    model = EncodecNoQuantizeModel.from_pretrained("facebook/encodec_48khz").to("cuda:0")
    processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
    return model, processor
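

# Pre-processing sketch (illustration only; mirrors the first steps of invert_audio() below).
# `wav` stands for a hypothetical (channels, samples) float tensor already at 48 kHz:
#   model, processor = load_model()
#   inputs = processor(raw_audio=wav, sampling_rate=MODEL_SAMPLING_RATE, return_tensors="pt")
#   inputs["input_values"] = inputs["input_values"].to("cuda:0")
#   inputs["padding_mask"] = inputs["padding_mask"].to("cuda:0")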


@torch.no_grad()
def invert_audio(
        model, processor, input_audio, sampling_rate,
        normalize=True, flip_input=True, flip_output=False):
    model.config.normalize = normalize

    # Resample the input audio to the model's sampling rate if necessary.
    if sampling_rate != MODEL_SAMPLING_RATE:
        input_audio = julius.resample_frac(input_audio, sampling_rate, MODEL_SAMPLING_RATE)

    # Optionally reverse the input along the time axis (dim 1 of a (channels, samples) tensor).
    if flip_input:
        input_audio = torch.flip(input_audio, dims=(1,))

    # Pre-process the inputs and move them to the GPU.
    inputs_1 = processor(raw_audio=input_audio, sampling_rate=MODEL_SAMPLING_RATE, return_tensors="pt")
    inputs_1["input_values"] = inputs_1["input_values"].to("cuda:0")
    inputs_1["padding_mask"] = inputs_1["padding_mask"].to("cuda:0")

    # Explicitly encode the audio into continuous latents.
    print("Encoding...")
    encoder_outputs_1 = model.encode(
        inputs_1["input_values"],
        inputs_1["padding_mask"],
        bandwidth=max(model.config.target_bandwidths))

    # Express each latent as its difference from the mean latent (averaged over the frame and
    # time dimensions).
    avg = torch.mean(encoder_outputs_1.audio_codes, (0, 3), True)
    avg_repeat = avg.repeat(
        encoder_outputs_1.audio_codes.shape[0],
        encoder_outputs_1.audio_codes.shape[1],
        1,
        encoder_outputs_1.audio_codes.shape[3])
    diff_repeat = encoder_outputs_1.audio_codes - avg_repeat

    # Reshape the magnitude of the differences while keeping their sign
    # (with POWER_FACTOR = 1 this leaves them unchanged).
    POWER_FACTOR = 1
    max_abs_diff = torch.max(torch.abs(diff_repeat))
    diff_abs_power = ((torch.abs(diff_repeat) / max_abs_diff) ** POWER_FACTOR) * max_abs_diff
    latents = (diff_repeat >= 0) * diff_abs_power - (diff_repeat < 0) * diff_abs_power

    # Inversion of the differences: flip their sign before decoding.
    latents = latents * -1.0
print("Decoding...")
audio_values = model.decode(latents, encoder_outputs_1.audio_scales, inputs_1["padding_mask"])[0]
if flip_output:
audio_values = torch.flip(audio_values, dims=(2,))
# Return the decoded audio tensor (or NumPy array, based on your audio_write function)
decoded_wav = audio_values.squeeze(0).to("cpu")
return decoded_wav
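

# Example usage (an illustrative sketch, not part of the original module). "input.wav" and the
# output stem "inverted" are hypothetical names, and a CUDA device is assumed because
# load_model() moves the model to "cuda:0".
if __name__ == "__main__":
    model, processor = load_model()
    # audiocraft's audio_read returns a (channels, samples) float tensor and its sample rate.
    wav, sr = audio_read("input.wav")
    inverted = invert_audio(model, processor, wav, sr,
                            normalize=True, flip_input=True, flip_output=False)
    # audiocraft's audio_write takes a stem name and appends the extension, writing inverted.wav here.
    audio_write("inverted", inverted, MODEL_SAMPLING_RATE)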