import numpy as np import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import soundfile as sf import colorcet as clc from collections import defaultdict from dtw import dtw from sklearn_extra.cluster import KMedoids from scipy import stats from copy import deepcopy import os, librosa, json # based on original implementation by # https://colab.research.google.com/drive/1RApnJEocx3-mqdQC2h5SH8vucDkSlQYt?authuser=1#scrollTo=410ecd91fa29bc73 # by magnús freyr morthens 2023 supported by rannís nsn def z_score(x, mean, std): return (x - mean) / std # given a sentence and list of its speakers + their alignment files, # return a dictionary of word alignments def get_word_aligns(norm_sent, aln_paths): """ Returns a dictionary of word alignments for a given sentence. """ word_aligns = defaultdict(list) slist = norm_sent.split(" ") for spk,aln_path in aln_paths: with open(aln_path) as f: lines = f.read().splitlines() lines = [l.split('\t') for l in lines] try: assert len(lines) == len(slist) word_aligns[spk] = [(w,float(s),float(e)) for w,s,e in lines] except: print(slist, lines, "<---- something didn't match") return word_aligns def get_pitches(start_time, end_time, fpath): """ Returns an array of pitch values for a given speech. Reads from .f0 file of Time, F0, IsVoiced """ with open(fpath) as f: lines = f.read().splitlines() lines = [[float(x) for x in line.split()] for line in lines] # split lines into floats pitches = [] # find the mean of all pitches in the whole sentence mean = np.mean([line[1] for line in lines if line[2] == 1]) # find the std of all pitches in the whole sentence std = np.std([line[1] for line in lines if line[2] == 1]) tracked = [p for t,p,v in lines if v == 1] if tracked: low = min(tracked) - 1 for line in lines: time, pitch, is_pitch = line if start_time <= time <= end_time: if is_pitch == 1: pitches.append(z_score(pitch, mean, std)) else: pitches.append(z_score(low, mean, std)) #pitches.append(-0.99) return pitches # jcheng used energy from esps get_f0 # get f0 says (?) : #The RMS value of each record is computed based on a 30 msec hanning #window with its left edge placed 5 msec before the beginning of the #frame. # jcheng z-scored the energys, per file. # TODO: implement that. ? # not sure librosa provides hamming window in rms function directly # TODO handle audio that not originally .wav def get_rmse(start_time, end_time, wpath, znorm = True): """ Returns an array of RMSE values for a given speech. """ audio, sr = librosa.load(wpath, sr=16000) hop = 80 #segment = audio[int(np.floor(start_time * sr)):int(np.ceil(end_time * sr))] rmse = librosa.feature.rms(y=audio,frame_length=480,hop_length=hop) rmse = rmse[0] if znorm: rmse = stats.zscore(rmse) segment = rmse[int(np.floor(start_time * sr/hop)):int(np.ceil(end_time * sr/hop))] #idx = np.round(np.linspace(0, len(rmse) - 1, pitch_len)).astype(int) return segment#[idx] # may be unnecessary depending how rmse and pitch window/hop are calculated already def downsample_rmse2pitch(rmse,pitch_len): idx = np.round(np.linspace(0, len(rmse) - 1, pitch_len)).astype(int) return rmse[idx] # parse user input string to usable word indices for the sentence # TODO handle more user input cases def parse_word_indices(start_end_word_index): ixs = start_end_word_index.split('-') if len(ixs) == 1: s = int(ixs[0]) e = int(ixs[0]) else: s = int(ixs[0]) e = int(ixs[-1]) return s-1,e-1 # take any (1stword, lastword) or (word) # unit and prepare data for that unit def get_data(norm_sent,path_key,start_end_word_index): """ Returns a dictionary of pitch, rmse, and spectral centroids values for a given sentence/word combinations. """ s_ix, e_ix = parse_word_indices(start_end_word_index) words = '_'.join(norm_sent.split(' ')[s_ix:e_ix+1]) align_paths = [(spk,pdict['aln']) for spk,pdict in path_key] word_aligns = get_word_aligns(norm_sent, align_paths) data = defaultdict(list) align_data = defaultdict(list) playable_audio = {} exclude = [] for spk, pdict in path_key: word_al = word_aligns[spk] start_time = word_al[s_ix][1] end_time = word_al[e_ix][2] seg_aligns = word_al[s_ix:e_ix+1] seg_aligns = [(w,round(s-start_time,2),round(e-start_time,2)) for w,s,e in seg_aligns] pitches = get_pitches(start_time, end_time, pdict['f0']) rmses = get_rmse(start_time, end_time, pdict['wav']) rmses = downsample_rmse2pitch(rmses,len(pitches)) #spectral_centroids = get_spectral_centroids(start_time, end_time, id, wav_dir, len(pitches)) if pitches and seg_aligns: pitches_cpy = np.array(deepcopy(pitches)) rmses_cpy = np.array(deepcopy(rmses)) d = [[p, r] for p, r in zip(pitches_cpy, rmses_cpy)] #words = "-".join(word_combs) data[f"{words}**{spk}"] = d align_data[f"{words}**{spk}"] = seg_aligns playable_audio[spk] = (pdict['play'], start_time, end_time) else: exclude.append(spk) return words, data, align_data, exclude, playable_audio def dtw_distance(x, y): """ Returns the DTW distance between two pitch sequences. """ alignment = dtw(x, y, keep_internals=True) return alignment.normalizedDistance # recs is a sorted list of rec IDs # all recs/data contain the same words # rec1 and rec2 can be the same def pair_dists(data,words,recs): dtw_dists = [] for rec1 in recs: key1 = f'{words}**{rec1}' val1 = data[key1] for rec2 in recs: key2 = f'{words}**{rec2}' val2 = data[key2] dtw_dists.append((f"{rec1}**{rec2}", dtw_distance(val1, val2))) return dtw_dists # TODO # make n_clusters a param with default 3 def kmedoids_clustering(X): kmedoids = KMedoids(n_clusters=3, random_state=0).fit(X) y_km = kmedoids.labels_ return y_km, kmedoids def match_tts(clusters, speech_data, tts_data, tts_align, words, seg_aligns, voice): tts_info = defaultdict(list) for label in set([c for r,c in clusters]): recs = [r for r,c in clusters if c==label] dists = [] for rec in recs: dists.append(dtw_distance(tts_data[f'{words}**{voice}'], speech_data[f'{words}**{rec}'])) tts_info[voice].append((label,np.nanmean(dists))) #tts_info[voice] = sorted(tts_info[voice],key = lambda x: x[1]) #best_cluster = tts_info[voice][0][0] #best_cluster_score = tts_info[voice][0][1] #tts_pldat = {f'{words}**{voice}': tts_data} f0_fig_tts, _ = plot_one_cluster(words,'pitch',tts_data,tts_align,0,['#c97eb7'],gtype='tts',voice=voice) en_fig_tts, _ = plot_one_cluster(words,'energy',tts_data,tts_align,0,['#9276d9'],gtype='tts',voice=voice) return tts_info[voice], f0_fig_tts, en_fig_tts def gp(d,s,x): return os.path.join(d, f'{s}.{x}') def gen_tts_paths(tdir,voices): plist = [(v, {'wav': gp(tdir,v,'wav'), 'aln': gp(tdir,v,'tsv'), 'f0': gp(tdir,v,'f0'), 'play': gp(tdir,v,'wav')}) for v in voices] return plist def gen_h_paths(wdir,adir,f0dir,pldir,spks): plist = [(s, {'wav': gp(wdir,s,'wav'), 'aln': gp(adir,s,'tsv'), 'f0': gp(f0dir,s,'f0'), 'play': gp(pldir,s,'wav')}) for s in spks] return plist # since clustering strictly operates on X, # once reduce a duration metric down to pair-distances, # it no longer matters that duration and pitch/energy had different dimensionality # TODO option to dtw on 3 feats pitch/ener/dur separately # check if possible cluster with 3dim distance mat? # or can it not take that input in multidimensional space # then the 3 dists can still be averaged to flatten, if appropriately scaled def cluster(norm_sent,orig_sent,h_spk_ids, h_align_dir, h_f0_dir, h_wav_dir, h_play_dir, tts_sent_dir, voices, start_end_word_index): h_spk_ids = sorted(h_spk_ids) h_all_paths = gen_h_paths(h_wav_dir,h_align_dir,h_f0_dir,h_play_dir,h_spk_ids) words, h_data, h_seg_aligns, drop_spk, h_playable = get_data(norm_sent,h_all_paths,start_end_word_index) h_spk_ids = [spk for spk in h_spk_ids if spk not in drop_spk] h_all_paths = [pinfo for pinfo in h_all_paths if pinfo[0] not in drop_spk] nsents = len(h_spk_ids) dtw_dists = pair_dists(h_data,words,h_spk_ids) kmedoids_cluster_dists = [] X = [d[1] for d in dtw_dists] X = [X[i:i+nsents] for i in range(0, len(X), nsents)] X = np.array(X) y_km, kmedoids = kmedoids_clustering(X) result = zip(X, kmedoids.labels_) groups = [[r,c] for r,c in zip(h_spk_ids,kmedoids.labels_)] f0_fig_c0, f0_fig_c1, f0_fig_c2, en_fig_c0, en_fig_c1, en_fig_c2, spk_cc_map = graph_humans(groups,h_data,words,h_seg_aligns) audio_html = clusters_audio(groups,spk_cc_map,h_playable) tts_all_paths = gen_tts_paths(tts_sent_dir, voices) _, tts_data, tts_seg_aligns, _, _ = get_data(norm_sent,tts_all_paths,start_end_word_index) tts_results = defaultdict(dict) for v in voices: #voice_data = tts_data[f"{words}**{v}"] #voice_align = tts_seg_aligns[f"{words}**{v}"] # match the data with a cluster ----- cluster_scores, f0_fig_tts, en_fig_tts = match_tts(groups, h_data, tts_data, tts_seg_aligns, words, h_seg_aligns, v) best_cluster = [c for c,s in cluster_scores if s == min([s for c,s in cluster_scores])] scorestring = [] for c,s in cluster_scores: if c== best_cluster: scorestring.append(f' **Cluster {c}: {round(s,2)}** ') else: scorestring.append(f' Cluster {c}: {round(s,2)} ') scorestring = ' - '.join(scorestring) audiosample = [pdict['play'] for voic, pdict in tts_all_paths if voic == v][0] tts_results[v] = {'audio': audiosample, 'f0_fig_tts': f0_fig_tts, 'en_fig_tts':en_fig_tts, 'scoreinfo': scorestring} return f0_fig_c0, f0_fig_c1, f0_fig_c2, en_fig_c0, en_fig_c1, en_fig_c2, audio_html, tts_results #return words, kmedoids_cluster_dists, group # generate html panel to play audios for each human cluster # audios is dict {recording_id : (wav_path, seg_start_time, seg_end_time)} def clusters_audio(clusters,colormap,audios): html = '''''' for label in set([c for r,c in clusters]): recs = [r for r,c in clusters if c==label] html += '
' html += f'

Cluster {label}

' html += '
' html += '' for rec in recs: cc = colormap[label][rec] html += f'' html += f'' html += '
{rec}
' html += '
' #html += '
' html += '
' html += '' return html # find offsets to visually align start of each word for speakers in cluster def reset_cluster_times(words,cluster_speakers,human_aligns,tts_align=None): words = words.split('_') retimes = [(words[0], 0.0)] for i in range(len(words)-1): gaps = [human_aligns[spk][i+1][1]-human_aligns[spk][i][1] for spk in cluster_speakers] if tts_align: gaps.append(tts_align[i+1][1] - tts_align[i][1]) retimes.append((words[i+1],retimes[i][1]+max(gaps))) return retimes # apply offsets for a speaker def retime_speaker_xvals(retimes, speaker_aligns, speaker_xvals): new_xvals = [] def xlim(x,i,retimes,speaker_aligns): return (x < speaker_aligns[i+1][1]) if i+1= s) and xlim(x,i,retimes,speaker_aligns) ] return [round(x,3) for x in new_xvals] # interpolate NAN to break lines def retime_xs_feats(retimes, speaker_aligns, speaker_xvals, feats): feat_xvals = retime_speaker_xvals(retimes, speaker_aligns, speaker_xvals) xf0 = list(zip(feat_xvals, feats)) xf = [xf0[0]] for x,f in xf0[1:]: lx = xf[-1][0] if x - lx >= 0.01: xf.append((lx+0.005,np.nan)) xf.append((x,f)) return [x for x,f in xf], [f for x,f in xf] # TODO handle the ccmap in here not inside plot_one def graph_humans(clusters,speech_data,words,seg_aligns): c0,c1,c2 = (0,1,2) nsents = len(speech_data) c0_data = {f'{words}**{r}': speech_data[f'{words}**{r}'] for r,c in clusters if c==c0} c1_data = {f'{words}**{r}': speech_data[f'{words}**{r}'] for r,c in clusters if c==c1} c2_data = {f'{words}**{r}': speech_data[f'{words}**{r}'] for r,c in clusters if c==c2} colors = [(pc,ec) for pc,ec in zip(clc.CET_C8s,clc.CET_C9s)] cix = [int(x) for x in np.linspace(0,len(colors)-1, nsents)] pcolors = [colors[x][0] for x in cix] ecolors= [colors[x][1] for x in cix] f0_fig_c0, c0_cc = plot_one_cluster(words,'pitch',c0_data,seg_aligns,c0,pcolors) f0_fig_c1, c1_cc= plot_one_cluster(words,'pitch',c1_data,seg_aligns,c1,pcolors[len(c0_data):]) f0_fig_c2, c2_cc = plot_one_cluster(words,'pitch',c2_data,seg_aligns,c2,pcolors[len(c0_data)+len(c1_data):]) en_fig_c0, _ = plot_one_cluster(words,'rmse',c0_data,seg_aligns,c0,ecolors) en_fig_c1, _ = plot_one_cluster(words,'rmse',c1_data,seg_aligns,c1,ecolors[len(c0_data):]) en_fig_c2, _ = plot_one_cluster(words,'rmse',c2_data,seg_aligns,c2,ecolors[len(c0_data)+len(c1_data):]) # TODO # not necessarily here, bc paths to audio files. spk_cc_map = {c0 : c0_cc, c1 : c1_cc, c2 : c2_cc} #playable = audio_htmls(spk_cc_map) return f0_fig_c0, f0_fig_c1, f0_fig_c2, en_fig_c0, en_fig_c1, en_fig_c2, spk_cc_map #TODO handle the colour list OUTSIDE of this part.... def plot_one_cluster(words,feature,speech_data,seg_aligns,cluster_id,colors,gtype='cluster',voice=None): cc=0 gclr = "#909090" spk_ccs = {} # for external display #fig = plt.figure(figsize=(10, 5)) if voice: fig, ax = plt.subplots(figsize=(7.5,4)) else: fig, ax = plt.subplots(figsize=(10,5)) fig.patch.set_facecolor('none') ax.patch.set_facecolor('none') fig.patch.set_alpha(0) ax.tick_params(color=gclr,labelcolor=gclr) for spine in ['bottom','left']: ax.spines[spine].set_color(gclr) for spine in ['top','right']: ax.spines[spine].set(visible=False) if feature.lower() in ['pitch','f0']: fname = 'Pitch' def _ffunc(feats): ps = [p for p,e in feats] nv = min(ps) ps = [np.nan if p == nv else p for p in ps] return ps ffunc = _ffunc pfunc = plt.plot ylab = "Mean-variance normalised F0" elif feature.lower() in ['energy', 'rmse']: fname = 'Energy' ffunc = lambda x: [e for p,e in x] pfunc = plt.plot ylab = "Mean-variance normalised energy" else: print('problem with the figure') return fig, [] if gtype == 'cluster': # boundary for start of each word retimes = reset_cluster_times(words,list(speech_data.keys()),seg_aligns)#,tts_align) plt.title(f"{words} - {fname} - Cluster {cluster_id}", color=gclr, fontsize=16) xmax = 0 for k,v in speech_data.items(): spk = k.split('**')[1] word_times = seg_aligns[k] feats = ffunc(v) # datapoint interval is 0.005 seconds feat_xvals = [x*0.005 for x in range(len(feats))] feat_xvals, feats = retime_xs_feats(retimes,word_times,feat_xvals,feats) pfunc(feat_xvals, feats, color=colors[cc], linewidth=2, label=f"Speaker {spk}") xmax = max(xmax,max(feat_xvals)) spk_ccs[spk] = colors[cc] cc += 1 if cc >= len(colors): cc=0 elif gtype == 'tts': # boundary for start of each word retimes = reset_cluster_times(words,[f'{words}**{voice}'],seg_aligns) word_times = seg_aligns[f'{words}**{voice}'] tfeats = ffunc(speech_data[f'{words}**{voice}']) t_xvals = [x*0.005 for x in range(len(tfeats))] t_xvals, tfeats = retime_xs_feats(retimes, word_times, t_xvals, tfeats) pfunc(t_xvals, tfeats, color=colors[cc], label=f"TTS {voice}") plt.title(f"{fname}", color=gclr, fontsize=14) xmax = max(t_xvals) if len(retimes)>1: for w,bound_line in retimes: plt.axvline(x=bound_line, color=gclr, linestyle='--', linewidth=1, label=f'Start "{w}"') plt.xlim([0, xmax]) ax.set_xlabel("Time --->",fontsize=13,color=gclr) ax.set_ylabel(ylab,fontsize=13,color=gclr) #plt.legend() #plt.show() return fig, spk_ccs