team3 / psy.py
TomSmail's picture
Update psy.py
290ce4a verified
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import numpy as np
import evaluate
# When True, the script works on a small sample of each split instead of the
# full data, trading accuracy for a much shorter run.
QUICK_TEST = True
# Seed handed to every dataset shuffle so the sampled rows are reproducible.
DATA_SEED = 9843203

# Baseline MBTI classification dataset, loaded through `datasets.load_dataset`.
dataset = load_dataset("ClaudiaRichard/mbti_classification_v2")
# Llama 3 8B tokeniser
tokeniser = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# Padding needs a pad token. When the checkpoint does not define one, reuse
# EOS so that padding="max_length" in tokenise_function can work.
if tokeniser.pad_token is None:
    tokeniser.pad_token = tokeniser.eos_token


def tokenise_function(examples, max_length=512):
    """Tokenise the ``"text"`` column of a batch into fixed-length sequences.

    Args:
        examples: Batch mapping containing a ``"text"`` entry, as supplied by
            ``dataset.map(..., batched=True)``.
        max_length: Number of tokens every sequence is padded or truncated to.
            NOTE(review): 512 is only a starting point; tune it to the text
            lengths in the dataset and the available GPU memory.

    Returns:
        The tokeniser output for the batch.
    """
    # An explicit max_length is passed because padding="max_length" and
    # truncation=True otherwise fall back to the tokeniser's own
    # model_max_length, which this checkpoint does not appear to set to a
    # usable value (TODO confirm).
    return tokeniser(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
# Tokenise every split in batches.
tokenised_dataset = dataset.map(tokenise_function, batched=True)

# Shuffle both splits with the shared seed so runs are reproducible.
# (These lines previously read the undefined name `tokenised_datasets`.)
train_dataset = tokenised_dataset["train"].shuffle(seed=DATA_SEED)
test_dataset = tokenised_dataset["test"].shuffle(seed=DATA_SEED)

# Different sized datasets will allow for different training times: in quick
# mode keep at most 1000 rows per split, capped at the split size so that a
# split with fewer rows does not make `select` fail.
if QUICK_TEST:
    train_dataset = train_dataset.select(range(min(1000, len(train_dataset))))
    test_dataset = test_dataset.select(range(min(1000, len(test_dataset))))
# Each of our MBTI types has its own label, hence 16 output classes.
model = AutoModelForSequenceClassification.from_pretrained("meta-llama/Meta-Llama-3-8B", num_labels=16)
# The sequence-classification head uses config.pad_token_id to find the last
# real token of each sequence and cannot handle batches larger than one when
# it is unset. Mirror the tokeniser's pad token, falling back to EOS.
if model.config.pad_token_id is None:
    pad_id = tokeniser.pad_token_id
    model.config.pad_token_id = pad_id if pad_id is not None else tokeniser.eos_token_id

# The training arguments are built once, further down, right before the
# Trainer is created; the duplicate default-only instance that used to live
# here was overwritten before any use and has been removed.

# A default metric for checking accuracy
metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level ``metric``.

    Args:
        eval_pred: Pair that unpacks into ``(logits, labels)``.

    Returns:
        Whatever ``metric.compute`` returns for the predicted and reference
        labels.
    """
    scores, references = eval_pred
    # The predicted class is the index of the largest score on the last axis.
    predicted_labels = np.argmax(scores, axis=-1)
    return metric.compute(predictions=predicted_labels, references=references)
# Training configuration: output goes under "test_trainer" and evaluation is
# scheduled per epoch; every other hyperparameter keeps the library default.
# NOTE(review): newer transformers releases name this argument
# `eval_strategy` - confirm against the version pinned for this project.
training_args = TrainingArguments(
    output_dir="test_trainer",
    evaluation_strategy="epoch",
)

# Bundle the model, both data splits, and the metric callback into a Trainer.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    eval_dataset=test_dataset,
    train_dataset=train_dataset,
)
def train_model():
    """Run fine-tuning with the module-level ``trainer``."""
    trainer.train()


# Finally, fine-tune! Only when executed as a script, not when imported.
if __name__ == "__main__":
    train_model()