dnnsdunca commited on
Commit
774e73a
·
verified ·
1 Parent(s): a5dd61d

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +46 -0
main.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
2
+ from dataset import MyDataset
3
+ from data_collator import MyDataCollator
4
+
5
+ # Set hyperparameters
6
+ model_name = 'bert-base-uncased'
7
+ batch_size = 16
8
+ num_epochs = 3
9
+
10
+ # Load data
11
+ train_data = MyDataset('train.csv', AutoTokenizer.from_pretrained(model_name))
12
+ val_data = MyDataset('val.csv', AutoTokenizer.from_pretrained(model_name))
13
+
14
+ # Create data collator
15
+ data_collator = MyDataCollator(AutoTokenizer.from_pretrained(model_name))
16
+
17
+ # Create model
18
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=8)
19
+
20
+ # Create training arguments
21
+ training_args = TrainingArguments(
22
+ output_dir='./results',
23
+ num_train_epochs=num_epochs,
24
+ per_device_train_batch_size=batch_size,
25
+ per_device_eval_batch_size=batch_size,
26
+ evaluation_strategy='epoch',
27
+ save_total_limit=2,
28
+ save_steps=500,
29
+ load_best_model_at_end=True,
30
+ metric_for_best_model='accuracy',
31
+ greater_is_better=True,
32
+ save_on_each_node=True,
33
+ )
34
+
35
+ # Create trainer
36
+ trainer = Trainer(
37
+ model=model,
38
+ args=training_args,
39
+ train_dataset=train_data,
40
+ eval_dataset=val_data,
41
+ compute_metrics=lambda pred: {'accuracy': torch.sum(torch.argmax(pred.label_ids, dim=1) == torch.argmax(pred.predictions, dim=1))},
42
+ data_collator=data_collator,
43
+ )
44
+
45
+ # Train model
46
+ trainer.train()