TomSmail committed on
Commit
4c24cb9
1 Parent(s): 656f752

feat: writes basic pipeline for psychologist model

Browse files
Files changed (1) hide show
  1. psy.py +38 -3
psy.py CHANGED
@@ -1,5 +1,8 @@
1
  from datasets import load_dataset
2
- from transformers import AutoTokenizer
 
 
 
3
 
4
  DATA_SEED = 9843203
5
  QUICK_TEST = True
@@ -17,5 +20,37 @@ tokenised_dataset = dataset.map(tokenise_function, batched=True)
17
 
18
 
19
  # Different sized datasets will allow for different training times
20
- train_dataset = tokenized_datasets["train"].shuffle(seed=DATA_SEED).select(range(1000)) if QUICK_TEST else tokenized_datasets["train"].shuffle(seed=DATA_SEED)
21
- test_dataset = tokenized_datasets["test"].shuffle(seed=DATA_SEED).select(range(1000)) if QUICK_TEST else tokenized_datasets["test"].shuffle(seed=DATA_SEED)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from datasets import load_dataset
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
3
+ import numpy as np
4
+ import evaluate
5
+
6
 
7
  DATA_SEED = 9843203
8
  QUICK_TEST = True
 
20
 
21
 
22
  # Different sized datasets will allow for different training times
23
+ train_dataset = tokenised_datasets["train"].shuffle(seed=DATA_SEED).select(range(1000)) if QUICK_TEST else tokenised_datasets["train"].shuffle(seed=DATA_SEED)
24
+ test_dataset = tokenised_datasets["test"].shuffle(seed=DATA_SEED).select(range(1000)) if QUICK_TEST else tokenised_datasets["test"].shuffle(seed=DATA_SEED)
25
+
26
+
27
+ # Each of the 16 MBTI personality types gets its own class label
28
+ model = AutoModelForSequenceClassification.from_pretrained("meta-llama/Meta-Llama-3-8B", num_labels=16)
29
+
30
+ # Using default hyperparameters at the moment (note: training_args is reassigned further down with evaluation_strategy="epoch")
31
+ training_args = TrainingArguments(output_dir="test_trainer")
32
+
33
+ # A default metric for checking accuracy
34
+ metric = evaluate.load("accuracy")
35
+
36
+ def compute_metrics(eval_pred):
37
+ logits, labels = eval_pred
38
+ predictions = np.argmax(logits, axis=-1)
39
+ return metric.compute(predictions=predictions, references=labels)
40
+
41
+ # Training arguments: write outputs to "test_trainer" and run evaluation at the end of every epoch
42
+ training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")
43
+
44
+ # Builds a training object using previously defined data
45
+ trainer = Trainer(
46
+ model=model,
47
+ args=training_args,
48
+ train_dataset=train_dataset,
49
+ eval_dataset=test_dataset,
50
+ compute_metrics=compute_metrics,
51
+ )
52
+
53
+ # Finally, fine-tune!
54
+ if __name__ == "__main__":
55
+ trainer.train()
56
+