# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import atexit
import glob
import os
import shutil
import time
from html import unescape
from uuid import uuid4

import model_api
import torch
import werkzeug
from flask import Flask, make_response, render_template, request, url_for
from flask_cors import CORS
from werkzeug.utils import secure_filename

from nemo.utils import logging

app = Flask(__name__)
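# Enable CORS so the browser-based UI can call these endpoints from a different origin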
CORS(app)

# Upload folder for audio files. Models are stored in a separate, permanent cache
# which is deleted only when the container shuts down
app.config['UPLOAD_FOLDER'] = "tmp/"


@app.route('/initialize_model', methods=['POST'])
def initialize_model():
    """
    API Endpoint to instantiate a model

    Loads ASR model by its pretrained checkpoint name or upload ASR model that is provided by the user,
    then load that checkpoint into the cache.

    Loading of the model into cache is done once per worker. Number of workers should be limited
    so as not to exhaust the GPU memory available on device (if GPU is being used).
    """
    logging.info("Starting ASR service")
    if torch.cuda.is_available():
        logging.info("CUDA is available. Running on GPU")
    else:
        logging.info("CUDA is not available. Defaulting to CPUs")

    # get form fields
    model_name = request.form['model_names_select']
    use_gpu_if_available = request.form.get('use_gpu_ckbx', "off")

    # get the .nemo model file uploaded by the user (if any)
    nemo_model_file = request.files.get('nemo_model', '')

    # if a model file was actually uploaded, store it in the model cache.
    # Note: an empty file input yields a falsy FileStorage, so truthiness is checked
    # rather than comparing against the '' default.
    if nemo_model_file:
        model_name = _store_model(nemo_model_file)

        # Alert the user that the model has been uploaded into the model cache,
        # and that they should refresh the page to access the model
        result = render_template(
            'toast_msg.html', toast_message=f"Model {model_name} has been uploaded. Refresh the page!", timeout=5000
        )

    else:
        # Alert the user that the model has been loaded into a worker's memory
        result = render_template(
            'toast_msg.html', toast_message=f"Model {model_name} has been initialized!", timeout=2000
        )

    # Load model into memory cache
    model_api.initialize_model(model_name=model_name)

    # reset the file input field in the page
    reset_nemo_model_file_script = """
        <script>
            document.getElementById('nemo_model_file').value = ""
        </script>
    """

    result = result + reset_nemo_model_file_script
    result = make_response(result)

    # set cookies
    result.set_cookie("model_name", model_name)
    result.set_cookie("use_gpu", use_gpu_if_available)
    return result
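
# A minimal client-side sketch of calling the endpoint above (assumes the `requests`
# package and a server on localhost:5000; the model name is only illustrative).
# A Session is used so cookies set by each endpoint carry over to later calls:
#
#   import requests
#   session = requests.Session()
#   session.post(
#       "http://localhost:5000/initialize_model",
#       data={'model_names_select': 'stt_en_citrinet_256', 'use_gpu_ckbx': 'on'},
#   )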


def _store_model(nemo_model_file):
    """
    Preserve the model supplied by user into permanent cache
    This cache needs to be manually deleted (if run locally), or gets deleted automatically
    (when the container gets shutdown / killed).

    Args:
        nemo_model_file: User path to .nemo checkpoint.

    Returns:
        A file name (with a .nemo) at the end - to signify this is an uploaded checkpoint.
    """
    filename = secure_filename(nemo_model_file.filename)
    file_basename = os.path.basename(filename)
    model_dir = os.path.splitext(file_basename)[0]

    model_store = os.path.join('models', model_dir)
    if not os.path.exists(model_store):
        os.makedirs(model_store)

    # upload model
    model_path = os.path.join(model_store, filename)
    nemo_model_file.save(model_path)
    return file_basename
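
# Uploaded checkpoints land at models/<checkpoint_name>/<checkpoint_name>.nemo, which
# model_api.get_model_names() is expected to report as a local model on the next page load.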


@app.route('/upload_audio_files', methods=['POST'])
def upload_audio_files():
    """
    API Endpoint to upload audio files for inference.

    The uploaded files must be wav files, 16 KHz sample rate, mono-channel audio samples.
    """
    # Try to get one or more files from form
    try:
        f = request.files.getlist('file')
    except werkzeug.exceptions.BadRequestKeyError:
        f = None

    # If the user did not select any file to upload, notify them.
    if f is None or len(f) == 0:
        toast = render_template('toast_msg.html', toast_message="No file has been selected to upload!", timeout=2000)
        result = render_template('updates/upload_files_failed.html', pre_exec=toast, url=url_for('upload_audio_files'))
        result = unescape(result)
        return result

    # temporary id to store data
    uuid = str(uuid4())
    data_store = os.path.join(app.config['UPLOAD_FOLDER'], uuid)

    # If the user attempts to upload another set of files without first transcribing
    # the previous set, delete the old file cache and create a new one
    _remove_older_files_if_exists()

    # Save each file into this unique cache
    if not os.path.exists(data_store):
        os.makedirs(data_store)

    for fn in f:
        filename = secure_filename(fn.filename)
        fn.save(os.path.join(data_store, filename))
        logging.info(f"Saving file : {fn.filename}")

    # Notify the user that N files were uploaded.
    msg = f"{len(f)} file(s) uploaded. Click to upload more!"
    toast = render_template('toast_msg.html', toast_message=f"{len(f)} file(s) uploaded!", timeout=2000)
    result = render_template(
        'updates/upload_files_successful.html', pre_exec=toast, msg=msg, url=url_for('upload_audio_files')
    )
    result = unescape(result)

    result = make_response(result)
    result.set_cookie("uuid", uuid)
    return result
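
# Continuing the client sketch above, files can be uploaded with the same session
# (the 'uuid' cookie set by this endpoint is needed by /transcribe later):
#
#   with open('sample.wav', 'rb') as fp:
#       session.post("http://localhost:5000/upload_audio_files", files={'file': fp})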


def _remove_older_files_if_exists():
    """
    Helper method to prevent cache leakage when user attempts to upload another set of files
    without first transcribing the files already uploaded.
    """
    # remove old data store (if exists)
    old_uuid = secure_filename(request.cookies.get('uuid', ''))
    if old_uuid:
        # delete old data store
        old_data_store = os.path.join(app.config['UPLOAD_FOLDER'], old_uuid)

        logging.info("Tried uploading more data without using old uploaded data. Purging data cache.")
        shutil.rmtree(old_data_store, ignore_errors=True)


@app.route('/remove_audio_files', methods=['POST'])
def remove_audio_files():
    """
    API Endpoint for removing audio files

    # Note: Sometimes data may persist due to set of circumstances:

        - User uploads audio then closes app without transcribing anything

    In such a case, the files will be deleted when gunicorn shutsdown, or container is stopped.
    However the data may not be automatically deleted if the flast server is used as is.
    """
    # Get the unique cache id from cookie
    uuid = secure_filename(request.cookies.get("uuid", ""))
    data_store = os.path.join(app.config['UPLOAD_FOLDER'], uuid)

    # If the data does not exist (cache is empty), notify user
    if not os.path.exists(data_store) or uuid == "":
        files_dont_exist = render_template(
            'toast_msg.html', toast_message="No files have been uploaded!", timeout=2000
        )
        result = render_template(
            'updates/remove_files.html', pre_exec=files_dont_exist, url=url_for('remove_audio_files')
        )
        result = unescape(result)
        return result

    else:
        # delete data that exists in cache
        shutil.rmtree(data_store, ignore_errors=True)

        logging.info("Removed all data")

        # Notify user that cache was deleted.
        toast = render_template('toast_msg.html', toast_message="All files removed!", timeout=2000)
        result = render_template('updates/remove_files.html', pre_exec=toast, url=url_for('remove_audio_files'))
        result = unescape(result)

        result = make_response(result)
        result.set_cookie("uuid", '', expires=0)
        return result


@app.route('/transcribe', methods=['POST'])
def transcribe():
    """
    API Endpoint to transcribe a set of audio files.

    The files are sorted according to their name, so order may not be same as upload order.

    Utilizing the cached info inside the cookies, a model with selected name will be loaded into memory,
    and maybe onto a GPU (if it is supported on the device).

    Then the transcription api will be called from the model_api. If all is successful, a template is updated
    with results. If some issue occurs (memory ran out, file is invalid format), notify the user.
    """
    # load model name from cookie
    model_name = request.cookies.get('model_name')
    logging.info(f"Model name : {model_name}")

    # If model name is not selected via Load Model, notify user.
    if model_name is None or model_name == '':
        result = render_template('toast_msg.html', toast_message="Model has not been initialized!", timeout=2000)
        return result

    # load whether gpu should be used
    use_gpu_if_available = request.cookies.get('use_gpu') == 'on'
    gpu_used = torch.cuda.is_available() and use_gpu_if_available

    # Load audio from paths
    uuid = secure_filename(request.cookies.get("uuid", ""))
    data_store = os.path.join(app.config['UPLOAD_FOLDER'], uuid)

    files = sorted(glob.glob(os.path.join(data_store, "*.wav")))

    # If no files found in cache, notify user
    if len(files) == 0:
        result = render_template('toast_msg.html', toast_message="No audio files were found!", timeout=2000)
        return result

    # transcribe file via model api
    t1 = time.time()
    transcriptions = model_api.transcribe_all(files, model_name, use_gpu_if_available=use_gpu_if_available)
    t2 = time.time()

    # delete all transcribed files immediately
    for fp in files:
        try:
            os.remove(fp)
        except FileNotFoundError:
            logging.info(f"Failed to delete transcribed file : {os.path.basename(fp)}")

    # delete temporary transcription directory
    shutil.rmtree(data_store, ignore_errors=True)

    # If transcription failed for some reason, notify the user.
    if isinstance(transcriptions, str) and transcriptions == model_api.TAG_ERROR_DURING_TRANSCRIPTION:
        toast = render_template(
            'toast_msg.html',
            toast_message="Failed to transcribe files due to an unknown reason. "
            "Please provide 16 kHz mono-channel wav files only.",
            timeout=5000,
        )
        transcriptions = ["" for _ in range(len(files))]

    else:
        # Transcriptions obtained successfully, notify user.
        toast = render_template(
            'toast_msg.html',
            toast_message=f"Transcribed {len(files)} file(s) using {model_name} (gpu={gpu_used}) "
            f"in {t2 - t1:0.2f} s",
            timeout=5000,
        )

    # Write results to data table
    results = []
    for filename, transcript in zip(files, transcriptions):
        results.append(dict(filename=os.path.basename(filename), transcription=transcript))

    result = render_template('transcripts.html', transcripts=results)
    result = toast + result
    result = unescape(result)

    result = make_response(result)
    result.set_cookie("uuid", "", expires=0)
    return result
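
# Completing the client sketch: with a model initialized and files uploaded via the same
# session, the transcription call relies only on the cookies set by the earlier endpoints:
#
#   resp = session.post("http://localhost:5000/transcribe")
#   print(resp.text)  # rendered HTML fragment containing the transcripts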


def remove_tmp_dir_at_exit():
    """
    Helper method to attempt a deletion of audio file cache on flask api exit.
    Gunicorn and Docker container (based on gunicorn) will delete any remaining files on
    shutdown of the gunicorn server or the docker container.

    This is a patch that might not always work for Flask server, but in general should ensure
    that local audio file cache is deleted.

    This does *not* impact the model cache. Flask and Gunicorn servers will *never* delete uploaded models.
    Docker container will delete models *only* when the container is killed (since models are uploaded to
    local storage path inside container).
    """
    try:
        uuid = secure_filename(request.cookies.get("uuid", ""))

        # Note: `and` (not `or`) is required here, otherwise an empty uuid passes the check
        if uuid is not None and uuid != "":
            cache_dir = os.path.join(app.config['UPLOAD_FOLDER'], uuid)
            logging.info(f"Removing cache file for worker : {os.getpid()}")

            if os.path.exists(cache_dir):
                shutil.rmtree(cache_dir, ignore_errors=True)
                logging.info(f"Deleted tmp folder : {cache_dir}")

    except RuntimeError:
        # Working outside of request context (probably shutdown)
        # simply delete entire tmp folder
        shutil.rmtree(app.config['UPLOAD_FOLDER'], ignore_errors=True)


@app.route('/')
def main():
    """
    API Endpoint for ASR Service.
    """
    nemo_model_names, local_model_names = model_api.get_model_names()
    model_names = []
    model_names.extend(local_model_names)  # prioritize local models
    model_names.extend(nemo_model_names)  # attach all other pretrained models

    # page initializations
    result = render_template('main.html', model_names=model_names)
    result = make_response(result)

    # Reset cookies
    result.set_cookie("model_name", '', expires=0)  # model name from pretrained model list
    result.set_cookie("use_gpu", '', expires=0)  # flag to use gpu (if available)
    result.set_cookie("uuid", '', expires=0)  # session id
    return result


# Register hook to delete file cache (for flask server only)
atexit.register(remove_tmp_dir_at_exit)
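
# When serving with gunicorn (as the docstrings above assume for automatic cleanup), an
# invocation along these lines might be used; the module name here is only an assumption:
#   gunicorn -w 1 -b 0.0.0.0:5000 app:app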


if __name__ == '__main__':
    app.run(debug=False)