"""
Module for storing the Model class, which can be used for wrapping sklearn or
PyTorch models. This exists mainly so that evaluation can be abstracted.
"""

import pickle
import os
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np
import matplotlib.pyplot as plt

from scipy.sparse import hstack
from sklearn.svm import SVC

import torch
from torch.optim import AdamW
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler
)
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from .dataset import JobDataset, SVMJobDataset, HuggingFaceJobDataset
from .utils import FocalLoss, compute_metrics


class Model(ABC):
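    """Abstract base class defining the save/load, fit, evaluate, and predict
    interface shared by the model wrappers in this module."""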

    @abstractmethod
    def save_model(self, path: str, *args):
        """Save the model into a serialized format (e.g. pickle, tensors)."""
        pass

    @abstractmethod
    def load_model(self, path: str, *args):
        """Load the model from its serialized format."""
        pass

    @abstractmethod
    def fit(self, dataset: JobDataset):
        """Given the dataset class, train the underlying model."""
        pass

    @abstractmethod
    def evaluate(self, dataset: JobDataset):
        """Given the dataset class, output the evaluation metrics."""
        pass

    @abstractmethod
    def __call__(self, *args, **kwargs):
        """Given model inputs, predict the corresponding labels."""
        pass


class DistilBERTBaseModel(Model):
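    """
    Wraps a Hugging Face DistilBERT sequence-classification model, with
    optional layer freezing and an optional alpha-balanced focal loss for
    imbalanced classes.
    """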

    def __init__(self,
                 pretrained_model="distilbert-base-uncased",
                 num_labels=2,
                 freeze=False,
                 class_frequencies: Optional[torch.Tensor] = None,
                 cpu=False):
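        """
        Args:
            pretrained_model: Hugging Face checkpoint to load the model from.
            num_labels: number of output classes for the classification head.
            freeze: if True, freeze the lower layers via freeze_layers().
            class_frequencies: per-class frequencies; when given, an
                alpha-balanced focal loss is used instead of the default loss.
            cpu: force the model onto the CPU even if CUDA is available.
        """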
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if cpu:
            self._device = torch.device("cpu")
        print("Torch device: ", repr(self._device))

        self._model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model, num_labels=num_labels
        ).to(self._device)
        # Load the tokenizer from the same checkpoint as the model.
        self._tokenizer = AutoTokenizer.from_pretrained(pretrained_model)

        if freeze:
            self.freeze_layers()

        self._loss = None
        if class_frequencies is not None:
            print(f"Loading alpha-balanced focal loss with weights {str(class_frequencies)}")
            self._loss = FocalLoss(
                class_frequencies=class_frequencies
            )

        self.set_training_args()

    def freeze_layers(self, layer_prefixes: Optional[set] = None):
        """
        Freezes layers whose parameter names start with any of the given
        prefixes, so that training only updates the remaining layers.
        """
        # By default, freeze the embeddings and the first four of DistilBERT's
        # six transformer layers, leaving the last two layers and the
        # classification head trainable.
        if layer_prefixes is None:
            layer_prefixes = set([
                "distilbert.embeddings",
                "distilbert.transformer.layer.0",
                "distilbert.transformer.layer.1",
                "distilbert.transformer.layer.2",
                "distilbert.transformer.layer.3",
            ])

        for name, params in self._model.named_parameters():
            if any(name.startswith(prefix) for prefix in layer_prefixes):
                params.requires_grad = False

    def set_training_args(self, **training_args):
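        """
        Store training hyperparameters, filling in defaults for any that are
        not supplied. The names mirror Hugging Face TrainingArguments, though
        the custom training loop in fit() only reads a subset of them.
        """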
        training_args.setdefault("output_dir", "../models/DistilBERTBase")
        training_args.setdefault("learning_rate", 2e-5)
        training_args.setdefault("per_device_train_batch_size", 16)
        training_args.setdefault("per_device_eval_batch_size", 16)
        training_args.setdefault("num_train_epochs", 3)
        training_args.setdefault("weight_decay", 0.01)
        training_args.setdefault("save_strategy", "epoch")
        training_args.setdefault("evaluation_strategy", "epoch")
        training_args.setdefault("logging_strategy", "epoch")

        self._train_args = training_args

    def save_model(self, path, checkpoint_name: str = "checkpoint"):
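        """Save the model weights to `path/<checkpoint_name>` via save_pretrained()."""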
        path = os.path.join(path, checkpoint_name)
        self._model.save_pretrained(path)

    def load_model(self, path):
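        """Load a model saved with save_pretrained() and move it to the configured device."""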
        self._model = AutoModelForSequenceClassification \
            .from_pretrained(path) \
            .to(self._device)

    def fit(self,
            dataset: HuggingFaceJobDataset,
            subsample: bool = False,
            plot_loss: bool = False,
            eval_loss: bool = False):
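        """
        Fine-tune the model with a manual training loop (AdamW plus a linear
        learning-rate schedule). Optionally subsamples the data, tracks the
        validation loss each epoch, and plots smoothed loss curves at the end.
        """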

        train_dataloader = dataset.get_training_set(dataloader=True, subsample=subsample)
        eval_dataloader = dataset.get_validation_set(dataloader=True, subsample=subsample)
        num_epochs = self._train_args["num_train_epochs"]
        num_batches = len(train_dataloader)
        num_training_steps = num_epochs * num_batches
        # Use the configured learning rate rather than a hard-coded value.
        optimizer = AdamW(self._model.parameters(), lr=self._train_args["learning_rate"])
        lr_scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=num_training_steps
        )
        progress_bar = tqdm(range(num_training_steps))
        losses = []
        eval_losses = []

        self._model.train()
        for epoch in range(num_epochs):
            epoch_loss = 0.0
            for batch in train_dataloader:
                batch = {k: v.to(self._device) for k, v in batch.items()}
                outputs = self._model(**batch)
                if self._loss is None:
                    # Fall back to the cross-entropy loss computed by the model.
                    loss = outputs.loss
                else:
                    # Focal loss operates on the positive-class probabilities.
                    logits = outputs.logits
                    labels = batch["labels"]
                    scores = torch.softmax(logits, dim=-1)[:len(labels), 1]
                    loss = self._loss(scores, labels)
                loss.backward()

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)

                epoch_loss += loss.item()
                losses.append(loss.item())
            avg_loss = epoch_loss / num_batches
            print(f"Epoch {epoch+1} avg_loss: {avg_loss:.5f}")

            if eval_loss:
                # Compute the validation loss in eval mode (dropout disabled),
                # then switch back to training mode afterwards.
                self._model.eval()
                eval_epoch_loss = 0.0
                num_eval_batches = len(eval_dataloader)
                for batch in eval_dataloader:
                    batch = {k: v.to(self._device) for k, v in batch.items()}
                    with torch.no_grad():
                        outputs = self._model(**batch)
                    loss = outputs.loss
                    eval_epoch_loss += loss.item()
                    eval_losses.append(loss.item())
                avg_loss = eval_epoch_loss / num_eval_batches
                print(f" eval avg_loss: {avg_loss:.5f}")
                self._model.train()

        if plot_loss:
            # Smooth both curves with an 8-point moving average before plotting.
            kernel = np.ones(8) / 8
            losses = np.convolve(np.array(losses), kernel, mode='valid')
            fig, ax = plt.subplots(figsize=(10, 5))
            ax.plot(losses, label='Training Loss (MA-8)')
            if eval_losses:
                # Plot evaluation losses against their own x-axis, since they
                # are sampled on a different batch scale.
                ax2 = ax.twiny()
                eval_losses = np.convolve(np.array(eval_losses), kernel, mode='valid')
                ax2.plot(eval_losses, color='orange', label='Eval Loss (MA-8)')
                ax2.legend()
            ax.set_xlabel('Batch')
            ax.set_ylabel('Average Loss')
            ax.set_title('Loss over Batches')
            ax.legend()
            fig.show()

    def evaluate(self, dataset: DataLoader, get_raw_results: bool = False, plot_pr_curve: bool = True):
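        """
        Run the model over a DataLoader of tokenized batches and either return
        the raw (targets, scores, predictions) tensors or the metrics computed
        by compute_metrics().
        """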
        self._model.eval()
        targs_list = []
        score_list = []
        preds_list = []

        for batch in tqdm(dataset):
            batch = {k: v.to(self._device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = self._model(**batch)

            # Positive-class probabilities and hard predictions for this batch.
            logits = outputs.logits
            labels = batch["labels"]
            scores = torch.softmax(logits, dim=-1)[:len(labels), 1]
            predictions = torch.argmax(logits, dim=-1)

            targs_list.append(labels)
            score_list.append(scores)
            preds_list.append(predictions)

        # Collect the per-batch results into flat CPU tensors.
        targs = torch.concat(targs_list).cpu()
        scores = torch.concat(score_list).cpu()
        preds = torch.concat(preds_list).cpu()

        if get_raw_results:
            return targs, scores, preds
        else:
            return compute_metrics(targs, scores, preds, plot_pr_curve)

    def __call__(self, title: str, description: str) -> bool:
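        """Return True if the concatenated title and description is predicted
        to belong to the positive class (label 1)."""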
        # Tokenize the combined title and description as a single sequence.
        inputs = self._tokenizer(
            title + " " + description,
            return_tensors="pt",
            truncation=True,
            padding=True
        ).to(self._device)

        with torch.inference_mode():
            outputs = self._model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=-1).tolist()[0]
        return bool(predictions)