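# Spaces:
# Runtime error
# Runtime error
# NOTE(review): the three lines above look like status-banner text captured
# together with this listing rather than part of the script; kept as comments.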
import dataclasses

import evaluate
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
# Seed passed to every dataset shuffle so the selected subsets are reproducible.
DATA_SEED = 9843203
# When True, only a 1000-example subset of each split is used further down.
QUICK_TEST = True

# This is our baseline dataset
dataset = load_dataset("ClaudiaRichard/mbti_classification_v2")

# Llama 3 8B
tokeniser = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# Padding is requested when the text is tokenised, which raises if the
# tokeniser defines no padding token. Llama checkpoints typically ship without
# one, so fall back to the end-of-sequence token when none is configured.
if tokeniser.pad_token is None:
    tokeniser.pad_token = tokeniser.eos_token
def tokenise_function(examples, max_length=512):
    """Tokenise the ``"text"`` entries of a batch to a fixed sequence length.

    Args:
        examples: Batch mapping that contains a ``"text"`` entry (this is what
            ``Dataset.map(..., batched=True)`` passes in).
        max_length: Number of tokens every example is padded or truncated to.
            Passing it explicitly avoids depending on the tokeniser's
            ``model_max_length``, which may be unset or impractically large
            for this checkpoint.

    Returns:
        The encoding produced by the module-level ``tokeniser``.
    """
    return tokeniser(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
tokenised_dataset = dataset.map(tokenise_function, batched=True)


def _prepare_split(split, quick_test_size=1000):
    """Shuffle *split* with DATA_SEED; under QUICK_TEST keep a small prefix.

    The prefix length is capped at the split's own size so that a split with
    fewer than *quick_test_size* rows does not make ``select`` fail.
    """
    shuffled = split.shuffle(seed=DATA_SEED)
    if not QUICK_TEST:
        return shuffled
    return shuffled.select(range(min(quick_test_size, len(shuffled))))


# Different sized datasets will allow for different training times
train_dataset = _prepare_split(tokenised_dataset["train"])
test_dataset = _prepare_split(tokenised_dataset["test"])
# Each of our MBTI types has a specific label here
model = AutoModelForSequenceClassification.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    num_labels=16,
)
# The sequence-classification head uses config.pad_token_id to locate the last
# real token of each row and rejects batches larger than one when it is unset.
# Mirror the tokeniser's pad id (this is a no-op if the tokeniser has none).
if model.config.pad_token_id is None:
    model.config.pad_token_id = tokeniser.pad_token_id

# Using default hyperparameters at the moment: TrainingArguments is built once,
# right before the Trainer further down. A second, default-only instance that
# used to be created here was overwritten before anything read it.

# A default metric for checking accuracy
metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level ``metric``.

    ``eval_pred`` unpacks into ``(logits, labels)``; the predicted class for
    each row is the index of its largest logit along the last axis.
    """
    scores, gold = eval_pred
    predicted = np.argmax(scores, axis=-1)
    return metric.compute(predictions=predicted, references=gold)
# Training configuration: default hyperparameters, evaluating once per epoch.
# The evaluation-schedule keyword was renamed from ``evaluation_strategy`` to
# ``eval_strategy`` in newer transformers releases, so use whichever name the
# installed TrainingArguments actually declares.
_training_arg_names = {field.name for field in dataclasses.fields(TrainingArguments)}
_eval_strategy_key = (
    "eval_strategy"
    if "eval_strategy" in _training_arg_names
    else "evaluation_strategy"
)
training_args = TrainingArguments(
    output_dir="test_trainer",
    **{_eval_strategy_key: "epoch"},
)

# Builds a training object using previously defined data
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)
def train_model() -> None:
    """Run fine-tuning using the module-level ``trainer``."""
    trainer.train()
# Entry point: the training run itself only starts when this file is executed
# directly (the setup above still runs on import).
if __name__ == "__main__":
    train_model()