"""
Module for storing the Model class, which can be used for wrapping sklearn or
PyTorch models. This is more so that evaluation can be abstracted.
"""
import pickle
import os
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np
import matplotlib.pyplot as plt

# Torch
import torch
from torch.optim import AdamW
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler
)
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from .dataset import JobDataset, SVMJobDataset, HuggingFaceJobDataset
from .utils import FocalLoss, compute_metrics


class Model(ABC):
    """Abstract interface shared by every model wrapper in this module."""

    # Saving and loading
    @abstractmethod
    def save_model(self, path: str, *args):
        """Save the model into a serialized format (e.g. pickle, tensors)"""
        pass

    @abstractmethod
    def load_model(self, path: str, *args):
        """Loads the model from the serialized format"""
        pass

    # Training, inference, evaluation
    @abstractmethod
    def fit(self, dataset: JobDataset):
        """Given the dataset class, train the underlying model"""
        pass

    @abstractmethod
    def evaluate(self, dataset: JobDataset):
        """Given the dataset class, output the evaluation metrics"""
        pass

    @abstractmethod
    def __call__(self, *args, **kwargs):
        """Given model inputs, predict the test set labels"""
        pass


class DistilBERTBaseModel(Model):
    """
    Sequence-classification wrapper around a pretrained HuggingFace model
    (DistilBERT by default) with a hand-written training loop.
    """

    def __init__(self,
                 pretrained_model="distilbert-base-uncased",
                 num_labels=2,
                 freeze=False,
                 class_frequencies: Optional[torch.Tensor] = None,
                 cpu=False):
        """
        Args:
            pretrained_model: Name or path passed to
                ``AutoModelForSequenceClassification.from_pretrained``.
            num_labels: Number of output classes of the classification head.
            freeze: If true, freeze the default set of early layers (see
                ``freeze_layers``).
            class_frequencies: If given, training uses ``FocalLoss`` built
                from these frequencies instead of the model's built-in loss.
            cpu: Force the CPU device even when CUDA is available.
        """
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if cpu:
            self._device = torch.device("cpu")
        print("Torch device: ", repr(self._device))

        self._model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model,
            num_labels=num_labels
        ).to(self._device)
        # NOTE(review): the tokenizer is always the stock DistilBERT one, even
        # when `pretrained_model` names something else. Left unchanged because
        # `save_model` does not store a tokenizer, so checkpoint paths would
        # not be loadable as a tokenizer source.
        self._tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

        # Initially we trained it with the earlier layers frozen to try and
        # speed up training, however we eventually undid this once we've
        # established the training loop.
        if freeze:
            self.freeze_layers()

        # If class frequencies were provided, use them to construct the focal
        # loss formulation
        self._loss = None
        if class_frequencies is not None:
            print(f"Loading a-balanced focal loss with weights {str(class_frequencies)}")
            self._loss = FocalLoss(
                class_frequencies=class_frequencies
            )

        # Set defaults, calling this after the fact to re-set parameters is
        # simple enough
        self.set_training_args()

    def freeze_layers(self, layer_prefixes: Optional[set] = None):
        """
        Freezes certain layers by prefixes in order to focus training on only
        certain layers.

        Args:
            layer_prefixes: Parameter-name prefixes to freeze. Defaults to the
                embeddings and transformer layers 0-3.
        """
        if layer_prefixes is None:
            layer_prefixes = {
                "distilbert.embeddings",
                "distilbert.transformer.layer.0",
                "distilbert.transformer.layer.1",
                "distilbert.transformer.layer.2",
                "distilbert.transformer.layer.3",
            }

        # str.startswith accepts a tuple of candidates; this tests the match
        # itself rather than the truthiness of the prefix string.
        prefixes = tuple(layer_prefixes)
        for name, params in self._model.named_parameters():
            if name.startswith(prefixes):
                params.requires_grad = False

    def set_training_args(self, **training_args):
        """
        Store training hyper-parameters, filling in defaults for any that are
        not supplied. ``fit`` reads ``num_train_epochs``, ``learning_rate``
        and ``weight_decay`` from the stored arguments.
        """
        defaults = {
            "output_dir": "../models/DistilBERTBase",
            "learning_rate": 2e-5,
            "per_device_train_batch_size": 16,
            "per_device_eval_batch_size": 16,
            "num_train_epochs": 3,
            "weight_decay": 0.01,
            "save_strategy": "epoch",
            "evaluation_strategy": "epoch",
            "logging_strategy": "epoch",
        }
        for key, value in defaults.items():
            training_args.setdefault(key, value)
        self._train_args = training_args

    def save_model(self, path, checkpoint_name: str = "checkpoint"):
        """Save the model weights/config under ``path/checkpoint_name``."""
        path = os.path.join(path, checkpoint_name)
        self._model.save_pretrained(path)

    def load_model(self, path, checkpoint_name: Optional[str] = None):
        """
        Load the model from ``path`` and move it to the active device.

        Args:
            path: Directory to load from.
            checkpoint_name: Optional sub-directory, mirroring ``save_model``.
                When omitted, ``path`` is used as-is.
        """
        if checkpoint_name is not None:
            path = os.path.join(path, checkpoint_name)
        self._model = AutoModelForSequenceClassification \
            .from_pretrained(path) \
            .to(self._device)

    def _to_device(self, batch):
        """Move every tensor of a batch dict onto the active device."""
        return {k: v.to(self._device) for k, v in batch.items()}

    def _compute_loss(self, outputs, batch):
        """
        Return the training objective for one batch: the model's own loss, or
        the focal loss on the class-1 probability when one was configured.
        """
        if self._loss is None:
            return outputs.loss
        labels = batch["labels"]
        scores = torch.softmax(outputs.logits, dim=-1)[:len(labels), 1]
        return self._loss(scores, labels)

    @staticmethod
    def _moving_average(values, window: int = 8):
        """
        Smooth ``values`` with a box filter of at most ``window`` samples.

        The window is clamped to the series length so that short runs still
        produce a meaningful curve. Returns ``(smoothed, effective_window)``.
        """
        effective = max(1, min(window, len(values)))
        if len(values) == 0:
            return np.array([]), effective
        kernel = np.ones(effective) / effective
        smoothed = np.convolve(np.array(values), kernel, mode='valid')
        return smoothed, effective

    def _plot_losses(self, losses, eval_losses):
        """Plot moving averages of the per-batch training (and eval) losses."""
        train_curve, train_window = self._moving_average(losses)
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(train_curve, label=f'Training Loss (MA-{train_window})')
        if eval_losses:
            # Separate x axis: the eval set has a different number of batches.
            ax2 = ax.twiny()
            eval_curve, eval_window = self._moving_average(eval_losses)
            ax2.plot(eval_curve, color='orange',
                     label=f'Eval Loss (MA-{eval_window})')
            ax2.legend()
        ax.set_xlabel('Batch')
        ax.set_ylabel('Average Loss')
        ax.set_title('Loss over Batches')
        ax.legend()
        fig.show()

    def fit(self,
            dataset: HuggingFaceJobDataset,
            subsample: bool = False,
            plot_loss: bool = False,
            eval_loss: bool = False):
        """
        Fine-tune the model on the dataset's training split.

        Args:
            dataset: Provides ``get_training_set`` / ``get_validation_set``,
                both called with ``dataloader=True``.
            subsample: Forwarded to the dataset when fetching both splits.
            plot_loss: Plot a moving average of the per-batch losses.
            eval_loss: Also compute the validation loss after every epoch,
                using the same objective as training.
        """
        # Set up optimizer and LR scheduler
        train_dataloader = dataset.get_training_set(dataloader=True,
                                                    subsample=subsample)
        eval_dataloader = dataset.get_validation_set(dataloader=True,
                                                     subsample=subsample)
        num_epochs = self._train_args["num_train_epochs"]
        num_batches = len(train_dataloader)
        num_training_steps = num_epochs * num_batches

        # Honour the configured hyper-parameters instead of a hard-coded
        # learning rate that silently ignored `set_training_args`.
        optimizer = AdamW(
            self._model.parameters(),
            lr=self._train_args["learning_rate"],
            weight_decay=self._train_args["weight_decay"],
        )
        lr_scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=num_training_steps
        )

        progress_bar = tqdm(total=num_training_steps)
        losses = []
        eval_losses = []

        # Training loop
        try:
            for epoch in range(num_epochs):
                # Re-enter train mode each epoch: the validation pass below
                # switches the model to eval mode.
                self._model.train()
                epoch_loss = 0.0
                for batch in train_dataloader:
                    batch = self._to_device(batch)
                    outputs = self._model(**batch)
                    loss = self._compute_loss(outputs, batch)
                    loss.backward()

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                    progress_bar.update(1)

                    loss_value = loss.item()
                    epoch_loss += loss_value
                    losses.append(loss_value)

                avg_loss = epoch_loss / num_batches
                print(f"Epoch {epoch+1} avg_loss: {avg_loss:.5f}")

                if eval_loss:
                    # Disable dropout so the validation loss is deterministic.
                    self._model.eval()
                    eval_epoch_loss = 0.0
                    num_eval_batches = len(eval_dataloader)
                    with torch.no_grad():
                        for batch in eval_dataloader:
                            batch = self._to_device(batch)
                            outputs = self._model(**batch)
                            loss_value = self._compute_loss(outputs, batch).item()
                            eval_epoch_loss += loss_value
                            eval_losses.append(loss_value)
                    avg_loss = eval_epoch_loss / num_eval_batches
                    print(f" eval avg_loss: {avg_loss:.5f}")
        finally:
            progress_bar.close()

        # Plot the loss if requested
        # Note that this is a moving average of the per-batch loss, which is
        # different from the usual per-epoch loss, as we only fine-tune for a
        # small number of epochs
        if plot_loss:
            self._plot_losses(losses, eval_losses)

    def evaluate(self,
                 dataset: DataLoader,
                 get_raw_results: bool = False,
                 plot_pr_curve: bool = True):
        """
        Run the model over every batch of ``dataset`` in eval mode.

        Args:
            dataset: Iterable of batch dicts containing a ``"labels"`` entry.
            get_raw_results: Return the raw ``(targets, scores, predictions)``
                tensors instead of computed metrics.
            plot_pr_curve: Forwarded to ``compute_metrics``.

        Returns:
            The CPU tensors ``(targets, scores, predictions)`` when
            ``get_raw_results`` is true, otherwise whatever
            ``compute_metrics`` returns.
        """
        self._model.eval()
        targs_list = []
        score_list = []
        preds_list = []
        for batch in tqdm(dataset):
            batch = self._to_device(batch)
            with torch.no_grad():
                outputs = self._model(**batch)

            logits = outputs.logits
            labels = batch["labels"]
            scores = torch.softmax(logits, dim=-1)[:len(labels), 1]
            predictions = torch.argmax(logits, dim=-1)
            # Move results off the accelerator right away so they do not
            # accumulate in device memory across the whole dataset.
            targs_list.append(labels.cpu())
            score_list.append(scores.cpu())
            preds_list.append(predictions.cpu())

        targs = torch.concat(targs_list)
        scores = torch.concat(score_list)
        preds = torch.concat(preds_list)

        if get_raw_results:
            return targs, scores, preds
        else:
            return compute_metrics(targs, scores, preds, plot_pr_curve)

    def __call__(self, title: str, description: str) -> bool:
        """
        Classify a single posting given as ``title`` and ``description``.

        Returns:
            ``bool`` of the arg-max class index (true for any non-zero class).
        """
        # The model may still be in train mode after `fit`; dropout would
        # make single predictions non-deterministic.
        self._model.eval()
        inputs = self._tokenizer(
            title + " " + description,
            return_tensors="pt",
            truncation=True,
            padding=True
        ).to(self._device)
        with torch.inference_mode():
            outputs = self._model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=-1).tolist()[0]
        return bool(predictions)