import os
import json
from collections import defaultdict
from copy import deepcopy

import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend: figures are returned or saved, never shown
import matplotlib.pyplot as plt
import soundfile as sf
import librosa
from dtw import dtw
from scipy import stats
from sklearn_extra.cluster import KMedoids


def z_score(x, mean, std):
    """Standardise a value against a reference mean and standard deviation."""
    return (x - mean) / std


def get_word_aligns(norm_sent, aln_paths):
    """
    Return a dictionary of word alignments, per speaker, for a given sentence.
    Each alignment file holds one tab-separated word, start, end row per word.
    """
    word_aligns = defaultdict(list)
    slist = norm_sent.split(" ")

    for spk, aln_path in aln_paths:
        with open(aln_path) as f:
            lines = [l.split('\t') for l in f.read().splitlines()]
        if len(lines) == len(slist):
            word_aligns[spk] = [(w, float(s), float(e)) for w, s, e in lines]
        else:
            print(slist, lines, "<---- alignment did not match the sentence")
    return word_aligns
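
# A minimal sketch of the expected alignment (.tsv) content, inferred from the
# parser above (the real aligner's column conventions are assumed): one
# tab-separated "word<TAB>start<TAB>end" row per word of the sentence, e.g.
#
#   halló   0.31    0.65
#   heimur  0.65    1.20
#
# get_word_aligns("halló heimur", [("spk1", "spk1.tsv")]) would then return
# {"spk1": [("halló", 0.31, 0.65), ("heimur", 0.65, 1.2)]}.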


def get_pitches(start_time, end_time, fpath):
    """
    Return a list of z-scored pitch values for one speech segment.
    Reads from a .f0 file of whitespace-separated time, F0, is_voiced rows.
    """
    with open(fpath) as f:
        lines = [[float(x) for x in line.split()] for line in f.read().splitlines()]
    pitches = []

    # Normalisation statistics come from the voiced frames of the whole
    # utterance, not just the requested segment.
    voiced = [line[1] for line in lines if line[2] == 1]
    if voiced:
        mean = np.mean(voiced)
        std = np.std(voiced)
        # Unvoiced frames are mapped to a floor just below the lowest tracked
        # pitch, so they stay distinguishable after z-scoring.
        low = min(voiced) - 1
        for time, pitch, is_pitch in lines:
            if start_time <= time <= end_time:
                if is_pitch == 1:
                    pitches.append(z_score(pitch, mean, std))
                else:
                    pitches.append(z_score(low, mean, std))

    return pitches
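
# A sketch of the expected .f0 content, inferred from the parser above
# (time in seconds, F0 in Hz, voicing flag; a 5 ms hop is assumed):
#
#   0.005  0.0    0
#   0.010  182.3  1
#   0.015  184.1  1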


def get_rmse(start_time, end_time, wpath, znorm=True):
    """
    Return an array of RMS energy values for one speech segment.
    """
    audio, sr = librosa.load(wpath, sr=16000)
    hop = 80  # 5 ms at 16 kHz, matching the pitch tracker's frame rate

    rmse = librosa.feature.rms(y=audio, frame_length=480, hop_length=hop)[0]
    if znorm:
        rmse = stats.zscore(rmse)
    # sr/hop frames per second -> convert segment times to frame indices
    segment = rmse[int(np.floor(start_time * sr / hop)):int(np.ceil(end_time * sr / hop))]

    return segment


def downsample_rmse2pitch(rmse, pitch_len):
    """Downsample an RMSE track to the pitch track's length by taking
    evenly spaced frame indices."""
    idx = np.round(np.linspace(0, len(rmse) - 1, pitch_len)).astype(int)
    return rmse[idx]
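
# Worked example: with 10 energy frames and 4 pitch frames,
# np.linspace(0, 9, 4) gives [0, 3, 6, 9], so
# downsample_rmse2pitch(np.arange(10), 4) returns array([0, 3, 6, 9]).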


def parse_word_indices(start_end_word_index):
    """
    Parse a 1-based "start-end" word-index string (or a single "start" index)
    into a 0-based (start, end) tuple.
    """
    ixs = start_end_word_index.split('-')
    s = int(ixs[0])
    e = int(ixs[-1])
    return s - 1, e - 1
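
# E.g. parse_word_indices("3") -> (2, 2) and parse_word_indices("2-4") -> (1, 3).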


def get_data(norm_sent, path_key, start_end_word_index):
    """
    Return pitch and RMSE values for a given sentence/word-span combination,
    plus the segment word alignments, any excluded speakers, and playable
    audio info per speaker.
    """
    s_ix, e_ix = parse_word_indices(start_end_word_index)
    words = '_'.join(norm_sent.split(' ')[s_ix:e_ix + 1])

    align_paths = [(spk, pdict['aln']) for spk, pdict in path_key]
    word_aligns = get_word_aligns(norm_sent, align_paths)

    data = defaultdict(list)
    align_data = defaultdict(list)
    playable_audio = {}
    exclude = []

    for spk, pdict in path_key:
        word_al = word_aligns[spk]
        start_time = word_al[s_ix][1]
        end_time = word_al[e_ix][2]

        # Word alignments rebased so the segment starts at time 0
        seg_aligns = word_al[s_ix:e_ix + 1]
        seg_aligns = [(w, round(s - start_time, 2), round(e - start_time, 2)) for w, s, e in seg_aligns]

        pitches = get_pitches(start_time, end_time, pdict['f0'])

        rmses = get_rmse(start_time, end_time, pdict['wav'])
        rmses = downsample_rmse2pitch(rmses, len(pitches))

        if pitches and seg_aligns:
            pitches_cpy = np.array(deepcopy(pitches))
            rmses_cpy = np.array(deepcopy(rmses))
            d = [[p, r] for p, r in zip(pitches_cpy, rmses_cpy)]

            data[f"{words}**{spk}"] = d
            align_data[f"{words}**{spk}"] = seg_aligns
            playable_audio[spk] = (pdict['play'], start_time, end_time)
        else:
            exclude.append(spk)

    return words, data, align_data, exclude, playable_audio
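
# The returned data is keyed "<words>**<speaker>", e.g. "halló_heimur**spk1"
# (the speaker ID here is illustrative); each value is a list of
# [pitch, energy] pairs on the same 5 ms frame grid.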


def dtw_distance(x, y):
    """
    Return the normalised DTW distance between two feature sequences.
    """
    alignment = dtw(x, y, keep_internals=True)
    return alignment.normalizedDistance
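
# A minimal sketch (toy values, not real features): identical sequences give
# distance 0.0, and the distance grows with contour mismatch.
#
#   a = [[0.0, 0.1], [0.5, 0.2], [1.0, 0.3]]
#   b = [[0.1, 0.1], [0.6, 0.2], [0.9, 0.4]]
#   dtw_distance(a, a)  # 0.0
#   dtw_distance(a, b)  # small positive value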


def pair_dists(data, words, recs):
    """Compute all pairwise DTW distances between recordings, in row-major
    order (including self-distances, which are 0)."""
    dtw_dists = []

    for rec1 in recs:
        key1 = f'{words}**{rec1}'
        val1 = data[key1]
        for rec2 in recs:
            key2 = f'{words}**{rec2}'
            val2 = data[key2]
            dtw_dists.append((f"{rec1}**{rec2}", dtw_distance(val1, val2)))

    return dtw_dists


def kmedoids_clustering(X):
    """Cluster recordings into three groups, where X is the precomputed
    pairwise DTW distance matrix."""
    kmedoids = KMedoids(n_clusters=3, metric='precomputed', random_state=0).fit(X)
    y_km = kmedoids.labels_
    return y_km, kmedoids
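
# A small sanity check with a toy symmetric distance matrix (four items, so
# KMedoids picks three medoids and the two closest items share a cluster):
#
#   D = np.array([[0.0, 0.1, 0.9, 0.8],
#                 [0.1, 0.0, 0.85, 0.9],
#                 [0.9, 0.85, 0.0, 0.2],
#                 [0.8, 0.9, 0.2, 0.0]])
#   labels, model = kmedoids_clustering(D)  # one label per row of D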


def match_tts(clusters, speech_data, tts_data, tts_align, words, seg_aligns, voice):
    """Rank the three clusters by mean DTW distance to the TTS sample and
    plot pitch and energy figures for each."""
    tts_info = []
    for label in set([c for r, c in clusters]):
        recs = [r for r, c in clusters if c == label]
        dists = []
        for rec in recs:
            key = f'{words}**{rec}'
            dists.append(dtw_distance(tts_data, speech_data[key]))
        tts_info.append((label, np.nanmean(dists)))

    # Clusters sorted from closest to the TTS sample to furthest
    tts_info = sorted(tts_info, key=lambda x: x[1])
    best_cluster = tts_info[0][0]
    best_cluster_score = tts_info[0][1]

    matched_data = {f'{words}**{r}': speech_data[f'{words}**{r}'] for r, c in clusters if c == best_cluster}

    mid_cluster = tts_info[1][0]
    mid_data = {f'{words}**{r}': speech_data[f'{words}**{r}'] for r, c in clusters if c == mid_cluster}
    bad_cluster = tts_info[2][0]
    bad_data = {f'{words}**{r}': speech_data[f'{words}**{r}'] for r, c in clusters if c == bad_cluster}

    # Pitch figures; the TTS contour is overlaid on the best cluster only
    tts_fig_p, best_cc = plot_one_cluster(words, 'pitch', matched_data, seg_aligns, best_cluster, tts_data=tts_data, tts_align=tts_align, voice=voice)
    fig_mid_p, mid_cc = plot_one_cluster(words, 'pitch', mid_data, seg_aligns, mid_cluster)
    fig_bad_p, bad_cc = plot_one_cluster(words, 'pitch', bad_data, seg_aligns, bad_cluster)

    # Energy (RMSE) figures for the same three clusters
    tts_fig_e, _ = plot_one_cluster(words, 'rmse', matched_data, seg_aligns, best_cluster, tts_data=tts_data, tts_align=tts_align, voice=voice)
    fig_mid_e, _ = plot_one_cluster(words, 'rmse', mid_data, seg_aligns, mid_cluster)
    fig_bad_e, _ = plot_one_cluster(words, 'rmse', bad_data, seg_aligns, bad_cluster)

    # Speaker-to-colour legend info per cluster (built but not yet returned)
    spk_cc_map = [('Best', best_cluster, best_cc), ('Mid', mid_cluster, mid_cc), ('Last', bad_cluster, bad_cc)]

    return best_cluster_score, tts_fig_p, fig_mid_p, fig_bad_p, tts_fig_e, fig_mid_e, fig_bad_e


def gp(d, s, x):
    """Join a directory, file stem, and extension into one path."""
    return os.path.join(d, f'{s}.{x}')


def gen_tts_paths(tdir, voices):
    """Per-voice paths for TTS audio, alignments, and pitch files."""
    return [(v, {'wav': gp(tdir, v, 'wav'), 'aln': gp(tdir, v, 'tsv'),
                 'f0': gp(tdir, v, 'f0'), 'play': gp(tdir, v, 'wav')}) for v in voices]


def gen_h_paths(wdir, adir, f0dir, pldir, spks):
    """Per-speaker paths for human audio, alignments, pitch, and playback files."""
    return [(s, {'wav': gp(wdir, s, 'wav'), 'aln': gp(adir, s, 'tsv'),
                 'f0': gp(f0dir, s, 'f0'), 'play': gp(pldir, s, 'wav')}) for s in spks]
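
# For example, gen_h_paths('/wav', '/aln', '/f0', '/play', ['spk1']) yields
# [('spk1', {'wav': '/wav/spk1.wav', 'aln': '/aln/spk1.tsv',
#            'f0': '/f0/spk1.f0', 'play': '/play/spk1.wav'})]
# (the directory names here are illustrative).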


def cluster(norm_sent, orig_sent, h_spk_ids, h_align_dir, h_f0_dir, h_wav_dir, h_play_dir, tts_sent_dir, voices, start_end_word_index):
    """Cluster the human recordings of one word span and compare each TTS
    voice against the clusters."""
    h_spk_ids = sorted(h_spk_ids)
    h_all_paths = gen_h_paths(h_wav_dir, h_align_dir, h_f0_dir, h_play_dir, h_spk_ids)

    words, h_data, h_seg_aligns, drop_spk, h_playable = get_data(norm_sent, h_all_paths, start_end_word_index)
    h_spk_ids = [spk for spk in h_spk_ids if spk not in drop_spk]
    h_all_paths = [pinfo for pinfo in h_all_paths if pinfo[0] not in drop_spk]
    nsents = len(h_spk_ids)

    dtw_dists = pair_dists(h_data, words, h_spk_ids)

    # Reshape the flat pairwise list into an nsents x nsents distance matrix
    X = [d[1] for d in dtw_dists]
    X = [X[i:i + nsents] for i in range(0, len(X), nsents)]
    X = np.array(X)

    y_km, kmedoids = kmedoids_clustering(X)
    groups = [[r, c] for r, c in zip(h_spk_ids, kmedoids.labels_)]

    tts_all_paths = gen_tts_paths(tts_sent_dir, voices)
    _, tts_data, tts_seg_aligns, _, _ = get_data(norm_sent, tts_all_paths, start_end_word_index)

    # NOTE: with several voices, only the figures for the last voice in this
    # loop are returned below.
    for v in voices:
        voice_data = tts_data[f"{words}**{v}"]
        voice_align = tts_seg_aligns[f"{words}**{v}"]

        best_cluster_score, tts_fig_p, fig_mid_p, fig_bad_p, tts_fig_e, fig_mid_e, fig_bad_e = match_tts(groups, h_data, voice_data, voice_align, words, h_seg_aligns, v)

    audio_html = clusters_audio(groups, h_playable)

    return best_cluster_score, tts_fig_p, fig_mid_p, fig_bad_p, tts_fig_e, fig_mid_e, fig_bad_e, audio_html
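
# A hedged end-to-end sketch (all paths, speaker IDs, and voice names below
# are hypothetical placeholders, not values shipped with this module):
#
#   score, *figs, html = cluster(
#       norm_sent="halló heimur",
#       orig_sent="Halló heimur",
#       h_spk_ids=["spk1", "spk2", "spk3", "spk4"],
#       h_align_dir="/data/aln", h_f0_dir="/data/f0",
#       h_wav_dir="/data/wav", h_play_dir="/data/play",
#       tts_sent_dir="/data/tts/hallo_heimur",
#       voices=["alfur"], start_end_word_index="1-2")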


def clusters_audio(clusters, audios):
    """Build an HTML page with one playable-audio table per cluster, using
    media-fragment URLs (#t=start,end) to play only the word segment."""
    html = '<html><body>'

    for label in set([c for r, c in clusters]):
        recs = [r for r, c in clusters if c == label]

        html += '<div>'
        html += f'<h2>Cluster {label}</h2>'

        html += '<div>'
        html += '<table><tbody>'

        for rec in recs:
            html += f'<tr><td><audio controls id="{rec}">'
            html += f'<source src="{audios[rec][0]}#t={audios[rec][1]:.2f},{audios[rec][2]:.2f}" type="audio/wav">'
            html += '</audio></td>'
            html += f'<td>{rec}</td></tr>'

        html += '</tbody></table>'
        html += '</div>'

        html += '</div>'
    html += '</body></html>'

    return html


def reset_cluster_times(words, cluster_speakers, human_aligns, tts_align):
    """Compute shared word-start times for a cluster by taking, at each word
    boundary, the widest start-to-start gap over all speakers (and over the
    TTS voice, when one is plotted)."""
    words = words.split('_')

    retimes = [(words[0], 0.0)]
    for i in range(len(words) - 1):
        gaps = [human_aligns[spk][i + 1][1] - human_aligns[spk][i][1] for spk in cluster_speakers]
        if tts_align:
            gaps.append(tts_align[i + 1][1] - tts_align[i][1])
        retimes.append((words[i + 1], retimes[i][1] + max(gaps)))
    return retimes
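
# For example, with words "halló_heimur" and two speakers whose second word
# starts 0.40 s and 0.55 s after the first, the shared start times become
# [("halló", 0.0), ("heimur", 0.55)], so every speaker's second word is
# shifted to begin at 0.55 s when plotted.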


def retime_speaker_xvals(retimes, speaker_aligns, speaker_xvals):
    """Shift one speaker's x-axis values word by word so that each word
    starts at the cluster-wide time from reset_cluster_times."""
    new_xvals = []

    def in_word(x, i):
        # Keep x-values before the next word's start (the last word is open-ended)
        return (x < speaker_aligns[i + 1][1]) if i + 1 < len(retimes) else True

    for i in range(len(retimes)):
        _, st = retimes[i]
        _, s, _ = speaker_aligns[i]
        xdiff = st - s
        new_xvals += [x + xdiff for x in speaker_xvals if (x >= s) and in_word(x, i)]

    return [round(x, 3) for x in new_xvals]


def retime_xs_feats(retimes, speaker_aligns, speaker_xvals, feats):
    """Retime the x-values and break the plotted line wherever retiming
    leaves a gap, by inserting a NaN point between distant x-values."""
    feat_xvals = retime_speaker_xvals(retimes, speaker_aligns, speaker_xvals)
    xf0 = list(zip(feat_xvals, feats))
    if not xf0:
        return [], []
    xf = [xf0[0]]
    for x, f in xf0[1:]:
        lx = xf[-1][0]
        if x - lx >= 0.01:
            # More than one 5 ms frame apart: insert a NaN so matplotlib
            # does not draw a line across the gap
            xf.append((lx + 0.005, np.nan))
        xf.append((x, f))
    return [x for x, f in xf], [f for x, f in xf]


def plot_one_cluster(words, feature, speech_data, seg_aligns, cluster_id, tts_data=None, tts_align=None, voice=None):
    """Plot pitch or energy contours for every speaker in one cluster,
    retimed to shared word boundaries, optionally with a TTS overlay."""
    colors = ["red", "green", "blue", "orange", "purple", "pink", "brown", "gray", "cyan"]
    cc = 0
    spk_ccs = []
    fig = plt.figure(figsize=(10, 5))

    if feature.lower() in ['pitch', 'f0']:
        fname = 'Pitch'
        ffunc = lambda x: [p for p, e in x]
        pfunc = plt.scatter
    elif feature.lower() in ['energy', 'rmse']:
        fname = 'Energy'
        ffunc = lambda x: [e for p, e in x]
        pfunc = plt.plot
    else:
        print(f'Unknown feature "{feature}" - cannot plot')
        return fig, []

    retimes = reset_cluster_times(words, list(speech_data.keys()), seg_aligns, tts_align)
    if len(retimes) > 1:
        # Dashed vertical lines mark the shared word-start boundaries
        for w, bound_line in retimes:
            plt.axvline(x=bound_line, color="gray", linestyle='--', linewidth=1, label=f'Start "{w}"')

    plt.title(f"{words} - {fname} - Cluster {cluster_id}")

    for k, v in speech_data.items():
        spk = k.split('**')[1]
        word_times = seg_aligns[k]

        feats = ffunc(v)

        # Features are sampled every 5 ms, starting at the segment start
        feat_xvals = [x * 0.005 for x in range(len(feats))]

        feat_xvals, feats = retime_xs_feats(retimes, word_times, feat_xvals, feats)
        pfunc(feat_xvals, feats, color=colors[cc], label=f"Speaker {spk}")

        spk_ccs.append((spk, colors[cc]))
        cc += 1
        if cc >= len(colors):
            cc = 0

    if voice:
        tfeats = ffunc(tts_data)
        t_xvals = [x * 0.005 for x in range(len(tfeats))]

        t_xvals, tfeats = retime_xs_feats(retimes, tts_align, t_xvals, tfeats)
        pfunc(t_xvals, tfeats, color="black", label=f"TTS {voice}")

    return fig, spk_ccs
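
# Because the Agg backend is selected at import time, figures returned by
# plot_one_cluster are saved rather than shown, e.g. (filename illustrative):
#
#   fig, spk_colors = plot_one_cluster('halló_heimur', 'pitch', data, aligns, 0)
#   fig.savefig('cluster0_pitch.png')
#   plt.close(fig)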