{ "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "from transformers import AutoTokenizer, DataCollatorWithPadding\n", "\n", "raw_datasets = load_dataset(\"glue\", \"mrpc\")\n", "checkpoint = \"bert-base-uncased\"\n", "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n", "\n", "def tokenize_function(example):\n", " return tokenizer(example[\"sentence1\"], example[\"sentence2\"], truncation=True)\n", "\n", "tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n", "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "from transformers import TrainingArguments\n", "\n", "training_args = TrainingArguments(\n", " 'test-trainer',\n", " save_strategy='epoch',\n", " push_to_hub=True\n", ")" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "from transformers import AutoModelForSequenceClassification\n", "\n", "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "from transformers import Trainer\n", "\n", "trainer = Trainer(\n", " model,\n", " training_args,\n", " train_dataset=tokenized_datasets['train'],\n", " eval_dataset=tokenized_datasets['validation'],\n", " # data_collator=data_collator, THE DEFAULT DATACOLLATOR IS DataCollatorWithPadding\n", " tokenizer=tokenizer\n", ")" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
Step | \n", "Training Loss | \n", "
---|---|
500 | \n", "0.528100 | \n", "
1000 | \n", "0.284700 | \n", "
"
],
"text/plain": [
"