{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Ноутбук с обучением модели" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "from transformers import AutoTokenizer\n", "from datasets import load_dataset\n", "from transformers import DataCollatorWithPadding, Trainer, TrainingArguments, AutoModelForSequenceClassification" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained('roberta-base')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration default-bd943fc5bb724360\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading and preparing dataset csv/default to /Users/seal/.cache/huggingface/datasets/csv/default-bd943fc5bb724360/0.0.0/51cce309a08df9c4d82ffd9363bbe090bf173197fc01a71b034e8594995a1a58...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "548860682d5749cd979a64b49275ea32", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading data files: 0%| | 0/2 [00:00 at 0x1641221f0> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. 
Subsequent hashing failures won't be showed.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "484e702a879448a3b09e2c7a2d672407", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/11883 [00:00\n", " \n", " \n", " [4458/4458 20:39:50, Epoch 3/3]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
5001.356300
10001.088700
15001.026400
20000.897900
25000.835500
30000.795100
35000.684000
40000.641000

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Saving model checkpoint to logs/checkpoint-500\n", "Configuration saved in logs/checkpoint-500/config.json\n", "Model weights saved in logs/checkpoint-500/pytorch_model.bin\n", "Saving model checkpoint to logs/checkpoint-1000\n", "Configuration saved in logs/checkpoint-1000/config.json\n", "Model weights saved in logs/checkpoint-1000/pytorch_model.bin\n", "Saving model checkpoint to logs/checkpoint-1500\n", "Configuration saved in logs/checkpoint-1500/config.json\n", "Model weights saved in logs/checkpoint-1500/pytorch_model.bin\n", "Saving model checkpoint to logs/checkpoint-2000\n", "Configuration saved in logs/checkpoint-2000/config.json\n", "Model weights saved in logs/checkpoint-2000/pytorch_model.bin\n", "Saving model checkpoint to logs/checkpoint-2500\n", "Configuration saved in logs/checkpoint-2500/config.json\n", "Model weights saved in logs/checkpoint-2500/pytorch_model.bin\n", "Saving model checkpoint to logs/checkpoint-3000\n", "Configuration saved in logs/checkpoint-3000/config.json\n", "Model weights saved in logs/checkpoint-3000/pytorch_model.bin\n", "Saving model checkpoint to logs/checkpoint-3500\n", "Configuration saved in logs/checkpoint-3500/config.json\n", "Model weights saved in logs/checkpoint-3500/pytorch_model.bin\n", "Saving model checkpoint to logs/checkpoint-4000\n", "Configuration saved in logs/checkpoint-4000/config.json\n", "Model weights saved in logs/checkpoint-4000/pytorch_model.bin\n", "\n", "\n", "Training completed. 
Do not forget to share your model on huggingface.co/models =)\n", "\n", "\n" ] }, { "data": { "text/plain": [ "TrainOutput(global_step=4458, training_loss=0.8849190480509057, metrics={'train_runtime': 74409.2364, 'train_samples_per_second': 0.479, 'train_steps_per_second': 0.06, 'total_flos': 4731926386857384.0, 'train_loss': 0.8849190480509057, 'epoch': 3.0})" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer = Trainer(\n", " model, \n", " args,\n", " train_dataset=data_hf_tokenized[\"train\"],\n", " eval_dataset=data_hf_tokenized[\"test\"],\n", " data_collator=DataCollatorWithPadding(tokenizer),\n", ")\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "***** Running Prediction *****\n", " Num examples = 2097\n", " Batch size = 8\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [263/263 21:44]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "predictions = trainer.predict(data_hf_tokenized[\"test\"])" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Saving model checkpoint to model_roberta_trained\n", "Configuration saved in model_roberta_trained/config.json\n", "Model weights saved in model_roberta_trained/pytorch_model.bin\n" ] } ], "source": [ "trainer.save_model('model_roberta_trained')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Обучила ещё DistilBERT в Google Colab, но там результат немного хуже.\n", "\n", "При работе пользовалась этими материалами:\n", "\n", "- https://github.com/ThilinaRajapakse/pytorch-transformers-classification\n", "- https://github.com/huggingface/transformers/tree/main/src/transformers" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 2 }