{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "3ece795d", "metadata": { "cellId": "icbn5fcdkdjmv2xo6f1uym" }, "outputs": [], "source": [ "#!g1.1\n", "from sklearn.preprocessing import LabelEncoder\n", "import transformers\n", "import torch\n", "import nltk\n", "import numpy as np\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 2, "id": "2383e35c", "metadata": { "cellId": "r7277d47zkhjj04zr4od8g" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a3fbc0c0072b4198bb84d870b39a6c74", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1202.0, style=ProgressStyle(description…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration default\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading and preparing dataset json/default-71bc0cd49f840871 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /tmp/xdg_cache/huggingface/datasets/json/default-71bc0cd49f840871/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f8063733bbb9475babf7469daf6e7d56", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f3df0e8c4d2a48968429ac5320020316", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…" ] }, "metadata": {}, "output_type": "display_data" }, { 
"data": { "application/vnd.jupyter.widget-view+json": { "model_id": "71a715a43bcd4f4a859c247b1f375e51", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Dataset json downloaded and prepared to /tmp/xdg_cache/huggingface/datasets/json/default-71bc0cd49f840871/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514. Subsequent calls will reuse this data.\n" ] } ], "source": [
 "#!g1.1\n",
 "from datasets import load_dataset\n",
 "\n",
 "# Read the three JSON split files into one DatasetDict keyed by split name.\n",
 "dataset_train_test_val = load_dataset(\n",
 "    'json',\n",
 "    data_files=dict(\n",
 "        train='train_dataset.json',\n",
 "        test='test_dataset.json',\n",
 "        val='val_dataset.json',\n",
 "    ),\n",
 ")"
 ] }, { "cell_type": "code", "execution_count": 3, "id": "5affcf2d", "metadata": { "cellId": "d3dqrbyaerahlxtoqhusl" }, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['labels', 'input_ids', 'attention_mask'],\n", " num_rows: 44928\n", " })\n", " test: Dataset({\n", " features: ['labels', 'input_ids', 'attention_mask'],\n", " num_rows: 11981\n", " })\n", " val: Dataset({\n", " features: ['labels', 'input_ids', 'attention_mask'],\n", " num_rows: 14976\n", " })\n", "})" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#!g1.1\n", "dataset_train_test_val" ] }, { "cell_type": "code", "execution_count": 4, "id": "1a1956c6", "metadata": { "cellId": "iv6a51fd9tlbrs4he3kizo" }, "outputs": [], "source": [
 "#!g1.1\n",
 "# Bind each split to its own name for the cells that follow.\n",
 "train_dataset, val_dataset, test_dataset = (\n",
 "    dataset_train_test_val[split] for split in ('train', 'val', 'test')\n",
 ")"
 ] }, { "cell_type": "code", "execution_count": 5, "id": "c161630b", "metadata": { "cellId": "t9fridyqfq20q78rkgitt" }, "outputs": [], "source": [ "#!g1.1\n", "train_dataset.set_format(\"torch\")\n", 
"val_dataset.set_format(\"torch\")\n", "test_dataset.set_format(\"torch\")" ] }, { "cell_type": "code", "execution_count": 6, "id": "7ee3ce1c", "metadata": { "cellId": "1y1jaan8t8gjs3masmvulu" }, "outputs": [], "source": [
 "#!g1.1\n",
 "def compute_metrics(eval_pred):\n",
 "    \"\"\"Return classification accuracy for one evaluation pass.\n",
 "\n",
 "    Args:\n",
 "        eval_pred: pair ``(logits, labels)``; the predicted class is the\n",
 "            argmax of ``logits`` over the last axis.\n",
 "\n",
 "    Returns:\n",
 "        dict with the single key ``\"accuracy\"`` holding the fraction of\n",
 "        predictions that equal ``labels``.\n",
 "    \"\"\"\n",
 "    logits, labels = eval_pred\n",
 "    predictions = np.argmax(logits, axis=-1)\n",
 "    # Computed with numpy directly: `datasets.load_metric` is deprecated and\n",
 "    # downloads a metric script just to compare two integer arrays.\n",
 "    return {\"accuracy\": float(np.mean(predictions == np.asarray(labels)))}"
 ] }, { "cell_type": "code", "execution_count": 7, "id": "c5d12dc8", "metadata": { "cellId": "6eds6is9lek1hcs87cizgy" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4de02bce2bd448efa4d6e7c1e02c427a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=483.0, style=ProgressStyle(description_…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e8f550b59f4b418094cbcb1d13c5dd97", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=267967963.0, style=ProgressStyle(descri…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing 
DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight']\n", "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [
 "#!g1.1\n",
 "from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification\n",
 "\n",
 "# Use the first CUDA device when one is available, otherwise the CPU.\n",
 "if torch.cuda.is_available():\n",
 "    device = \"cuda:0\"\n",
 "else:\n",
 "    device = \"cpu\"\n",
 "\n",
 "# Pretrained DistilBERT with an 8-label sequence-classification head,\n",
 "# moved to the selected device.\n",
 "model = AutoModelForSequenceClassification.from_pretrained(\n",
 "    \"distilbert-base-uncased\",\n",
 "    num_labels=8,\n",
 ").to(device)\n",
 "\n",
 "# NOTE(review): save_steps=10000 is larger than the 5616 optimization steps\n",
 "# logged by the recorded training run, so no intermediate checkpoint is\n",
 "# written to output_dir - confirm this is intended.\n",
 "training_args = TrainingArguments(\n",
 "    output_dir=\"./my_saved_model\",\n",
 "    overwrite_output_dir=True,\n",
 "    num_train_epochs=4,\n",
 "    per_device_train_batch_size=32,\n",
 "    save_steps=10000,\n",
 "    save_total_limit=2,\n",
 ")\n",
 "\n",
 "# Train on the train split; the val split and compute_metrics are handed\n",
 "# to the Trainer for evaluation.\n",
 "trainer = Trainer(\n",
 "    model=model,\n",
 "    args=training_args,\n",
 "    train_dataset=train_dataset,\n",
 "    eval_dataset=val_dataset,\n",
 "    compute_metrics=compute_metrics,\n",
 ")"
 ] }, { "cell_type": "code", "execution_count": 13, "id": "59b4c995", "metadata": { "cellId": "enykeyqh04h85cnkvsnyvr" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "***** Running training *****\n", " Num 
examples = 44928\n", " Num Epochs = 4\n", " Instantaneous batch size per device = 32\n", " Total train batch size (w. parallel, distributed & accumulation) = 32\n", " Gradient Accumulation steps = 1\n", " Total optimization steps = 5616\n" ] }, { "data": { "text/html": [ "\n", "
Step | \n", "Training Loss | \n", "
---|---|
500 | \n", "0.068200 | \n", "
1000 | \n", "0.065100 | \n", "
1500 | \n", "0.069500 | \n", "
2000 | \n", "0.064600 | \n", "
2500 | \n", "0.070400 | \n", "
3000 | \n", "0.069800 | \n", "
3500 | \n", "0.066200 | \n", "
4000 | \n", "0.070000 | \n", "
4500 | \n", "0.060200 | \n", "
5000 | \n", "0.064800 | \n", "
5500 | \n", "0.072600 | \n", "
"
],
"text/plain": [
"