dnnsdunca commited on
Commit
718f9ec
1 Parent(s): be6e5a6

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +43 -0
train.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification, AutoTokenizer
3
+ from datasets import load_dataset
4
+ import json
5
+
6
+ # Load configuration
7
+ with open('../config/config.json') as f:
8
+ config = json.load(f)
9
+
10
+ # Load dataset
11
+ dataset = load_dataset('csv', data_files={'train': '../data/train.csv', 'validation': '../data/valid.csv'})
12
+
13
+ # Load model and tokenizer
14
+ model = AutoModelForSequenceClassification.from_pretrained(config['model_name'], num_labels=config['num_labels'])
15
+ tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
16
+
17
+ # Tokenize dataset
18
+ def tokenize_function(examples):
19
+ return tokenizer(examples['text'], padding="max_length", truncation=True)
20
+
21
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
22
+
23
+ # Training arguments
24
+ training_args = TrainingArguments(
25
+ output_dir='./results',
26
+ learning_rate=config['learning_rate'],
27
+ per_device_train_batch_size=config['batch_size'],
28
+ num_train_epochs=config['num_epochs'],
29
+ evaluation_strategy="epoch",
30
+ save_strategy="epoch",
31
+ logging_dir='./logs'
32
+ )
33
+
34
+ trainer = Trainer(
35
+ model=model,
36
+ args=training_args,
37
+ train_dataset=tokenized_datasets['train'],
38
+ eval_dataset=tokenized_datasets['validation'],
39
+ tokenizer=tokenizer
40
+ )
41
+
42
+ trainer.train()
43
+ trainer.save_model('../model')