Spaces:
Running
Running
File size: 12,214 Bytes
c24ff9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"from zeno import zeno\n",
"import math\n",
"import os\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv(\"metadata.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Rebind instead of mutating in place so re-running this cell is side-effect free;\n",
"# drop=False keeps 'id' available as a regular column as well as the index.\n",
"df = df.set_index('id', drop=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 'id' is both the index and a column (set_index(..., drop=False) above), so\n",
"# groupby('id') raises an ambiguity ValueError; name the index level explicitly.\n",
"df.groupby(level='id')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Allow the recordings directory to be overridden without editing the notebook;\n",
"# the default is the original local path.\n",
"data_path = os.environ.get(\n",
"    \"ZENO_DATA_PATH\",\n",
"    \"/Users/acabrera/dev/data/speech-accent-archive/recordings/recordings/\",\n",
")\n",
"\n",
"zeno({\n",
"    \"metadata\": df.iloc[:10],  # first ten rows, selected by position\n",
"    \"view\": \"audio-transcription\",\n",
"    \"data_path\": data_path,\n",
"    \"label_column\": \"label\",\n",
"    \"data_column\": \"id\"\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import whisper\n",
"from jiwer import wer\n",
"from zeno import ZenoOptions, distill, metric, model\n",
"import numpy as np\n",
"from zeno import ZenoOptions, distill"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@model\n",
"def load_model(model_path):\n",
"    \"\"\"Build a prediction function for the model named by ``model_path``.\n",
"\n",
"    A path containing \"sst\" selects the Silero speech-to-text model loaded\n",
"    through torch.hub on the CPU; a path containing \"whisper\" selects\n",
"    Whisper's \"tiny\" checkpoint.\n",
"\n",
"    Returns:\n",
"        A ``pred(df, ops)`` callable that joins every entry of\n",
"        ``df[ops.data_column]`` onto ``ops.data_path`` and returns the list of\n",
"        transcripts produced for those files.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``model_path`` matches neither supported model.\n",
"    \"\"\"\n",
"    if \"sst\" in model_path:\n",
"        device = torch.device(\"cpu\")\n",
"        stt_model, decoder, utils = torch.hub.load(\n",
"            repo_or_dir=\"snakers4/silero-models\",\n",
"            model=\"silero_stt\",\n",
"            language=\"en\",\n",
"            device=device,\n",
"        )\n",
"        (read_batch, _, _, prepare_model_input) = utils\n",
"\n",
"        def pred(df, ops: ZenoOptions):\n",
"            files = [os.path.join(ops.data_path, f) for f in df[ops.data_column]]\n",
"            # Named `batch` rather than `input` to avoid shadowing the builtin.\n",
"            batch = prepare_model_input(read_batch(files), device=device)\n",
"            return [decoder(x.cpu()) for x in stt_model(batch)]\n",
"\n",
"        return pred\n",
"\n",
"    if \"whisper\" in model_path:\n",
"        whisper_model = whisper.load_model(\"tiny\")\n",
"\n",
"        def pred(df, ops: ZenoOptions):\n",
"            files = [os.path.join(ops.data_path, f) for f in df[ops.data_column]]\n",
"            return [whisper_model.transcribe(f)[\"text\"] for f in files]\n",
"\n",
"        return pred\n",
"\n",
"    # Fail loudly instead of handing back None as the prediction function.\n",
"    raise ValueError(f\"Unsupported model: {model_path!r}\")\n",
"\n",
"\n",
"@distill\n",
"def country(df, ops: ZenoOptions):\n",
"    \"\"\"Extract, for every row, the last \", \"-separated part of \"0birthplace\".\n",
"\n",
"    That last part is presumably the country (inferred from the function\n",
"    name; TODO confirm against the metadata). Missing values map to \"\".\n",
"\n",
"    Returns:\n",
"        A Series aligned with ``df`` holding one string per row.\n",
"    \"\"\"\n",
"    # Use the vectorised .str accessor: looking rows up with [0] / [-1] is\n",
"    # label-based and breaks once the frame is indexed by 'id', and it would\n",
"    # yield a single value for the whole frame rather than one per row.\n",
"    birthplaces = df[\"0birthplace\"].fillna(\"\").astype(str)\n",
"    return birthplaces.str.split(\", \").str[-1]\n",
"\n",
"\n",
"@distill\n",
"def wer_m(df, ops: ZenoOptions):\n",
"    \"\"\"Compute the word error rate of every row's output against its label.\"\"\"\n",
"\n",
"    def row_wer(row):\n",
"        return wer(row[ops.label_column], row[ops.output_column])\n",
"\n",
"    return df.apply(row_wer, axis=1)\n",
"\n",
"\n",
"@metric\n",
"def avg_wer(df, ops: ZenoOptions):\n",
"    \"\"\"Return the mean of the distilled \"wer_m\" column, or 0 when it is NaN.\"\"\"\n",
"    mean_wer = df[ops.distill_columns[\"wer_m\"]].mean()\n",
"    return 0 if math.isnan(mean_wer) else mean_wer\n",
"\n",
"# @distill\n",
"# def amplitude(df, ops: ZenoOptions):\n",
"# files = [os.path.join(ops.data_path, f) for f in df[ops.data_column]]\n",
"# amps = []\n",
"# for audio in files:\n",
"# y, _ = librosa.load(audio)\n",
"# amps.append(float(np.abs(y).mean()))\n",
"# return amps\n",
"\n",
"\n",
"# @distill\n",
"# def length(df, ops: ZenoOptions):\n",
"# files = [os.path.join(ops.data_path, f) for f in df[ops.data_column]]\n",
"# amps = []\n",
"# for audio in files:\n",
"# y, _ = librosa.load(audio)\n",
"# amps.append(len(y))\n",
"# return amps"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Allow the recordings directory to be overridden without editing the notebook;\n",
"# the default is the original local path.\n",
"data_path = os.environ.get(\n",
"    \"ZENO_DATA_PATH\",\n",
"    \"/Users/acabrera/dev/data/speech-accent-archive/recordings/recordings/\",\n",
")\n",
"\n",
"zeno({\n",
"    \"metadata\": df,\n",
"    \"functions\": [load_model, country, wer_m, avg_wer],\n",
"    \"view\": \"audio-transcription\",\n",
"    \"models\": [\"silero_sst\", \"whisper\"],\n",
"    \"data_path\": data_path,\n",
"    \"data_column\": \"id\",\n",
"    \"label_column\": \"label\",\n",
"    \"samples\": 10,\n",
"})\n",
"# metadata = \"metadata.csv\"\n",
"# # data_path = \"https://zenoml.s3.amazonaws.com/accents/\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
},
"vscode": {
"interpreter": {
"hash": "59d606a796fde3c997548ee5ab3f3009081de8aa2fb58c2406e58b3c7613e786"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
|