|
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig |
|
import numpy as np |
|
import logging |
|
import os |
|
from typing import Dict, Type, Callable, List |
|
import transformers |
|
import torch |
|
from torch import nn |
|
from torch.optim import Optimizer |
|
from torch.utils.data import DataLoader |
|
from tqdm.autonotebook import tqdm, trange |
|
from .. import SentenceTransformer, util |
|
from ..evaluation import SentenceEvaluator |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class CrossEncoder:
|
def __init__(self, model_name:str, num_labels:int = None, max_length:int = None, device:str = None, tokenizer_args:Dict = {}, |
|
automodel_args:Dict = {}, default_activation_function = None): |
|
""" |
|
A CrossEncoder takes exactly two sentences / texts as input and either predicts |
|
a score or label for this sentence pair. It can for example predict the similarity of the sentence pair |
|
on a scale of 0 ... 1. |
|
|
|
        It does not yield a sentence embedding and does not work for individual sentences.
|
|
|
:param model_name: Any model name from Huggingface Models Repository that can be loaded with AutoModel. We provide several pre-trained CrossEncoder models that can be used for common tasks |
|
        :param num_labels: Number of labels of the classifier. If 1, the CrossEncoder is a regression model that outputs a continuous score 0...1. If > 1, it outputs several scores that can be soft-maxed to get probability scores for the different classes.
|
:param max_length: Max length for input sequences. Longer sequences will be truncated. If None, max length of the model will be used |
|
:param device: Device that should be used for the model. If None, it will use CUDA if available. |
|
:param tokenizer_args: Arguments passed to AutoTokenizer |
|
:param automodel_args: Arguments passed to AutoModelForSequenceClassification |
|
        :param default_activation_function: Callable (like nn.Sigmoid) that is used as the default activation function on top of model.predict(). If None, nn.Sigmoid() will be used if num_labels=1, else nn.Identity()
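
        Example (a minimal usage sketch; the checkpoint name and sentences below are
        illustrative, any AutoModelForSequenceClassification-compatible checkpoint can be
        substituted)::

            from sentence_transformers import CrossEncoder

            model = CrossEncoder('cross-encoder/stsb-roberta-base', max_length=128)
            scores = model.predict([['Sentence A', 'Sentence B'],
                                    ['Sentence C', 'Sentence D']])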
|
""" |
|
|
|
self.config = AutoConfig.from_pretrained(model_name) |
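        # If the checkpoint does not already contain a trained sequence-classification head,
        # default to a single output label (a regression / ranking score) unless num_labels is set.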
|
classifier_trained = True |
|
if self.config.architectures is not None: |
|
classifier_trained = any([arch.endswith('ForSequenceClassification') for arch in self.config.architectures]) |
|
|
|
if num_labels is None and not classifier_trained: |
|
num_labels = 1 |
|
|
|
if num_labels is not None: |
|
self.config.num_labels = num_labels |
|
|
|
self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config, **automodel_args) |
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_args) |
|
self.max_length = max_length |
|
|
|
if device is None: |
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
logger.info("Use pytorch device: {}".format(device)) |
|
|
|
self._target_device = torch.device(device) |
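        # Resolve the activation used by default in predict(): an explicitly passed callable wins,
        # otherwise an activation stored in the model config is restored, otherwise fall back to
        # nn.Sigmoid() for num_labels == 1 and nn.Identity() otherwise.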
|
|
|
if default_activation_function is not None: |
|
self.default_activation_function = default_activation_function |
|
try: |
|
self.config.sbert_ce_default_activation_function = util.fullname(self.default_activation_function) |
|
except Exception as e: |
|
                logger.warning("Was not able to update config with the default_activation_function: {}".format(str(e)))
|
elif hasattr(self.config, 'sbert_ce_default_activation_function') and self.config.sbert_ce_default_activation_function is not None: |
|
self.default_activation_function = util.import_from_string(self.config.sbert_ce_default_activation_function)() |
|
else: |
|
self.default_activation_function = nn.Sigmoid() if self.config.num_labels == 1 else nn.Identity() |
|
|
|
def smart_batching_collate(self, batch): |
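        # Collate function for training: each InputExample contributes its texts (one list per
        # text column) and its label; the text pairs are tokenized as one padded batch and both
        # features and labels are moved to the target device.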
|
texts = [[] for _ in range(len(batch[0].texts))] |
|
labels = [] |
|
|
|
for example in batch: |
|
for idx, text in enumerate(example.texts): |
|
texts[idx].append(text.strip()) |
|
|
|
labels.append(example.label) |
|
|
|
tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length) |
|
labels = torch.tensor(labels, dtype=torch.float if self.config.num_labels == 1 else torch.long).to(self._target_device) |
|
|
|
for name in tokenized: |
|
tokenized[name] = tokenized[name].to(self._target_device) |
|
|
|
return tokenized, labels |
|
|
|
def smart_batching_collate_text_only(self, batch): |
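        # Collate function for predict(): like smart_batching_collate, but the batch contains raw
        # text tuples without labels, so only the tokenized features are returned.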
|
texts = [[] for _ in range(len(batch[0]))] |
|
|
|
for example in batch: |
|
for idx, text in enumerate(example): |
|
texts[idx].append(text.strip()) |
|
|
|
tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length) |
|
|
|
for name in tokenized: |
|
tokenized[name] = tokenized[name].to(self._target_device) |
|
|
|
return tokenized |
|
|
|
def fit(self, |
|
train_dataloader: DataLoader, |
|
evaluator: SentenceEvaluator = None, |
|
epochs: int = 1, |
|
loss_fct = None, |
|
activation_fct = nn.Identity(), |
|
scheduler: str = 'WarmupLinear', |
|
warmup_steps: int = 10000, |
|
optimizer_class: Type[Optimizer] = torch.optim.AdamW, |
|
optimizer_params: Dict[str, object] = {'lr': 2e-5}, |
|
weight_decay: float = 0.01, |
|
evaluation_steps: int = 0, |
|
output_path: str = None, |
|
save_best_model: bool = True, |
|
max_grad_norm: float = 1, |
|
use_amp: bool = False, |
|
callback: Callable[[float, int, int], None] = None, |
|
show_progress_bar: bool = True |
|
): |
|
""" |
|
        Train the model with the given train_dataloader for the given number of epochs.

        If an evaluator is provided, the model is evaluated during training and the best

        checkpoint (according to the evaluator) can be stored at output_path.
|
|
|
:param train_dataloader: DataLoader with training InputExamples |
|
:param evaluator: An evaluator (sentence_transformers.evaluation) evaluates the model performance during training on held-out dev data. It is used to determine the best model that is saved to disc. |
|
:param epochs: Number of epochs for training |
|
:param loss_fct: Which loss function to use for training. If None, will use nn.BCEWithLogitsLoss() if self.config.num_labels == 1 else nn.CrossEntropyLoss() |
|
:param activation_fct: Activation function applied on top of logits output of model. |
|
:param scheduler: Learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts |
|
        :param warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is increased from 0 up to the maximal learning rate over this many training steps. Afterwards, the learning rate is decreased linearly back to zero.
|
:param optimizer_class: Optimizer |
|
:param optimizer_params: Optimizer parameters |
|
:param weight_decay: Weight decay for model parameters |
|
        :param evaluation_steps: If > 0, evaluate the model with the evaluator every evaluation_steps training steps
|
:param output_path: Storage path for the model and evaluation files |
|
:param save_best_model: If true, the best model (according to evaluator) is stored at output_path |
|
        :param max_grad_norm: Used for gradient clipping; gradients are clipped to this max norm.
|
:param use_amp: Use Automatic Mixed Precision (AMP). Only for Pytorch >= 1.6.0 |
|
:param callback: Callback function that is invoked after each evaluation. |
|
It must accept the following three parameters in this order: |
|
`score`, `epoch`, `steps` |
|
:param show_progress_bar: If True, output a tqdm progress bar |
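
        Example (a minimal training sketch; the texts, labels and hyperparameters are
        illustrative, and InputExample is the standard sentence_transformers input class
        whose .texts / .label fields are read by smart_batching_collate)::

            from torch.utils.data import DataLoader
            from sentence_transformers import InputExample

            train_samples = [InputExample(texts=['First sentence', 'Second sentence'], label=0.8),
                             InputExample(texts=['Another', 'pair'], label=0.3)]
            train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=16)
            model.fit(train_dataloader=train_dataloader, epochs=1, warmup_steps=100)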
|
""" |
|
train_dataloader.collate_fn = self.smart_batching_collate |
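        # Replace the DataLoader's collate_fn so that batches of InputExamples are tokenized on the fly.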
|
|
|
if use_amp: |
|
from torch.cuda.amp import autocast |
|
scaler = torch.cuda.amp.GradScaler() |
|
|
|
self.model.to(self._target_device) |
|
|
|
if output_path is not None: |
|
os.makedirs(output_path, exist_ok=True) |
|
|
|
self.best_score = -9999999 |
|
num_train_steps = int(len(train_dataloader) * epochs) |
|
|
|
|
|
param_optimizer = list(self.model.named_parameters()) |
|
|
|
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] |
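        # Standard transformer fine-tuning setup: bias and LayerNorm parameters are excluded from weight decay.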
|
optimizer_grouped_parameters = [ |
|
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay}, |
|
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} |
|
] |
|
|
|
optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params) |
|
|
|
if isinstance(scheduler, str): |
|
scheduler = SentenceTransformer._get_scheduler(optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps) |
|
|
|
if loss_fct is None: |
|
loss_fct = nn.BCEWithLogitsLoss() if self.config.num_labels == 1 else nn.CrossEntropyLoss() |
|
|
|
|
|
skip_scheduler = False |
|
for epoch in trange(epochs, desc="Epoch", disable=not show_progress_bar): |
|
training_steps = 0 |
|
self.model.zero_grad() |
|
self.model.train() |
|
|
|
for features, labels in tqdm(train_dataloader, desc="Iteration", smoothing=0.05, disable=not show_progress_bar): |
|
if use_amp: |
|
with autocast(): |
|
model_predictions = self.model(**features, return_dict=True) |
|
logits = activation_fct(model_predictions.logits) |
|
if self.config.num_labels == 1: |
|
logits = logits.view(-1) |
|
loss_value = loss_fct(logits, labels) |
|
|
|
scale_before_step = scaler.get_scale() |
|
scaler.scale(loss_value).backward() |
|
scaler.unscale_(optimizer) |
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm) |
|
scaler.step(optimizer) |
|
scaler.update() |
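                    # GradScaler skips optimizer.step() and lowers its scale when inf/NaN gradients are
                    # found; in that case the LR scheduler step below is skipped as well.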
|
|
|
skip_scheduler = scaler.get_scale() != scale_before_step |
|
else: |
|
model_predictions = self.model(**features, return_dict=True) |
|
logits = activation_fct(model_predictions.logits) |
|
if self.config.num_labels == 1: |
|
logits = logits.view(-1) |
|
loss_value = loss_fct(logits, labels) |
|
loss_value.backward() |
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm) |
|
optimizer.step() |
|
|
|
optimizer.zero_grad() |
|
|
|
if not skip_scheduler: |
|
scheduler.step() |
|
|
|
training_steps += 1 |
|
|
|
if evaluator is not None and evaluation_steps > 0 and training_steps % evaluation_steps == 0: |
|
self._eval_during_training(evaluator, output_path, save_best_model, epoch, training_steps, callback) |
|
|
|
self.model.zero_grad() |
|
self.model.train() |
|
|
|
if evaluator is not None: |
|
self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1, callback) |
|
|
|
|
|
|
|
def predict(self, sentences: List[List[str]], |
|
batch_size: int = 32, |
|
show_progress_bar: bool = None, |
|
num_workers: int = 0, |
|
activation_fct = None, |
|
apply_softmax = False, |
|
convert_to_numpy: bool = True, |
|
convert_to_tensor: bool = False |
|
): |
|
""" |
|
        Performs predictions with the CrossEncoder on the given sentence pairs.
|
|
|
:param sentences: A list of sentence pairs [[Sent1, Sent2], [Sent3, Sent4]] |
|
:param batch_size: Batch size for encoding |
|
:param show_progress_bar: Output progress bar |
|
:param num_workers: Number of workers for tokenization |
|
        :param activation_fct: Activation function applied on the logits output of the CrossEncoder. If None, nn.Sigmoid() will be used if num_labels=1, else nn.Identity()
|
        :param apply_softmax: If True and the model has more than one output label, applies softmax on the logits output

        :param convert_to_numpy: Convert the output to a numpy matrix.

        :param convert_to_tensor: Convert the output to a tensor.
|
:return: Predictions for the passed sentence pairs |
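
        Example (illustrative sentence pairs; each inner list is one pair)::

            scores = model.predict([['How many people live in Berlin?',
                                     'Berlin has a population of 3.5 million.'],
                                    ['How many people live in Berlin?',
                                     'Berlin is a city in Germany.']])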
|
""" |
|
input_was_string = False |
|
if isinstance(sentences[0], str): |
|
sentences = [sentences] |
|
input_was_string = True |
|
|
|
inp_dataloader = DataLoader(sentences, batch_size=batch_size, collate_fn=self.smart_batching_collate_text_only, num_workers=num_workers, shuffle=False) |
|
|
|
if show_progress_bar is None: |
|
show_progress_bar = (logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG) |
|
|
|
iterator = inp_dataloader |
|
if show_progress_bar: |
|
iterator = tqdm(inp_dataloader, desc="Batches") |
|
|
|
if activation_fct is None: |
|
activation_fct = self.default_activation_function |
|
|
|
pred_scores = [] |
|
self.model.eval() |
|
self.model.to(self._target_device) |
|
with torch.no_grad(): |
|
for features in iterator: |
|
model_predictions = self.model(**features, return_dict=True) |
|
logits = activation_fct(model_predictions.logits) |
|
|
|
if apply_softmax and len(logits[0]) > 1: |
|
logits = torch.nn.functional.softmax(logits, dim=1) |
|
pred_scores.extend(logits) |
|
|
|
if self.config.num_labels == 1: |
|
pred_scores = [score[0] for score in pred_scores] |
|
|
|
if convert_to_tensor: |
|
pred_scores = torch.stack(pred_scores) |
|
elif convert_to_numpy: |
|
pred_scores = np.asarray([score.cpu().detach().numpy() for score in pred_scores]) |
|
|
|
if input_was_string: |
|
pred_scores = pred_scores[0] |
|
|
|
return pred_scores |
|
|
|
|
|
def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps, callback): |
|
"""Runs evaluation during the training""" |
|
if evaluator is not None: |
|
score = evaluator(self, output_path=output_path, epoch=epoch, steps=steps) |
|
if callback is not None: |
|
callback(score, epoch, steps) |
|
if score > self.best_score: |
|
self.best_score = score |
|
if save_best_model: |
|
self.save(output_path) |
|
|
|
def save(self, path): |
|
""" |
|
        Saves the model and tokenizer to the given path
|
""" |
|
if path is None: |
|
return |
|
|
|
logger.info("Save model to {}".format(path)) |
|
self.model.save_pretrained(path) |
|
self.tokenizer.save_pretrained(path) |
|
|
|
def save_pretrained(self, path): |
|
""" |
|
Same function as save |
|
""" |
|
return self.save(path) |
|
|
|
|