{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Train IMDb Classifier\n", "> Train an IMDb classifier with DistilBERT." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!huggingface-cli login" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load IMDb dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset, load_metric" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ds = load_dataset(\"imdb\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['text', 'label'],\n", " num_rows: 25000\n", " })\n", " test: Dataset({\n", " features: ['text', 'label'],\n", " num_rows: 25000\n", " })\n", " unsupervised: Dataset({\n", " features: ['text', 'label'],\n", " num_rows: 50000\n", " })\n", "})" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'label': ClassLabel(num_classes=2, names=['neg', 'pos'], names_file=None, id=None),\n", " 'text': Value(dtype='string', id=None)}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds['train'].features" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Pretrained DistilBERT" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n", "\n", "model_name = \"distilbert-base-uncased\"\n", "model = AutoModelForSequenceClassification.from_pretrained(model_name)\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"## Preprocess Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6ddef2e0d4a04e12ad7513950158236c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/25 [00:00