FULL_WAV = [
    'english_hfullh.wav',
    'english_4x_hfullh.wav',
    'human_hfullh.wav',
    'foreign_hfullh.wav',
    'foreign_4x_hfullh.wav',
]
WIN = 40   # sliding-window duration in seconds
HOP = 10   # hop between successive windows in seconds

import os
import json
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import transformers
import audb
import audinterface
import audiofile
import audmodel
import audonnx

LABELS = ['arousal', 'dominance', 'valence',
          'Angry',
          'Sad',
          'Happy',
          'Surprise',
          'Fear',
          'Disgust',
          'Contempt',
          'Neutral',
          ]
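
# A Wav2Vec2Config instance is used here simply as a container for the target
# devices of the two teacher models; both point to the same GPU in this script.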
config = transformers.Wav2Vec2Config()
config.dev = torch.device('cuda:0')
config.dev2 = torch.device('cuda:0')
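
# Helpers mapping teacher logits to probabilities; _softmax is applied to the
# categorical teacher's outputs, _sigmoid is kept for completeness but unused here.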
def _softmax(x):
    '''x : (batch, num_class)'''
    x -= x.max(1, keepdims=True)
    x = np.maximum(-100, x)
    x = np.exp(x)
    x /= x.sum(1, keepdims=True)
    return x

def _sigmoid(x):
    '''x : (batch, num_class)'''
    return 1 / (1 + np.exp(-x))
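
# Stage 1: run both emotion teachers over every long recording with a sliding
# window and cache the per-window predictions as a pickled DataFrame, so the
# plotting stage below can rerun without touching the GPU.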
for long_audio in FULL_WAV:

    file_interface = f'timeseries_{long_audio.replace("/", "")}.pkl'

    if not os.path.exists(file_interface):

        print('_______________________________________\nProcessing\n',
              file_interface, '\n___________')
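
        # Teacher 1: categorical SER model (WavLM backbone). Its forward pass is
        # replaced below by _infer, which applies the model's mean/std waveform
        # normalisation and attentive-statistics pooling before the classifier,
        # returning logits for the 8 emotion categories listed in LABELS.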
        from transformers import AutoModelForAudioClassification
        import types

        def _infer(self, x):
            '''x: (batch, audio-samples-16KHz)'''
            # z-normalise the raw waveform with the model's own statistics
            x = (x - self.config.mean) / self.config.std
            x = self.ssl_model(x, attention_mask=None).last_hidden_state

            # attentive statistics pooling: attention-weighted mean and std over time
            h = self.pool_model.sap_linear(x).tanh()
            w = torch.matmul(h, self.pool_model.attention)
            w = w.softmax(1)
            mu = (x * w).sum(1)
            x = torch.cat(
                [
                    mu,
                    ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt()
                ], 1)
            return self.ser_model(x)

        teacher_cat = AutoModelForAudioClassification.from_pretrained(
            '3loi/SER-Odyssey-Baseline-WavLM-Categorical-Attributes',
            trust_remote_code=True
        ).to(config.dev2).eval()
        teacher_cat.forward = types.MethodType(_infer, teacher_cat)
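
        # The audeering checkpoint expects zero-mean, unit-variance waveforms
        # (normally handled by the Wav2Vec2 feature extractor); _prenorm
        # replicates that normalisation, honouring an attention mask if given.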
        def _prenorm(x, attention_mask=None):
            '''mean/var'''
            if attention_mask is not None:
                N = attention_mask.sum(1, keepdim=True)
                x -= x.sum(1, keepdim=True) / N
                var = (x * x).sum(1, keepdim=True) / N
            else:
                x -= x.mean(1, keepdim=True)
                var = (x * x).mean(1, keepdim=True)
            return x / torch.sqrt(var + 1e-7)
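
        # Teacher 2: audeering's wav2vec2-based arousal/dominance/valence
        # regressor, rebuilt as a minimal Wav2Vec2 wrapper (mean-pooling over
        # time followed by the regression head defined next).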
        from torch import nn
        from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel, Wav2Vec2Model

        class RegressionHead(nn.Module):
            r"""Classification head."""

            def __init__(self, config):
                super().__init__()
                self.dense = nn.Linear(config.hidden_size, config.hidden_size)
                self.dropout = nn.Dropout(config.final_dropout)
                self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

            def forward(self, features, **kwargs):
                x = features
                x = self.dropout(x)
                x = self.dense(x)
                x = torch.tanh(x)
                x = self.dropout(x)
                x = self.out_proj(x)
                return x

        class Dawn(Wav2Vec2PreTrainedModel):
            r"""Speech emotion classifier."""

            def __init__(self, config):
                super().__init__(config)
                self.config = config
                self.wav2vec2 = Wav2Vec2Model(config)
                self.classifier = RegressionHead(config)
                self.init_weights()

            def forward(
                self,
                input_values,
                attention_mask=None,
            ):
                x = _prenorm(input_values, attention_mask=attention_mask)
                outputs = self.wav2vec2(x, attention_mask=attention_mask)
                hidden_states = outputs[0]
                hidden_states = torch.mean(hidden_states, dim=1)
                logits = self.classifier(hidden_states)
                return logits

        dawn = Dawn.from_pretrained(
            'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
        ).to(config.dev).eval()

        def process_function(x, sampling_rate, idx):
            '''Run both teachers on one sliding-window segment.

            x : (batch, audio-samples-16kHz)

            Returns an array of shape (batch, 11): arousal, dominance, valence
            from dawn, followed by the softmaxed probabilities of the 8 emotion
            categories from teacher_cat, matching the order of LABELS.
            '''
            logits_cat = teacher_cat(torch.from_numpy(x).to(config.dev2)).cpu().detach().numpy()
            logits_adv = dawn(torch.from_numpy(x).to(config.dev)).cpu().detach().numpy()

            cat = np.concatenate([logits_adv,
                                  _softmax(logits_cat)],
                                 1)
            print(cat)
            return cat
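
        # audinterface cuts WIN-second windows every HOP seconds (resampling to
        # 16 kHz first if needed) and calls process_function once per window,
        # because process_func_applies_sliding_window=False.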
        interface = audinterface.Feature(
            feature_names=LABELS,
            process_func=process_function,
            process_func_applies_sliding_window=False,
            win_dur=WIN,
            hop_dur=HOP,
            sampling_rate=16000,
            resample=True,
            verbose=True,
        )
        df_pred = interface.process_file(long_audio)
        df_pred.to_pickle(file_interface)
    else:
        print(file_interface, 'FOUND')
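
# Stage 2: load the cached per-window predictions, truncate every series to the
# shortest one, and replace the (file, start, end) index by the window end time
# in seconds so all curves share a common time axis.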
preds = {}
SHORTEST_PD = 100000
for long_audio in FULL_WAV:
    file_interface = f'timeseries_{long_audio.replace("/", "")}.pkl'
    y = pd.read_pickle(file_interface)
    preds[long_audio] = y
    SHORTEST_PD = min(SHORTEST_PD, len(y))

for k, v in preds.items():
    p = v[:SHORTEST_PD]

    # drop the file and window-start index levels, keep only the window end time
    p.reset_index(inplace=True)
    p.drop(columns=['file', 'start'], inplace=True)
    p.set_index('end', inplace=True)

    # index as seconds instead of Timedelta
    p.index = p.index.map(mapper=(lambda x: x.total_seconds()))
    preds[k] = p

print(preds.keys(), 'p')
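
# One figure per language: 8 rows x 2 columns. Rows 0-2 show arousal/dominance/
# valence, rows 3-7 selected categorical emotions; the left column is normal
# speed, the right column 4x speed, each overlaid on the human reference (fill).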
for lang in ['english', 'foreign']:

    fig, ax = plt.subplots(nrows=8, ncols=2, figsize=(24, 20.7),
                           gridspec_kw={'hspace': 0, 'wspace': .04})

    time_stamp = preds['human_hfullh.wav'].index.to_numpy()

    for j, dim in enumerate(['arousal',
                             'dominance',
                             'valence']):

        # left column: normal-speed synthesis vs. human reference (grey fill)
        ax[j, 0].plot(time_stamp, preds[f'{lang}_hfullh.wav'][dim],
                      color=(0, 104/255, 139/255),
                      label='mean_1',
                      linewidth=2)
        ax[j, 0].fill_between(time_stamp,
                              0 * preds[f'{lang}_hfullh.wav'][dim],
                              preds['human_hfullh.wav'][dim],
                              color=(.2, .2, .2),
                              alpha=0.244)
        if j == 0:
            if lang == 'english':
                desc = 'English'
            else:
                desc = 'Non-English'
            ax[j, 0].legend([f'StyleTTS2 using Mimic-3 {desc}',
                             'StyleTTS2 using EmoDB'],
                            prop={'size': 14},
                            )
        ax[j, 0].set_ylabel(dim.lower(), color=(.4, .4, .4), fontsize=17)
        ax[j, 0].set_ylim([1e-7, .9999])
        ax[j, 0].set_xticklabels(['' for _ in ax[j, 0].get_xticklabels()])
        ax[j, 0].set_xlim([time_stamp[0], time_stamp[-1]])

        # right column: 4x-speed synthesis vs. the same human reference
        ax[j, 1].plot(time_stamp, preds[f'{lang}_4x_hfullh.wav'][dim],
                      color=(0, 104/255, 139/255),
                      label='mean_1',
                      linewidth=2)
        ax[j, 1].fill_between(time_stamp,
                              0 * preds[f'{lang}_4x_hfullh.wav'][dim],
                              preds['human_hfullh.wav'][dim],
                              color=(.2, .2, .2),
                              alpha=0.244)
        if j == 0:
            if lang == 'english':
                desc = 'English'
            else:
                desc = 'Non-English'
            ax[j, 1].legend([f'StyleTTS2 using Mimic-3 {desc} 4x speed',
                             'StyleTTS2 using EmoDB'],
                            prop={'size': 14},
                            )

        ax[j, 1].set_xlabel('720 Harvard Sentences')
        ax[j, 1].set_ylim([1e-7, .9999])
        ax[j, 1].set_xticklabels(['' for _ in ax[j, 0].get_xticklabels()])
        ax[j, 1].set_xlim([time_stamp[0], time_stamp[-1]])

        ax[j, 0].grid()
        ax[j, 1].grid()

    # rows 3-7: selected categorical emotions over the same time axis
    time_stamp = preds['human_hfullh.wav'].index.to_numpy()
    for j, dim in enumerate(['Angry',
                             'Sad',
                             'Happy',
                             'Fear',
                             'Disgust',
                             ]):
        j = j + 3

        ax[j, 0].plot(time_stamp, preds[f'{lang}_hfullh.wav'][dim],
                      color=(0, 104/255, 139/255),
                      label='mean_1',
                      linewidth=2)
        ax[j, 0].fill_between(time_stamp,
                              0 * preds[f'{lang}_hfullh.wav'][dim],
                              preds['human_hfullh.wav'][dim],
                              color=(.2, .2, .2),
                              alpha=0.244)

        ax[j, 0].set_ylabel(dim.lower(), color=(.4, .4, .4), fontsize=17)
        ax[j, 0].set_ylim([1e-7, .9999])
        ax[j, 0].set_xlim([time_stamp[0], time_stamp[-1]])
        ax[j, 0].set_xticklabels(['' for _ in ax[j, 0].get_xticklabels()])
        ax[j, 0].set_xlabel('720 Harvard Sentences', fontsize=17, color=(.2, .2, .2))

        ax[j, 1].plot(time_stamp, preds[f'{lang}_4x_hfullh.wav'][dim],
                      color=(0, 104/255, 139/255),
                      label='mean_1',
                      linewidth=2)
        ax[j, 1].fill_between(time_stamp,
                              0 * preds[f'{lang}_4x_hfullh.wav'][dim],
                              preds['human_hfullh.wav'][dim],
                              color=(.2, .2, .2),
                              alpha=0.244)

        ax[j, 1].set_xlabel('720 Harvard Sentences', fontsize=17, color=(.2, .2, .2))
        ax[j, 1].set_ylim([1e-7, .9999])
        ax[j, 1].set_xticklabels(['' for _ in ax[j, 1].get_xticklabels()])
        ax[j, 1].set_xlim([time_stamp[0], time_stamp[-1]])

        ax[j, 0].grid()
        ax[j, 1].grid()

    plt.savefig(f'fig_{lang}_{WIN=}_{HOP=}_HFdisc.png', bbox_inches='tight')
    plt.close()