# Author: bvishnu123 — commit fa6a4f6 (verified): "Update fake_job_detector/models.py"
"""
Module defining the abstract Model class, which is used to wrap sklearn or
PyTorch models behind a common interface, mainly so that evaluation code can
be written once and shared across model types.
"""
import pickle
import os
from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
import matplotlib.pyplot as plt
# Torch
import torch
from torch.optim import AdamW
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
get_scheduler
)
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from .dataset import JobDataset, SVMJobDataset, HuggingFaceJobDataset
from .utils import FocalLoss, compute_metrics
class Model(ABC):
    """
    Abstract interface shared by the model wrappers in this module.

    Subclasses provide persistence (``save_model`` / ``load_model``),
    training (``fit``), evaluation (``evaluate``) and prediction
    (``__call__``), so callers can drive any wrapped model the same way.
    """

    # ---- Persistence -------------------------------------------------

    @abstractmethod
    def save_model(self, path: str, *args):
        """Serialize the underlying model to ``path`` (e.g. pickle, tensors)."""

    @abstractmethod
    def load_model(self, path: str, *args):
        """Restore the underlying model from its serialized form at ``path``."""

    # ---- Training, inference, evaluation ---------------------------------

    @abstractmethod
    def fit(self, dataset: JobDataset):
        """Train the underlying model on the given dataset object."""

    @abstractmethod
    def evaluate(self, dataset: JobDataset):
        """Compute and return evaluation metrics for the given dataset object."""

    @abstractmethod
    def __call__(self, *args, **kwargs):
        """Predict labels for the supplied model inputs."""
class DistilBERTBaseModel(Model):
    """
    ``Model`` implementation wrapping a Hugging Face sequence-classification
    checkpoint (DistilBERT by default), trained with a manual PyTorch loop.
    """

    def __init__(self,
                 pretrained_model="distilbert-base-uncased",
                 num_labels=2,
                 freeze=False,
                 class_frequencies: Optional[torch.Tensor] = None,
                 cpu=False,
                 tokenizer_name: str = "distilbert-base-uncased"):
        """
        Args:
            pretrained_model: Name or path passed to
                ``AutoModelForSequenceClassification.from_pretrained``.
            num_labels: Number of output classes of the classification head.
            freeze: If True, freeze the default layer set (see ``freeze_layers``).
            class_frequencies: If given, train with a ``FocalLoss`` built from
                these frequencies instead of the model's built-in loss.
            cpu: Force the CPU even when CUDA is available.
            tokenizer_name: Name or path passed to
                ``AutoTokenizer.from_pretrained``. It should match the
                vocabulary of ``pretrained_model``. The default preserves the
                previously hard-coded tokenizer.
        """
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if cpu:
            self._device = torch.device("cpu")
        print("Torch device: ", repr(self._device))
        self._model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model, num_labels=num_labels
        ).to(self._device)
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Freezing the earlier layers was originally used to speed up
        # training. It was dropped once the training loop was established,
        # so it is now opt-in.
        if freeze:
            self.freeze_layers()
        # When class frequencies are provided, train with the focal loss
        # built from them rather than the model's own loss.
        self._loss = None
        if class_frequencies is not None:
            print(f"Loading a-balanced focal loss with weights {str(class_frequencies)}")
            self._loss = FocalLoss(
                class_frequencies=class_frequencies
            )
        # Populate default training arguments. Calling set_training_args()
        # again later overrides them.
        self.set_training_args()

    def freeze_layers(self, layer_prefixes: Optional[set] = None):
        """
        Freeze every parameter whose name starts with one of ``layer_prefixes``
        by setting ``requires_grad = False``, so training focuses on the rest.

        Args:
            layer_prefixes: Parameter-name prefixes to freeze. Defaults to the
                DistilBERT embeddings and transformer layers 0-3.
        """
        if layer_prefixes is None:
            layer_prefixes = {
                "distilbert.embeddings",
                "distilbert.transformer.layer.0",
                "distilbert.transformer.layer.1",
                "distilbert.transformer.layer.2",
                "distilbert.transformer.layer.3",
            }
        # str.startswith accepts a tuple of prefixes. Unlike the previous
        # truthiness-based check, this also honours an empty prefix.
        prefixes = tuple(layer_prefixes)
        for name, param in self._model.named_parameters():
            if name.startswith(prefixes):
                param.requires_grad = False

    def set_training_args(self, **training_args):
        """
        Store training arguments, filling in defaults for any not supplied.

        ``fit`` reads ``num_train_epochs``, ``learning_rate`` and
        ``weight_decay``. The remaining keys are stored but not used by this
        class. The ``learning_rate`` default is 5e-5, the rate ``fit`` used
        when it was hard-coded, so default training behaviour is unchanged.
        ``weight_decay`` 0.01 matches torch's AdamW default.
        """
        training_args.setdefault("output_dir", "../models/DistilBERTBase")
        training_args.setdefault("learning_rate", 5e-5)
        training_args.setdefault("per_device_train_batch_size", 16)
        training_args.setdefault("per_device_eval_batch_size", 16)
        training_args.setdefault("num_train_epochs", 3)
        training_args.setdefault("weight_decay", 0.01)
        training_args.setdefault("save_strategy", "epoch")
        training_args.setdefault("evaluation_strategy", "epoch")
        training_args.setdefault("logging_strategy", "epoch")
        self._train_args = training_args

    def save_model(self, path, checkpoint_name: str = "checkpoint"):
        """Save the model weights/config (not the tokenizer) to ``path/checkpoint_name``."""
        path = os.path.join(path, checkpoint_name)
        self._model.save_pretrained(path)

    def load_model(self, path):
        """Replace the wrapped model with the checkpoint at ``path`` on this device."""
        self._model = AutoModelForSequenceClassification \
            .from_pretrained(path) \
            .to(self._device)

    def _compute_loss(self, outputs, batch):
        """Return the batch loss: the model's own loss, or the configured focal
        loss applied to the positive-class (index 1) softmax probabilities."""
        if self._loss is None:
            return outputs.loss
        labels = batch["labels"]
        scores = torch.softmax(outputs.logits, dim=-1)[:len(labels), 1]
        return self._loss(scores, labels)

    def _validation_loss(self, eval_dataloader, eval_losses: list) -> float:
        """Compute per-batch validation losses (appended to ``eval_losses``) and
        return their mean. Runs in eval mode without gradients, then restores
        train mode for the next epoch."""
        self._model.eval()
        total = 0.0
        with torch.no_grad():
            for batch in eval_dataloader:
                batch = {k: v.to(self._device) for k, v in batch.items()}
                outputs = self._model(**batch)
                # Use the same objective as training so the two are comparable.
                value = self._compute_loss(outputs, batch).item()
                total += value
                eval_losses.append(value)
        self._model.train()
        return total / len(eval_dataloader)

    @staticmethod
    def _plot_losses(losses, eval_losses):
        """Plot 8-batch moving averages of the training (and, if present,
        validation) per-batch losses."""
        kernel = np.ones(8) / 8
        smoothed = np.convolve(np.array(losses), kernel, mode='valid')
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(smoothed, label='Training Loss (MA-8)')
        if eval_losses:
            # The validation series has a different number of batches, so it
            # gets its own x-axis.
            ax2 = ax.twiny()
            eval_smoothed = np.convolve(np.array(eval_losses), kernel, mode='valid')
            ax2.plot(eval_smoothed, color='orange', label='Eval Loss (MA-8)')
            ax2.legend()
        ax.set_xlabel('Batch')
        ax.set_ylabel('Average Loss')
        ax.set_title('Loss over Batches')
        ax.legend()
        fig.show()

    def fit(self,
            dataset: HuggingFaceJobDataset,
            subsample: bool = False,
            plot_loss: bool = False,
            eval_loss: bool = False):
        """
        Fine-tune the model with AdamW and a linear learning-rate schedule.

        Args:
            dataset: Source of the training/validation DataLoaders
                (``get_training_set`` / ``get_validation_set``).
            subsample: Forwarded to the dataset getters.
            plot_loss: Plot a moving average of the per-batch losses. This
                differs from the usual per-epoch curve because fine-tuning
                runs for only a few epochs.
            eval_loss: After each epoch, compute and print the validation loss.
        """
        train_dataloader = dataset.get_training_set(dataloader=True, subsample=subsample)
        eval_dataloader = dataset.get_validation_set(dataloader=True, subsample=subsample)

        num_epochs = self._train_args["num_train_epochs"]
        num_batches = len(train_dataloader)
        num_training_steps = num_epochs * num_batches
        # Honour the configured hyper-parameters. Previously the learning
        # rate was hard-coded here, ignoring set_training_args().
        optimizer = AdamW(
            self._model.parameters(),
            lr=self._train_args["learning_rate"],
            weight_decay=self._train_args["weight_decay"],
        )
        lr_scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=num_training_steps
        )
        progress_bar = tqdm(range(num_training_steps))
        losses = []
        eval_losses = []

        self._model.train()
        for epoch in range(num_epochs):
            epoch_loss = 0.0
            for batch in train_dataloader:
                batch = {k: v.to(self._device) for k, v in batch.items()}
                outputs = self._model(**batch)
                loss = self._compute_loss(outputs, batch)
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                loss_value = loss.item()
                epoch_loss += loss_value
                losses.append(loss_value)
            avg_loss = epoch_loss / num_batches
            print(f"Epoch {epoch+1} avg_loss: {avg_loss:.5f}")
            if eval_loss:
                avg_loss = self._validation_loss(eval_dataloader, eval_losses)
                print(f" eval avg_loss: {avg_loss:.5f}")
        progress_bar.close()

        if plot_loss:
            self._plot_losses(losses, eval_losses)

    def evaluate(self, dataset: DataLoader, get_raw_results: bool = False, plot_pr_curve: bool = True):
        """
        Run the model over ``dataset`` in eval mode and score its predictions.

        Args:
            dataset: Iterable of batches (dicts of tensors including ``"labels"``).
            get_raw_results: If True, return the raw tensors instead of metrics.
            plot_pr_curve: Forwarded to ``compute_metrics``.

        Returns:
            ``compute_metrics(targets, scores, predictions, plot_pr_curve)``,
            or, if ``get_raw_results``, the tuple
            ``(targets, scores, predictions)`` of CPU tensors, where ``scores``
            is the softmax probability of class index 1.
        """
        self._model.eval()
        targs_list = []
        score_list = []
        preds_list = []
        for batch in tqdm(dataset):
            batch = {k: v.to(self._device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = self._model(**batch)
            logits = outputs.logits
            labels = batch["labels"]
            targs_list.append(labels)
            score_list.append(torch.softmax(logits, dim=-1)[:len(labels), 1])
            preds_list.append(torch.argmax(logits, dim=-1))
        targs = torch.concat(targs_list).cpu()
        scores = torch.concat(score_list).cpu()
        preds = torch.concat(preds_list).cpu()
        if get_raw_results:
            return targs, scores, preds
        return compute_metrics(targs, scores, preds, plot_pr_curve)

    def __call__(self, title: str, description: str) -> bool:
        """
        Classify a single posting from its title and description.

        Returns:
            True if the predicted class index is non-zero (class 1 when
            ``num_labels == 2``).
        """
        # fit() leaves the model in train mode, so switch to eval mode to
        # disable dropout and make predictions deterministic.
        self._model.eval()
        inputs = self._tokenizer(
            title + " " + description,
            return_tensors="pt",
            truncation=True,
            padding=True
        ).to(self._device)
        with torch.inference_mode():
            outputs = self._model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=-1).tolist()[0]
        return bool(predictions)