wav2vec2 / src /utils /metrics.py
hoang1007
init
5381499
raw
history blame contribute delete
No virus
1.9 kB
from typing import Tuple, Union
import re
def levenshtein_distance(source: Tuple[str], target: Tuple[str]) -> int:
    """
    Compute the Levenshtein distance between two sequences.

    The distance is the minimum number of single-item insertions,
    deletions and substitutions needed to turn one sequence into the
    other. Items are compared with ``!=``, so the inputs may be strings
    (character-level distance) or sequences of tokens (word-level
    distance).

    Args:
        source: First sequence.
        target: Second sequence.

    Returns:
        The edit distance as a non-negative integer. Two empty sequences
        have distance 0.
    """
    n, m = len(source), len(target)
    if n > m:
        # Make sure n <= m, to use O(min(n,m)) space
        source, target = target, source
        n, m = m, n

    # Keep current and previous row, not entire matrix. Row 0: matching an
    # empty prefix of `target` against source[:j] costs j edits.
    current_row = list(range(n + 1))
    for i in range(1, m + 1):
        # Column 0 of row i: matching target[:i] against an empty prefix
        # of `source` costs i edits.
        previous_row, current_row = current_row, [i] + [0] * n
        for j in range(1, n + 1):
            add = previous_row[j] + 1
            delete = current_row[j - 1] + 1
            change = previous_row[j - 1]
            if source[j - 1] != target[i - 1]:
                change += 1
            current_row[j] = min(add, delete, change)

    # When m == 0 the loop never runs and the initial row already holds the
    # answer (0), so no per-row temporaries need to exist at this point.
    return current_row[n]
def word_error_rate(
    predicted: Union[str, Tuple[str]], transcript: Union[str, Tuple[str]]
) -> float:
    """
    Compute the word error rate of predictions against reference transcripts.

    Each text is split into words on runs of non-word characters
    (``\\W+``). The word-level Levenshtein distances of all pairs are
    summed and divided by the total number of reference words.

    Args:
        predicted: A predicted text, or a sequence of predicted texts.
        transcript: The reference text, or a sequence of reference texts.
            Items are paired with ``zip``, so any extra items in the
            longer sequence are ignored.

    Returns:
        Total word edits divided by the total number of reference words.
        If the references contain no words at all, the edit count is
        divided by 1 instead, so the result is 0.0 when the predictions
        are empty as well.
    """
    if isinstance(predicted, str):
        predicted = (predicted,)
    if isinstance(transcript, str):
        transcript = (transcript,)

    pattern = r"\W+"

    err, total = 0, 0
    for pred, tgt in zip(predicted, transcript):
        # re.split yields empty strings when the text starts or ends with a
        # separator (e.g. trailing punctuation) or is empty; those are not
        # words and must not be counted as tokens.
        pred_tokens = [tok for tok in re.split(pattern, pred) if tok]
        tgt_tokens = [tok for tok in re.split(pattern, tgt) if tok]

        # Two empty token lists are identical and contribute no errors, so
        # skip the distance computation for them.
        if pred_tokens or tgt_tokens:
            err += levenshtein_distance(pred_tokens, tgt_tokens)
        total += len(tgt_tokens)

    # Guard the denominator so inputs without any reference word do not
    # raise ZeroDivisionError.
    return err / max(total, 1)
def character_error_rate(
    predicted: Union[str, Tuple[str]], transcript: Union[str, Tuple[str]]
) -> float:
    """
    Compute the character error rate of predictions against references.

    The character-level Levenshtein distances of all pairs are summed and
    divided by the total number of reference characters.

    Args:
        predicted: A predicted text, or a sequence of predicted texts.
        transcript: The reference text, or a sequence of reference texts.
            Items are paired with ``zip``, so any extra items in the
            longer sequence are ignored.

    Returns:
        Total character edits divided by the total number of reference
        characters. If the references contain no characters at all, the
        edit count is divided by 1 instead, so the result is 0.0 when the
        predictions are empty as well.
    """
    if isinstance(predicted, str):
        predicted = (predicted,)
    if isinstance(transcript, str):
        transcript = (transcript,)

    err, total = 0, 0
    for pred, tgt in zip(predicted, transcript):
        # Two empty strings are identical and contribute no errors, so skip
        # the distance computation for them.
        if pred or tgt:
            err += levenshtein_distance(pred, tgt)
        total += len(tgt)

    # Guard the denominator so inputs without any reference character do
    # not raise ZeroDivisionError.
    return err / max(total, 1)