mpc001's picture
Upload 125 files
09481f3
raw
history blame
No virus
1.2 kB
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# This code refers https://github.com/espnet/espnet/blob/24c3676a8d4c2e60d2726e9bcd9bdbed740610e0/espnet/nets/e2e_asr_common.py#L213-L249
import numpy as np
def get_wer(s, ref):
return get_er(s.split(), ref.split())
def get_cer(s, ref):
return get_er(s.replace(" ", ""), ref.replace(" ", ""))
def get_er(s, ref):
"""
FROM wikipedia levenshtein distance
s: list of words/char in sentence to measure
ref: list of words/char in reference
"""
costs = np.zeros((len(s) + 1, len(ref) + 1))
for i in range(len(s) + 1):
costs[i, 0] = i
for j in range(len(ref) + 1):
costs[0, j] = j
for j in range(1, len(ref) + 1):
for i in range(1, len(s) + 1):
cost = None
if s[i-1] == ref[j-1]:
cost = 0
else:
cost = 1
costs[i,j] = min(
costs[i-1, j] + 1,
costs[i, j-1] + 1,
costs[i-1, j-1] + cost
)
return costs[-1,-1] / len(ref)