{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "1df3c609-62a6-49c3-bcc6-29c520f9501c", "metadata": {}, "outputs": [], "source": [ "# Pretty print\n", "from pprint import pprint\n", "# Datasets load_dataset function\n", "from datasets import load_dataset\n", "# Transformers Autokenizer\n", "from transformers import AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizer, DistilBertTokenizerFast, Trainer, TrainingArguments, AdamW\n", "from torch.utils.data import DataLoader\n", "import torch" ] }, { "cell_type": "code", "execution_count": 3, "id": "58167c28-eb27-4f82-b7d0-8216dbeaf650", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset hupd (C:/Users/calia/.cache/huggingface/datasets/HUPD___hupd/sample-5094df4de61ed3bc/0.0.0/6920d2def8fd7767046c0470603357f76866e5a09c97e19571896bfdca521142)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a2f090474cb148548ce3eb73698fcc6c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/2 [00:00\n", " \n", " \n", " [2020/2020 11:47, Epoch 2/2]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
100.692000
200.685100
300.684000
400.685100
500.678400
600.687300
700.681900
800.691100
900.683200
1000.694100
1100.673300
1200.694100
1300.669500
1400.691100
1500.683400
1600.654900
1700.684300
1800.679300
1900.662600
2000.598400
2100.717700
2200.679100
2300.677500
2400.668800
2500.678100
2600.657500
2700.707200
2800.670300
2900.659900
3000.633300
3100.676300
3200.684800
3300.673100
3400.670500
3500.657500
3600.618100
3700.670000
3800.607400
3900.656200
4000.700000
4100.644800
4200.682800
4300.668800
4400.662600
4500.647700
4600.688600
4700.682400
4800.642900
4900.726900
5000.660400
5100.649500
5200.637200
5300.669700
5400.667100
5500.617000
5600.725300
5700.656800
5800.664600
5900.702600
6000.686300
6100.668400
6200.648200
6300.628700
6400.676700
6500.652400
6600.654300
6700.640800
6800.672000
6900.636100
7000.689100
7100.691100
7200.650300
7300.655200
7400.668400
7500.659200
7600.647800
7700.662800
7800.648500
7900.656700
8000.669400
8100.607800
8200.683200
8300.663800
8400.700900
8500.648200
8600.619400
8700.649200
8800.717500
8900.669600
9000.669700
9100.683900
9200.636900
9300.656400
9400.650000
9500.617800
9600.665600
9700.642700
9800.644000
9900.688900
10000.654700
10100.645800
10200.609200
10300.602300
10400.618800
10500.643500
10600.611000
10700.645000
10800.641000
10900.595400
11000.635100
11100.611600
11200.600300
11300.618100
11400.617200
11500.633400
11600.597600
11700.619400
11800.584200
11900.600700
12000.657400
12100.569600
12200.575500
12300.617900
12400.610300
12500.570600
12600.545700
12700.656300
12800.554700
12900.598200
13000.606300
13100.600500
13200.569800
13300.604700
13400.628300
13500.602700
13600.583700
13700.623800
13800.670300
13900.622400
14000.590200
14100.587000
14200.555500
14300.561000
14400.514300
14500.553100
14600.692400
14700.605200
14800.548000
14900.672600
15000.531100
15100.610600
15200.580200
15300.571300
15400.644400
15500.558500
15600.624000
15700.659200
15800.580500
15900.649900
16000.608700
16100.595100
16200.592900
16300.584000
16400.607100
16500.565800
16600.568300
16700.572200
16800.597500
16900.602700
17000.692900
17100.597900
17200.538600
17300.599400
17400.704300
17500.580500
17600.595600
17700.583100
17800.569500
17900.603300
18000.564500
18100.592100
18200.617000
18300.656500
18400.563600
18500.624800
18600.686700
18700.572300
18800.587700
18900.583000
19000.601500
19100.559700
19200.610100
19300.571300
19400.549900
19500.589200
19600.634800
19700.584200
19800.557000
19900.602700
20000.669700
20100.607500
20200.631800

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Saving model checkpoint to ./results/checkpoint-500\n", "Configuration saved in ./results/checkpoint-500\\config.json\n", "Model weights saved in ./results/checkpoint-500\\pytorch_model.bin\n", "Saving model checkpoint to ./results/checkpoint-1000\n", "Configuration saved in ./results/checkpoint-1000\\config.json\n", "Model weights saved in ./results/checkpoint-1000\\pytorch_model.bin\n", "Saving model checkpoint to ./results/checkpoint-1500\n", "Configuration saved in ./results/checkpoint-1500\\config.json\n", "Model weights saved in ./results/checkpoint-1500\\pytorch_model.bin\n", "Saving model checkpoint to ./results/checkpoint-2000\n", "Configuration saved in ./results/checkpoint-2000\\config.json\n", "Model weights saved in ./results/checkpoint-2000\\pytorch_model.bin\n", "\n", "\n", "Training completed. Do not forget to share your model on huggingface.co/models =)\n", "\n", "\n" ] }, { "data": { "text/plain": [ "TrainOutput(global_step=2020, training_loss=0.6342116433795136, metrics={'train_runtime': 708.5025, 'train_samples_per_second': 45.598, 'train_steps_per_second': 2.851, 'total_flos': 4279491780980736.0, 'train_loss': 0.6342116433795136, 'epoch': 2.0})" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 }