jharrison27 committed
Commit
31fb546
1 parent: 6ef9e0f

Upload 4 files

Files changed (4)
  1. app.py +152 -0
  2. packages.txt +1 -0
  3. requirements.txt +10 -0
  4. test.json +12 -0
app.py ADDED
@@ -0,0 +1,152 @@
+ import gradio as gr
+ import logging
+ import sys
+ import tempfile
+ import numpy as np
+ import datetime
+
+ from transformers import pipeline, AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
+ from typing import Optional
+ from TTS.utils.manage import ModelManager
+ from TTS.utils.synthesizer import Synthesizer
+
+ logging.basicConfig(
+     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+     datefmt="%m/%d/%Y %H:%M:%S",
+     handlers=[logging.StreamHandler(sys.stdout)],
+ )
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.DEBUG)
+
+
+ LARGE_MODEL_BY_LANGUAGE = {
+     "Arabic": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic", "has_lm": False},
+     "Chinese": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn", "has_lm": False},
+     #"Dutch": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch", "has_lm": False},
+     "English": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-english", "has_lm": True},
+     "Finnish": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish", "has_lm": False},
+     "French": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-french", "has_lm": True},
+     "German": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-german", "has_lm": True},
+     "Greek": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-greek", "has_lm": False},
+     "Hungarian": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian", "has_lm": False},
+     "Italian": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-italian", "has_lm": True},
+     "Japanese": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese", "has_lm": False},
+     "Persian": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-persian", "has_lm": False},
+     "Polish": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-polish", "has_lm": True},
+     "Portuguese": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese", "has_lm": True},
+     "Russian": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-russian", "has_lm": True},
+     "Spanish": {"model_id": "jonatasgrosman/wav2vec2-large-xlsr-53-spanish", "has_lm": True},
+ }
+
+ XLARGE_MODEL_BY_LANGUAGE = {
+     "English": {"model_id": "jonatasgrosman/wav2vec2-xls-r-1b-english", "has_lm": True},
+     "Spanish": {"model_id": "jonatasgrosman/wav2vec2-xls-r-1b-spanish", "has_lm": True},
+     "German": {"model_id": "jonatasgrosman/wav2vec2-xls-r-1b-german", "has_lm": True},
+     "Russian": {"model_id": "jonatasgrosman/wav2vec2-xls-r-1b-russian", "has_lm": True},
+     "French": {"model_id": "jonatasgrosman/wav2vec2-xls-r-1b-french", "has_lm": True},
+     "Italian": {"model_id": "jonatasgrosman/wav2vec2-xls-r-1b-italian", "has_lm": True},
+     #"Dutch": {"model_id": "jonatasgrosman/wav2vec2-xls-r-1b-dutch", "has_lm": False},
+     "Polish": {"model_id": "jonatasgrosman/wav2vec2-xls-r-1b-polish", "has_lm": True},
+     "Portuguese": {"model_id": "jonatasgrosman/wav2vec2-xls-r-1b-portuguese", "has_lm": True},
+ }
+
+
+ # LANGUAGES = sorted(LARGE_MODEL_BY_LANGUAGE.keys())
+
+ # the container given by HF has 16GB of RAM, so we need to limit the number of models to load
+ LANGUAGES = sorted(XLARGE_MODEL_BY_LANGUAGE.keys())
+ CACHED_MODELS_BY_ID = {}
+
+
+ def run(input_file, language, decoding_type, history, model_size="300M"):
+
+     logger.info(f"Running ASR {language}-{model_size}-{decoding_type} for {input_file}")
+
+     history = history or []
+
+     if model_size == "300M":
+         model = LARGE_MODEL_BY_LANGUAGE.get(language, None)
+     else:
+         model = XLARGE_MODEL_BY_LANGUAGE.get(language, None)
+
+     if model is None:
+         history.append({
+             "error_message": f"Model size {model_size} not found for {language} language :("
+         })
+     elif decoding_type == "LM" and not model["has_lm"]:
+         history.append({
+             "error_message": f"LM not available for {language} language :("
+         })
+     else:
+
+         # model_instance = AutoModelForCTC.from_pretrained(model["model_id"])
+         model_instance = CACHED_MODELS_BY_ID.get(model["model_id"], None)
+         if model_instance is None:
+             model_instance = AutoModelForCTC.from_pretrained(model["model_id"])
+             CACHED_MODELS_BY_ID[model["model_id"]] = model_instance
+
+         if decoding_type == "LM":
+             processor = Wav2Vec2ProcessorWithLM.from_pretrained(model["model_id"])
+             asr = pipeline("automatic-speech-recognition", model=model_instance, tokenizer=processor.tokenizer,
+                            feature_extractor=processor.feature_extractor, decoder=processor.decoder)
+         else:
+             processor = Wav2Vec2Processor.from_pretrained(model["model_id"])
+             asr = pipeline("automatic-speech-recognition", model=model_instance, tokenizer=processor.tokenizer,
+                            feature_extractor=processor.feature_extractor, decoder=None)
+
+         transcription = asr(input_file, chunk_length_s=5, stride_length_s=1)["text"]
+
+         logger.info(f"Transcription for {input_file}: {transcription}")
+
+         history.append({
+             "model_id": model["model_id"],
+             "language": language,
+             "model_size": model_size,
+             "decoding_type": decoding_type,
+             "transcription": transcription,
+             "error_message": None
+         })
+
+     html_output = "<div class='result'>"
+     for item in history:
+         if item["error_message"] is not None:
+             html_output += f"<div class='result_item result_item_error'>{item['error_message']}</div>"
+         else:
+             url_suffix = " + LM" if item["decoding_type"] == "LM" else ""
+             html_output += "<div class='result_item result_item_success'>"
+             html_output += f'<strong><a target="_blank" href="https://huggingface.co/{item["model_id"]}">{item["model_id"]}{url_suffix}</a></strong><br/><br/>'
+             html_output += f'{item["transcription"]}<br/>'
+             html_output += "</div>"
+     html_output += "</div>"
+
+     return html_output, history
+
+
+ gr.Interface(
+     run,
+     inputs=[
+         #gr.inputs.Audio(source="microphone", type="filepath", label="Record something..."),
+         gr.Audio(source="microphone", type='filepath', streaming=True),
+         #gr.inputs.Audio(source="microphone", type="filepath", label="Record something...", streaming="True"),
+         gr.inputs.Radio(label="Language", choices=LANGUAGES),
+         gr.inputs.Radio(label="Decoding type", choices=["greedy", "LM"]),
+         # gr.inputs.Radio(label="Model size", choices=["300M", "1B"]),
+         "state"
+     ],
+     outputs=[
+         gr.outputs.HTML(label="Outputs"),
+         "state"
+     ],
+     title="🗣️NLP ASR Wav2Vec2 GR📄",
+     description="",
+     css="""
+     .result {display:flex;flex-direction:column}
+     .result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
+     .result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
+     .result_item_error {background-color:#ff7070;color:white;align-self:start}
+     """,
+     allow_screenshot=False,
+     allow_flagging="never",
+     theme="grass",
+     live=True # test1
+ ).launch(enable_queue=True)
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ transformers
+ torch
+ pyctcdecode
+ pypi-kenlm
+ streamlit
+ google-cloud-firestore
+ firebase-admin
+ Werkzeug==2.0.3
+ huggingface_hub==0.4.0
+ TTS==0.2.1
test.json ADDED
@@ -0,0 +1,12 @@
+ {
+ "type": "service_account",
+ "project_id": "clinical-nlp-b9117",
+ "private_key_id": "6972d02311e8ee0c5b582551fbcf9c99b9169b58",
+ "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCmrSoB92G/ihxL\nzIk7Y8RUNc6Iezr6pZ+eSz2RGxEz2qPMfWjNeOJEAlACYJp4aUwyX5IHGb8Eh/oj\nkr7nVsgvuDyrTWpCAv16AuRycKgxvqj0+uDaVrF0vLgTumy62x5QM7i+n2YTDXoP\nXHMHX7yXZ6zc9Ibmm065f2kgWyjmIZDt+flTBYeBS203ZIzMBHhN1e1jdtzR36z/\n1MBmLjpRKvmuHF2SnraVjoRh7Xe6R99K8DxRQ61TJt9xLukvLBYelnqf2/cK8bZM\n5p2pErR4FE7ki3MX7HWdMJQSe+Uj10hurjNBdHcCaNUou5EL5+NRgqLow0tfatWC\n+Jpiw3K9AgMBAAECggEAGpT7YhzmBfos0RnpuQMMSLHcIoAkw9yuPDybsQy0DaUN\nAovtrvdcfqQvxnFJsXJ5qH79dwxwHnThO9MnhxWcD6A+bMOH8scvTcowTOASsvxJ\nTejE+41f99IxOVQ+Cv7vMrNM/3nEeb1ofhKsdbybAzqRoxuMeDLEt2jOh06Ck1D8\n/YV8kavGYR/VNxO2l7C5DZJYXgcm18ZrTFEXZes8bydZesoHl+JRVO1utjR2IhAj\nnYqqNaf5RXruEzXWxP0+jjEgg4NLFfqVnQTZFrLwokwc8NEMXf3dZJ0k0cHHmxOB\n6BHuPZhMOZ56U74PyWgCmbPp9g/SLt3iInpZ4ahmAQKBgQDhQwdbUEQ1q+KSMsMm\ndJl+ghX/Ff3uaZ7LjdBiOgtmTaIVbuf/bw0V9x8GbRGdJJyp546R5vhUE0zKzkMt\nTNdDNrWk3Zh4tCRHvPEHiqmDn91pWFeDDQf/OjKz+SFV31mQ050BOatZ8dBEy+md\nvHG8yLTB7oJvSpviim4ty15wIQKBgQC9a5jsBFB0fltHNJ0lZp7I2hF+aOqOngJM\nqEipPjJABJ4izGTOK/KW8CyWEP82nb6p7u9v0f4sV8CFWXG178DMv1NlRYzom3CQ\nkXdx+nRgO4oX4eEfYuoP2PxF0hCOwbh55NgFdwTt/dExX6bau4d9yQMV7o0TXpRW\nZzygOOTfHQKBgQC7ayhwyfymZydwmjmSAks/XX5tqN+IgGo1U/1/7GlVqdvkV01B\nUiUiFGTE1PRluXN7TYRqUjBky1YGGsz7oMYtTxScYh6ctszEvygPLUhSki0GnBDb\noXj42nQbF3mr19POUrJ7tX6irDWrN7lcmtBK0PbLr+ToMbw3JRP8mAsv4QKBgEac\nC18/pHYofAIpHMNKY7pff9HtbjJHuHe2648bPkQa9I/oPVOVklKtqREvuNM1LlPO\nW7cFQohpFb0fwIGfo/EvCPlhWcuD1gwuDaaRRDxzNWD9tJusla/epPup+L4efJQD\nuHshCNdmnEqZa2tyKGm9Osc8K56izQ0AYtsfGkIJAoGAMtaXTA96OXUvpEm4waQX\nOTbuEZQEdntnYWHacNrGlvwnNmvNC9hXwB38ijxXHEn0j1QUcV3w5QXFupwzjpZ2\nlIp9vTq1mOTVhHzmQmOb9DKKAE/2pi2HnekItncoQCBtgJ7k6tIk1KEfvXuQS/oM\nh8qPMwuMcQ/vKGhl3xLYo9M=\n-----END PRIVATE KEY-----\n",
+ "client_email": "firebase-adminsdk-qaxaj@clinical-nlp-b9117.iam.gserviceaccount.com",
+ "client_id": "117623958723912081118",
+ "auth_uri": "https://accounts.google.com/o/oauth2/auth",
+ "token_uri": "https://oauth2.googleapis.com/token",
+ "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
+ "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/firebase-adminsdk-qaxaj%40clinical-nlp-b9117.iam.gserviceaccount.com"
+ }