{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": { "provenance": [] },
  "kernelspec": { "name": "python3", "display_name": "Python 3" },
  "language_info": { "name": "python" }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": { "id": "IAGKskIWS9C0md" },
   "source": [
    "# MBTI classification: fine-tuning Llama-3-8B\n",
    "\n",
    "Fine-tunes `meta-llama/Meta-Llama-3-8B` as a 16-way sequence classifier (one class per MBTI type) on `ClaudiaRichard/mbti_classification_v2`, evaluating accuracy once per epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "IAGKskIWS9C0" },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
    "import numpy as np\n",
    "import evaluate\n",
    "\n",
    "# Configuration -- everything a reader might tune lives here.\n",
    "DATA_SEED = 9843203   # fixed shuffle seed for reproducible splits\n",
    "QUICK_TEST = True     # cap each split at 1000 examples for fast iteration\n",
    "MODEL_NAME = \"meta-llama/Meta-Llama-3-8B\"\n",
    "NUM_LABELS = 16       # one label per MBTI type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "IAGKskIWS9C1" },
   "outputs": [],
   "source": [
    "# This is our baseline dataset\n",
    "dataset = load_dataset(\"ClaudiaRichard/mbti_classification_v2\")\n",
    "\n",
    "tokeniser = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "\n",
    "# Llama tokenisers ship without a pad token; padding=\"max_length\" below\n",
    "# raises without one, so reuse the EOS token for padding.\n",
    "if tokeniser.pad_token is None:\n",
    "    tokeniser.pad_token = tokeniser.eos_token\n",
    "\n",
    "def tokenise_function(examples):\n",
    "    \"\"\"Tokenise a batch of examples, padding/truncating to the model max length.\"\"\"\n",
    "    return tokeniser(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
    "\n",
    "# FIX: this was assigned to `tokenised_dataset` (singular) while every later\n",
    "# reference used `tokenised_datasets` (plural), raising NameError on a fresh run.\n",
    "tokenised_datasets = dataset.map(tokenise_function, batched=True)\n",
    "\n",
    "# Different sized datasets will allow for different training times\n",
    "train_dataset = tokenised_datasets[\"train\"].shuffle(seed=DATA_SEED)\n",
    "test_dataset = tokenised_datasets[\"test\"].shuffle(seed=DATA_SEED)\n",
    "if QUICK_TEST:\n",
    "    train_dataset = train_dataset.select(range(1000))\n",
    "    test_dataset = test_dataset.select(range(1000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "IAGKskIWS9C2" },
   "outputs": [],
   "source": [
    "# Each of our MBTI types has a specific label here\n",
    "model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS)\n",
    "# Keep the model's padding id in sync with the tokeniser so padded batches\n",
    "# are handled correctly by the classification head.\n",
    "model.config.pad_token_id = tokeniser.pad_token_id\n",
    "\n",
    "# A default metric for checking accuracy\n",
    "metric = evaluate.load(\"accuracy\")\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Argmax over the logits, then accuracy against the gold labels.\"\"\"\n",
    "    logits, labels = eval_pred\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    return metric.compute(predictions=predictions, references=labels)\n",
    "\n",
    "# FIX: TrainingArguments was previously constructed twice (the first, default\n",
    "# copy was dead code). Keep only the one with per-epoch evaluation.\n",
    "training_args = TrainingArguments(output_dir=\"test_trainer\", evaluation_strategy=\"epoch\")\n",
    "\n",
    "# Builds a training object using previously defined data\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=test_dataset,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "IAGKskIWS9C3" },
   "outputs": [],
   "source": [
    "# Finally, fine-tune!\n",
    "if __name__ == \"__main__\":\n",
    "    trainer.train()"
   ]
  }
 ]
}