Spaces:
Runtime error
Runtime error
File size: 3,072 Bytes
a638d55 ea1b084 a638d55 ea1b084 a638d55 ea1b084 a638d55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IAGKskIWS9C0"
},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
"import numpy as np\n",
"import evaluate\n",
"\n",
"\n",
"DATA_SEED = 9843203\n",
"QUICK_TEST = True          # when True, train/eval on a small subsample for fast iteration\n",
"QUICK_TEST_SIZE = 1000     # number of examples per split in quick-test mode\n",
"\n",
"# Baseline dataset: MBTI personality-type classification (16 classes)\n",
"dataset = load_dataset(\"ClaudiaRichard/mbti_classification_v2\")\n",
"\n",
"# Llama 3 8B tokeniser\n",
"tokeniser = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3-8B\")\n",
"\n",
"# Llama 3 ships without a padding token, so padding=\"max_length\" below would\n",
"# raise a ValueError; reuse the EOS token as the pad token.\n",
"if tokeniser.pad_token is None:\n",
"    tokeniser.pad_token = tokeniser.eos_token\n",
"\n",
"def tokenise_function(examples):\n",
"    \"\"\"Tokenise a batch of examples to fixed-length, truncated sequences.\"\"\"\n",
"    return tokeniser(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
"\n",
"tokenised_dataset = dataset.map(tokenise_function, batched=True)\n",
"\n",
"\n",
"# Different sized datasets allow for different training times.\n",
"# NOTE: this previously read from `tokenised_datasets` (undefined — NameError);\n",
"# the variable defined above is `tokenised_dataset`.\n",
"train_dataset = tokenised_dataset[\"train\"].shuffle(seed=DATA_SEED)\n",
"test_dataset = tokenised_dataset[\"test\"].shuffle(seed=DATA_SEED)\n",
"if QUICK_TEST:\n",
"    train_dataset = train_dataset.select(range(QUICK_TEST_SIZE))\n",
"    test_dataset = test_dataset.select(range(QUICK_TEST_SIZE))\n",
"\n",
"\n",
"# Each of the 16 MBTI types gets a specific classification label\n",
"model = AutoModelForSequenceClassification.from_pretrained(\"meta-llama/Meta-Llama-3-8B\", num_labels=16)\n",
"# Keep the model's pad-token id in sync with the tokeniser (required for batched padding)\n",
"model.config.pad_token_id = tokeniser.pad_token_id\n",
"\n",
"# A default metric for checking accuracy\n",
"metric = evaluate.load(\"accuracy\")\n",
"\n",
"def compute_metrics(eval_pred):\n",
"    \"\"\"Compute accuracy from the (logits, labels) pair the Trainer passes in.\"\"\"\n",
"    logits, labels = eval_pred\n",
"    predictions = np.argmax(logits, axis=-1)\n",
"    return metric.compute(predictions=predictions, references=labels)\n",
"\n",
"# Default hyperparameters; evaluate once per epoch.\n",
"# (TrainingArguments was previously constructed twice — the first instance,\n",
"# without evaluation_strategy, was dead code and has been removed.)\n",
"training_args = TrainingArguments(output_dir=\"test_trainer\", evaluation_strategy=\"epoch\")\n",
"\n",
"# Builds a training object using previously defined data\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=train_dataset,\n",
"    eval_dataset=test_dataset,\n",
"    compute_metrics=compute_metrics,\n",
")\n",
"\n",
"# Finally, fine-tune!\n",
"if __name__ == \"__main__\":\n",
"    trainer.train()"
]
}
]
} |