hf-speech-bench / app.py
sanchit-gandhi's picture
recommend latest CV
1c31846
raw
history blame contribute delete
No virus
7.62 kB
import html
import json

import pandas as pd
import requests
import streamlit as st
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.repocard import metadata_load
from tqdm.auto import tqdm
# Language codes found in model cards that are reported under a different
# code on the leaderboard (e.g. plain "sv" is listed as "sv-SE").
aliases_lang = dict(sv="sv-SE")
# Languages whose models are ranked by CER instead of WER.
cer_langs = "ja zh-CN zh-HK zh-TW".split()
# Language code -> human-readable name, used by the language picker and the
# per-language summary. JSON text is UTF-8, so decode it explicitly instead of
# relying on the platform's locale encoding (which breaks non-ASCII names).
with open("languages.json", encoding="utf-8") as f:
    lang2name = json.load(f)
# Recommended datasets. sort_datasets lists them first, in exactly this order,
# and generate_dataset_info tags each one "(recommended)" in the sidebar.
suggested_datasets = [
    "librispeech_asr",
    "mozilla-foundation/common_voice_8_0",
    "mozilla-foundation/common_voice_11_0",
    "speech-recognition-community-v2/eval_data",
    "facebook/multilingual_librispeech",
]
def make_clickable(model_name):
    """Return an HTML anchor linking to a model's page on the Hugging Face Hub.

    The link is placed in a table rendered with ``to_html(escape=False)`` and
    shown with ``unsafe_allow_html=True``. ``model_name`` comes from the Hub
    API, so it is HTML-escaped here. Ids made only of letters, digits, ``-``,
    ``_``, ``.`` and ``/`` are left unchanged by the escaping.

    Args:
        model_name: Hub model id, e.g. ``"org/model"``.

    Returns:
        ``<a target="_blank" href="https://huggingface.co/<id>"><id></a>``.
    """
    safe_name = html.escape(model_name)
    link = "https://huggingface.co/" + safe_name
    return f'<a target="_blank" href="{link}">{safe_name}</a>'
def get_model_ids():
    """Return the ids of every Hub model tagged ``hf-asr-leaderboard``."""
    return [info.modelId for info in HfApi().list_models(filter="hf-asr-leaderboard")]
def get_metadata(model_id):
    """Download a model's README.md from the Hub and return its card metadata.

    Args:
        model_id: Hub model id.

    Returns:
        The value of ``metadata_load`` for the downloaded README, or ``None``
        if downloading or parsing fails (e.g. the repo has no README.md).
    """
    try:
        readme_path = hf_hub_download(model_id, filename="README.md")
        return metadata_load(readme_path)
    except Exception:
        # Best effort: one unreadable model card must not break the whole
        # leaderboard. Catching Exception rather than using a bare except lets
        # KeyboardInterrupt / SystemExit still propagate.
        print(f"Model id: {model_id} is not great!")
        return None
def parse_metric_value(value):
    """Normalise a raw WER/CER value from model-card metadata.

    Args:
        value: A number, a numeric string (optionally containing ``%``), or a
            list whose first element is one of those.

    Returns:
        The value rounded to 2 decimals, or ``None`` when it cannot be parsed.
        Floats below 1.1 are assumed to be given as fractions (``0.12``) and
        are scaled by 100. Strings are not rescaled. For a list, only the
        first element is used, and it is parsed with these same rules.
    """
    if isinstance(value, str):
        # Strip percent signs, e.g. "12.5%" -> "12.5". The original code
        # computed this but discarded the result.
        value = value.replace("%", "").strip()
        try:
            value = float(value)
        except ValueError:
            value = None
    elif isinstance(value, float) and value < 1.1:
        # assuming that WER is given in 0.xx format
        value = 100 * value
    elif isinstance(value, list):
        # Parse the first element like a scalar, so strings and fractions in
        # lists are handled consistently; an empty list has no value.
        return parse_metric_value(value[0]) if value else None
    elif not isinstance(value, (int, float)):
        # Unsupported type (None, dict, ...): round() would raise on it.
        value = None
    return round(value, 2) if value is not None else None
def parse_metrics_rows(meta):
    """Yield one leaderboard row per usable evaluation result in card metadata.

    Args:
        meta: Model-card metadata mapping (as returned by ``get_metadata``).
            Only the first ``model-index`` entry is inspected.

    Yields:
        dict with keys ``dataset``, ``lang``, ``config`` and ``split`` plus
        ``wer`` and/or ``cer``. When a result lists the same metric more than
        once, the lowest value is kept. Results that lack a dataset type, a
        language, or any parseable WER/CER value are skipped. Malformed
        entries are skipped instead of raising, so one bad model card cannot
        abort ``get_data``.
    """
    if "model-index" not in meta or "language" not in meta:
        return
    model_index = meta["model-index"]
    if (
        not isinstance(model_index, list)
        or not model_index
        or not isinstance(model_index[0], dict)
    ):
        return
    for result in model_index[0].get("results") or []:
        if (
            not isinstance(result, dict)
            or "dataset" not in result
            or "metrics" not in result
        ):
            continue
        dataset_info = result["dataset"]
        if not isinstance(dataset_info, dict) or "type" not in dataset_info:
            continue
        dataset = dataset_info["type"]

        # Prefer the language the result was evaluated on; fall back to the
        # language declared for the model as a whole.
        args = dataset_info.get("args")
        if isinstance(args, dict) and "language" in args:
            lang = args["language"]
        else:
            lang = meta["language"]
        if isinstance(lang, list):
            if not lang:
                continue
            lang = lang[0]
        lang = aliases_lang.get(lang, lang)

        row = {
            "dataset": dataset,
            "lang": lang,
            # Without an explicit config, the language doubles as the config.
            "config": dataset_info.get("config", lang),
            "split": dataset_info.get("split"),
        }
        for metric in result["metrics"] or []:
            if not isinstance(metric, dict) or "value" not in metric:
                continue
            metric_type = str(metric.get("type", "")).lower().strip()
            if metric_type not in ("wer", "cer"):
                continue
            value = parse_metric_value(metric["value"])
            if value is None:
                continue
            if metric_type not in row or value < row[metric_type]:
                # overwrite the metric if the new value is lower (e.g. with LM)
                row[metric_type] = value
        if "wer" in row or "cer" in row:
            yield row
@st.cache(ttl=600)
def get_data():
    """Collect the parsed WER/CER results of all leaderboard models.

    Returns:
        A DataFrame with one record per row yielded by ``parse_metrics_rows``,
        each extended with the ``model_id`` it came from. Models whose
        metadata cannot be loaded are skipped. Results are cached by Streamlit
        for 600 seconds.
    """
    records = []
    for model_id in tqdm(get_model_ids()):
        card_meta = get_metadata(model_id)
        rows = [] if card_meta is None else parse_metrics_rows(card_meta)
        for row in rows:
            if row is not None:
                row["model_id"] = model_id
                records.append(row)
    return pd.DataFrame.from_records(records)
def sort_datasets(datasets, suggested=None):
    """Order dataset ids for display: suggested ones first, then the rest.

    Args:
        datasets: Iterable of dataset ids (e.g. a column's ``unique()``).
        suggested: Ids to surface first, in priority order. Defaults to the
            module-level ``suggested_datasets``.

    Returns:
        A new list. Suggested ids come first, in the order they appear in
        ``suggested``. All remaining ids follow in ascending (alphabetical)
        order.
    """
    if suggested is None:
        suggested = suggested_datasets
    # Rank each suggested id by its first position; every other id shares the
    # last rank. Breaking ties by name reproduces the original two-pass
    # "sort by name, then stable-sort by rank" in a single sort, with O(1)
    # rank lookups instead of list.index().
    rank = {}
    for position, dataset_id in enumerate(suggested):
        rank.setdefault(dataset_id, position)
    last_rank = len(suggested)
    return sorted(
        datasets,
        key=lambda dataset_id: (rank.get(dataset_id, last_rank), dataset_id),
    )
@st.cache(ttl=600)
def generate_dataset_info(datasets):
    """Build the sidebar markdown that lists and links every dataset.

    Each id becomes a bullet linking to ``https://hf.co/datasets/<id>``; ids in
    ``suggested_datasets`` are additionally tagged *(recommended)*. The text
    starts with an empty line and ends with a newline, and every line is
    stripped of surrounding whitespace.
    """
    lines = ["", "The models have been trained and/or evaluated on the following datasets:"]
    for dataset_id in datasets:
        marker = " *(recommended)*" if dataset_id in suggested_datasets else ""
        lines.append(f"* [{dataset_id}](https://hf.co/datasets/{dataset_id}){marker}")
    lines.append("")
    raw = "\n".join(lines)
    return "\n".join(line.strip() for line in raw.split("\n"))
dataframe = get_data()
if dataframe.empty:
    # get_data() produced no rows, so the columns used below ("model_id",
    # "lang", "dataset", ...) do not exist; stop rendering instead of crashing.
    st.warning("No evaluation results could be loaded from the Hub.")
    st.stop()

# Fill gaps only in the descriptive columns. The metric columns must stay
# numeric: a blanket fillna("") mixes "" with floats. sort_values then raises
# TypeError whenever a dataset has a model that reported only the other
# metric, and pivot_table would try to average strings.
metric_columns = ("wer", "cer")
dataframe = dataframe.fillna(
    {col: "" for col in dataframe.columns if col not in metric_columns}
)
# Guarantee both metric columns exist (all-NaN when no model card reported
# that metric), so selecting metric_col below cannot raise KeyError.
# assign() returns a new frame instead of modifying the one from get_data().
dataframe = dataframe.assign(
    **{col: float("nan") for col in metric_columns if col not in dataframe.columns}
)

st.sidebar.image("logo.png", width=200)
st.markdown("# The 🤗 Speech Bench")
st.markdown(
    f"This is a leaderboard of **{dataframe['model_id'].nunique()}** speech recognition models "
    f"and **{dataframe['dataset'].nunique()}** datasets.\n\n"
    "⬅ Please select the language you want to find a model for from the dropdown on the left."
)

lang = st.sidebar.selectbox(
    "Language",
    sorted(dataframe["lang"].unique(), key=lambda key: lang2name.get(key, key)),
    format_func=lambda key: lang2name.get(key, key),
    index=0,
)
lang_df = dataframe[dataframe.lang == lang]
sorted_datasets = sort_datasets(lang_df["dataset"].unique())

lang_name = lang2name.get(lang, "")
num_models = lang_df["model_id"].nunique()
num_datasets = lang_df["dataset"].nunique()
text = f"""
For the `{lang}` ({lang_name}) language, there are currently `{num_models}` model(s)
trained on `{num_datasets}` dataset(s) available for `automatic-speech-recognition`.
"""
st.markdown(text)

st.sidebar.markdown("""
Choose the dataset that is most relevant to your task and select it from the dropdown below:
""")
dataset = st.sidebar.selectbox(
    "Dataset",
    sorted_datasets,
    index=0,
)
dataset_df = lang_df[lang_df.dataset == dataset]

text = generate_dataset_info(sorted_datasets)
st.sidebar.markdown(text)

# sort by WER or CER depending on the language
metric_col = "cer" if lang in cer_langs else "wer"
if dataset_df["config"].nunique() > 1:
    # more than one dataset config: one score column per config
    dataset_df = dataset_df[["model_id", "config", metric_col]]
    dataset_df = dataset_df.pivot_table(
        index=["model_id"], columns=["config"], values=[metric_col]
    )
    dataset_df = dataset_df.reset_index(level=0)
else:
    dataset_df = dataset_df[["model_id", metric_col]]
# Sort by the last score column; models without a score (NaN) go last.
dataset_df = dataset_df.sort_values(dataset_df.columns[-1])
# Blank out missing scores only now, for display.
dataset_df = dataset_df.fillna("")
dataset_df = dataset_df.rename(
    columns={
        "model_id": "Model",
        "wer": "WER (lower is better)",
        "cer": "CER (lower is better)",
    },
)

st.markdown(
    "Please click on the model's name to be redirected to its model card which includes documentation and examples on how to use it."
)
# display the model ranks (starting at 1)
dataset_df = dataset_df.reset_index(drop=True)
dataset_df.index += 1
# turn the model ids into clickable links
dataset_df["Model"] = dataset_df["Model"].apply(make_clickable)
table_html = dataset_df.to_html(escape=False)
table_html = table_html.replace("<th>", '<th align="left">')  # left-align the headers
st.write(table_html, unsafe_allow_html=True)

# "\\*" is a backslash followed by "*" (a markdown-escaped asterisk). Writing
# the backslash doubled avoids Python's invalid-escape warning for "\*".
if lang in cer_langs:
    st.markdown(
        "---\n\\* **CER** is [Char Error Rate](https://huggingface.co/metrics/cer)"
    )
else:
    st.markdown(
        "---\n\\* **WER** is [Word Error Rate](https://huggingface.co/metrics/wer)"
    )
st.markdown(
    "Want to beat the Leaderboard? Don't see your speech recognition model show up here? "
    "Simply add the `hf-asr-leaderboard` tag to your model card alongside your evaluation metrics. "
    "Try our [Metrics Editor](https://huggingface.co/spaces/huggingface/speech-bench-metrics-editor) to get started!"
)