JanSt committed
Commit
beefda1
1 Parent(s): b8403c9

Upload balanced_train_full.py

Files changed (1): balanced_train_full.py (+98, -0)
balanced_train_full.py ADDED

# Fine-tune albert-base-v2 as a 16-way MBTI classifier on Reddit comments.
import pandas as pd

df = pd.read_feather("//media/data/mbti-reddit/disprop_sample100k_total.feather")  # change this to the proper path
df = df.drop(columns=['authors', 'subreddit'])

df = df.sample(80000, random_state=1)  # random subsample

# Encode the 16 MBTI type strings as integer class ids 0..15.
df['labels'] = df['labels'].replace(
    ['INTP', 'ISTP', 'ENTP', 'ESTP', 'INFP', 'ISFP', 'ENFP', 'ESFP',
     'INTJ', 'ISTJ', 'ENTJ', 'ESTJ', 'INFJ', 'ISFJ', 'ENFJ', 'ESFJ'],
    list(range(16)))
df = df.rename(columns={'comments': 'text'})

from datasets import Dataset

dataset = Dataset.from_pandas(df)
dataset = dataset.shuffle(seed=27)  # shuffle() is not in-place; keep the returned Dataset
split_set = dataset.train_test_split(test_size=0.2)

from transformers import AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')

from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

model = AutoModelForSequenceClassification.from_pretrained("albert-base-v2", num_labels=16)


def preprocess_function(examples):
    # Tokenize the comments; padding is deferred to the data collator.
    return tokenizer(examples["text"], truncation=True)


tokenized_dataset = split_set.map(preprocess_function, batched=True)

from transformers import DataCollatorWithPadding

# Pad dynamically per batch rather than padding the whole corpus up front.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

import evaluate
import numpy as np


def compute_metrics(eval_preds):
    # Weighted averaging accounts for class imbalance across the 16 types.
    metric = evaluate.combine([
        evaluate.load("precision"),
        evaluate.load("recall")])
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels, average='weighted')

training_args = TrainingArguments(
    output_dir="/home/deimann/mbti-project/balanced_train",
    evaluation_strategy="epoch",
    # save_strategy="epoch",
    # save_total_limit=5,
    # load_best_model_at_end=True,
    learning_rate=2e-5,
    per_device_train_batch_size=36,  # up from 16
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    # compute_metrics=compute_metrics,  # uncomment to log weighted precision/recall at each eval
)

trainer.train()
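
Training writes checkpoints under output_dir. A minimal inference sketch, assuming one of those checkpoint directories exists; the checkpoint-1000 name and the example comment are illustrative, and the label list must match the 0..15 encoding used above:

from transformers import pipeline

# Hypothetical checkpoint path; substitute an actual checkpoint-* directory from output_dir.
ckpt = "/home/deimann/mbti-project/balanced_train/checkpoint-1000"

# Same MBTI label order the training script used for its integer encoding.
mbti_types = ['INTP', 'ISTP', 'ENTP', 'ESTP', 'INFP', 'ISFP', 'ENFP', 'ESFP',
              'INTJ', 'ISTJ', 'ENTJ', 'ESTJ', 'INFJ', 'ISFJ', 'ENFJ', 'ESFJ']

clf = pipeline("text-classification", model=ckpt)

pred = clf("I spent the whole weekend reorganizing my notes instead of going out.")[0]
class_id = int(pred['label'].split('_')[-1])  # labels default to LABEL_0 .. LABEL_15
print(mbti_types[class_id], round(pred['score'], 3))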