|
|
"""Based on https://pytorch.org/audio/stable/tutorials/ctc_forced_alignment_api_tutorial.html""" |
|
|
import csv |
|
|
import os |
|
|
import re |
|
|
import sys |
|
|
import typing |
|
|
|
|
|
import sqlalchemy |
|
|
import torch |
|
|
import torchaudio |
|
|
import torchaudio.functional as F |
|
|
import tqdm |
|
|
from kalpy.fstext.lexicon import CtmInterval as KalpyCtmInterval |
|
|
from kalpy.fstext.lexicon import HierarchicalCtm, WordCtmInterval |
|
|
from montreal_forced_aligner import config |
|
|
from montreal_forced_aligner.command_line.mfa import mfa_cli |
|
|
from montreal_forced_aligner.corpus.acoustic_corpus import AcousticCorpus |
|
|
from montreal_forced_aligner.data import CtmInterval |
|
|
from montreal_forced_aligner.db import File, Utterance, Word, WordInterval |
|
|
from montreal_forced_aligner.exceptions import TextGridParseError |
|
|
from montreal_forced_aligner.helper import mfa_open |
|
|
from praatio import textgrid as tgio |
|
|
|
|
|
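# Machine-specific paths to the benchmark corpora and to archived/trained MFA model releases; adjust for your environment.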
root_dir = r"D:\Data\experiments\alignment_benchmarking" |
|
|
mfa10_dir = r"D:\Data\models\1.0_archived" |
|
|
mfa20_dir = r"D:\Data\models\2.0_archived" |
|
|
mfa20a_dir = r"D:\Data\models\2.0.0a_archived" |
|
|
mfa21_dir = r"D:\Data\models\2.1_trained" |
|
|
mfa22_dir = r"D:\Data\models\2.2_trained" |
|
|
mfa30_dir = r"D:\Data\models\3.0_trained" |
|
|
trained22_dir = r"D:\Data\models\2.2_trained\buckeye" |
|
|
trained30_dir = r"D:\Data\models\3.0_trained\buckeye" |
|
|
mapping_directory = os.path.join( |
|
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "mapping_files" |
|
|
) |
|
|
nemo_tools_path = r"C:\Users\micha\Documents\Dev\NeMo\tools" |
|
|
config.CLEAN = True |
|
|
config.USE_POSTGRES = True |
|
|
|
|
|
|
|
|
|
|
corpus_directories = { |
|
|
"timit": r"D:\Data\speech\benchmark_datasets\timit", |
|
|
"buckeye": r"D:\Data\speech\Buckeye", |
|
|
} |
|
|
|
|
|
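# Each condition maps to a (pronunciation dictionary, acoustic model) pair for the MFA aligners,
# or to (None, None) for the third-party aligners handled by dedicated branches below.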
conditions = { |
|
|
"torchaudio_mms_fa": (None, None), |
|
|
"nemo_forced_aligner": (None, None), |
|
|
"whisperx": (None, None), |
|
|
"arpa_1.0": ( |
|
|
os.path.join(mfa10_dir, "english.dict"), |
|
|
os.path.join(mfa10_dir, "english.zip"), |
|
|
), |
|
|
"arpa_2.0": ( |
|
|
os.path.join(mfa20_dir, "english_us_arpa.dict"), |
|
|
os.path.join(mfa20_dir, "english_us_arpa.zip"), |
|
|
), |
|
|
"arpa_2.0a": ( |
|
|
os.path.join(mfa20a_dir, "english_us_arpa.dict"), |
|
|
os.path.join(mfa20a_dir, "english_us_arpa.zip"), |
|
|
), |
|
|
"mfa_2.0": ( |
|
|
os.path.join(mfa20_dir, "english_us_mfa.dict"), |
|
|
os.path.join(mfa20_dir, "english_mfa.zip"), |
|
|
), |
|
|
"mfa_2.0a": ( |
|
|
os.path.join(mfa20a_dir, "english_us_mfa.dict"), |
|
|
os.path.join(mfa20a_dir, "english_mfa.zip"), |
|
|
), |
|
|
"mfa_2.1": ( |
|
|
os.path.join(mfa21_dir, "english_us_mfa.dict"), |
|
|
os.path.join(mfa21_dir, "english_mfa.zip"), |
|
|
), |
|
|
"mfa_2.2": ( |
|
|
os.path.join(mfa22_dir, "english_us_mfa.dict"), |
|
|
os.path.join(mfa22_dir, "english_mfa.zip"), |
|
|
), |
|
|
"mfa_3.0": ( |
|
|
os.path.join(mfa30_dir, "english_us_mfa.dict"), |
|
|
os.path.join(mfa30_dir, "english_mfa.zip"), |
|
|
), |
|
|
"arpa_3.0": ( |
|
|
os.path.join(mfa30_dir, "english_us_arpa.dict"), |
|
|
os.path.join(mfa30_dir, "english_us_arpa.zip"), |
|
|
), |
|
|
"trained_2.2": ( |
|
|
os.path.join(trained22_dir, "english_us_mfa.dict"), |
|
|
os.path.join(trained22_dir, "english_mfa.zip"), |
|
|
), |
|
|
"trained_3.0": ( |
|
|
os.path.join(trained30_dir, "english_us_mfa.dict"), |
|
|
os.path.join(trained30_dir, "english_mfa.zip"), |
|
|
), |
|
|
"arpa_2.2": ( |
|
|
os.path.join(mfa20a_dir, "english_us_arpa.dict"), |
|
|
os.path.join(mfa20a_dir, "english_us_arpa.zip"), |
|
|
), |
|
|
} |
|
|
|
|
|
|
|
|
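# Run CTC forced alignment of the character token sequence against the emission matrix;
# returns the per-frame alignment and its exponentiated (posterior) scores.
# Uses the module-level `device` set in the __main__ block (following the torchaudio tutorial linked above).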
def align(emission, tokens): |
|
|
targets = torch.tensor([tokens], dtype=torch.int32, device=device) |
|
|
alignments, scores = F.forced_align(emission, targets, blank=0) |
|
|
|
|
|
alignments, scores = ( |
|
|
alignments[0], |
|
|
scores[0], |
|
|
) |
|
|
scores = scores.exp() |
|
|
return alignments, scores |
|
|
|
|
|
|
|
|
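# Split a flat list into sublists of the given lengths,
# e.g. unflatten([1, 2, 3, 4, 5, 6], [1, 2, 3]) -> [[1], [2, 3], [4, 5, 6]].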
def unflatten(list_, lengths): |
|
|
assert len(list_) == sum(lengths) |
|
|
i = 0 |
|
|
ret = [] |
|
|
for l in lengths: |
|
|
ret.append(list_[i : i + l]) |
|
|
i += l |
|
|
return ret |
|
|
|
|
|
|
|
|
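# Build a word-level CTM from per-character token spans: spans are scaled by `ratio` and the
# bundle sample rate, offset by the utterance start, and grouped into words, merging hyphenated
# compounds so labels line up with the reference tier. Relies on the module-level LABELS mapping
# populated from the torchaudio bundle in __main__.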
def generate_ctm( |
|
|
token_spans, transcript, ratio, sample_rate, utterance_begin, reference |
|
|
) -> HierarchicalCtm: |
|
|
i = 0 |
|
|
reference_index = 0 |
|
|
word_intervals = [] |
|
|
for word in transcript: |
|
|
character_intervals = [] |
|
|
for j in range(len(word)): |
|
|
span = token_spans[i] |
|
|
begin = int(ratio * span.start) / sample_rate |
|
|
begin += utterance_begin |
|
|
end = int(ratio * span.end) / sample_rate |
|
|
end += utterance_begin |
|
|
character_intervals.append( |
|
|
KalpyCtmInterval(begin, end, LABELS[span.token], span.token) |
|
|
) |
|
|
i += 1 |
|
|
if word == reference[reference_index].label: |
|
|
reference_index += 1 |
|
|
elif ( |
|
|
word_intervals |
|
|
and f"{word_intervals[-1].label}-{word}" in reference[reference_index].label |
|
|
): |
|
|
            merged_label = f"{word_intervals[-1].label}-{word}"
            word_intervals[-1].label = merged_label
            word_intervals[-1].symbol = merged_label
|
|
word_intervals[-1].phones.extend(character_intervals) |
|
|
if reference[reference_index].label == word_intervals[-1].label: |
|
|
reference_index += 1 |
|
|
continue |
|
|
word_intervals.append(WordCtmInterval(word, word, character_intervals)) |
|
|
return HierarchicalCtm(word_intervals) |
|
|
|
|
|
|
|
|
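# Mean absolute boundary error in seconds between reference and test word intervals;
# the two lists must have equal length and be in corresponding order.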
def align_words( |
|
|
ref: typing.List[CtmInterval], |
|
|
test: typing.List[CtmInterval], |
|
|
): |
|
|
try: |
|
|
assert len(ref) == len(test) |
|
|
except AssertionError: |
|
|
print(ref) |
|
|
print(test) |
|
|
raise |
|
|
error_sum = 0.0 |
|
|
boundary_count = 0 |
|
|
for i, r in enumerate(ref): |
|
|
t = test[i] |
|
|
|
|
|
error_sum += abs(r.begin - t.begin) |
|
|
error_sum += abs(r.end - t.end) |
|
|
boundary_count += 2 |
|
|
return error_sum / boundary_count |
|
|
|
|
|
|
|
|
def parse_aligned_textgrid(path: str, exclude_unknowns=True) -> typing.List[CtmInterval]: |
|
|
""" |
|
|
Load a TextGrid as a dictionary of speaker's phone tiers |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
path: :class:`~pathlib.Path` |
|
|
TextGrid file to parse |
|
|
|
|
|
Returns |
|
|
------- |
|
|
dict[str, list[:class:`~montreal_forced_aligner.data.CtmInterval`]] |
|
|
Parsed phone tier |
|
|
""" |
|
|
tg = tgio.openTextgrid(path, includeEmptyIntervals=False, reportingMode="silence") |
|
|
num_tiers = len(tg.tiers) |
|
|
if num_tiers == 0: |
|
|
raise TextGridParseError(path, "Number of tiers parsed was zero") |
|
|
data = [] |
|
|
for tier_name in tg.tierNames: |
|
|
ti = tg._tierDict[tier_name] |
|
|
if not isinstance(ti, tgio.IntervalTier): |
|
|
continue |
|
|
if "words" not in tier_name: |
|
|
continue |
|
|
for begin, end, text in ti.entries: |
|
|
text = text.lower().strip().replace("?", "") |
|
|
if exclude_unknowns and re.match(r"^[\[(<].*", text): |
|
|
continue |
|
|
if not text: |
|
|
continue |
|
|
begin, end = round(begin, 4), round(end, 4) |
|
|
|
|
|
|
|
|
interval = CtmInterval(begin, end, text) |
|
|
data.append(interval) |
|
|
return data |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
csv_header = [ |
|
|
"file", |
|
|
"begin", |
|
|
"end", |
|
|
"speaker", |
|
|
"duration", |
|
|
"normalized_text", |
|
|
"filtered_text", |
|
|
"alignment_score", |
|
|
"alignment_likelihood", |
|
|
"word_count", |
|
|
"frames_per_second", |
|
|
"words_per_second", |
|
|
] |
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
print(device) |
|
|
|
|
|
for condition, (dictionary_path, model_path) in conditions.items(): |
|
|
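        # Branch 1: torchaudio's MMS forced aligner (CTC character alignment, per the tutorial linked above).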
if condition == "torchaudio_mms_fa": |
|
|
with torch.inference_mode(): |
|
|
for corpus, root in corpus_directories.items(): |
|
|
output_directory = os.path.join(root_dir, "alignments", condition, corpus) |
|
|
if os.path.exists(output_directory): |
|
|
continue |
|
|
try: |
|
|
bundle = torchaudio.pipelines.MMS_FA |
|
|
except AttributeError: |
|
|
print("Incorrect version of torchaudio, skipping torchaudio_mms_fa") |
|
|
continue |
|
|
|
|
|
model = bundle.get_model(with_star=False).to(device) |
|
|
|
|
|
LABELS = bundle.get_labels(star=None) |
|
|
DICTIONARY = bundle.get_dict(star=None) |
|
|
corpus = AcousticCorpus(corpus_directory=os.path.join(root, "benchmark")) |
|
|
corpus.delete_database() |
|
|
corpus._load_corpus() |
|
|
os.makedirs(output_directory, exist_ok=True) |
|
|
csv_path = os.path.join(output_directory, "word_alignment.csv") |
|
|
with ( |
|
|
corpus.session() as session, |
|
|
mfa_open(csv_path, "w") as f, |
|
|
tqdm.tqdm(total=corpus.num_utterances) as progress_bar, |
|
|
): |
|
|
writer = csv.DictWriter(f, fieldnames=csv_header) |
|
|
writer.writeheader() |
|
|
file_query = ( |
|
|
session.query(File) |
|
|
.options( |
|
|
sqlalchemy.orm.joinedload(File.sound_file, innerjoin=True), |
|
|
) |
|
|
.order_by(File.name) |
|
|
) |
|
|
for file in file_query: |
|
|
file_ctm = HierarchicalCtm([]) |
|
|
if file.relative_path: |
|
|
output_path = os.path.join(output_directory, file.relative_path) |
|
|
os.makedirs(output_path, exist_ok=True) |
|
|
else: |
|
|
output_path = output_directory |
|
|
output_path = os.path.join(output_path, file.name + ".TextGrid") |
|
|
query = ( |
|
|
session.query(Utterance) |
|
|
.filter(Utterance.file_id == file.id) |
|
|
.options( |
|
|
sqlalchemy.orm.joinedload( |
|
|
Utterance.file, innerjoin=True |
|
|
).joinedload(File.sound_file, innerjoin=True), |
|
|
) |
|
|
.order_by(Utterance.begin) |
|
|
) |
|
|
reference_path = os.path.join(root, "reference") |
|
|
if file.relative_path: |
|
|
reference_path = os.path.join(reference_path, file.relative_path) |
|
|
reference_path = os.path.join(reference_path, file.name + ".TextGrid") |
|
|
reference_word_intervals = parse_aligned_textgrid(reference_path) |
|
|
reference_interval_index = 0 |
|
|
for utterance in query: |
|
|
kalpy_utterance = utterance.to_kalpy() |
|
|
waveform = torch.unsqueeze( |
|
|
torch.Tensor(kalpy_utterance.segment.load_audio()), |
|
|
0, |
|
|
) |
|
|
text = utterance.text.lower().replace("-", " ") |
|
|
transcript = [ |
|
|
re.sub(rf'[^{"".join(LABELS[1:])}]', "", x) |
|
|
for x in text.split() |
|
|
if not re.match(r"^[\[(<].*", x) |
|
|
] |
|
|
transcript = [x for x in transcript if x] |
|
|
tokenized_transcript = [ |
|
|
DICTIONARY[c] for word in transcript for c in word |
|
|
] |
|
|
emission, _ = model(waveform.to(device)) |
|
|
aligned_tokens, alignment_scores = align( |
|
|
emission, tokenized_transcript |
|
|
) |
|
|
token_spans = F.merge_tokens(aligned_tokens, alignment_scores) |
|
|
alignment_likelihood = sum(alignment_scores) / len( |
|
|
alignment_scores |
|
|
) |
|
|
word_spans = unflatten( |
|
|
token_spans, [len(word) for word in transcript] |
|
|
) |
|
|
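                                # Audio samples per emission frame, used to convert frame indices into seconds.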
ratio = waveform.size(1) / emission.size(1) |
|
|
frames_per_second = emission.size(1) / utterance.duration |
|
|
reference = [] |
|
|
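                                # Gather the reference word intervals whose start times fall inside this utterance.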
while True: |
|
|
try: |
|
|
interval = reference_word_intervals[ |
|
|
reference_interval_index |
|
|
] |
|
|
except IndexError: |
|
|
break |
|
|
if interval.begin >= utterance.end: |
|
|
break |
|
|
if interval.begin >= utterance.begin: |
|
|
reference.append(interval) |
|
|
reference_interval_index += 1 |
|
|
ctm = generate_ctm( |
|
|
token_spans, |
|
|
transcript, |
|
|
ratio, |
|
|
bundle.sample_rate, |
|
|
utterance.begin, |
|
|
reference, |
|
|
) |
|
|
file_ctm.word_intervals.extend(ctm.word_intervals) |
|
|
try: |
|
|
alignment_score = align_words(reference, ctm.word_intervals) |
|
|
except AssertionError: |
|
|
print(file.name) |
|
|
raise |
|
|
|
|
words_per_second = len(reference) / utterance.duration |
|
|
|
|
|
data = { |
|
|
"file": file.name, |
|
|
"begin": utterance.begin, |
|
|
"end": utterance.end, |
|
|
"duration": utterance.duration, |
|
|
"speaker": utterance.speaker_name, |
|
|
"normalized_text": text.replace(",", ""), |
|
|
"filtered_text": " ".join(transcript), |
|
|
"alignment_score": alignment_score, |
|
|
"alignment_likelihood": float(alignment_likelihood), |
|
|
"word_count": len(ctm.word_intervals), |
|
|
"frames_per_second": frames_per_second, |
|
|
"words_per_second": words_per_second, |
|
|
} |
|
|
writer.writerow(data) |
|
|
progress_bar.update(1) |
|
|
file_ctm.export_textgrid( |
|
|
output_path, |
|
|
file_duration=file.duration, |
|
|
output_format="short_textgrid", |
|
|
) |
|
|
elif condition == "nemo_forced_aligner": |
|
|
try: |
|
|
from dataclasses import dataclass, field, is_dataclass |
|
|
|
|
|
from nemo.collections.asr.models.ctc_models import EncDecCTCModel |
|
|
from nemo.collections.asr.models.hybrid_rnnt_ctc_models import ( |
|
|
EncDecHybridRNNTCTCModel, |
|
|
) |
|
|
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR |
|
|
from nemo.collections.asr.parts.utils.transcribe_utils import setup_model |
|
|
from nemo.core.config import hydra_runner |
|
|
        except ImportError:
            # Skip this condition if NeMo is unavailable (mirrors the whisperx branch below).
            print("NeMo not installed, skipping nemo_forced_aligner")
            continue
|
|
sys.path.append(nemo_tools_path) |
|
|
sys.path.insert(0, os.path.join(nemo_tools_path, "nemo_forced_aligner")) |
|
|
import tempfile |
|
|
|
|
|
from nemo_forced_aligner.align import AlignmentConfig, ASSFileConfig, CTMFileConfig |
|
|
from nemo_forced_aligner.utils.data_prep import ( |
|
|
V_NEGATIVE_NUM, |
|
|
Segment, |
|
|
Token, |
|
|
                Word as NemoWord,  # aliased so it does not shadow montreal_forced_aligner.db.Word used later
|
|
add_t_start_end_to_utt_obj, |
|
|
get_utt_obj, |
|
|
) |
|
|
from nemo_forced_aligner.utils.viterbi_decoding import viterbi_decoding |
|
|
|
|
|
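            # Branch 2: NeMo forced aligner (NFA). Configure and load the CTC model once,
            # then reuse NFA's utilities (get_utt_obj, viterbi_decoding) per utterance.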
cfg = AlignmentConfig() |
|
|
cfg.pretrained_name = "stt_en_conformer_ctc_medium" |
|
|
if cfg.transcribe_device is None: |
|
|
transcribe_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
else: |
|
|
transcribe_device = torch.device(cfg.transcribe_device) |
|
|
if cfg.viterbi_device is None: |
|
|
viterbi_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
else: |
|
|
viterbi_device = torch.device(cfg.viterbi_device) |
|
|
model, _ = setup_model(cfg, transcribe_device) |
|
|
model.eval() |
|
|
if isinstance(model, EncDecHybridRNNTCTCModel): |
|
|
model.change_decoding_strategy(decoder_type="ctc") |
|
|
if cfg.use_local_attention: |
|
|
model.change_attention_model( |
|
|
self_attention_model="rel_pos_local_attn", att_context_size=[64, 64] |
|
|
) |
|
|
|
|
|
with torch.inference_mode(): |
|
|
for corpus, root in corpus_directories.items(): |
|
|
output_directory = os.path.join(root_dir, "alignments", condition, corpus) |
|
|
if os.path.exists(output_directory): |
|
|
continue |
|
|
corpus = AcousticCorpus(corpus_directory=os.path.join(root, "benchmark")) |
|
|
corpus.delete_database() |
|
|
corpus._load_corpus() |
|
|
os.makedirs(output_directory, exist_ok=True) |
|
|
csv_path = os.path.join(output_directory, "word_alignment.csv") |
|
|
with ( |
|
|
corpus.session() as session, |
|
|
mfa_open(csv_path, "w") as f, |
|
|
tempfile.TemporaryDirectory() as tmpdir, |
|
|
tqdm.tqdm(total=corpus.num_utterances) as progress_bar, |
|
|
): |
|
|
writer = csv.DictWriter(f, fieldnames=csv_header) |
|
|
writer.writeheader() |
|
|
file_query = ( |
|
|
session.query(File) |
|
|
.options( |
|
|
sqlalchemy.orm.joinedload(File.sound_file, innerjoin=True), |
|
|
) |
|
|
.order_by(File.name) |
|
|
) |
|
|
for file in file_query: |
|
|
file_ctm = HierarchicalCtm([]) |
|
|
if file.relative_path: |
|
|
output_path = os.path.join(output_directory, file.relative_path) |
|
|
os.makedirs(output_path, exist_ok=True) |
|
|
else: |
|
|
output_path = output_directory |
|
|
output_path = os.path.join(output_path, file.name + ".TextGrid") |
|
|
query = ( |
|
|
session.query(Utterance) |
|
|
.filter(Utterance.file_id == file.id) |
|
|
.options( |
|
|
sqlalchemy.orm.joinedload( |
|
|
Utterance.file, innerjoin=True |
|
|
).joinedload(File.sound_file, innerjoin=True), |
|
|
) |
|
|
.order_by(Utterance.begin) |
|
|
) |
|
|
reference_path = os.path.join(root, "reference") |
|
|
if file.relative_path: |
|
|
reference_path = os.path.join(reference_path, file.relative_path) |
|
|
reference_path = os.path.join(reference_path, file.name + ".TextGrid") |
|
|
reference_word_intervals = parse_aligned_textgrid(reference_path) |
|
|
reference_interval_index = 0 |
|
|
for utterance in query: |
|
|
kalpy_utterance = utterance.to_kalpy() |
|
|
reference = [] |
|
|
while True: |
|
|
try: |
|
|
interval = reference_word_intervals[ |
|
|
reference_interval_index |
|
|
] |
|
|
except IndexError: |
|
|
break |
|
|
if interval.begin >= utterance.end: |
|
|
break |
|
|
if interval.begin >= utterance.begin: |
|
|
reference.append(interval) |
|
|
reference_interval_index += 1 |
|
|
text = " ".join(x.label for x in reference) |
|
|
with torch.no_grad(): |
|
|
output_timestep_duration = None |
|
|
path = os.path.join(tmpdir, f"sound_file.wav") |
|
|
torchaudio.save( |
|
|
path, |
|
|
torch.unsqueeze( |
|
|
torch.Tensor(kalpy_utterance.segment.load_audio()), |
|
|
0, |
|
|
), |
|
|
16000, |
|
|
) |
|
|
hypotheses = model.transcribe( |
|
|
[path], |
|
|
return_hypotheses=True, |
|
|
verbose=False, |
|
|
batch_size=1, |
|
|
) |
|
|
|
|
|
if type(hypotheses) == tuple and len(hypotheses) == 2: |
|
|
hypotheses = hypotheses[0] |
|
|
|
|
|
for hypothesis in hypotheses: |
|
|
log_probs_list_batch = [hypothesis.y_sequence] |
|
|
T_list_batch = [hypothesis.y_sequence.shape[0]] |
|
|
pred_text_batch = [hypothesis.text] |
|
|
utt_obj = get_utt_obj( |
|
|
text, |
|
|
model, |
|
|
cfg.additional_segment_grouping_separator, |
|
|
T_list_batch[0], |
|
|
file.sound_file.sound_file_path, |
|
|
utterance.kaldi_id, |
|
|
) |
|
|
alignment_likelihood = float(hypothesis.score) |
|
|
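                                        # Assemble single-utterance "batches" in the shapes NFA's viterbi_decoding expects
                                        # (padded log-probs, token ids with blanks, and their lengths).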
y_list_batch = [utt_obj.token_ids_with_blanks] |
|
|
U_list_batch = [len(utt_obj.token_ids_with_blanks)] |
|
|
utt_obj_batch = [utt_obj] |
|
|
T_max = max(T_list_batch) |
|
|
U_max = max(U_list_batch) |
|
|
if hasattr(model, "tokenizer"): |
|
|
V = len(model.tokenizer.vocab) + 1 |
|
|
else: |
|
|
V = len(model.decoder.vocabulary) + 1 |
|
|
T_batch = torch.tensor(T_list_batch) |
|
|
U_batch = torch.tensor(U_list_batch) |
|
|
log_probs_batch = V_NEGATIVE_NUM * torch.ones((1, T_max, V)) |
|
|
for b, log_probs_utt in enumerate(log_probs_list_batch): |
|
|
t = log_probs_utt.shape[0] |
|
|
log_probs_batch[b, :t, :] = log_probs_utt |
|
|
y_batch = V * torch.ones((1, U_max), dtype=torch.int64) |
|
|
for b, y_utt in enumerate(y_list_batch): |
|
|
U_utt = U_batch[b] |
|
|
y_batch[b, :U_utt] = torch.tensor(y_utt) |
|
|
|
|
|
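                                        # Derive the duration of one output frame from the preprocessor hop length
                                        # and the model's downsampling factor.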
if output_timestep_duration is None: |
|
|
if not "window_stride" in model.cfg.preprocessor: |
|
|
raise ValueError( |
|
|
"Don't have attribute 'window_stride' in " |
|
|
"'model.cfg.preprocessor' => cannot calculate " |
|
|
" model_downsample_factor => stopping process" |
|
|
) |
|
|
|
|
|
if not "sample_rate" in model.cfg.preprocessor: |
|
|
raise ValueError( |
|
|
"Don't have attribute 'sample_rate' in " |
|
|
"'model.cfg.preprocessor' => cannot calculate start " |
|
|
" and end time of segments => stopping process" |
|
|
) |
|
|
|
|
|
audio_dur = utterance.duration |
|
|
n_input_frames = ( |
|
|
audio_dur / model.cfg.preprocessor.window_stride |
|
|
) |
|
|
model_downsample_factor = round( |
|
|
n_input_frames / int(T_batch[0]) |
|
|
) |
|
|
|
|
|
output_timestep_duration = ( |
|
|
model.preprocessor.featurizer.hop_length |
|
|
* model_downsample_factor |
|
|
/ model.cfg.preprocessor.sample_rate |
|
|
) |
|
|
alignments_batch = viterbi_decoding( |
|
|
log_probs_batch, |
|
|
y_batch, |
|
|
T_batch, |
|
|
U_batch, |
|
|
viterbi_device, |
|
|
) |
|
|
utt_obj, alignment_utt = ( |
|
|
utt_obj_batch[0], |
|
|
alignments_batch[0], |
|
|
) |
|
|
|
|
|
utt_obj = add_t_start_end_to_utt_obj( |
|
|
utt_obj, alignment_utt, output_timestep_duration |
|
|
) |
|
|
segment = [ |
|
|
x |
|
|
for x in utt_obj.segments_and_tokens |
|
|
if isinstance(x, Segment) |
|
|
][0] |
|
|
words = [ |
|
|
                                            x for x in segment.words_and_tokens if isinstance(x, NemoWord)
|
|
] |
|
|
word_intervals = [] |
|
|
for word in words: |
|
|
character_intervals = [] |
|
|
word_label = word.text |
|
|
for token in word.tokens: |
|
|
if token.text == "<b>": |
|
|
continue |
|
|
begin = token.t_start + utterance.begin |
|
|
end = token.t_end + utterance.begin |
|
|
if end > utterance.end: |
|
|
end = utterance.end |
|
|
label = token.text |
|
|
character_intervals.append( |
|
|
KalpyCtmInterval(begin, end, label, label) |
|
|
) |
|
|
word_intervals.append( |
|
|
WordCtmInterval( |
|
|
word_label, word_label, character_intervals |
|
|
) |
|
|
) |
|
|
ctm = HierarchicalCtm(word_intervals) |
|
|
|
|
|
frames_per_second = ( |
|
|
log_probs_list_batch[0].size(0) / utterance.duration |
|
|
) |
|
|
file_ctm.word_intervals.extend(ctm.word_intervals) |
|
|
try: |
|
|
alignment_score = align_words(reference, ctm.word_intervals) |
|
|
except AssertionError: |
|
|
print(file.name) |
|
|
raise |
|
|
|
|
words_per_second = len(reference) / utterance.duration |
|
|
|
|
|
data = { |
|
|
"file": file.name, |
|
|
"begin": utterance.begin, |
|
|
"end": utterance.end, |
|
|
"duration": utterance.duration, |
|
|
"speaker": utterance.speaker_name, |
|
|
"normalized_text": utterance.normalized_text, |
|
|
"filtered_text": text, |
|
|
"alignment_score": alignment_score, |
|
|
"alignment_likelihood": float(alignment_likelihood), |
|
|
"word_count": len(ctm.word_intervals), |
|
|
"frames_per_second": frames_per_second, |
|
|
"words_per_second": words_per_second, |
|
|
} |
|
|
writer.writerow(data) |
|
|
progress_bar.update(1) |
|
|
file_ctm.export_textgrid( |
|
|
output_path, |
|
|
file_duration=file.duration, |
|
|
output_format="short_textgrid", |
|
|
) |
|
|
elif condition == "whisperx": |
|
|
try: |
|
|
import whisperx |
|
|
except ImportError: |
|
|
print("whisperx not installed, skipping") |
|
|
continue |
|
|
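        # Branch 3: WhisperX. The Whisper ASR model is loaded for completeness, but only the
        # wav2vec2 alignment model (model_a) is used below to align the reference transcript.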
compute_type = "float16" |
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
model = whisperx.load_model( |
|
|
"large-v2", device, compute_type=compute_type, language="en" |
|
|
) |
|
|
model_a, metadata = whisperx.load_align_model(language_code="en", device=device) |
|
|
|
|
|
with torch.inference_mode(): |
|
|
for corpus, root in corpus_directories.items(): |
|
|
output_directory = os.path.join(root_dir, "alignments", condition, corpus) |
|
|
if os.path.exists(output_directory): |
|
|
continue |
|
|
corpus = AcousticCorpus(corpus_directory=os.path.join(root, "benchmark")) |
|
|
corpus.delete_database() |
|
|
corpus._load_corpus() |
|
|
os.makedirs(output_directory, exist_ok=True) |
|
|
csv_path = os.path.join(output_directory, "word_alignment.csv") |
|
|
with ( |
|
|
corpus.session() as session, |
|
|
mfa_open(csv_path, "w") as f, |
|
|
tqdm.tqdm(total=corpus.num_utterances) as progress_bar, |
|
|
): |
|
|
writer = csv.DictWriter(f, fieldnames=csv_header) |
|
|
writer.writeheader() |
|
|
file_query = ( |
|
|
session.query(File) |
|
|
.options( |
|
|
sqlalchemy.orm.joinedload(File.sound_file, innerjoin=True), |
|
|
) |
|
|
.order_by(File.name) |
|
|
) |
|
|
for file in file_query: |
|
|
file_ctm = HierarchicalCtm([]) |
|
|
if file.relative_path: |
|
|
output_path = os.path.join(output_directory, file.relative_path) |
|
|
os.makedirs(output_path, exist_ok=True) |
|
|
else: |
|
|
output_path = output_directory |
|
|
output_path = os.path.join(output_path, file.name + ".TextGrid") |
|
|
query = ( |
|
|
session.query(Utterance) |
|
|
.filter(Utterance.file_id == file.id) |
|
|
.options( |
|
|
sqlalchemy.orm.joinedload( |
|
|
Utterance.file, innerjoin=True |
|
|
).joinedload(File.sound_file, innerjoin=True), |
|
|
) |
|
|
.order_by(Utterance.begin) |
|
|
) |
|
|
reference_path = os.path.join(root, "reference") |
|
|
if file.relative_path: |
|
|
reference_path = os.path.join(reference_path, file.relative_path) |
|
|
reference_path = os.path.join(reference_path, file.name + ".TextGrid") |
|
|
reference_word_intervals = parse_aligned_textgrid(reference_path) |
|
|
reference_interval_index = 0 |
|
|
for utterance in query: |
|
|
kalpy_utterance = utterance.to_kalpy() |
|
|
reference = [] |
|
|
while True: |
|
|
try: |
|
|
interval = reference_word_intervals[ |
|
|
reference_interval_index |
|
|
] |
|
|
except IndexError: |
|
|
break |
|
|
if interval.begin >= utterance.end: |
|
|
break |
|
|
if interval.begin >= utterance.begin: |
|
|
reference.append(interval) |
|
|
reference_interval_index += 1 |
|
|
text = " ".join(x.label for x in reference) |
|
|
|
|
|
audio = kalpy_utterance.segment.load_audio() |
|
|
|
|
|
|
|
|
|
|
|
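                                # Align the reference transcript as a single segment spanning the whole utterance,
                                # requesting character-level alignments.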
segments = [ |
|
|
{ |
|
|
"text": text, |
|
|
"start": 0.0, |
|
|
"end": utterance.duration, |
|
|
} |
|
|
] |
|
|
result = whisperx.align( |
|
|
segments, |
|
|
model_a, |
|
|
metadata, |
|
|
audio, |
|
|
device, |
|
|
return_char_alignments=True, |
|
|
) |
|
|
|
|
|
word_intervals = [] |
|
|
words = result["segments"][0]["words"] |
|
|
chars = result["segments"][0]["chars"] |
|
|
character_index = 0 |
|
|
score_sum = 0 |
|
|
for word in words: |
|
|
character_intervals = [] |
|
|
word_label = word["word"] |
|
|
while True: |
|
|
try: |
|
|
c = chars[character_index] |
|
|
except IndexError: |
|
|
break |
|
|
if c["char"] == " ": |
|
|
character_index += 1 |
|
|
continue |
|
|
if c["start"] >= word["end"]: |
|
|
break |
|
|
begin = c["start"] + utterance.begin |
|
|
end = c["end"] + utterance.begin |
|
|
character_intervals.append( |
|
|
KalpyCtmInterval(begin, end, c["char"], c["char"]) |
|
|
) |
|
|
character_index += 1 |
|
|
word_intervals.append( |
|
|
WordCtmInterval( |
|
|
word_label, word_label, character_intervals |
|
|
) |
|
|
) |
|
|
score_sum += word["score"] |
|
|
ctm = HierarchicalCtm(word_intervals) |
|
|
alignment_likelihood = score_sum / len(words) |
|
|
file_ctm.word_intervals.extend(ctm.word_intervals) |
|
|
try: |
|
|
alignment_score = align_words(reference, ctm.word_intervals) |
|
|
except AssertionError: |
|
|
print(file.name) |
|
|
raise |
|
|
|
|
words_per_second = len(reference) / utterance.duration |
|
|
|
|
|
data = { |
|
|
"file": file.name, |
|
|
"begin": utterance.begin, |
|
|
"end": utterance.end, |
|
|
"duration": utterance.duration, |
|
|
"speaker": utterance.speaker_name, |
|
|
"normalized_text": utterance.normalized_text, |
|
|
"filtered_text": text, |
|
|
"alignment_score": alignment_score, |
|
|
"alignment_likelihood": float(alignment_likelihood), |
|
|
"word_count": len(ctm.word_intervals), |
|
|
"frames_per_second": "NA", |
|
|
"words_per_second": words_per_second, |
|
|
} |
|
|
writer.writerow(data) |
|
|
progress_bar.update(1) |
|
|
file_ctm.export_textgrid( |
|
|
output_path, |
|
|
file_duration=file.duration, |
|
|
output_format="short_textgrid", |
|
|
) |
|
|
else: |
|
|
if not os.path.exists(model_path): |
|
|
continue |
|
|
for corpus, root in corpus_directories.items(): |
|
|
output_directory = os.path.join(root_dir, "alignments", condition, corpus) |
|
|
csv_path = os.path.join(output_directory, "word_alignment.csv") |
|
|
if os.path.exists(csv_path): |
|
|
continue |
|
|
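                # Remaining conditions are MFA dictionary/acoustic-model pairs: run the `mfa align` CLI,
                # then read the word intervals back out of the alignment database.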
command = [ |
|
|
"align", |
|
|
os.path.join(root, "benchmark"), |
|
|
dictionary_path, |
|
|
model_path, |
|
|
output_directory, |
|
|
"-j", |
|
|
"10", |
|
|
"--clean", |
|
|
"--debug", |
|
|
"--use_mp", |
|
|
"--use_cutoff_model", |
|
|
"--use_postgres", |
|
|
"--cleanup_textgrids", |
|
|
"--beam", |
|
|
"10", |
|
|
"--retry_beam", |
|
|
"40", |
|
|
] |
|
|
print(command) |
|
|
mfa_cli(command, standalone_mode=False) |
|
|
corpus = AcousticCorpus(corpus_directory=os.path.join(root, "benchmark")) |
|
|
with ( |
|
|
corpus.session() as session, |
|
|
mfa_open(csv_path, "w") as f, |
|
|
tqdm.tqdm(total=corpus.num_utterances) as progress_bar, |
|
|
): |
|
|
writer = csv.DictWriter(f, fieldnames=csv_header) |
|
|
writer.writeheader() |
|
|
file_query = ( |
|
|
session.query(File) |
|
|
.options( |
|
|
sqlalchemy.orm.joinedload(File.sound_file, innerjoin=True), |
|
|
) |
|
|
.order_by(File.name) |
|
|
) |
|
|
for file in file_query: |
|
|
query = ( |
|
|
session.query(Utterance) |
|
|
.filter(Utterance.file_id == file.id) |
|
|
.options( |
|
|
sqlalchemy.orm.joinedload( |
|
|
Utterance.file, innerjoin=True |
|
|
).joinedload(File.sound_file, innerjoin=True), |
|
|
) |
|
|
.order_by(Utterance.begin) |
|
|
) |
|
|
reference_path = os.path.join(root, "reference") |
|
|
if file.relative_path: |
|
|
reference_path = os.path.join(reference_path, file.relative_path) |
|
|
reference_path = os.path.join(reference_path, file.name + ".TextGrid") |
|
|
reference_word_intervals = parse_aligned_textgrid( |
|
|
reference_path, exclude_unknowns=True |
|
|
) |
|
|
reference_interval_index = 0 |
|
|
for utterance in query: |
|
|
word_interval_query = ( |
|
|
session.query(WordInterval) |
|
|
.join(WordInterval.word) |
|
|
.filter(WordInterval.utterance_id == utterance.id) |
|
|
.filter(Word.word.op("~")(r"^[^\[<{(]")) |
|
|
.options( |
|
|
sqlalchemy.orm.joinedload(WordInterval.word, innerjoin=True), |
|
|
) |
|
|
.order_by(WordInterval.begin) |
|
|
) |
|
|
reference = [] |
|
|
while True: |
|
|
try: |
|
|
interval = reference_word_intervals[reference_interval_index] |
|
|
except IndexError: |
|
|
break |
|
|
if interval.begin >= utterance.end: |
|
|
break |
|
|
if interval.begin >= utterance.begin: |
|
|
reference.append(interval) |
|
|
reference_interval_index += 1 |
|
|
word_intervals = [] |
|
|
reference_index = 0 |
|
|
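                            # Re-join words that the MFA lexicon split apart (concatenations and hyphenated
                            # compounds) so the aligned intervals correspond one-to-one with the reference words.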
for i, wi in enumerate(word_interval_query): |
|
|
if reference[reference_index].label == wi.word.word: |
|
|
word_intervals.append(wi.as_ctm()) |
|
|
reference_index += 1 |
|
|
elif ( |
|
|
word_intervals |
|
|
and reference[reference_index].label |
|
|
== word_intervals[-1].label + wi.word.word |
|
|
): |
|
|
word_intervals[-1].label += wi.word.word |
|
|
word_intervals[-1].end = wi.end |
|
|
reference_index += 1 |
|
|
elif ( |
|
|
word_intervals |
|
|
and word_intervals[-1].label + "-" + wi.word.word |
|
|
in reference[reference_index].label |
|
|
): |
|
|
word_intervals[-1].label += "-" + wi.word.word |
|
|
word_intervals[-1].end = wi.end |
|
|
if ( |
|
|
word_intervals[-1].label |
|
|
== reference[reference_index].label |
|
|
): |
|
|
reference_index += 1 |
|
|
else: |
|
|
word_intervals.append(wi.as_ctm()) |
|
|
if not word_intervals: |
|
|
continue |
|
|
|
|
alignment_score = align_words(reference, word_intervals) |
|
|
words_per_second = len(reference) / utterance.duration |
|
|
data = { |
|
|
"file": file.name, |
|
|
"begin": utterance.begin, |
|
|
"end": utterance.end, |
|
|
"duration": utterance.duration, |
|
|
"speaker": utterance.speaker_name, |
|
|
"normalized_text": utterance.normalized_text, |
|
|
"filtered_text": " ".join(x.label for x in word_intervals), |
|
|
"alignment_score": alignment_score, |
|
|
"alignment_likelihood": utterance.alignment_log_likelihood, |
|
|
"word_count": len(word_intervals), |
|
|
"frames_per_second": 100, |
|
|
"words_per_second": words_per_second, |
|
|
} |
|
|
writer.writerow(data) |
|
|
progress_bar.update(1) |
|
|
|