Spaces:
Running
Running
| import math | |
| import os | |
| import pandas as pd | |
| import torch | |
| import whisper | |
| from jiwer import wer | |
| from zeno import ( | |
| ZenoOptions, | |
| distill, | |
| metric, | |
| model, | |
| DistillReturn, | |
| ModelReturn, | |
| MetricReturn, | |
| ) | |
def load_model(model_path):
    """Build the prediction function for the model named by ``model_path``.

    Two back ends are recognised by substring match on ``model_path``:

    * ``"sst"``: the ``silero_stt`` English model, loaded on the CPU
      through ``torch.hub`` from ``snakers4/silero-models``.
    * ``"whisper"``: the ``"tiny"`` model loaded with ``whisper.load_model``.

    Args:
        model_path: String identifying which back end to load.

    Returns:
        A function ``pred(df, ops)`` that transcribes the files listed in
        ``df[ops.data_column]`` (resolved against ``ops.data_path``) and
        returns a ``ModelReturn`` with one transcription per file.

    Raises:
        ValueError: If ``model_path`` matches neither back end. Previously
            this case fell through and returned ``None``.
    """

    def audio_paths(df, ops):
        # Both back ends read the same files: data-column entries joined
        # onto the data directory.
        return [os.path.join(ops.data_path, name) for name in df[ops.data_column]]

    if "sst" in model_path:
        device = torch.device("cpu")
        # Named stt_model so it does not shadow the `model` name imported
        # from zeno at the top of the file.
        stt_model, decoder, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-models",
            model="silero_stt",
            language="en",
            device=device,
        )
        read_batch, _, _, prepare_model_input = utils

        def pred(df, ops: ZenoOptions):
            batch = prepare_model_input(
                read_batch(audio_paths(df, ops)), device=device
            )
            return ModelReturn(
                model_output=[decoder(example.cpu()) for example in stt_model(batch)]
            )

        return pred

    if "whisper" in model_path:
        whisper_model = whisper.load_model("tiny")

        def pred(df, ops: ZenoOptions):
            # Files are transcribed one at a time.
            return ModelReturn(
                model_output=[
                    whisper_model.transcribe(path)["text"]
                    for path in audio_paths(df, ops)
                ]
            )

        return pred

    raise ValueError(
        f"Unsupported model path {model_path!r}: expected it to contain "
        "'sst' or 'whisper'"
    )
def country(df, ops: ZenoOptions):
    """Take the last ``", "``-separated part of each row's ``birthplace``.

    The previous version checked only the first row for a missing value
    and returned a single-element list whatever the size of ``df``. Its
    ``[0]`` / ``[-1]`` Series lookups were also label- or position-based
    depending on the index type. Each row is now handled independently,
    so the output always has exactly ``len(df)`` entries.

    Args:
        df: DataFrame with a ``birthplace`` column.
        ops: Zeno options (unused here).

    Returns:
        ``DistillReturn`` whose ``distill_output`` is a list with one string
        per row: the text after the last ``", "`` in ``birthplace``, or
        ``""`` when the value is missing or not a string.
    """
    countries = [
        # Missing values (NaN/None) are not strings and map to "".
        place.split(", ")[-1] if isinstance(place, str) else ""
        for place in df["birthplace"]
    ]
    return DistillReturn(distill_output=countries)
def wer_m(df, ops: ZenoOptions):
    """Compute ``wer`` for every row of ``df``.

    Args:
        df: DataFrame holding the label and model-output columns.
        ops: Zeno options; ``label_column`` and ``output_column`` name the
            two columns passed to ``wer`` (label first, output second).

    Returns:
        ``DistillReturn`` whose ``distill_output`` is the result of the
        row-wise ``df.apply``.
    """

    def row_wer(row):
        return wer(row[ops.label_column], row[ops.output_column])

    per_row = df.apply(row_wer, axis=1)
    return DistillReturn(distill_output=per_row)
def avg_wer(df, ops: ZenoOptions):
    """Return the mean of the ``wer_m`` distill column as a metric.

    Args:
        df: DataFrame containing the column that
            ``ops.distill_columns["wer_m"]`` refers to.
        ops: Zeno options providing the ``distill_columns`` mapping.

    Returns:
        ``MetricReturn`` carrying the column mean, or ``0`` when that mean
        is null/NaN.
    """
    wer_column = ops.distill_columns["wer_m"]
    mean_wer = df[wer_column].mean()
    # Same two checks as before, in the same short-circuit order.
    usable = not pd.isnull(mean_wer) and not math.isnan(mean_wer)
    return MetricReturn(metric=mean_wer if usable else 0)