|
|
import io |
|
|
import math |
|
|
import numpy as np |
|
|
import torch |
|
|
import torchaudio |
|
|
import torchaudio.compliance.kaldi as kaldi |
|
|
import traceback |
|
|
from torch.utils.dlpack import from_dlpack |
|
|
import triton_python_backend_utils as pb_utils |
|
|
|
|
|
|
|
|
class TritonPythonModel:
    """Triton Python-backend model that scores speaker similarity.

    For each request it decodes two in-memory audio byte buffers, extracts
    Kaldi-style fbank features, obtains speaker embeddings from the
    downstream ``speaker_model`` via BLS, and returns their cosine
    similarity rescaled to [0, 1] in the ``SIMILARITY`` output tensor.
    """

    def initialize(self, args):
        """One-time model setup called by Triton at load time."""
        self.sample_rate = 16000      # target sampling rate (Hz) for feature extraction
        self.feature_dim = 80         # number of mel bins for fbank features
        self.vad_enabled = True       # NOTE(review): not referenced anywhere in this file — confirm intent
        self.min_duration = 0.1       # minimum clip length (seconds); shorter clips are padded by repetition
        self.speaker_model_name = "speaker_model"  # BLS target model producing "embs"

    def execute(self, requests):
        """Handle a batch of requests; one response (or error) per request.

        Per-request errors are converted to error responses so one bad
        request does not fail the whole batch.
        """
        responses = []
        for request in requests:
            try:
                # Inputs are shape [1, 1] BYTES tensors; [0][0] yields the raw bytes.
                audio1_bytes = pb_utils.get_input_tensor_by_name(request, "AUDIO_BYTES_1").as_numpy()[0][0]
                audio2_bytes = pb_utils.get_input_tensor_by_name(request, "AUDIO_BYTES_2").as_numpy()[0][0]

                feats1 = self.preprocess(audio1_bytes)
                feats2 = self.preprocess(audio2_bytes)

                similarity = self.compute_similarity(feats1, feats2)

                # BUG FIX: the logger takes a message string, not a float.
                pb_utils.Logger.log_info(f"similarity={similarity}")

                output_tensor = pb_utils.Tensor("SIMILARITY", np.array([similarity], dtype=np.float32))
                responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
            except pb_utils.TritonModelException as e:
                # Expected, already-described failures (e.g. preprocessing errors).
                responses.append(pb_utils.InferenceResponse(error=pb_utils.TritonError(str(e))))
            except Exception as e:
                # Anything unexpected: log the full traceback and surface it to the client.
                error_message = f"Unexpected error: {e}\n{traceback.format_exc()}"
                pb_utils.Logger.log_error(error_message)
                responses.append(pb_utils.InferenceResponse(error=pb_utils.TritonError(error_message)))
        return responses

    def preprocess(self, audio_bytes: bytes):
        """Decode audio bytes and return fbank features.

        Resamples to ``self.sample_rate`` if needed; audio shorter than
        ``self.min_duration`` seconds is padded by repeating itself.

        Returns:
            A 2-D torch.Tensor of shape (num_frames, self.feature_dim).

        Raises:
            pb_utils.TritonModelException: on any decoding/feature failure.
        """
        try:
            # BUG FIX: raw encoded audio is binary data — wrap it in an
            # in-memory file object instead of decoding it as UTF-8 text
            # (the old .decode('utf-8') raised UnicodeDecodeError).
            buffer = io.BytesIO(audio_bytes)
            waveform, sample_rate = torchaudio.load(buffer)

            if sample_rate != self.sample_rate:
                resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)
                waveform = resampler(waveform)

            num_samples = waveform.shape[1]
            # Guard against empty audio, which would otherwise cause a
            # ZeroDivisionError in the padding computation below.
            if num_samples == 0:
                raise ValueError("decoded audio contains no samples")

            duration = num_samples / self.sample_rate
            if duration < self.min_duration:
                repeat_times = math.ceil(self.min_duration / duration)
                waveform = waveform.repeat(1, repeat_times)

            # kaldi.fbank expects a 2-D (channels, samples) tensor; the old
            # squeeze(0).unsqueeze(0) was a no-op for mono input and broke
            # multi-channel input, so the waveform is passed directly.
            features = kaldi.fbank(
                waveform,
                num_mel_bins=self.feature_dim,
                sample_frequency=self.sample_rate,
                frame_length=25,
                frame_shift=10,
            )

            # fbank already returns (num_frames, num_mel_bins); the old
            # squeeze(0) would wrongly drop the frame axis for 1-frame clips.
            return features
        except Exception as e:
            raise pb_utils.TritonModelException(f"Failed during audio preprocessing: {e}")

    def compute_similarity(self, waveform1, waveform2):
        """Return the speaker similarity of two feature matrices in [0, 1].

        Despite the parameter names (kept for interface compatibility),
        the arguments are the fbank feature tensors from ``preprocess``.
        """
        # BUG FIX: removed the unconditional .to("cuda") — it crashed on
        # CPU-only deployments, and a dot product of two small embedding
        # vectors gains nothing from the GPU.
        e1 = torch.from_numpy(self.call_speaker_model(waveform1)).flatten()
        e2 = torch.from_numpy(self.call_speaker_model(waveform2)).flatten()

        norm_e1 = torch.norm(e1)
        norm_e2 = torch.norm(e2)

        # Degenerate zero embedding: cosine is undefined, report no similarity.
        if norm_e1 == 0 or norm_e2 == 0:
            return 0.0

        cosine = (torch.dot(e1, e2) / (norm_e1 * norm_e2)).item()
        # Map cosine from [-1, 1] to [0, 1].
        return (cosine + 1) / 2

    def call_speaker_model(self, waveform):
        """Call ``speaker_model`` via BLS and return the embedding as numpy.

        Args:
            waveform: fbank feature tensor, (num_frames, feature_dim) or
                already batched (1, num_frames, feature_dim).

        Raises:
            pb_utils.TritonModelException: if the downstream model errors.
        """
        # Add the batch dimension the downstream model expects.
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(0)
        input_tensor = pb_utils.Tensor("feats", waveform.cpu().numpy().astype(np.float32))

        inference_request = pb_utils.InferenceRequest(
            model_name=self.speaker_model_name,
            requested_output_names=["embs"],
            inputs=[input_tensor],
        )
        inference_response = inference_request.exec()

        if inference_response.has_error():
            raise pb_utils.TritonModelException(
                f"Error from speaker_model: {inference_response.error().message()}"
            )

        output_tensor = pb_utils.get_output_tensor_by_name(inference_response, "embs")
        if output_tensor.is_cpu():
            return output_tensor.as_numpy()
        # GPU output: go through DLPack, then detach and copy to host memory.
        return from_dlpack(output_tensor.to_dlpack()).detach().cpu().numpy()