"""custom_interface.py for speechbrain/emotion-diarization-wavlm-large."""
import torch
from speechbrain.pretrained import Pretrained
class Speech_Emotion_Diarization(Pretrained):
"""A ready-to-use SED interface (audio -> emotions and their durations)
Arguments
---------
hparams
Hyperparameters (from HyperPyYAML)
Example
-------
>>> from speechbrain.pretrained import Speech_Emotion_Diarization
>>> tmpdir = getfixture("tmpdir")
>>> sed_model = Speech_Emotion_Diarization.from_hparams(source="speechbrain/emotion-diarization-wavlm-large", savedir=tmpdir,) # doctest: +SKIP
>>> sed_model.diarize_file("speechbrain/emotion-diarization-wavlm-large/example.wav") # doctest: +SKIP
"""
    MODULES_NEEDED = ["input_norm", "wav2vec2", "output_mlp"]
def diarize_file(self, path):
"""Get emotion diarization of a spoken utterance.
Arguments
---------
        path : str
            Path to the audio file to diarize.
Returns
-------
dict
The emotions and their boundaries.
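
        Example
        -------
        >>> # Illustrative output; exact boundaries and labels depend on the
        >>> # model and the audio.
        >>> sed_model.diarize_file(
        ...     "speechbrain/emotion-diarization-wavlm-large/example.wav"
        ... )  # doctest: +SKIP
        {'speechbrain/...': [{'start': 0.0, 'end': 1.94, 'emotion': 'n'}, ...]}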
"""
waveform = self.load_audio(path)
# Fake a batch:
batch = waveform.unsqueeze(0)
rel_length = torch.tensor([1.0])
        frame_class = self.diarize_batch(batch, rel_length, [path])
return frame_class
def encode_batch(self, wavs, wav_lens):
"""Encodes audios into fine-grained emotional embeddings
Arguments
---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and the others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.Tensor
            The encoded batch of frame-level embeddings.
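
        Example
        -------
        >>> # Illustrative shape: WavLM-large produces 1024-dimensional
        >>> # frame embeddings; the frame count depends on the audio length.
        >>> embeddings = sed_model.encode_batch(batch, rel_length)  # doctest: +SKIP
        >>> embeddings.shape  # doctest: +SKIP
        torch.Size([1, 49, 1024])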
"""
if len(wavs.shape) == 1:
wavs = wavs.unsqueeze(0)
# Assign full length if wav_lens is not assigned
if wav_lens is None:
wav_lens = torch.ones(wavs.shape[0], device=self.device)
wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
wavs = self.mods.input_norm(wavs, wav_lens)
outputs = self.mods.wav2vec2(wavs)
return outputs
def diarize_batch(self, wavs, wav_lens, batch_id):
"""Get emotion diarization of a batch of waveforms.
The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = self.audio_normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.
Arguments
---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and the others len(waveform) / max_length.
            Used for ignoring padding.
        batch_id : list
            One id per waveform (e.g., the file path), used as the key for
            that waveform's results in the returned dictionary.

        Returns
        -------
        dict
            The emotions and their time boundaries, keyed by ``batch_id``.
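
        Example
        -------
        >>> # ``batch`` and ``rel_length`` prepared as in ``diarize_file``.
        >>> sed_model.diarize_batch(batch, rel_length, ["example.wav"])  # doctest: +SKIP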
"""
outputs = self.encode_batch(wavs, wav_lens)
averaged_out = self.hparams.avg_pool(outputs)
outputs = self.mods.output_mlp(averaged_out)
outputs = self.hparams.log_softmax(outputs)
score, index = torch.max(outputs, dim=-1)
preds = self.hparams.label_encoder.decode_torch(index)
results = self.preds_to_diarization(preds, batch_id)
return results
def preds_to_diarization(self, prediction, batch_id):
"""Convert frame-wise predictions into a dictionary of
diarization results.
Returns
-------
dictionary
A dictionary with the start/end of each emotion
"""
results = {}
for i in range(len(prediction)):
pred = prediction[i]
lol = []
for j in range(len(pred)):
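                # Each encoder output frame corresponds to 20 ms of audio, so
                # ``stride`` and ``window_length`` (given in frames) are
                # scaled by 0.02 to obtain boundaries in seconds.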
start = round(self.hparams.stride * 0.02 * j, 2)
end = round(start + self.hparams.window_length * 0.02, 2)
lol.append([batch_id[i], start, end, pred[j]])
lol = merge_ssegs_same_emotion_adjacent(lol)
            results[batch_id[i]] = [
                {"start": k[1], "end": k[2], "emotion": k[3]} for k in lol
            ]
return results
    def forward(self, wavs, wav_lens, batch_id):
        """Runs the full diarization pipeline (no gradients through decoding)."""
        return self.diarize_batch(wavs, wav_lens, batch_id)
def is_overlapped(end1, start2):
"""Returns True if segments are overlapping.
Arguments
---------
end1 : float
End time of the first segment.
start2 : float
Start time of the second segment.
Returns
-------
    overlapped : bool
        True if the segments overlap, else False.

    Example
    -------
    >>> is_overlapped(5.5, 3.4)
    True
    >>> is_overlapped(5.5, 6.4)
    False
"""
    # Segments that share a boundary (start2 == end1) count as overlapping,
    # which is what the merging below relies on.
    return start2 <= end1
def merge_ssegs_same_emotion_adjacent(lol):
"""Merge adjacent sub-segs if they are the same emotion.
Arguments
---------
lol : list of list
Each list contains [utt_id, sseg_start, sseg_end, emo_label].
Returns
-------
new_lol : list of list
new_lol contains adjacent segments merged from the same emotion ID.
Example
-------
>>> from speechbrain.utils.EDER import merge_ssegs_same_emotion_adjacent
>>> lol=[['u1', 0.0, 7.0, 'a'],
... ['u1', 7.0, 9.0, 'a'],
... ['u1', 9.0, 11.0, 'n'],
... ['u1', 11.0, 13.0, 'n'],
... ['u1', 13.0, 15.0, 'n'],
... ['u1', 15.0, 16.0, 'a']]
>>> merge_ssegs_same_emotion_adjacent(lol)
[['u1', 0.0, 9.0, 'a'], ['u1', 9.0, 15.0, 'n'], ['u1', 15.0, 16.0, 'a']]
"""
new_lol = []
# Start from the first sub-seg
sseg = lol[0]
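    # Note: ``sseg`` aliases ``lol[0]``, so merging updates that entry in place.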
flag = False
for i in range(1, len(lol)):
next_sseg = lol[i]
        # IF the sub-segments overlap AND have the same emotion THEN merge
if is_overlapped(sseg[2], next_sseg[1]) and sseg[3] == next_sseg[3]:
sseg[2] = next_sseg[2] # just update the end time
            # The final sub-segment has already been merged into ``sseg``, so
            # set the flag to avoid appending it a second time below.
if i == len(lol) - 1:
flag = True
new_lol.append(sseg)
else:
new_lol.append(sseg)
sseg = next_sseg
    # Append the last segment unless it was already merged above.
if flag is False:
new_lol.append(lol[-1])
return new_lol
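

# Minimal usage sketch (guarded so it does not run on import). The ``savedir``
# path is an illustrative assumption; the source repo and example file are the
# ones referenced in the class docstring above.
if __name__ == "__main__":
    sed_model = Speech_Emotion_Diarization.from_hparams(
        source="speechbrain/emotion-diarization-wavlm-large",
        savedir="pretrained_models/emotion-diarization-wavlm-large",
    )
    results = sed_model.diarize_file(
        "speechbrain/emotion-diarization-wavlm-large/example.wav"
    )
    print(results)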