File size: 5,374 Bytes
856bef6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fce27b
 
 
 
 
 
 
 
 
 
 
 
 
856bef6
7fce27b
 
856bef6
7fce27b
 
 
856bef6
7fce27b
 
 
 
856bef6
 
 
7fce27b
 
856bef6
7fce27b
 
 
 
 
 
 
 
 
 
 
 
856bef6
 
7fce27b
 
 
 
 
856bef6
7fce27b
 
 
 
 
 
 
 
856bef6
7fce27b
856bef6
7fce27b
 
 
 
 
 
 
 
 
 
856bef6
 
 
 
 
 
 
7fce27b
 
856bef6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fce27b
 
e6a252d
 
 
7fce27b
856bef6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import logging
import sys
import gradio as gr
from transformers import pipeline, AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM

# Log to stdout so messages show up in the hosting container's log stream.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


# Registry of selectable Italian wav2vec2 checkpoints.
# "has_lm": whether the repo ships a KenLM decoder usable with
# Wav2Vec2ProcessorWithLM ("Guided by Language Model" mode).
DICT_MODELS = {
    "robust-300m": {"model_id": "dbdmg/wav2vec2-xls-r-300m-italian-robust", "has_lm": True},
    "robust-1b": {"model_id": "dbdmg/wav2vec2-xls-r-1b-italian-robust", "has_lm": True},
    "300m": {"model_id": "dbdmg/wav2vec2-xls-r-300m-italian", "has_lm": True},
}


# LANGUAGES = sorted(LARGE_MODEL_BY_LANGUAGE.keys())

# the container given by HF has 16GB of RAM, so we need to limit the number of models to load
MODELS = sorted(DICT_MODELS.keys())  # radio-button choices shown in the UI
CACHED_MODELS_BY_ID = {}  # model_id -> loaded AutoModelForCTC instance (in-process cache)

def build_html(history):
    """Render the transcription history as an HTML list of result cards.

    Args:
        history: list of dicts appended by ``run``. Error entries may carry
            only an ``"error_message"`` key; success entries carry
            ``"model_id"``, ``"decoding_type"``, ``"transcription"`` and
            ``"error_message": None``.

    Returns:
        An HTML string: one ``result_item_error`` or ``result_item_success``
        div per history entry, wrapped in a ``result`` container div.
    """
    html_output = "<div class='result'>"
    for item in history:
        # Use .get: error entries from the model-not-found branch of `run`
        # contain only "error_message", so a hard subscript could KeyError
        # on malformed entries.
        if item.get("error_message") is not None:
            html_output += f"<div class='result_item result_item_error'>{item['error_message']}</div>"
        else:
            # Advertise LM-guided decoding in the linked model title.
            url_suffix = " + Guided by Language Model" if item["decoding_type"] == "Guided by Language Model" else ""
            html_output += "<div class='result_item result_item_success'>"
            html_output += f'<strong><a target="_blank" href="https://huggingface.co/{item["model_id"]}">{item["model_id"]}{url_suffix}</a></strong><br/><br/>'
            html_output += f'{item["transcription"]}<br/>'
            html_output += "</div>"
    html_output += "</div>"
    return html_output

def run(uploaded_file, input_file, model_name, decoding_type, history):
    """Gradio callback: transcribe an audio file with the selected model.

    Args:
        uploaded_file: path of an uploaded audio file, or None.
        input_file: path of a microphone recording, or None. Takes
            precedence over ``uploaded_file`` when both are given.
        model_name: key into ``DICT_MODELS`` (UI radio choice).
        decoding_type: "Standard" or "Guided by Language Model".
        history: gradio state — list of previous result dicts, or None on
            the first call.

    Returns:
        (html_output, history): the rendered result list and the updated
        state, appended with either a transcription entry or an error entry.
    """
    model = DICT_MODELS.get(model_name)
    history = history or []

    # Validate the model selection *before* dereferencing it: the original
    # code read model["model_id"] first and raised TypeError for an unknown
    # model name, and its error strings referenced undefined names
    # (model_size / language), which raised NameError.
    if model is None:
        history.append({
            "error_message": f"Model {model_name} not found :("
        })
    elif uploaded_file is None and input_file is None:
        history.append({
            "model_id": model["model_id"],
            "decoding_type": decoding_type,
            "transcription": "",
            "error_message": "No input provided."
        })
    elif decoding_type == "Guided by Language Model" and not model["has_lm"]:
        history.append({
            "error_message": f"LM not available for model {model_name} :("
        })
    else:
        # Microphone input wins; fall back to the uploaded file.
        if input_file is None:
            input_file = uploaded_file

        logger.info(f"Running ASR {model_name}-{decoding_type} for {input_file}")

        # Cache loaded checkpoints across calls — the hosting container's
        # RAM is limited and from_pretrained is expensive.
        model_instance = CACHED_MODELS_BY_ID.get(model["model_id"])
        if model_instance is None:
            model_instance = AutoModelForCTC.from_pretrained(model["model_id"])
            CACHED_MODELS_BY_ID[model["model_id"]] = model_instance

        # LM-guided decoding needs the processor variant that ships a
        # KenLM decoder; standard decoding passes decoder=None.
        if decoding_type == "Guided by Language Model":
            processor = Wav2Vec2ProcessorWithLM.from_pretrained(model["model_id"])
            decoder = processor.decoder
        else:
            processor = Wav2Vec2Processor.from_pretrained(model["model_id"])
            decoder = None

        asr = pipeline(
            "automatic-speech-recognition",
            model=model_instance,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            decoder=decoder,
        )

        # Chunked inference keeps memory bounded on long recordings.
        transcription = asr(input_file, chunk_length_s=5, stride_length_s=1)["text"]

        logger.info(f"Transcription for {input_file}: {transcription}")

        history.append({
            "model_id": model["model_id"],
            "decoding_type": decoding_type,
            "transcription": transcription,
            "error_message": None
        })

    html_output = build_html(history)

    return html_output, history


# UI wiring. NOTE(review): this uses the legacy gradio 2.x API
# (gr.inputs.* / gr.outputs.*, allow_screenshot, enable_queue) — it will
# not run on gradio >= 3; pin the gradio version accordingly.
gr.Interface(
    run,
    inputs=[
        # Two audio sources feed the same callback; `run` prefers the
        # microphone recording when both are present.
        gr.inputs.Audio(source="upload", type='filepath', optional=True),
        gr.inputs.Audio(source="microphone", type="filepath", label="Record something...", optional=True),
        gr.inputs.Radio(label="Model", choices=MODELS),
        gr.inputs.Radio(label="Decoding type", choices=["Standard", "Guided by Language Model"]),
        # Session state: the history list threaded through `run`.
        "state"
    ],
    outputs=[
        gr.outputs.HTML(label="Outputs"),
        "state"
    ],
    title="Italian Robust ASR",
    description="",
    # Styling for the result cards emitted by build_html.
    css="""
    .result {display:flex;flex-direction:column}
    .result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
    .result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
    .result_item_error {background-color:#ff7070;color:white;align-self:start}
    """,
    allow_screenshot=False,
    allow_flagging="never",
    theme="huggingface",
    # Each example fills both audio inputs with the same file.
    examples = [
        ['demo_example_1.mp3', 'demo_example_1.mp3', 'robust-300m', 'Guided by Language Model'],
        ['demo_luca_1.wav', 'demo_luca_1.wav', 'robust-300m', 'Guided by Language Model'],
        ['demo_luca_2.wav', 'demo_luca_2.wav', 'robust-300m', 'Guided by Language Model']
    ]
).launch(enable_queue=True)