# arabic-dialect-classifier / src/model_training.py
# Last commit c112e25 by zaidmehdi: "plot training history" (file size 4.7 kB).
import copy

import torch
import torch.nn as nn
import torch.optim as optim
from datasets import DatasetDict, Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from utils import get_dataset, plot_training_history
def _run_epoch(model, loader, device, optimizer=None, progress_bar=None):
    """Run one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

    Returns:
        (mean per-batch loss, accuracy over all samples seen in the pass).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    # Gradients are only needed for the training pass.
    with torch.set_grad_enabled(training):
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            if training:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                if progress_bar is not None:
                    progress_bar.update(1)
            total_loss += loss.item()
            predictions = outputs.logits.argmax(dim=1)
            labels = batch["labels"]
            correct += (predictions == labels).sum().item()
            seen += labels.size(0)
    return total_loss / len(loader), correct / seen


def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10,
                checkpoint_path="best_model_checkpoint.pth"):
    """Train ``model`` with early stopping on the mean validation loss.

    Each epoch runs one optimisation pass over ``train_loader`` followed by an
    evaluation pass over ``val_loader``. Whenever the validation loss improves,
    the weights are snapshotted in memory and written to ``checkpoint_path``.
    Training stops after ``patience`` consecutive non-improving epochs. The best
    snapshot, if any, is loaded back into ``model`` before returning.

    Args:
        model: module whose ``model(**batch)`` result exposes ``.loss`` and
            ``.logits`` (presumably a Hugging Face sequence-classification
            model; confirm for other callers).
        optimizer: optimizer over (some of) ``model``'s parameters.
        train_loader: iterable of dict batches of tensors, including ``"labels"``.
        val_loader: same batch format as ``train_loader``; drives early stopping.
        num_epochs: maximum number of epochs.
        patience: consecutive epochs without improvement before stopping.
        checkpoint_path: file the best ``state_dict`` is saved to.

    Returns:
        ``(model, history)``, where ``history`` maps "train_loss", "valid_loss",
        "train_accuracies" and "valid_accuracies" to per-epoch lists.

    Raises:
        ValueError: if either loader yields no batches (averages are undefined).
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must yield at least one batch")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    history = {"train_loss": [],
               "valid_loss": [],
               "train_accuracies": [],
               "valid_accuracies": []}
    best_valid_loss = float("inf")
    best_state = None
    epochs_no_improve = 0

    progress_bar = tqdm(range(num_epochs * len(train_loader)))
    try:
        for epoch in range(num_epochs):
            train_loss, train_accuracy = _run_epoch(
                model, train_loader, device, optimizer, progress_bar)
            valid_loss, valid_accuracy = _run_epoch(model, val_loader, device)

            history["train_loss"].append(train_loss)
            history["train_accuracies"].append(train_accuracy)
            history["valid_loss"].append(valid_loss)
            history["valid_accuracies"].append(valid_accuracy)
            print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_accuracy:.4f}')

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                epochs_no_improve = 0
                # state_dict() returns references to the live tensors; deep-copy
                # so later optimizer steps cannot overwrite the best snapshot.
                best_state = copy.deepcopy(model.state_dict())
                torch.save(best_state, checkpoint_path)
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    print(f'Early stopping after {epoch+1} epochs with no improvement.')
                    break
    finally:
        progress_bar.close()

    # No snapshot exists if no epoch ran or the loss never improved (e.g. NaN).
    if best_state is not None:
        model.load_state_dict(best_state)
    return model, history
def main():
    """Train the AraBART classification head on the Arabic dialect dataset.

    Loads the tokenizer, builds the dataset splits with ``get_dataset`` from
    the labelled train/dev TSV files, and converts each split to torch format.
    It then freezes every pretrained weight except ``classification_head``,
    trains with early stopping via ``train_model`` and plots the history.
    """
    checkpoint = "moussaKam/AraBART"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    dataset, label_encoder = get_dataset(
        "data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv", tokenizer
    )

    # Drop the raw text and rename the target column to "labels" so that
    # model(**batch) in train_model receives it under that keyword.
    for split_name in dataset:
        split = dataset[split_name].remove_columns(["tweet"])
        split = split.rename_column("label", "labels")
        dataset[split_name] = split
        split.set_format("torch")

    batch_size = 8
    train_loader = DataLoader(dataset["train"], batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset["val"], batch_size=batch_size)
    test_loader = DataLoader(dataset["test"], batch_size=batch_size)  # not used yet

    # NOTE(review): 21 is hard-coded; presumably the number of dialect classes
    # produced by get_dataset's label encoder. Confirm they stay in sync.
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=21)
    # The pretrained model is not retrained because no GPU is available.
    # Freeze everything, then re-enable gradients for the classification head only.
    model.requires_grad_(False)
    model.classification_head.requires_grad_(True)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=100)
    plot_training_history(history)
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()