# NOTE: the lines below were Hugging Face page chrome (Space status, file size,
# commit hashes, line-number gutter) captured when this file was scraped; they
# are not Python and are kept as comments so the module parses.
# Spaces: Runtime error / Runtime error
# File size: 1,936 Bytes
# Commits seen: 656f752 4c24cb9 a2b8c61 290ce4a
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import numpy as np
import evaluate
# Seed for reproducible shuffles; QUICK_TEST trains on a small subset for fast smoke runs.
DATA_SEED = 9843203
QUICK_TEST = True
# This is our baseline dataset
dataset = load_dataset("ClaudiaRichard/mbti_classification_v2")
# LLama3 8b
tokeniser = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# BUG FIX: Llama 3's tokenizer ships without a pad token, so the
# padding="max_length" call in tokenise_function would raise
# ("Asking to pad but the tokenizer does not have a padding token").
# Reusing EOS as the pad token is the standard workaround for causal LMs.
if tokeniser.pad_token is None:
    tokeniser.pad_token = tokeniser.eos_token
def tokenise_function(examples):
    """Tokenise one batch of examples, padded/truncated to the model's max length.

    Relies on the module-level ``tokeniser``; expects each batch to carry a
    ``"text"`` column (as the MBTI dataset does).
    """
    texts = examples["text"]
    return tokeniser(texts, padding="max_length", truncation=True)
# Tokenise every split up front (batched=True processes many rows per call).
tokenised_dataset = dataset.map(tokenise_function, batched=True)
# Different sized datasets will allow for different training times.
# BUG FIX: the original referenced `tokenised_datasets` (plural), which was
# never defined -> NameError at runtime. Use the mapped dataset defined above.
if QUICK_TEST:
    train_dataset = tokenised_dataset["train"].shuffle(seed=DATA_SEED).select(range(1000))
    test_dataset = tokenised_dataset["test"].shuffle(seed=DATA_SEED).select(range(1000))
else:
    train_dataset = tokenised_dataset["train"].shuffle(seed=DATA_SEED)
    test_dataset = tokenised_dataset["test"].shuffle(seed=DATA_SEED)
# Each of our 16 MBTI types gets its own classification label here.
model = AutoModelForSequenceClassification.from_pretrained("meta-llama/Meta-Llama-3-8B", num_labels=16)
# NOTE(review): if the tokeniser pads with a token the model config does not
# know about, model.config.pad_token_id may also need setting — confirm before
# a full training run.
# Removed dead code: a first TrainingArguments(output_dir="test_trainer") was
# created here and immediately shadowed by the full definition further down.
# A default metric for checking accuracy
metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    """Return accuracy for one evaluation pass.

    ``eval_pred`` is the (logits, labels) pair the Trainer hands to this hook;
    scoring is delegated to the module-level ``metric`` (accuracy).
    """
    logits, labels = eval_pred
    # The highest-scoring class per row is the model's prediction.
    predicted_classes = np.asarray(logits).argmax(axis=-1)
    return metric.compute(predictions=predicted_classes, references=labels)
# Extract arguments from training
# NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in newer
# transformers releases (deprecated from v4.41) — confirm against the pinned
# transformers version before upgrading.
training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")
# Builds a training object using previously defined data
# (model, eval-per-epoch args, the shuffled train/test splits, and the
# accuracy hook defined above).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)
def train_model():
    """Run fine-tuning via the module-level Trainer (blocks until training ends)."""
    trainer.train()
# Finally, fine-tune!
if __name__ == "__main__":
    train_model()