{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0f1e2d3c",
   "metadata": {},
   "source": [
    "# Fine-tuning DistilBERT on an IMDb subset\n",
    "\n",
    "Binary sentiment classification (`pos` = 1, `neg` = 0) on a small slice of the aclImdb movie reviews:\n",
    "\n",
    "1. Prepare the dataset\n",
    "2. Load the pretrained tokenizer and encode the texts\n",
    "3. Wrap the encodings in a PyTorch `Dataset`\n",
    "4. Load the pretrained model\n",
    "5. Train with the Hugging Face `Trainer` (a native PyTorch training loop would also work), then evaluate on the validation and test splits\n",
    "\n",
    "Expects the extracted `aclImdb/` directory (with `train/` and `test/`, each containing `pos/` and `neg/`) in the working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80baea1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from itertools import islice\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.utils.data import Dataset\n",
    "from transformers import (\n",
    "    DistilBertForSequenceClassification,\n",
    "    DistilBertTokenizerFast,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    set_seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c7d9e2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = 'distilbert-base-uncased'\n",
    "DATA_DIR = Path('aclImdb')\n",
    "MAX_PER_CLASS = 100    # reviews read per label per split; None reads everything\n",
    "VAL_FRACTION = 0.2     # share of the training subset held out for validation\n",
    "NUM_EPOCHS = 2\n",
    "TRAIN_BATCH_SIZE = 16\n",
    "EVAL_BATCH_SIZE = 64\n",
    "LEARNING_RATE = 5e-5\n",
    "WEIGHT_DECAY = 0.01\n",
    "WARMUP_FRACTION = 0.1  # share of all optimizer steps spent in LR warmup\n",
    "SEED = 42\n",
    "\n",
    "set_seed(SEED)  # seeds python's random, numpy and torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b8e0f3b",
   "metadata": {},
   "source": [
    "## 1. Prepare the dataset\n",
    "\n",
    "Only the first `MAX_PER_CLASS` files (in sorted order) of each label directory are read, so that the run stays small and reproducible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a9f1c4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_imdb_split(split_dir, max_per_class=None):\n",
    "    \"\"\"Read labelled reviews from one aclImdb split.\n",
    "\n",
    "    Args:\n",
    "        split_dir: Split directory (str or Path) holding `pos/` and `neg/` sub-directories.\n",
    "        max_per_class: Maximum number of files read from each label directory;\n",
    "            None reads all of them.\n",
    "\n",
    "    Returns:\n",
    "        (texts, labels): parallel lists; labels are 1 for `pos` and 0 for `neg`.\n",
    "        All `pos` reviews come before all `neg` reviews.\n",
    "    \"\"\"\n",
    "    split_dir = Path(split_dir)\n",
    "    texts, labels = [], []\n",
    "    for label_dir, label in [('pos', 1), ('neg', 0)]:\n",
    "        # iterdir() order depends on the filesystem; sort so the same subset is picked everywhere.\n",
    "        files = sorted((split_dir / label_dir).iterdir())\n",
    "        # islice stops after max_per_class files instead of walking the whole directory.\n",
    "        for text_file in islice(files, max_per_class):\n",
    "            # read_text opens and closes the file, so no handle is leaked.\n",
    "            texts.append(text_file.read_text(encoding='utf8'))\n",
    "            labels.append(label)\n",
    "    return texts, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d0a2b5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "full_train_texts, full_train_labels = read_imdb_split(DATA_DIR / 'train', MAX_PER_CLASS)\n",
    "test_texts, test_labels = read_imdb_split(DATA_DIR / 'test', MAX_PER_CLASS)\n",
    "assert full_train_texts and test_texts, f'no reviews found under {DATA_DIR.resolve()}'\n",
    "\n",
    "# Seeded and stratified so the small validation split is reproducible and keeps the pos/neg balance.\n",
    "train_texts, val_texts, train_labels, val_labels = train_test_split(\n",
    "    full_train_texts,\n",
    "    full_train_labels,\n",
    "    test_size=VAL_FRACTION,\n",
    "    random_state=SEED,\n",
    "    stratify=full_train_labels,\n",
    ")\n",
    "len(train_texts), len(val_texts), len(test_texts)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e1b3c6e",
   "metadata": {},
   "source": [
    "## 2-3. Tokenize and wrap as a PyTorch `Dataset`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f2c4d7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class IMDBDataset(Dataset):\n",
    "    \"\"\"Map-style PyTorch dataset over tokenizer output plus integer labels.\n",
    "\n",
    "    `encodings` is the tokenizer's dict-like output (e.g. `input_ids`, `attention_mask`),\n",
    "    each value indexable by example position; `labels` is a parallel list of ints.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, encodings, labels):\n",
    "        self.encodings = encodings\n",
    "        self.labels = labels\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # One example as tensors, plus the `labels` key the model uses to compute its loss.\n",
    "        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
    "        item['labels'] = torch.tensor(self.labels[idx])\n",
    "        return item\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a3d5e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)\n",
    "\n",
    "# truncation=True cuts to the model's max length; padding=True pads to the longest text in each call.\n",
    "train_encodings = tokenizer(train_texts, truncation=True, padding=True)\n",
    "val_encodings = tokenizer(val_texts, truncation=True, padding=True)\n",
    "test_encodings = tokenizer(test_texts, truncation=True, padding=True)\n",
    "\n",
    "train_dataset = IMDBDataset(train_encodings, train_labels)\n",
    "val_dataset = IMDBDataset(val_encodings, val_labels)\n",
    "test_dataset = IMDBDataset(test_encodings, test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b4e6f91",
   "metadata": {},
   "source": [
    "## 4-5. Load the model and train\n",
    "\n",
    "With the default subset there are only about 20 optimizer steps (160 training reviews / batch 16 x 2 epochs). A fixed `warmup_steps=500` would keep the learning rate far below `LEARNING_RATE` for the whole run, so the warmup is sized from the actual step count instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c5f70a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(eval_pred):\n",
    "    \"\"\"Trainer `compute_metrics` hook: fraction of argmax predictions matching the labels.\"\"\"\n",
    "    logits, labels = eval_pred\n",
    "    return {'accuracy': float((np.argmax(logits, axis=-1) == labels).mean())}\n",
    "\n",
    "\n",
    "# Assumes a single device and no gradient accumulation when counting optimizer steps.\n",
    "steps_per_epoch = math.ceil(len(train_dataset) / TRAIN_BATCH_SIZE)\n",
    "warmup_steps = int(WARMUP_FRACTION * steps_per_epoch * NUM_EPOCHS)\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='./results',\n",
    "    num_train_epochs=NUM_EPOCHS,\n",
    "    per_device_train_batch_size=TRAIN_BATCH_SIZE,\n",
    "    per_device_eval_batch_size=EVAL_BATCH_SIZE,\n",
    "    warmup_steps=warmup_steps,\n",
    "    learning_rate=LEARNING_RATE,\n",
    "    weight_decay=WEIGHT_DECAY,\n",
    "    logging_dir='./logs',\n",
    "    logging_steps=10,\n",
    "    seed=SEED,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad6081b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=val_dataset,\n",
    "    compute_metrics=compute_accuracy,\n",
    ")\n",
    "\n",
    "train_result = trainer.train()\n",
    "train_result.metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be7192c4",
   "metadata": {},
   "source": [
    "## Evaluate\n",
    "\n",
    "No evaluation strategy is configured, so `eval_dataset` is only used when `evaluate()` is called explicitly. The held-out test split is evaluated as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf82a3d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_metrics = trainer.evaluate()\n",
    "test_metrics = trainer.evaluate(test_dataset, metric_key_prefix='test')\n",
    "{**val_metrics, **test_metrics}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}