"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"import IPython.display as ipd\n",
"import numpy as np\n",
@@ -1144,11 +1665,11 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 26,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T19:16:25.363356Z",
- "start_time": "2021-03-13T19:16:25.290149Z"
+ "end_time": "2021-03-14T10:06:50.779682Z",
+ "start_time": "2021-03-14T10:06:50.745322Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
@@ -1156,7 +1677,17 @@
"id": "1Po2g7YPuRTx",
"outputId": "96b0b82c-a5df-4ae6-d17b-9c7d4f710b42"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Target text: ξεπούλησε τα κτήματα μας \n",
+ "Input array shape: (47232,)\n",
+ "Sampling rate: 16000\n"
+ ]
+ }
+ ],
"source": [
"rand_int = random.randint(0, len(common_voice_train))\n",
"\n",
@@ -1192,11 +1723,11 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 27,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T19:16:27.583682Z",
- "start_time": "2021-03-13T19:16:27.581228Z"
+ "end_time": "2021-03-14T10:06:51.188480Z",
+ "start_time": "2021-03-14T10:06:51.185722Z"
},
"id": "eJY7I0XAwe9p"
},
@@ -1217,11 +1748,11 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 28,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T19:16:59.218803Z",
- "start_time": "2021-03-13T19:16:28.412442Z"
+ "end_time": "2021-03-14T10:07:12.818162Z",
+ "start_time": "2021-03-14T10:06:51.733417Z"
},
"colab": {
"base_uri": "https://localhost:8080/",
@@ -1296,7 +1827,290 @@
"id": "-np9xYK-wl8q",
"outputId": "6155b5f0-a5a2-4e20-d0e2-0b3a60c13f98"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/earendil/anaconda3/envs/cuda110/lib/python3.8/site-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
+ " return array(a, dtype, copy=False, order=order)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " "
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "76843ffa65004ab894917eaf37673b94",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#1', max=59, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1e96ebd9483a491bbe2e78938b0c1444",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#4', max=59, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a22ca1bc906b44c68d8f3455cda46d21",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#2', max=59, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1c19299d349c465094ffb86a81d8445c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#3', max=59, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "bb87efafd5c44ba49cfbb0b9dea203bf",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#5', max=58, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7859c3186f9a40ba8c96ef0b6dd03e78",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#6', max=58, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b6cd7f6248ed423ca43ef5fe5698aa05",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#7', max=58, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " "
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1087e951f9854d3f832e841838929931",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#0', max=59, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ " "
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "ce50341c8bf147a5aac784d409a84ea8",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#1', max=24, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7a539299f9ad455ca53b7317b5d6fcc0",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#2', max=24, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "07b03f8be0da4857901f54f8f6f7b96f",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#4', max=24, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "fbb85e7ec24a4f22a12525d2b8ef53a1",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#3', max=24, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " "
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2c4d007a3cf14faaa37d656e2eb1da34",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#7', max=24, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "8165748923df4b828a6595c7b416a2cb",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#6', max=24, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c34f0a513be54597b83ba64aa3819d61",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#5', max=24, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "e00e506138d046d1ab41aa243ced2680",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(IntProgress(value=0, description='#0', max=24, style=ProgressStyle(description_width='initial')…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ }
+ ],
"source": [
"common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names, batch_size=8, num_proc=8, batched=True)\n",
"common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=8, batched=True)"
@@ -1339,11 +2153,11 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 29,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T19:16:59.275801Z",
- "start_time": "2021-03-13T19:16:59.270920Z"
+ "end_time": "2021-03-14T10:07:12.850867Z",
+ "start_time": "2021-03-14T10:07:12.844816Z"
},
"id": "tborvC9hx88e"
},
@@ -1419,11 +2233,11 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 30,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T19:16:59.329608Z",
- "start_time": "2021-03-13T19:16:59.328133Z"
+ "end_time": "2021-03-14T10:07:12.883379Z",
+ "start_time": "2021-03-14T10:07:12.881910Z"
},
"id": "lbQf5GuZyQ4_"
},
@@ -1444,11 +2258,11 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 31,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T19:17:02.227469Z",
- "start_time": "2021-03-13T19:16:59.387853Z"
+ "end_time": "2021-03-14T10:07:15.225831Z",
+ "start_time": "2021-03-14T10:07:12.909977Z"
},
"colab": {
"base_uri": "https://localhost:8080/",
@@ -1486,11 +2300,11 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 32,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T19:17:02.292196Z",
- "start_time": "2021-03-13T19:17:02.289811Z"
+ "end_time": "2021-03-14T10:07:15.267780Z",
+ "start_time": "2021-03-14T10:07:15.265135Z"
},
"id": "1XZ-kjweyTy_"
},
@@ -1526,11 +2340,11 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 33,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T19:17:10.387115Z",
- "start_time": "2021-03-13T19:17:02.345154Z"
+ "end_time": "2021-03-14T10:07:23.145790Z",
+ "start_time": "2021-03-14T10:07:15.296367Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
@@ -1538,7 +2352,16 @@
"id": "e7cqAWIayn6w",
"outputId": "0a5ab559-6c38-47c6-b4f5-64480ed1df65"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.weight', 'lm_head.bias']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ }
+ ],
"source": [
"from transformers import Wav2Vec2ForCTC\n",
"\n",
@@ -1562,7 +2385,7 @@
"id": "1DwR3XLSzGDD"
},
"source": [
- "NOTE: Since Greek is not one of the 53 languages that XLSR-Wav2Vec2 had been pretrained on, I did not follow the below suggestion of freezing the CNN layers. Thus, the following cell has been commented out\n",
+ "NOTE: Since Greek is not one of the 53 languages that XLSR-Wav2Vec2 had been pretrained on, we may or may not follow the below suggestion of freezing the CNN layers. \n",
"\n",
"Original text: The first component of XLSR-Wav2Vec2 consists of a stack of CNN layers that are used to extract acoustically meaningful - but contextually independent - features from the raw speech signal. This part of the model has already been sufficiently trained during pretraining and as stated in the [paper](https://arxiv.org/pdf/2006.13979.pdf) does not need to be fine-tuned anymore. \n",
"Thus, we can set the `requires_grad` to `False` for all parameters of the *feature extraction* part."
@@ -1570,17 +2393,17 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 34,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T19:17:10.485699Z",
- "start_time": "2021-03-13T19:17:10.484259Z"
+ "end_time": "2021-03-14T10:07:23.193215Z",
+ "start_time": "2021-03-14T10:07:23.191472Z"
},
"id": "oGI8zObtZ3V0"
},
"outputs": [],
"source": [
- "#model.freeze_feature_extractor()"
+ "model.freeze_feature_extractor()"
]
},
{
@@ -1601,11 +2424,11 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 35,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T19:17:10.616187Z",
- "start_time": "2021-03-13T19:17:10.592735Z"
+ "end_time": "2021-03-14T10:07:23.256610Z",
+ "start_time": "2021-03-14T10:07:23.234660Z"
},
"id": "KbeKSV7uzGPP"
},
@@ -1620,7 +2443,7 @@
" per_device_train_batch_size=6,\n",
" gradient_accumulation_steps=2,\n",
" evaluation_strategy=\"steps\",\n",
- " num_train_epochs=30,\n",
+ " num_train_epochs=60,\n",
" fp16=True,\n",
" save_steps=400,\n",
" eval_steps=400,\n",
@@ -1642,11 +2465,11 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 36,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T19:17:12.452612Z",
- "start_time": "2021-03-13T19:17:10.702247Z"
+ "end_time": "2021-03-14T10:07:24.993889Z",
+ "start_time": "2021-03-14T10:07:23.305173Z"
},
"id": "rY7vBmFCPFgC"
},
@@ -1717,11 +2540,11 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 37,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T18:59:31.434948Z",
- "start_time": "2021-03-13T18:12:38.586326Z"
+ "end_time": "2021-03-14T17:43:05.019454Z",
+ "start_time": "2021-03-14T10:07:25.037347Z"
},
"colab": {
"base_uri": "https://localhost:8080/",
@@ -1731,7 +2554,434 @@
"outputId": "2e23b190-ca76-48ad-8117-376d1d7c058e",
"scrolled": true
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/earendil/anaconda3/envs/cuda110/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:131: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n",
+ " warnings.warn(\"Detected call of `lr_scheduler.step()` before `optimizer.step()`. \"\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ "
\n",
+ " [18600/18600 7:32:53, Epoch 60/60]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Step | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ " Wer | \n",
+ " Runtime | \n",
+ " Samples Per Second | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 400 | \n",
+ " 7.193600 | \n",
+ " 3.278684 | \n",
+ " 1.000000 | \n",
+ " 123.640600 | \n",
+ " 12.310000 | \n",
+ "
\n",
+ " \n",
+ " 800 | \n",
+ " 3.198500 | \n",
+ " 2.831500 | \n",
+ " 0.996669 | \n",
+ " 121.957800 | \n",
+ " 12.480000 | \n",
+ "
\n",
+ " \n",
+ " 1200 | \n",
+ " 1.156600 | \n",
+ " 0.779333 | \n",
+ " 0.772585 | \n",
+ " 113.901900 | \n",
+ " 13.362000 | \n",
+ "
\n",
+ " \n",
+ " 1600 | \n",
+ " 0.529700 | \n",
+ " 0.647442 | \n",
+ " 0.675179 | \n",
+ " 122.788700 | \n",
+ " 12.395000 | \n",
+ "
\n",
+ " \n",
+ " 2000 | \n",
+ " 0.372500 | \n",
+ " 0.567449 | \n",
+ " 0.607348 | \n",
+ " 114.373900 | \n",
+ " 13.307000 | \n",
+ "
\n",
+ " \n",
+ " 2400 | \n",
+ " 0.287700 | \n",
+ " 0.563049 | \n",
+ " 0.562128 | \n",
+ " 114.404000 | \n",
+ " 13.304000 | \n",
+ "
\n",
+ " \n",
+ " 2800 | \n",
+ " 0.234200 | \n",
+ " 0.547992 | \n",
+ " 0.540325 | \n",
+ " 123.026600 | \n",
+ " 12.371000 | \n",
+ "
\n",
+ " \n",
+ " 3200 | \n",
+ " 0.196700 | \n",
+ " 0.627634 | \n",
+ " 0.557787 | \n",
+ " 114.868300 | \n",
+ " 13.250000 | \n",
+ "
\n",
+ " \n",
+ " 3600 | \n",
+ " 0.172900 | \n",
+ " 0.536431 | \n",
+ " 0.533259 | \n",
+ " 114.600900 | \n",
+ " 13.281000 | \n",
+ "
\n",
+ " \n",
+ " 4000 | \n",
+ " 0.160700 | \n",
+ " 0.510704 | \n",
+ " 0.495306 | \n",
+ " 115.385200 | \n",
+ " 13.191000 | \n",
+ "
\n",
+ " \n",
+ " 4400 | \n",
+ " 0.138100 | \n",
+ " 0.592979 | \n",
+ " 0.522055 | \n",
+ " 123.419400 | \n",
+ " 12.332000 | \n",
+ "
\n",
+ " \n",
+ " 4800 | \n",
+ " 0.126100 | \n",
+ " 0.663255 | \n",
+ " 0.530534 | \n",
+ " 124.109600 | \n",
+ " 12.263000 | \n",
+ "
\n",
+ " \n",
+ " 5200 | \n",
+ " 0.121600 | \n",
+ " 0.611431 | \n",
+ " 0.508428 | \n",
+ " 116.885800 | \n",
+ " 13.021000 | \n",
+ "
\n",
+ " \n",
+ " 5600 | \n",
+ " 0.111900 | \n",
+ " 0.609117 | \n",
+ " 0.504189 | \n",
+ " 123.424400 | \n",
+ " 12.331000 | \n",
+ "
\n",
+ " \n",
+ " 6000 | \n",
+ " 0.107000 | \n",
+ " 0.581159 | \n",
+ " 0.495811 | \n",
+ " 131.865500 | \n",
+ " 11.542000 | \n",
+ "
\n",
+ " \n",
+ " 6400 | \n",
+ " 0.098500 | \n",
+ " 0.653878 | \n",
+ " 0.508227 | \n",
+ " 132.657900 | \n",
+ " 11.473000 | \n",
+ "
\n",
+ " \n",
+ " 6800 | \n",
+ " 0.095900 | \n",
+ " 0.602002 | \n",
+ " 0.490663 | \n",
+ " 132.224900 | \n",
+ " 11.511000 | \n",
+ "
\n",
+ " \n",
+ " 7200 | \n",
+ " 0.089900 | \n",
+ " 0.600693 | \n",
+ " 0.488140 | \n",
+ " 115.003200 | \n",
+ " 13.234000 | \n",
+ "
\n",
+ " \n",
+ " 7600 | \n",
+ " 0.086700 | \n",
+ " 0.592105 | \n",
+ " 0.487635 | \n",
+ " 115.062900 | \n",
+ " 13.228000 | \n",
+ "
\n",
+ " \n",
+ " 8000 | \n",
+ " 0.082500 | \n",
+ " 0.615173 | \n",
+ " 0.493792 | \n",
+ " 115.057900 | \n",
+ " 13.228000 | \n",
+ "
\n",
+ " \n",
+ " 8400 | \n",
+ " 0.076800 | \n",
+ " 0.608913 | \n",
+ " 0.477238 | \n",
+ " 123.812500 | \n",
+ " 12.293000 | \n",
+ "
\n",
+ " \n",
+ " 8800 | \n",
+ " 0.069800 | \n",
+ " 0.618013 | \n",
+ " 0.473302 | \n",
+ " 123.652500 | \n",
+ " 12.309000 | \n",
+ "
\n",
+ " \n",
+ " 9200 | \n",
+ " 0.075300 | \n",
+ " 0.621384 | \n",
+ " 0.486626 | \n",
+ " 115.310600 | \n",
+ " 13.199000 | \n",
+ "
\n",
+ " \n",
+ " 9600 | \n",
+ " 0.067400 | \n",
+ " 0.638355 | \n",
+ " 0.481377 | \n",
+ " 123.632300 | \n",
+ " 12.311000 | \n",
+ "
\n",
+ " \n",
+ " 10000 | \n",
+ " 0.061700 | \n",
+ " 0.668408 | \n",
+ " 0.476734 | \n",
+ " 122.694800 | \n",
+ " 12.405000 | \n",
+ "
\n",
+ " \n",
+ " 10400 | \n",
+ " 0.059000 | \n",
+ " 0.654593 | \n",
+ " 0.467346 | \n",
+ " 115.387700 | \n",
+ " 13.190000 | \n",
+ "
\n",
+ " \n",
+ " 10800 | \n",
+ " 0.059800 | \n",
+ " 0.636886 | \n",
+ " 0.466842 | \n",
+ " 115.148800 | \n",
+ " 13.218000 | \n",
+ "
\n",
+ " \n",
+ " 11200 | \n",
+ " 0.055800 | \n",
+ " 0.646353 | \n",
+ " 0.465630 | \n",
+ " 115.132300 | \n",
+ " 13.220000 | \n",
+ "
\n",
+ " \n",
+ " 11600 | \n",
+ " 0.056800 | \n",
+ " 0.606136 | \n",
+ " 0.469668 | \n",
+ " 115.484500 | \n",
+ " 13.179000 | \n",
+ "
\n",
+ " \n",
+ " 12000 | \n",
+ " 0.050000 | \n",
+ " 0.606987 | \n",
+ " 0.463813 | \n",
+ " 115.399700 | \n",
+ " 13.189000 | \n",
+ "
\n",
+ " \n",
+ " 12400 | \n",
+ " 0.048900 | \n",
+ " 0.643559 | \n",
+ " 0.454931 | \n",
+ " 115.288900 | \n",
+ " 13.202000 | \n",
+ "
\n",
+ " \n",
+ " 12800 | \n",
+ " 0.050800 | \n",
+ " 0.637720 | \n",
+ " 0.454931 | \n",
+ " 115.713100 | \n",
+ " 13.153000 | \n",
+ "
\n",
+ " \n",
+ " 13200 | \n",
+ " 0.047500 | \n",
+ " 0.629017 | \n",
+ " 0.458363 | \n",
+ " 115.546300 | \n",
+ " 13.172000 | \n",
+ "
\n",
+ " \n",
+ " 13600 | \n",
+ " 0.041500 | \n",
+ " 0.669488 | \n",
+ " 0.452912 | \n",
+ " 115.620100 | \n",
+ " 13.164000 | \n",
+ "
\n",
+ " \n",
+ " 14000 | \n",
+ " 0.043700 | \n",
+ " 0.592979 | \n",
+ " 0.445947 | \n",
+ " 115.604300 | \n",
+ " 13.166000 | \n",
+ "
\n",
+ " \n",
+ " 14400 | \n",
+ " 0.040500 | \n",
+ " 0.597802 | \n",
+ " 0.448572 | \n",
+ " 115.567400 | \n",
+ " 13.170000 | \n",
+ "
\n",
+ " \n",
+ " 14800 | \n",
+ " 0.037300 | \n",
+ " 0.616714 | \n",
+ " 0.448774 | \n",
+ " 129.269600 | \n",
+ " 11.774000 | \n",
+ "
\n",
+ " \n",
+ " 15200 | \n",
+ " 0.035100 | \n",
+ " 0.616041 | \n",
+ " 0.441708 | \n",
+ " 132.656900 | \n",
+ " 11.473000 | \n",
+ "
\n",
+ " \n",
+ " 15600 | \n",
+ " 0.037300 | \n",
+ " 0.586855 | \n",
+ " 0.438579 | \n",
+ " 116.875400 | \n",
+ " 13.022000 | \n",
+ "
\n",
+ " \n",
+ " 16000 | \n",
+ " 0.034600 | \n",
+ " 0.619885 | \n",
+ " 0.435752 | \n",
+ " 124.183200 | \n",
+ " 12.256000 | \n",
+ "
\n",
+ " \n",
+ " 16400 | \n",
+ " 0.032800 | \n",
+ " 0.600389 | \n",
+ " 0.439992 | \n",
+ " 124.377300 | \n",
+ " 12.237000 | \n",
+ "
\n",
+ " \n",
+ " 16800 | \n",
+ " 0.031500 | \n",
+ " 0.608220 | \n",
+ " 0.437468 | \n",
+ " 124.193900 | \n",
+ " 12.255000 | \n",
+ "
\n",
+ " \n",
+ " 17200 | \n",
+ " 0.031200 | \n",
+ " 0.615735 | \n",
+ " 0.434642 | \n",
+ " 124.173200 | \n",
+ " 12.257000 | \n",
+ "
\n",
+ " \n",
+ " 17600 | \n",
+ " 0.031000 | \n",
+ " 0.611275 | \n",
+ " 0.430302 | \n",
+ " 125.763500 | \n",
+ " 12.102000 | \n",
+ "
\n",
+ " \n",
+ " 18000 | \n",
+ " 0.029600 | \n",
+ " 0.603103 | \n",
+ " 0.428889 | \n",
+ " 125.680000 | \n",
+ " 12.110000 | \n",
+ "
\n",
+ " \n",
+ " 18400 | \n",
+ " 0.028700 | \n",
+ " 0.606192 | \n",
+ " 0.428687 | \n",
+ " 124.510200 | \n",
+ " 12.224000 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=18600, training_loss=0.34002172352165305, metrics={'train_runtime': 27175.8808, 'train_samples_per_second': 0.684, 'total_flos': 2.8460932280012886e+19, 'epoch': 60.0, 'init_mem_cpu_alloc_delta': 8143405, 'init_mem_gpu_alloc_delta': 1261972480, 'init_mem_cpu_peaked_delta': 18258, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 1358069, 'train_mem_gpu_alloc_delta': 3779249152, 'train_mem_cpu_peaked_delta': 183666894, 'train_mem_gpu_peaked_delta': 1681072128})"
+ ]
+ },
+ "execution_count": 37,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"trainer.train()"
]
@@ -1750,8 +3000,7 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T19:20:42.505632Z",
- "start_time": "2021-03-13T19:20:36.981206Z"
+ "start_time": "2021-03-14T10:06:00.880Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
@@ -1761,7 +3010,7 @@
},
"outputs": [],
"source": [
- "model = Wav2Vec2ForCTC.from_pretrained(\"wav2vec2-large-xlsr-greek/checkpoint-9200/\").to(\"cuda\")\n",
+ "model = Wav2Vec2ForCTC.from_pretrained(\"wav2vec2-large-xlsr-greek/checkpoint-18400/\").to(\"cuda\")\n",
"processor = Wav2Vec2Processor.from_pretrained(\"wav2vec2-large-xlsr-greek/\")"
]
},
@@ -1771,16 +3020,16 @@
"id": "QsfGCQYSvY8C"
},
"source": [
- "Now, we will just take the first example of the test set, run it through the model and take the `argmax(...)` of the logits to retrieve the predicted token ids."
+ "Now, we will just take a random example of the test set, run it through the model and take the `argmax(...)` of the logits to retrieve the predicted token ids."
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 61,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T19:21:08.891118Z",
- "start_time": "2021-03-13T19:20:48.016164Z"
+ "end_time": "2021-03-14T17:50:43.012433Z",
+ "start_time": "2021-03-14T17:50:20.994866Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
@@ -1790,7 +3039,7 @@
},
"outputs": [],
"source": [
- "input_dict = processor(common_voice_test[\"input_values\"][42], return_tensors=\"pt\", sampling_rate=16_000, padding=True)\n",
+ "input_dict = processor(common_voice_test[\"input_values\"][345], return_tensors=\"pt\", sampling_rate=16_000, padding=True)\n",
"\n",
"logits = model(input_dict.input_values.to(\"cuda\")).logits\n",
"\n",
@@ -1808,11 +3057,11 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 51,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T19:21:18.172643Z",
- "start_time": "2021-03-13T19:21:14.933617Z"
+ "end_time": "2021-03-14T17:47:32.341035Z",
+ "start_time": "2021-03-14T17:47:29.882908Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
@@ -1820,7 +3069,16 @@
"id": "8dPE2GRIgtx-",
"outputId": "a211d1ee-d850-481d-8bac-dc46c3efa561"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using custom data configuration el-ac779bf2c9f7c09b\n",
+ "Reusing dataset common_voice (/home/earendil/.cache/huggingface/datasets/common_voice/el-ac779bf2c9f7c09b/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564)\n"
+ ]
+ }
+ ],
"source": [
"common_voice_test_transcription = load_dataset(\"common_voice\", \"el\", data_dir=\"./cv-corpus-6.1-2020-12-11\", split=\"test\")"
]
@@ -1836,11 +3094,11 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 62,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-03-13T19:21:18.351131Z",
- "start_time": "2021-03-13T19:21:18.347466Z"
+ "end_time": "2021-03-14T17:50:43.087361Z",
+ "start_time": "2021-03-14T17:50:43.083839Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
@@ -1848,13 +3106,25 @@
"id": "Phqxa1O1jMDk",
"outputId": "60d48c9f-f745-45ac-9105-446dc71025ca"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Prediction:\n",
+ "και τι να δούμε\n",
+ "\n",
+ "Reference:\n",
+ "και τι να δούμε;\n"
+ ]
+ }
+ ],
"source": [
"print(\"Prediction:\")\n",
"print(processor.decode(pred_ids[0]))\n",
"\n",
"print(\"\\nReference:\")\n",
- "print(common_voice_test_transcription[\"sentence\"][42].lower())\n"
+ "print(common_voice_test_transcription[\"sentence\"][345].lower())\n"
]
},
{
@@ -1871,16 +3141,9 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2021-03-13T19:04:51.390742Z",
- "start_time": "2021-03-13T19:04:51.384593Z"
- }
- },
+ "metadata": {},
"outputs": [],
- "source": [
- "print(common_voice_test_transcription[\"sentence\"][42].lower())\n"
- ]
+ "source": []
}
],
"metadata": {