{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IAGKskIWS9C0"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
        "import numpy as np\n",
        "import evaluate\n",
        "\n",
        "\n",
        "DATA_SEED = 9843203\n",
        "QUICK_TEST = True\n",
        "\n",
        "# This is our baseline dataset\n",
        "dataset = load_dataset(\"ClaudiaRichard/mbti_classification_v2\")\n",
        "\n",
        "# LLama3 8b\n",
        "tokeniser = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3-8B\")\n",
        "\n",
        "def tokenise_function(examples):\n",
        "    return tokeniser(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
        "\n",
        "tokenised_dataset = dataset.map(tokenise_function, batched=True)\n",
        "\n",
        "\n",
        "# Different sized datasets will allow for different training times\n",
        "train_dataset = tokenised_datasets[\"train\"].shuffle(seed=DATA_SEED).select(range(1000)) if QUICK_TEST  else  tokenised_datasets[\"train\"].shuffle(seed=DATA_SEED)\n",
        "test_dataset = tokenised_datasets[\"test\"].shuffle(seed=DATA_SEED).select(range(1000)) if QUICK_TEST else tokenised_datasets[\"test\"].shuffle(seed=DATA_SEED)\n",
        "\n",
        "\n",
        "# Each of our Mtbi types has a specific label here\n",
        "model = AutoModelForSequenceClassification.from_pretrained(\"meta-llama/Meta-Llama-3-8B\", num_labels=16)\n",
        "\n",
        "# Using default hyperparameters at the moment\n",
        "training_args = TrainingArguments(output_dir=\"test_trainer\")\n",
        "\n",
        "# A default metric for checking accuracy\n",
        "metric = evaluate.load(\"accuracy\")\n",
        "\n",
        "def compute_metrics(eval_pred):\n",
        "    logits, labels = eval_pred\n",
        "    predictions = np.argmax(logits, axis=-1)\n",
        "    return metric.compute(predictions=predictions, references=labels)\n",
        "\n",
        "# Extract arguments from training\n",
        "training_args = TrainingArguments(output_dir=\"test_trainer\", evaluation_strategy=\"epoch\")\n",
        "\n",
        "# Builds a training object using previously defined data\n",
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=train_dataset,\n",
        "    eval_dataset=test_dataset,\n",
        "    compute_metrics=compute_metrics,\n",
        ")\n",
        "\n",
        "# Finally, fine-tune!\n",
        "if __name__ == \"__main__\":\n",
        "    trainer.train()"
      ]
    }
  ]
}