Canstralian commited on
Commit
61560c5
·
verified ·
1 Parent(s): 2ce6627

Create fine_tune.py

Browse files
Files changed (1) hide show
  1. fine_tune.py +70 -0
fine_tune.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader, Dataset
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
4
+ import pandas as pd
5
+
6
+ class CustomDataset(Dataset):
7
+ def __init__(self, data, tokenizer, max_len):
8
+ self.data = data
9
+ self.tokenizer = tokenizer
10
+ self.max_len = max_len
11
+
12
+ def __len__(self):
13
+ return len(self.data)
14
+
15
+ def __getitem__(self, index):
16
+ row = self.data.iloc[index]
17
+ inputs = self.tokenizer.encode_plus(
18
+ row['text'],
19
+ add_special_tokens=True,
20
+ max_length=self.max_len,
21
+ padding='max_length',
22
+ return_attention_mask=True,
23
+ return_tensors='pt'
24
+ )
25
+ return {
26
+ 'input_ids': inputs['input_ids'].flatten(),
27
+ 'attention_mask': inputs['attention_mask'].flatten(),
28
+ 'labels': torch.tensor(row['label'], dtype=torch.long)
29
+ }
30
+
31
+ def train_model(model_name, train_data_path, output_dir, epochs=3, batch_size=16, max_len=128):
32
+ # Load the dataset
33
+ df = pd.read_csv(train_data_path)
34
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
35
+ dataset = CustomDataset(df, tokenizer, max_len)
36
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
37
+
38
+ # Load the model
39
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(df['label'].unique()))
40
+
41
+ # Define training arguments
42
+ training_args = TrainingArguments(
43
+ output_dir=output_dir,
44
+ num_train_epochs=epochs,
45
+ per_device_train_batch_size=batch_size,
46
+ evaluation_strategy="epoch",
47
+ save_total_limit=2,
48
+ save_steps=10_000,
49
+ logging_dir=f'{output_dir}/logs',
50
+ )
51
+
52
+ # Initialize the Trainer
53
+ trainer = Trainer(
54
+ model=model,
55
+ args=training_args,
56
+ train_dataset=dataset,
57
+ )
58
+
59
+ # Train the model
60
+ trainer.train()
61
+
62
+ # Save the model
63
+ model.save_pretrained(output_dir)
64
+ tokenizer.save_pretrained(output_dir)
65
+
66
+ if __name__ == "__main__":
67
+ model_name = "bert-base-uncased"
68
+ train_data_path = "data/example_dataset.csv"
69
+ output_dir = "output"
70
+ train_model(model_name, train_data_path, output_dir)