# Speech similarity model (Triton Python backend).
# Author: MCplayer; revision 29c0409.
import io
import math
import os
import traceback

import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.utils.dlpack import from_dlpack

import triton_python_backend_utils as pb_utils
class TritonPythonModel:
    """Triton Python-backend model that scores how similar two speakers sound.

    Each request carries two audio payloads. Both are turned into 80-bin Kaldi
    fbank features, sent to the ``speaker_model`` deployment through BLS to get
    embeddings, and compared with cosine similarity rescaled to [0, 1].
    """

    def initialize(self, args):
        """Set fixed preprocessing parameters.

        ``args`` (the Triton model arguments) is not used.
        """
        self.sample_rate = 16000
        self.feature_dim = 80
        # NOTE: currently unused; no voice-activity detection is applied.
        self.vad_enabled = True
        # Clips shorter than this (seconds) are tiled by repetition.
        self.min_duration = 0.1
        # Name of the BLS target; must match the deployed model's name.
        self.speaker_model_name = "speaker_model"

    def execute(self, requests):
        """Compute a similarity score for every request.

        Each request must provide inputs ``AUDIO_BYTES_1`` and ``AUDIO_BYTES_2``.
        Each input holds either a UTF-8 path to an audio file readable by the
        server or the encoded audio itself (see ``_load_waveform``).

        Returns one response per request, in order. A successful response holds
        ``SIMILARITY``, a float32 array of shape [1] in [0, 1]. A failure
        produces an error response for that request only.
        """
        responses = []
        for request in requests:
            try:
                audio1 = self._get_audio_input(request, "AUDIO_BYTES_1")
                audio2 = self._get_audio_input(request, "AUDIO_BYTES_2")

                feats1 = self.preprocess(audio1)
                feats2 = self.preprocess(audio2)

                similarity = self.compute_similarity(feats1, feats2)
                # Logger.log_info takes a str message; a bare float is rejected.
                pb_utils.Logger.log_info(f"SIMILARITY: {similarity}")

                output_tensor = pb_utils.Tensor(
                    "SIMILARITY", np.array([similarity], dtype=np.float32)
                )
                responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
            except pb_utils.TritonModelException as e:
                # Expected failures (bad input, preprocessing, BLS errors).
                responses.append(
                    pb_utils.InferenceResponse(error=pb_utils.TritonError(str(e)))
                )
            except Exception as e:
                # Anything else: log with traceback and report it to the client.
                error_message = f"Unexpected error: {e}\n{traceback.format_exc()}"
                pb_utils.Logger.log_error(error_message)
                responses.append(
                    pb_utils.InferenceResponse(error=pb_utils.TritonError(error_message))
                )
        return responses

    @staticmethod
    def _get_audio_input(request, name):
        """Return the first element of the string/bytes input ``name``.

        Raises:
            pb_utils.TritonModelException: if the input is missing or empty.
        """
        tensor = pb_utils.get_input_tensor_by_name(request, name)
        if tensor is None:
            raise pb_utils.TritonModelException(f"Missing required input '{name}'")
        # Flattening yields the same element as the previous [0][0] indexing for
        # a [1, 1] tensor, and also works for a flat [1] tensor.
        values = tensor.as_numpy().reshape(-1)
        if values.size == 0:
            raise pb_utils.TritonModelException(f"Input '{name}' is empty")
        return values[0]

    def preprocess(self, audio_bytes: bytes):
        """Turn one audio payload into Kaldi fbank features.

        The audio is loaded, downmixed to mono and resampled to
        ``self.sample_rate``. It is then tiled by repetition if shorter than
        ``self.min_duration``, and converted to ``self.feature_dim``-bin fbank
        features.

        Returns:
            A 2-D tensor of shape [num_frames, self.feature_dim].

        Raises:
            pb_utils.TritonModelException: wrapping any failure along the way.
        """
        try:
            waveform, sample_rate = self._load_waveform(audio_bytes)
            waveform = self._downmix_to_mono(waveform)

            if sample_rate != self.sample_rate:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=self.sample_rate
                )
                waveform = resampler(waveform)

            waveform = self._pad_to_min_duration(waveform)

            # kaldi.fbank takes a [channel, time] waveform and already returns a
            # 2-D [num_frames, num_mel_bins] matrix, so no squeeze is needed.
            # NOTE(review): confirm these settings (no CMN, default window and
            # dither) match how speaker_model was trained.
            return kaldi.fbank(
                waveform,
                num_mel_bins=self.feature_dim,
                sample_frequency=self.sample_rate,
                frame_length=25,
                frame_shift=10,
            )
        except Exception as e:
            raise pb_utils.TritonModelException(
                f"Failed during audio preprocessing: {e}"
            ) from e

    @staticmethod
    def _load_waveform(audio):
        """Decode one payload into ``(waveform, sample_rate)`` via torchaudio.

        The payload is read as a file path when it is a ``str``, or when its
        bytes decode as UTF-8 to the path of an existing file. This keeps
        existing path-based clients working. Otherwise the bytes are parsed
        in memory as an encoded audio file.
        """
        if isinstance(audio, str):
            return torchaudio.load(audio)
        payload = bytes(audio)
        try:
            candidate_path = payload.decode("utf-8")
        except UnicodeDecodeError:
            candidate_path = None
        # os.path.isfile returns False (no exception) for names with NUL bytes
        # or names that are too long, so raw audio falls through safely.
        if candidate_path is not None and os.path.isfile(candidate_path):
            return torchaudio.load(candidate_path)
        return torchaudio.load(io.BytesIO(payload))

    @staticmethod
    def _downmix_to_mono(waveform):
        """Return a [1, time] waveform, averaging channels if there are several."""
        if waveform.dim() == 1:
            return waveform.unsqueeze(0)
        if waveform.shape[0] > 1:
            return waveform.mean(dim=0, keepdim=True)
        return waveform

    def _pad_to_min_duration(self, waveform):
        """Tile a [1, time] waveform so it lasts at least ``self.min_duration``.

        Raises:
            ValueError: if the waveform has no samples.
        """
        num_samples = waveform.shape[-1]
        if num_samples == 0:
            raise ValueError("audio contains no samples")
        min_samples = math.ceil(self.min_duration * self.sample_rate)
        if num_samples < min_samples:
            repeat_times = math.ceil(min_samples / num_samples)
            waveform = waveform.repeat(1, repeat_times)
        return waveform

    def compute_similarity(self, waveform1, waveform2):
        """Return the speaker similarity of two feature matrices, in [0, 1].

        ``waveform1``/``waveform2`` are the fbank features from ``preprocess``
        (the names are kept for compatibility). Each is embedded by
        ``speaker_model``.
        """
        emb1 = self.call_speaker_model(waveform1)
        emb2 = self.call_speaker_model(waveform2)
        return self._cosine_to_unit_interval(emb1, emb2)

    @staticmethod
    def _cosine_to_unit_interval(emb1, emb2):
        """Cosine similarity of two embeddings, mapped from [-1, 1] to [0, 1].

        Both inputs are flattened. A zero-norm embedding yields 0.0.
        The work is done on CPU in float64, so no GPU is required.

        Raises:
            ValueError: if the flattened embeddings differ in length.
        """
        e1 = np.asarray(emb1, dtype=np.float64).reshape(-1)
        e2 = np.asarray(emb2, dtype=np.float64).reshape(-1)
        if e1.shape != e2.shape:
            raise ValueError(
                f"embedding sizes differ: {e1.shape[0]} vs {e2.shape[0]}"
            )
        norm1 = np.linalg.norm(e1)
        norm2 = np.linalg.norm(e2)
        if norm1 == 0.0 or norm2 == 0.0:
            return 0.0
        cosine = float(np.dot(e1, e2) / (norm1 * norm2))
        # Guard against round-off pushing the value slightly outside [-1, 1].
        cosine = min(1.0, max(-1.0, cosine))
        return (cosine + 1.0) / 2.0

    def call_speaker_model(self, waveform):
        """Run ``speaker_model`` through BLS and return its ``embs`` output.

        A 2-D [num_frames, feat_dim] input gets a leading batch axis. The
        ``feats``/``embs`` names must match speaker_model's config.pbtxt.

        Returns:
            The embedding as a numpy array (copied to host if it was on GPU).

        Raises:
            pb_utils.TritonModelException: on a BLS error or a missing output.
        """
        feats = waveform.unsqueeze(0) if waveform.dim() == 2 else waveform
        input_tensor = pb_utils.Tensor("feats", feats.cpu().numpy().astype(np.float32))

        inference_request = pb_utils.InferenceRequest(
            model_name=self.speaker_model_name,
            requested_output_names=["embs"],
            inputs=[input_tensor],
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(
                f"Error from speaker_model: {inference_response.error().message()}"
            )

        output_tensor = pb_utils.get_output_tensor_by_name(inference_response, "embs")
        if output_tensor is None:
            raise pb_utils.TritonModelException(
                "speaker_model response has no 'embs' output"
            )
        if output_tensor.is_cpu():
            return output_tensor.as_numpy()
        return from_dlpack(output_tensor.to_dlpack()).detach().cpu().numpy()