File size: 6,455 Bytes
d30410b
 
 
 
9fc679b
d30410b
 
9fc679b
d30410b
5a5eb56
d30410b
 
9fc679b
5a899de
 
 
88109c0
 
 
d30410b
 
15ab68a
d30410b
 
 
 
15ab68a
 
d30410b
 
15ab68a
d30410b
 
15ab68a
 
 
d30410b
e6d7cc1
d30410b
c748316
e6d7cc1
d30410b
15ab68a
e6d7cc1
d30410b
88109c0
d30410b
 
 
 
15ab68a
1c6a282
15ab68a
d30410b
 
 
 
15ab68a
e6d7cc1
0ff8797
9fc679b
15ab68a
d30410b
9fc679b
 
 
 
 
 
 
 
 
 
 
f36944b
9fc679b
 
 
 
 
e6d7cc1
9fc679b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a899de
9fc679b
 
 
 
 
 
c748316
e6d7cc1
c748316
e6d7cc1
 
 
 
88109c0
9fc679b
 
 
 
 
 
d30410b
 
9fc679b
d30410b
 
 
 
 
 
 
9fc679b
d30410b
 
 
 
 
15ab68a
d30410b
 
 
 
 
 
 
15ab68a
 
 
 
d30410b
 
15ab68a
d30410b
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import gradio as gr
import pandas as pd
import json
from constants import BANNER, INTRODUCTION_TEXT, CITATION_TEXT, METRICS_TAB_TEXT, DIR_OUTPUT_REQUESTS
from init import is_model_on_hub, load_all_info_from_dataset_hub
from utils_display import AutoEvalColumn, fields, make_clickable_model, styled_error, styled_message
from datetime import datetime, timezone
import torch

# Date string shown in the "Last updated on" line at the bottom of the page.
LAST_UPDATED = "OCT 21st 2024"

# Maps column keys of the results table to the headers displayed in the UI
# (applied with DataFrame.rename below; keys that are absent are ignored).
column_names = {
    "MODEL": "Model",
    "WER": "Common Voice WER ⬇️",
    "CER": "Common Voice CER ⬇️",
    "WER2": "persian-asr-test-set WER",
    "CER2": "persian-asr-test-set CER",
    "WER3": "asr-farsi-youtube WER",
    "CER3": "asr-farsi-youtube CER",
}

# Load evaluation results
eval_queue_repo, requested_models, csv_results = load_all_info_from_dataset_hub()

# csv_results is presumably a pathlib.Path (it exposes .exists()) -- confirm in init.py.
# FileNotFoundError is a subclass of Exception, so existing handlers still match.
if not csv_results.exists():
    raise FileNotFoundError(f"CSV file {csv_results} does not exist locally")

# Read CSV with data and parse columns
original_df = pd.read_csv(csv_results)
# Format the columns
def formatter(x):
    """Return a display-ready version of a single leaderboard cell.

    Strings (including ``str`` subclasses) are returned unchanged; any other
    value is rounded to two decimal places with ``round``.
    """
    # isinstance rather than an exact type comparison: a str subclass must not
    # reach round(), which would raise TypeError.
    if isinstance(x, str):
        return x
    return round(x, 2)

# Round every column except 'Model' (model names are left as they are).
for col in original_df.columns:
    if col != 'Model':
        original_df[col] = original_df[col].apply(formatter)

# Switch to the display headers, then order rows by the YouTube-set WER.
original_df = original_df.rename(columns=column_names)
original_df = original_df.sort_values(by='asr-farsi-youtube WER')

# Column names and datatypes as declared on AutoEvalColumn.
COLS = [field.name for field in fields(AutoEvalColumn)]
TYPES = [field.type for field in fields(AutoEvalColumn)]

def request_model(model_text):
    """Evaluate a Hub ASR model on a Common Voice sample and record the result.

    The model id is first checked with ``is_model_on_hub`` and against the
    'Model' column of the current leaderboard. If both checks pass, the model
    is run through a ``transformers`` speech-recognition pipeline on a
    seed-fixed 1/150 sample of the Common Voice 17.0 "fa" test split. WER and
    CER (as percentages) are appended to the results CSV, and the module-level
    ``original_df`` is rebuilt from that file.

    Args:
        model_text: Model id as typed by the user (``user_name/model_name``).

    Returns:
        The result of ``styled_message`` on success, or of ``styled_error``
        when the model is rejected or the evaluation raises.
    """
    global original_df

    # Check if the model exists on the Hub
    base_model_on_hub, error_msg = is_model_on_hub(model_text)
    if not base_model_on_hub:
        return styled_error(f"Base model '{model_text}' {error_msg}")

    # Check if the model has already been evaluated
    if model_text in original_df['Model'].values:
        return styled_error(f"The model '{model_text}' is already in the leaderboard.")

    try:
        # Heavy dependencies are imported lazily, only when an evaluation runs.
        from transformers import pipeline
        from datasets import load_dataset
        from tqdm import tqdm
        from transformers.pipelines.pt_utils import KeyDataset
        from evaluate import load

        # Load the test split once, then keep a reproducible 1/150 sample of it.
        full_test_split = load_dataset(
            "mozilla-foundation/common_voice_17_0", "fa", split="test"
        )
        common_voice_test = full_test_split.shuffle(seed=42).select(
            range(len(full_test_split) // 150)
        )

        # Half precision only on GPU; fall back to float32 when running on CPU.
        use_cuda = torch.cuda.is_available()
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model_text,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device=0 if use_cuda else -1,
        )

        # Run inference and collect the transcriptions in dataset order.
        all_predictions = [
            prediction["text"]
            for prediction in tqdm(
                pipe(
                    KeyDataset(common_voice_test, "audio"),
                    max_new_tokens=128,
                    chunk_length_s=30,
                    generate_kwargs={"task": "transcribe"},
                    batch_size=32,
                ),
                total=len(common_voice_test),
            )
        ]

        references = common_voice_test["sentence"]
        wer_result = 100 * load("wer").compute(
            references=references, predictions=all_predictions
        )
        cer_result = 100 * load("cer").compute(
            references=references, predictions=all_predictions
        )

        # Update the results CSV. The row is keyed 'Model' so the name lands in
        # the column that the duplicate check above and the link rendering
        # below read. DataFrame.append was removed in pandas 2.0; concat is the
        # supported equivalent.
        # NOTE(review): the metric keys are display headers; if the CSV stores
        # raw keys (e.g. "WER") these create extra columns -- confirm schema.
        new_row = {
            'Model': model_text,
            'Common Voice WER ⬇️': wer_result,
            'Common Voice CER ⬇️': cer_result,
        }
        df_results = pd.read_csv(csv_results)
        df_results = pd.concat(
            [df_results, pd.DataFrame([new_row])], ignore_index=True
        )
        df_results.to_csv(csv_results, index=False)

        # Rebuild the in-memory leaderboard from the updated results.
        original_df = df_results.copy()
        original_df['Model'] = original_df['Model'].apply(make_clickable_model)
        for col in original_df.columns:
            if col != 'Model':
                original_df[col] = original_df[col].apply(formatter)
        original_df.rename(columns=column_names, inplace=True)
        original_df.sort_values(by='asr-farsi-youtube WER', inplace=True)

        # NOTE(review): the on-page table is not refreshed from here. Calling
        # leaderboard_table.update(...) inside a handler never changed the page
        # (older Gradio only returned an update dict; Gradio 4+ removed the
        # method, which turned a successful run into an error message).
        # Refreshing requires listing the table as an output of the click event.

        # Return success message
        return styled_message("🤗 Your model has been evaluated and added to the leaderboard!")

    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing the app.
        return styled_error(f"Error during evaluation: {e}")

# Page layout: banner and introduction, three tabs (leaderboard, metrics,
# model request), a "last updated" line and a collapsible citation box.
# The emoji in the labels were stored as mis-decoded UTF-8 and are restored here.
with gr.Blocks() as demo:
    gr.HTML(BANNER, elem_id="banner")
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard", elem_id="od-benchmark-tab-table", id=0):
            # Read-only table showing the leaderboard as loaded at start-up.
            leaderboard_table = gr.Dataframe(
                value=original_df,
                datatype=TYPES,
                elem_id="leaderboard-table",
                interactive=False,
                visible=True,
            )

        with gr.TabItem("📈 Metrics", elem_id="od-benchmark-tab-table", id=1):
            gr.Markdown(METRICS_TAB_TEXT, elem_classes="markdown-text")

        with gr.TabItem("✉️✨ Request a model here!", elem_id="od-benchmark-tab-table", id=2):
            with gr.Column():
                gr.Markdown("# ✉️✨ Request results for a new model here!", elem_classes="markdown-text")
                model_name_textbox = gr.Textbox(label="Model name (user_name/model_name)")
                mdw_submission_result = gr.Markdown()
                btn_submit = gr.Button(value="🚀 Request")
                # The handler's return value is rendered in the Markdown area.
                btn_submit.click(request_model, [model_name_textbox], mdw_submission_result)

    gr.Markdown(f"Last updated on **{LAST_UPDATED}**", elem_classes="markdown-text")

    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            gr.Textbox(
                value=CITATION_TEXT, lines=7,
                label="Copy the BibTeX snippet to cite this source",
                elem_id="citation-button",
                show_copy_button=True,
            )

demo.launch()