from transformers import EncodecModel, AutoProcessor
import torch
from audiocraft.data.audio import audio_read, audio_write
import datetime
import IPython
import os
import julius
from typing import List, Optional, Tuple, Union


class EncodecNoQuantizeModel(EncodecModel):
    """EnCodec variant that bypasses the residual vector quantizer: the continuous
    encoder embeddings are passed straight to the decoder."""

    def _encode_frame(
        self, input_values: torch.Tensor, bandwidth: float, padding_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Encodes the given input using the underlying VQVAE. If `config.normalize` is set to `True`,
        the input is first normalized. The padding mask is required to compute the correct scale.
        """
        length = input_values.shape[-1]
        duration = length / self.config.sampling_rate
        if self.config.chunk_length_s is not None and duration > 1e-5 + self.config.chunk_length_s:
            raise RuntimeError(f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}")

        scale = None
        if self.config.normalize:
            # Zero out the padded samples before computing the normalization scale.
            input_values = input_values * padding_mask
            mono = torch.sum(input_values, 1, keepdim=True) / input_values.shape[1]
            scale = mono.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-8
            input_values = input_values / scale

        embeddings = self.encoder(input_values)
        # Quantization is intentionally skipped; return the raw embeddings instead of codes.
        # codes = self.quantizer.encode(embeddings, bandwidth)
        # codes = codes.transpose(0, 1)
        return embeddings, scale

    def _decode_frame(self, embeddings: torch.Tensor, scale: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Dequantization is intentionally skipped; decode the embeddings directly.
        # codes = codes.transpose(0, 1)
        # embeddings = self.quantizer.decode(codes)
        outputs = self.decoder(embeddings)
        if scale is not None:
            outputs = outputs * scale.view(-1, 1, 1)
        return outputs


MODEL_SAMPLING_RATE = 48000


def load_model():
    # Load the model + processor (for pre-processing the audio).
    model = EncodecNoQuantizeModel.from_pretrained("facebook/encodec_48khz").to("cuda:0")
    processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
    return model, processor


@torch.no_grad()
def invert_audio(
        model, processor, input_audio, sampling_rate,
        normalize=True, flip_input=True, flip_output=False):
    model.config.normalize = normalize

    # Resample the input audio to the model's sampling rate if necessary.
    if sampling_rate != MODEL_SAMPLING_RATE:
        input_audio = julius.resample_frac(input_audio, sampling_rate, MODEL_SAMPLING_RATE)

    # Reverse the audio along the time axis if required.
    if flip_input:
        input_audio = torch.flip(input_audio, dims=(1,))

    # Pre-process the inputs.
    inputs_1 = processor(raw_audio=input_audio, sampling_rate=MODEL_SAMPLING_RATE, return_tensors="pt")
    inputs_1["input_values"] = inputs_1["input_values"].to("cuda:0")
    inputs_1["padding_mask"] = inputs_1["padding_mask"].to("cuda:0")

    # Explicitly encode then decode the audio inputs.
    print("Encoding...")
    encoder_outputs_1 = model.encode(
        inputs_1["input_values"], inputs_1["padding_mask"],
        bandwidth=max(model.config.target_bandwidths))

    # Average the embeddings over the chunk and time dimensions, then take each
    # embedding's offset from that average.
    avg = torch.mean(encoder_outputs_1.audio_codes, (0, 3), True)
    avg_repeat = avg.repeat(
        encoder_outputs_1.audio_codes.shape[0], encoder_outputs_1.audio_codes.shape[1],
        1, encoder_outputs_1.audio_codes.shape[3])
    diff_repeat = encoder_outputs_1.audio_codes - avg_repeat

    # Optionally reshape the magnitude of the offsets (POWER_FACTOR = 1 leaves them unchanged).
    POWER_FACTOR = 1
    max_abs_diff = torch.max(torch.abs(diff_repeat))
    diff_abs_power = ((torch.abs(diff_repeat) / max_abs_diff) ** POWER_FACTOR) * max_abs_diff
    latents = (diff_repeat >= 0) * diff_abs_power - (diff_repeat < 0) * diff_abs_power

    # Inversion of difference: negate each embedding's offset from the average.
    latents = latents * -1.0

    print("Decoding...")
    audio_values = model.decode(latents, encoder_outputs_1.audio_scales, inputs_1["padding_mask"])[0]

    if flip_output:
        audio_values = torch.flip(audio_values, dims=(2,))

    # Return the decoded audio tensor (or NumPy array, based on your audio_write function).
    decoded_wav = audio_values.squeeze(0).to("cpu")
    return decoded_wav
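

# Usage sketch (an assumption, not part of the original script): load a clip with
# audiocraft's audio_read, run the no-quantizer "inversion" round trip defined above,
# and save the result with audio_write. The file names and the CUDA device implied by
# load_model() are placeholders.
if __name__ == "__main__":
    model, processor = load_model()

    # audio_read returns a (channels, samples) tensor plus its sampling rate.
    wav, sr = audio_read("input.wav")

    decoded_wav = invert_audio(
        model, processor, wav, sr,
        normalize=True, flip_input=True, flip_output=False)

    # audio_write appends the extension itself ("inverted" -> "inverted.wav").
    audio_write("inverted", decoded_wav, MODEL_SAMPLING_RATE)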