from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import numpy as np
import evaluate


DATA_SEED = 9843203
QUICK_TEST = True

# This is our baseline dataset
dataset = load_dataset("ClaudiaRichard/mbti_classification_v2")

# Llama 3 8B tokeniser. It ships without a padding token, so reuse EOS for padding.
tokeniser = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokeniser.pad_token = tokeniser.eos_token

def tokenise_function(examples):
    return tokeniser(examples["text"], padding="max_length", truncation=True)

tokenised_datasets = dataset.map(tokenise_function, batched=True)


# QUICK_TEST trains and evaluates on 1,000-example subsets so runs finish quickly
train_dataset = tokenised_datasets["train"].shuffle(seed=DATA_SEED)
test_dataset = tokenised_datasets["test"].shuffle(seed=DATA_SEED)
if QUICK_TEST:
    train_dataset = train_dataset.select(range(1000))
    test_dataset = test_dataset.select(range(1000))


# One label for each of the 16 MBTI types
model = AutoModelForSequenceClassification.from_pretrained("meta-llama/Meta-Llama-3-8B", num_labels=16)
model.config.pad_token_id = tokeniser.pad_token_id  # needed because the inputs are padded
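# A sketch, not part of the original script: fully fine-tuning an 8B-parameter
# model needs a large amount of GPU memory. A common alternative is to train
# LoRA adapters with the peft library, roughly along these lines:
#
#     from peft import LoraConfig, TaskType, get_peft_model
#     lora_config = LoraConfig(task_type=TaskType.SEQ_CLS, r=8, lora_alpha=16, lora_dropout=0.05)
#     model = get_peft_model(model, lora_config)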


# Accuracy as the evaluation metric
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

# Default hyperparameters, reporting evaluation metrics at the end of each epoch
training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")
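# Hyperparameters one might expose later (illustrative values, not from the source):
#
#     training_args = TrainingArguments(
#         output_dir="test_trainer",
#         evaluation_strategy="epoch",
#         per_device_train_batch_size=8,
#         learning_rate=2e-5,
#         num_train_epochs=3,
#     )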

# Builds a training object using previously defined data
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)


def train_model():
    trainer.train()
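
# A possible follow-up, shown here as a sketch rather than part of the original
# script: score the fine-tuned model on the held-out split and save the weights.
# "mbti_llama3_classifier" is a placeholder output directory, not from the source.
def evaluate_and_save_model():
    metrics = trainer.evaluate()  # runs compute_metrics over eval_dataset
    print(metrics)
    trainer.save_model("mbti_llama3_classifier")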

# Finally, fine-tune!
if __name__ == "__main__":
    train_model()