Alex Cabrera
audio
c24ff9a
raw history blame
No virus
1.91 kB
import math
import os
import pandas as pd
import torch
import whisper
from jiwer import wer
from zeno import (
ZenoOptions,
distill,
metric,
model,
DistillReturn,
ModelReturn,
MetricReturn,
)
@model
def load_model(model_path):
    """Build a transcription function for the model named by ``model_path``.

    The model is chosen by substring match on ``model_path``: a value
    containing ``"sst"`` loads the Silero speech-to-text model through
    ``torch.hub`` on the CPU, and a value containing ``"whisper"`` loads
    Whisper's ``"tiny"`` checkpoint.

    Args:
        model_path: Name or path identifying which model to load.

    Returns:
        A function ``pred(df, ops)`` that joins every entry of
        ``df[ops.data_column]`` onto ``ops.data_path``, transcribes those
        files and returns the transcripts wrapped in a ``ModelReturn``.

    Raises:
        ValueError: If ``model_path`` matches neither supported model.
            Previously this case fell through and returned ``None``.
    """
    # NOTE(review): the check looks for "sst" although the hub model is named
    # "silero_stt" -- presumably this matches the configured model name;
    # confirm before changing the substring.
    if "sst" in model_path:
        device = torch.device("cpu")
        stt_model, decoder, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-models",
            model="silero_stt",
            language="en",
            device=device,
        )
        # Only the batch reader and the input-preparation helper are used.
        (read_batch, _, _, prepare_model_input) = utils

        def pred(df, ops: ZenoOptions):
            files = [os.path.join(ops.data_path, f) for f in df[ops.data_column]]
            batch = prepare_model_input(read_batch(files), device=device)
            return ModelReturn(
                model_output=[decoder(x.cpu()) for x in stt_model(batch)]
            )

        return pred

    if "whisper" in model_path:
        whisper_model = whisper.load_model("tiny")

        def pred(df, ops: ZenoOptions):
            files = [os.path.join(ops.data_path, f) for f in df[ops.data_column]]
            # Whisper is run one file at a time.
            return ModelReturn(
                model_output=[whisper_model.transcribe(f)["text"] for f in files]
            )

        return pred

    raise ValueError(
        f"Unsupported model path {model_path!r}: expected it to contain "
        "'sst' or 'whisper'"
    )
@distill
def country(df, ops: ZenoOptions):
    """Extract the country from each row's ``birthplace`` value.

    The country is taken to be the text after the last ``", "`` in the
    birthplace string. Rows whose birthplace is missing (NaN/None) or not a
    string yield an empty string.

    The result is computed row by row so it always has ``len(df)`` entries;
    the earlier version inspected only the first row for missing data and
    returned a single value regardless of how many rows were passed in.

    Args:
        df: DataFrame with a ``birthplace`` column.
        ops: Zeno options (unused here).

    Returns:
        ``DistillReturn`` whose ``distill_output`` is a list with one country
        string per row of ``df``.
    """
    countries = [
        # Missing values arrive as float NaN / None, so anything that is not
        # a string is treated as "no birthplace".
        place.split(", ")[-1] if isinstance(place, str) else ""
        for place in df["birthplace"]
    ]
    return DistillReturn(distill_output=countries)
@distill
def wer_m(df, ops: ZenoOptions):
    """Compute the word error rate of every row.

    For each row, ``jiwer.wer`` is called with the value in
    ``ops.label_column`` as the reference and the value in
    ``ops.output_column`` as the hypothesis.

    Args:
        df: DataFrame holding the label and model-output columns.
        ops: Zeno options naming the label and output columns.

    Returns:
        ``DistillReturn`` whose ``distill_output`` is the result of the
        row-wise ``DataFrame.apply``.
    """
    reference_col = ops.label_column
    hypothesis_col = ops.output_column

    def row_wer(row):
        return wer(row[reference_col], row[hypothesis_col])

    scores = df.apply(row_wer, axis=1)
    return DistillReturn(distill_output=scores)
@metric
def avg_wer(df, ops: ZenoOptions):
    """Return the mean of the ``wer_m`` distill column.

    Args:
        df: DataFrame containing the column that
            ``ops.distill_columns["wer_m"]`` refers to.
        ops: Zeno options providing the distill-column lookup.

    Returns:
        ``MetricReturn`` carrying the column mean, or ``0`` when the mean is
        null/NaN.
    """
    wer_column = ops.distill_columns["wer_m"]
    mean_wer = df[wer_column].mean()
    # Same two checks as before, in the same short-circuit order.
    has_value = not pd.isnull(mean_wer) and not math.isnan(mean_wer)
    return MetricReturn(metric=mean_wer if has_value else 0)