from collections import defaultdict from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union from datasets import Dataset from transformers.trainer_callback import TrainerCallback from setfit.span.modeling import AbsaModel, AspectModel, PolarityModel from setfit.training_args import TrainingArguments from .. import logging from ..trainer import ColumnMappingMixin, Trainer if TYPE_CHECKING: import optuna logger = logging.get_logger(__name__) class AbsaTrainer(ColumnMappingMixin): """Trainer to train a SetFit ABSA model. Args: model (`AbsaModel`): The AbsaModel model to train. args (`TrainingArguments`, *optional*): The training arguments to use. If `polarity_args` is not defined, then `args` is used for both the aspect and the polarity model. polarity_args (`TrainingArguments`, *optional*): The training arguments to use for the polarity model. If not defined, `args` is used for both the aspect and the polarity model. train_dataset (`Dataset`): The training dataset. The dataset must have "text", "span", "label" and "ordinal" columns. eval_dataset (`Dataset`, *optional*): The evaluation dataset. The dataset must have "text", "span", "label" and "ordinal" columns. metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`): The metric to use for evaluation. If a string is provided, we treat it as the metric name and load it with default settings. If a callable is provided, it must take two arguments (`y_pred`, `y_test`). metric_kwargs (`Dict[str, Any]`, *optional*): Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1". For example useful for providing an averaging strategy for computing f1 in a multi-label setting. callbacks (`List[`[`~transformers.TrainerCallback`]`]`, *optional*): A list of callbacks to customize the training loop. Will add those to the list of default callbacks detailed in [here](https://huggingface.co/docs/transformers/main/en/main_classes/callback). If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method. column_mapping (`Dict[str, str]`, *optional*): A mapping from the column names in the dataset to the column names expected by the model. The expected format is a dictionary with the following format: `{"text_column_name": "text", "span_column_name": "span", "label_column_name: "label", "ordinal_column_name": "ordinal"}`. """ _REQUIRED_COLUMNS = {"text", "span", "label", "ordinal"} def __init__( self, model: AbsaModel, args: Optional[TrainingArguments] = None, polarity_args: Optional[TrainingArguments] = None, train_dataset: Optional["Dataset"] = None, eval_dataset: Optional["Dataset"] = None, metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy", metric_kwargs: Optional[Dict[str, Any]] = None, callbacks: Optional[List[TrainerCallback]] = None, column_mapping: Optional[Dict[str, str]] = None, ) -> None: self.model = model self.aspect_extractor = model.aspect_extractor if train_dataset is not None and column_mapping: train_dataset = self._apply_column_mapping(train_dataset, column_mapping) aspect_train_dataset, polarity_train_dataset = self.preprocess_dataset( model.aspect_model, model.polarity_model, train_dataset ) if eval_dataset is not None and column_mapping: eval_dataset = self._apply_column_mapping(eval_dataset, column_mapping) aspect_eval_dataset, polarity_eval_dataset = self.preprocess_dataset( model.aspect_model, model.polarity_model, eval_dataset ) self.aspect_trainer = Trainer( model.aspect_model, args=args, train_dataset=aspect_train_dataset, eval_dataset=aspect_eval_dataset, metric=metric, metric_kwargs=metric_kwargs, callbacks=callbacks, ) self.aspect_trainer._set_logs_mapper( { "eval_embedding_loss": "eval_aspect_embedding_loss", "embedding_loss": "aspect_embedding_loss", } ) self.polarity_trainer = Trainer( model.polarity_model, args=polarity_args or args, train_dataset=polarity_train_dataset, eval_dataset=polarity_eval_dataset, metric=metric, metric_kwargs=metric_kwargs, callbacks=callbacks, ) self.polarity_trainer._set_logs_mapper( { "eval_embedding_loss": "eval_polarity_embedding_loss", "embedding_loss": "polarity_embedding_loss", } ) def preprocess_dataset( self, aspect_model: AspectModel, polarity_model: PolarityModel, dataset: Dataset ) -> Dataset: if dataset is None: return dataset, dataset # Group by "text" grouped_data = defaultdict(list) for sample in dataset: text = sample.pop("text") grouped_data[text].append(sample) def index_ordinal(text: str, target: str, ordinal: int) -> Tuple[int, int]: find_from = 0 for _ in range(ordinal + 1): start_idx = text.index(target, find_from) find_from = start_idx + 1 return start_idx, start_idx + len(target) def overlaps(aspect: slice, aspects: List[slice]) -> bool: for test_aspect in aspects: overlapping_indices = set(range(aspect.start, aspect.stop + 1)) & set( range(test_aspect.start, test_aspect.stop + 1) ) if overlapping_indices: return True return False docs, aspects_list = self.aspect_extractor(grouped_data.keys()) aspect_aspect_list = [] aspect_labels = [] polarity_aspect_list = [] polarity_labels = [] for doc, aspects, text in zip(docs, aspects_list, grouped_data): # Collect all of the gold aspects gold_aspects = [] gold_polarity_labels = [] for annotation in grouped_data[text]: try: start, end = index_ordinal(text, annotation["span"], annotation["ordinal"]) except ValueError: logger.info( f"The ordinal of {annotation['ordinal']} for span {annotation['span']!r} in {text!r} is too high. " "Skipping this sample." ) continue gold_aspect_span = doc.char_span(start, end) if gold_aspect_span is None: continue gold_aspects.append(slice(gold_aspect_span.start, gold_aspect_span.end)) gold_polarity_labels.append(annotation["label"]) # The Aspect model uses all gold aspects as "True", and all non-overlapping predicted # aspects as "False" aspect_labels.extend([True] * len(gold_aspects)) aspect_aspect_list.append(gold_aspects[:]) for aspect in aspects: if not overlaps(aspect, gold_aspects): aspect_labels.append(False) aspect_aspect_list[-1].append(aspect) # The Polarity model uses only the gold aspects and labels polarity_labels.extend(gold_polarity_labels) polarity_aspect_list.append(gold_aspects) aspect_texts = list(aspect_model.prepend_aspects(docs, aspect_aspect_list)) polarity_texts = list(polarity_model.prepend_aspects(docs, polarity_aspect_list)) return Dataset.from_dict({"text": aspect_texts, "label": aspect_labels}), Dataset.from_dict( {"text": polarity_texts, "label": polarity_labels} ) def train( self, args: Optional[TrainingArguments] = None, polarity_args: Optional[TrainingArguments] = None, trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, **kwargs, ) -> None: """ Main training entry point. Args: args (`TrainingArguments`, *optional*): Temporarily change the aspect training arguments for this training call. polarity_args (`TrainingArguments`, *optional*): Temporarily change the polarity training arguments for this training call. trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): The trial run or the hyperparameter dictionary for hyperparameter search. """ self.train_aspect(args=args, trial=trial, **kwargs) self.train_polarity(args=polarity_args, trial=trial, **kwargs) def train_aspect( self, args: Optional[TrainingArguments] = None, trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, **kwargs, ) -> None: """ Train the aspect model only. Args: args (`TrainingArguments`, *optional*): Temporarily change the aspect training arguments for this training call. trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): The trial run or the hyperparameter dictionary for hyperparameter search. """ self.aspect_trainer.train(args=args, trial=trial, **kwargs) def train_polarity( self, args: Optional[TrainingArguments] = None, trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, **kwargs, ) -> None: """ Train the polarity model only. Args: args (`TrainingArguments`, *optional*): Temporarily change the aspect training arguments for this training call. trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): The trial run or the hyperparameter dictionary for hyperparameter search. """ self.polarity_trainer.train(args=args, trial=trial, **kwargs) def add_callback(self, callback: Union[type, TrainerCallback]) -> None: """ Add a callback to the current list of [`~transformers.TrainerCallback`]. Args: callback (`type` or [`~transformers.TrainerCallback`]): A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the first case, will instantiate a member of that class. """ self.aspect_trainer.add_callback(callback) self.polarity_trainer.add_callback(callback) def pop_callback(self, callback: Union[type, TrainerCallback]) -> Tuple[TrainerCallback, TrainerCallback]: """ Remove a callback from the current list of [`~transformers.TrainerCallback`] and returns it. If the callback is not found, returns `None` (and no error is raised). Args: callback (`type` or [`~transformers.TrainerCallback`]): A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the first case, will pop the first member of that class found in the list of callbacks. Returns: `Tuple[`[`~transformers.TrainerCallback`], [`~transformers.TrainerCallback`]`]`: The callbacks removed from the aspect and polarity trainers, if found. """ return self.aspect_trainer.pop_callback(callback), self.polarity_trainer.pop_callback(callback) def remove_callback(self, callback: Union[type, TrainerCallback]) -> None: """ Remove a callback from the current list of [`~transformers.TrainerCallback`]. Args: callback (`type` or [`~transformers.TrainerCallback`]): A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the first case, will remove the first member of that class found in the list of callbacks. """ self.aspect_trainer.remove_callback(callback) self.polarity_trainer.remove_callback(callback) def push_to_hub(self, repo_id: str, polarity_repo_id: Optional[str] = None, **kwargs) -> None: """Upload model checkpoint to the Hub using `huggingface_hub`. See the full list of parameters for your `huggingface_hub` version in the\ [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.ModelHubMixin.push_to_hub). Args: repo_id (`str`): The full repository ID to push to, e.g. `"tomaarsen/setfit-aspect"`. repo_id (`str`): The full repository ID to push to, e.g. `"tomaarsen/setfit-sst2"`. config (`dict`, *optional*): Configuration object to be saved alongside the model weights. commit_message (`str`, *optional*): Message to commit while pushing. private (`bool`, *optional*, defaults to `False`): Whether the repository created should be private. api_endpoint (`str`, *optional*): The API endpoint to use when pushing the model to the hub. token (`str`, *optional*): The token to use as HTTP bearer authorization for remote files. If not set, will use the token set when logging in with `transformers-cli login` (stored in `~/.huggingface`). branch (`str`, *optional*): The git branch on which to push the model. This defaults to the default branch as specified in your repository, which defaults to `"main"`. create_pr (`boolean`, *optional*): Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`. allow_patterns (`List[str]` or `str`, *optional*): If provided, only files matching at least one pattern are pushed. ignore_patterns (`List[str]` or `str`, *optional*): If provided, files matching any of the patterns are not pushed. """ return self.model.push_to_hub(repo_id=repo_id, polarity_repo_id=polarity_repo_id, **kwargs) def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, Dict[str, float]]: """ Computes the metrics for a given classifier. Args: dataset (`Dataset`, *optional*): The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed via the `eval_dataset` argument at `Trainer` initialization. Returns: `Dict[str, Dict[str, float]]`: The evaluation metrics. """ aspect_eval_dataset = polarity_eval_dataset = None if dataset: aspect_eval_dataset, polarity_eval_dataset = self.preprocess_dataset( self.model.aspect_model, self.model.polarity_model, dataset ) return { "aspect": self.aspect_trainer.evaluate(aspect_eval_dataset), "polarity": self.polarity_trainer.evaluate(polarity_eval_dataset), }