# latest-demo / app.py
# Last commit by msis: "Update model restoration method" (4087c59)
import os
from datetime import datetime, timedelta, timezone

import gradio as gr
import nemo.collections.asr as nemo_asr
import wandb
# How far back (in days) to look for model artifacts. Overridable through the
# environment like the W&B settings below; defaults to 180.
MODEL_HISTORY_DAYS = int(os.environ.get("MODEL_HISTORY_DAYS", "180"))
WANDB_ENTITY = os.environ.get("WANDB_ENTITY", "tarteel")
WANDB_PROJECT_NAME = os.environ.get("WANDB_PROJECT_NAME", "nemo-experiments")
wandb_api = wandb.Api(overrides={"entity": WANDB_ENTITY})


def _created_at_utc(artifact):
    """Return ``artifact.created_at`` (an ISO-8601 string) as a naive UTC datetime.

    A trailing ``Z`` is accepted even on Python < 3.11, and tz-aware values are
    converted to UTC. Naive values are assumed to already be UTC.
    NOTE(review): assumption that W&B reports naive timestamps in UTC — confirm.
    """
    created = datetime.fromisoformat(artifact.created_at.replace("Z", "+00:00"))
    if created.tzinfo is not None:
        created = created.astimezone(timezone.utc).replace(tzinfo=None)
    return created


# Computed once so every artifact is judged against the same cutoff, in UTC to
# match the (assumed UTC) artifact timestamps rather than the host's local time.
_history_cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(
    days=MODEL_HISTORY_DAYS
)

# One ``versions()`` iterable per "model" artifact collection in the project.
all_artifacts_versions = [
    collection.versions()
    for collection in wandb_api.artifact_type(
        type_name="model", project=WANDB_PROJECT_NAME
    ).collections()
]
# Every non-deleted artifact version created inside the history window.
latest_artifacts = [
    artifact
    for artifact_versions in all_artifacts_versions
    for artifact in artifact_versions
    if _created_at_utc(artifact) > _history_cutoff and artifact.state != "DELETED"
]
# Newest first, so the UI lists the most recent models at the top.
latest_artifacts.sort(key=_created_at_utc, reverse=True)
# Artifact name -> loaded model; None until first requested (lazy loading).
models = {artifact.name: None for artifact in latest_artifacts}
def lazy_load_models(models_names):
    """Ensure every requested model is restored, in eval mode, in ``models``.

    Models are restored from their W&B artifact the first time they are
    requested; later calls reuse the cached instance.

    Args:
        models_names: Iterable of artifact names (keys of the module-level
            ``models`` dict) to make available for inference.

    Raises:
        KeyError: If a name is not a key of ``models`` or has no matching
            entry in ``latest_artifacts``.
    """
    for model_name in models_names:
        # ``None`` is the "not loaded yet" sentinel; compare by identity rather
        # than relying on the model object's truthiness.
        if models[model_name] is not None:
            continue
        artifact = next(
            (a for a in latest_artifacts if a.name == model_name), None
        )
        if artifact is None:
            raise KeyError(f"no model artifact named {model_name!r}")
        # ``artifact.file()`` presumably downloads the artifact and returns the
        # local path of its single file, which is what restore_from expects.
        model = nemo_asr.models.EncDecCTCModelBPE.restore_from(artifact.file())
        model.eval()  # inference only; done once, at load time
        models[model_name] = model
def _transcribe_with_models(models_names, audio_path):
    """Return one ``"<model> => <text>"`` line per model for *audio_path*.

    Returns ``""`` when *audio_path* is empty/None (nothing recorded or
    uploaded), because ``transcribe()`` fails if the path is empty.
    """
    if not audio_path:
        return ""
    return "\n".join(
        f"{model_name} => {models[model_name].transcribe([audio_path])[0]}"
        for model_name in models_names
    )


def transcribe(audio_mic, audio_file, models_names):
    """Transcribe the microphone and uploaded audio with each selected model.

    Args:
        audio_mic: File path of the microphone recording, or empty/None.
        audio_file: File path of the uploaded audio, or empty/None.
        models_names: Names of the models (keys of ``models``) to run.

    Returns:
        Tuple ``(mic_transcription, file_transcription)``; each is one line per
        selected model, or ``""`` when the corresponding input is missing.
    """
    lazy_load_models(models_names)
    # Each source is transcribed separately so a missing one is simply skipped.
    transcription_mic = _transcribe_with_models(models_names, audio_mic)
    transcription_file = _transcribe_with_models(models_names, audio_file)
    return transcription_mic, transcription_file
# Choices for the model checkbox group, in the order of ``models``.
model_selection = list(models.keys())
demo = gr.Blocks()
with demo:
    # The history window is interpolated from MODEL_HISTORY_DAYS so the text
    # cannot drift from the filter actually applied when listing models.
    gr.Markdown(
        f"""
# ﷽
These are the latest* Tarteel models.
Please note that the models are lazy loaded.
This means that the first time you use a model,
it might take some time to be downloaded and loaded for inference.
*: last {MODEL_HISTORY_DAYS} days since the space was launched.
To update the list, restart the space.
"""
    )
    with gr.Row():
        audio_mic = gr.Audio(source="microphone", type="filepath", label="Microphone")
        audio_file = gr.Audio(source="upload", type="filepath", label="File")
    with gr.Row():
        output_mic = gr.TextArea(label="Microphone Transcription")
        output_file = gr.TextArea(label="Audio Transcription")
    models_names = gr.CheckboxGroup(model_selection, label="Select Models to Use")
    b1 = gr.Button("Transcribe")
    b1.click(
        transcribe,
        inputs=[audio_mic, audio_file, models_names],
        outputs=[output_mic, output_file],
    )
demo.launch()