# ==================================================================================================
# DEEPFAKE AUDIO - encoder/inference.py (Neural Identity Distillation Interface)
# ==================================================================================================
#
# πŸ“ DESCRIPTION
# This module provides the high-level API for using the Speaker Encoder in a
# production environment. It encapsulates the complexities of model loading,
# tensor orchestration, and d-vector derivation. It is the primary bridge
# used by the web interface (app.py) to extract speaker identities from
# uploaded reference audio samples.
#
# 👀 AUTHORS
# - Amey Thakur (https://github.com/Amey-Thakur)
# - Mega Satish (https://github.com/msatmod)
#
# 🤝🏻 CREDITS
# Original Real-Time Voice Cloning methodology by CorentinJ
# Repository: https://github.com/CorentinJ/Real-Time-Voice-Cloning
#
# 🔗 PROJECT LINKS
# Repository: https://github.com/Amey-Thakur/DEEPFAKE-AUDIO
# Video Demo: https://youtu.be/i3wnBcbHDbs
# Research: https://github.com/Amey-Thakur/DEEPFAKE-AUDIO/blob/main/DEEPFAKE-AUDIO.ipynb
#
# 📜 LICENSE
# Released under the MIT License
# Release Date: 2021-02-06
# ==================================================================================================
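
# 💡 USAGE SKETCH (illustrative only; the weights path and audio filename below are hypothetical)
#   >>> from pathlib import Path
#   >>> from encoder import inference as encoder
#   >>> encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
#   >>> wav = encoder.preprocess_wav(Path("reference_sample.wav"))
#   >>> embed = encoder.embed_utterance(wav)  # -> (256,) unit-norm d-vector
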
from encoder.params_data import *
from encoder.model import SpeakerEncoder
from encoder.audio import preprocess_wav
from encoder import audio
from pathlib import Path
import numpy as np
import torch

# --- INTERNAL STATE (SINGLETON PATTERN) ---
_model = None  # type: SpeakerEncoder
_device = None  # type: torch.device


def load_model(weights_fpath: Path, device=None):
"""
Initializes the Speaker Encoder neural network.
Deserializes the PyTorch state dictionary and prepares the model for eval mode.
"""
global _model, _device
# Precise hardware targeting
    if device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        _device = torch.device(device)
    else:
        # A torch.device instance was passed in directly
        _device = device
# Constructing the architecture
_model = SpeakerEncoder(_device, torch.device("cpu"))
# Loading serialized weights
checkpoint = torch.load(weights_fpath, map_location=_device, weights_only=False)
_model.load_state_dict(checkpoint["model_state"])
_model.eval()
print("🀝🏻 Encoder Active: Loaded \"%s\" (Step %d)" % (weights_fpath.name, checkpoint["step"]))
def is_loaded():
"""Checks the initialization status of the neural engine."""
    return _model is not None


def embed_frames_batch(frames_batch):
"""
Neural Forward Pass: Computes speaker embeddings for a batch of spectrograms.
Returns l2-normalized d-vectors.
"""
    if _model is None:
        raise RuntimeError("Fatal: Neural Encoder is not initialized. Invoke load_model() first.")
frames = torch.from_numpy(frames_batch).to(_device)
embed = _model.forward(frames).detach().cpu().numpy()
return embed
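

# Shape sketch for embed_frames_batch (assuming the default hyperparameters; exact sizes
# come from params_data / params_model):
#   frames_batch: float32 array of shape (batch_size, n_frames, mel_n_channels)
#   returns:      float32 array of shape (batch_size, model_embedding_size), e.g. (4, 256)
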
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
min_pad_coverage=0.75, overlap=0.5):
"""
Spatio-Temporal Segmentation: Defines how a long utterance is sliced into
overlapping windows for stable embedding derivation.
"""
assert 0 <= overlap < 1
assert 0 < min_pad_coverage <= 1
    samples_per_frame = int(sampling_rate * mel_window_step / 1000)
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
# Window Orchestration
wav_slices, mel_slices = [], []
steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
for i in range(0, steps, frame_step):
mel_range = np.array([i, i + partial_utterance_n_frames])
wav_range = mel_range * samples_per_frame
mel_slices.append(slice(*mel_range))
wav_slices.append(slice(*wav_range))
# Defensive Padding Evaluation
last_wav_range = wav_slices[-1]
coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
if coverage < min_pad_coverage and len(mel_slices) > 1:
mel_slices = mel_slices[:-1]
wav_slices = wav_slices[:-1]
return wav_slices, mel_slices
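
# Worked example (assuming the RTVC defaults in params_data: sampling_rate=16000,
# mel_window_step=10 ms, partials_n_frames=160, i.e. 160 samples per frame, 1.6 s
# partials, 50% overlap -> frame_step=80):
#   a 3 s clip: n_samples=48000 -> n_frames=301, and i takes the values {0, 80, 160}
#   mel_slices = [0:160], [80:240], [160:320]
#   wav_slices = [0:25600], [12800:38400], [25600:51200]
#   last-window coverage = (48000 - 25600) / 25600 = 0.875 >= 0.75, so all three are kept
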
def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
"""
Core Identity Extraction: Distills a processed waveform into a single
256-dimensional identity vector (d-vector).
"""
# 1. Full-Waveform Processing (Fallback for short utterances)
if not using_partials:
frames = audio.wav_to_mel_spectrogram(wav)
embed = embed_frames_batch(frames[None, ...])[0]
if return_partials:
return embed, None, None
return embed
# 2. Windowed Distillation
wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
max_wave_length = wave_slices[-1].stop
if max_wave_length >= len(wav):
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
# 3. Batch Inference on Windows
frames = audio.wav_to_mel_spectrogram(wav)
frames_batch = np.array([frames[s] for s in mel_slices])
partial_embeds = embed_frames_batch(frames_batch)
# 4. Statistical Averaging & Re-Normalization
raw_embed = np.mean(partial_embeds, axis=0)
embed = raw_embed / np.linalg.norm(raw_embed, 2)
if return_partials:
return embed, partial_embeds, wave_slices
return embed
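
# Example (wav must already be preprocessed, e.g. via preprocess_wav):
#   embed = embed_utterance(wav)                                  # (256,) unit-norm d-vector
#   embed, partials, slices = embed_utterance(wav, return_partials=True)
#   # partials: (n_windows, 256) array; slices: the wav range covered by each partial
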
def embed_speaker(wavs, **kwargs):
    """Aggregate identity extraction for multiple utterances from the same speaker."""
    raise NotImplementedError("Multi-utterance aggregation is not implemented yet.")


def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
"""Visualizes the high-dimensional latent vector as a spatial intensity map."""
import matplotlib.pyplot as plt
if ax is None:
ax = plt.gca()
if shape is None:
height = int(np.sqrt(len(embed)))
shape = (height, -1)
embed = embed.reshape(shape)
    cmap = plt.get_cmap()
    mappable = ax.imshow(embed, cmap=cmap)
    # Clamp the color scale so heatmaps of different embeddings are directly comparable;
    # set_clim on the image also updates the attached colorbar.
    mappable.set_clim(*color_range)
    plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
    ax.set_xticks([])
    ax.set_yticks([])
ax.set_title(title)
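

# Example (hypothetical; a 256-dim d-vector reshapes to a 16x16 intensity map):
#   import matplotlib.pyplot as plt
#   fig, ax = plt.subplots()
#   plot_embedding_as_heatmap(embed, ax=ax, title="Speaker identity")
#   plt.show()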