Spaces:
Paused
Paused
File size: 47,444 Bytes
0fdb130 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 |
import math
import os
import shutil
import time
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import evaluate
import torch
from datasets import Dataset, DatasetDict
from sentence_transformers import InputExample, SentenceTransformer, losses
from sentence_transformers.datasets import SentenceLabelDataset
from sentence_transformers.losses.BatchHardTripletLoss import BatchHardTripletLossDistanceFunction
from sentence_transformers.util import batch_to_device
from sklearn.preprocessing import LabelEncoder
from torch import nn
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
from transformers.integrations import WandbCallback, get_reporting_integration_callbacks
from transformers.trainer_callback import (
CallbackHandler,
DefaultFlowCallback,
IntervalStrategy,
PrinterCallback,
ProgressCallback,
TrainerCallback,
TrainerControl,
TrainerState,
)
from transformers.trainer_utils import (
HPSearchBackend,
default_compute_objective,
number_of_arguments,
set_seed,
speed_metrics,
)
from transformers.utils.import_utils import is_in_notebook
from setfit.model_card import ModelCardCallback
from . import logging
from .integrations import default_hp_search_backend, is_optuna_available, run_hp_search_optuna
from .losses import SupConLoss
from .sampler import ContrastiveDataset
from .training_args import TrainingArguments
from .utils import BestRun, default_hp_space_optuna
# For Python 3.7 compatibility
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
if TYPE_CHECKING:
import optuna
from .modeling import SetFitModel
logging.set_verbosity_info()
logger = logging.get_logger(__name__)

# Callbacks every Trainer starts with; reporting integrations and the progress
# callback are added on top of these in `Trainer.__init__`.
DEFAULT_CALLBACKS = [DefaultFlowCallback]

# Progress reporting: the notebook widget when running inside Jupyter, tqdm otherwise.
if is_in_notebook():
    from transformers.utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
else:
    DEFAULT_PROGRESS_CALLBACK = ProgressCallback
class ColumnMappingMixin:
    """Mixin that checks and applies a `column_mapping` to `datasets` objects.

    The host class must expose a `column_mapping` attribute: either `None` or a dict that maps
    dataset column names onto the required `"text"` / `"label"` names.
    """

    _REQUIRED_COLUMNS = {"text", "label"}

    def _validate_column_mapping(self, dataset: "Dataset") -> None:
        """
        Check that `dataset`, possibly after applying `self.column_mapping`, provides the required columns.

        Raises:
            ValueError: If the required columns are absent and no mapping is given (with a dedicated
                message when a `DatasetDict` was passed instead of a `Dataset`), if the mapping leaves
                a required column unprovided, or if the mapping refers to columns the dataset lacks.
        """
        mapping = self.column_mapping
        present = set(dataset.column_names)

        if mapping is None:
            if self._REQUIRED_COLUMNS.issubset(present):
                return
            # Issue #226: load_dataset will automatically assign points to "train" if no split is specified
            if present == {"train"} and isinstance(dataset, DatasetDict):
                raise ValueError(
                    "SetFit expected a Dataset, but it got a DatasetDict with the split ['train']. "
                    "Did you mean to select the training split with dataset['train']?"
                )
            if isinstance(dataset, DatasetDict):
                raise ValueError(
                    f"SetFit expected a Dataset, but it got a DatasetDict with the splits {sorted(present)}. "
                    "Did you mean to select one of these splits from the dataset?"
                )
            raise ValueError(
                f"SetFit expected the dataset to have the columns {sorted(self._REQUIRED_COLUMNS)}, "
                f"but only the columns {sorted(present)} were found. "
                "Either make sure these columns are present, or specify which columns to use with column_mapping in Trainer."
            )

        # A required column is covered either by being a mapping target, or by already existing
        # in the dataset under a name that the mapping does not rename away.
        produced_by_mapping = set(mapping.values())
        kept_unmapped = set(dataset.column_names) - set(mapping.keys())
        missing_columns = set(self._REQUIRED_COLUMNS)
        missing_columns -= produced_by_mapping
        missing_columns -= kept_unmapped
        if missing_columns:
            raise ValueError(
                f"The following columns are missing from the column mapping: {missing_columns}. "
                "Please provide a mapping for all required columns."
            )

        unknown_sources = set(mapping.keys()) - present
        if unknown_sources:
            raise ValueError(
                f"The column mapping expected the columns {sorted(mapping.keys())} in the dataset, "
                f"but the dataset had the columns {sorted(present)}."
            )

    def _apply_column_mapping(self, dataset: "Dataset", column_mapping: Dict[str, str]) -> "Dataset":
        """
        Rename the columns of `dataset` according to `column_mapping`.

        Columns that are neither mapped nor already named like a required column get a `"feat_"`
        prefix. The dataset's output format is then re-applied with the renamed column list, keeping
        its format type, `output_all_columns` flag and extra format kwargs.
        """
        renames = dict(column_mapping)
        for name in dataset.column_names:
            if name not in column_mapping and name not in self._REQUIRED_COLUMNS:
                renames[name] = f"feat_{name}"
        renamed = dataset.rename_columns(renames)

        current_format = renamed.format
        return renamed.with_format(
            type=current_format["type"],
            columns=renamed.column_names,
            output_all_columns=current_format["output_all_columns"],
            **current_format["format_kwargs"],
        )
class Trainer(ColumnMappingMixin):
"""Trainer to train a SetFit model.
Args:
model (`SetFitModel`, *optional*):
The model to train. If not provided, a `model_init` must be passed.
args (`TrainingArguments`, *optional*):
The training arguments to use.
train_dataset (`Dataset`):
The training dataset.
eval_dataset (`Dataset`, *optional*):
The evaluation dataset.
model_init (`Callable[[], SetFitModel]`, *optional*):
A function that instantiates the model to be used. If provided, each call to
[`Trainer.train`] will start from a new instance of the model as given by this
function when a `trial` is passed.
metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):
The metric to use for evaluation. If a string is provided, we treat it as the metric
name and load it with default settings. If a callable is provided, it must take two arguments
(`y_pred`, `y_test`) and return a dictionary with metric keys to values.
metric_kwargs (`Dict[str, Any]`, *optional*):
Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1".
For example useful for providing an averaging strategy for computing f1 in a multi-label setting.
callbacks (`List[`[`~transformers.TrainerCallback`]`]`, *optional*):
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in [here](https://huggingface.co/docs/transformers/main/en/main_classes/callback).
If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
column_mapping (`Dict[str, str]`, *optional*):
A mapping from the column names in the dataset to the column names expected by the model.
The expected format is a dictionary with the following format:
`{"text_column_name": "text", "label_column_name: "label"}`.
"""
def __init__(
    self,
    model: Optional["SetFitModel"] = None,
    args: Optional[TrainingArguments] = None,
    train_dataset: Optional["Dataset"] = None,
    eval_dataset: Optional["Dataset"] = None,
    model_init: Optional[Callable[[], "SetFitModel"]] = None,
    metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
    metric_kwargs: Optional[Dict[str, Any]] = None,
    callbacks: Optional[List[TrainerCallback]] = None,
    column_mapping: Optional[Dict[str, str]] = None,
) -> None:
    """Set up the training arguments, datasets, model and callbacks.

    See the class docstring for the meaning of each argument.

    Raises:
        ValueError: If `args` is not a setfit `TrainingArguments` instance, or if a dataset does not
            match the expected (or mapped) columns.
        RuntimeError: If neither or both of `model` and `model_init` are provided.
    """
    if args is not None and not isinstance(args, TrainingArguments):
        raise ValueError("`args` must be a `TrainingArguments` instance imported from `setfit`.")
    self.args = args or TrainingArguments()
    self.column_mapping = column_mapping

    if train_dataset:
        self._validate_column_mapping(train_dataset)
        if self.column_mapping is not None:
            logger.info("Applying column mapping to the training dataset")
            train_dataset = self._apply_column_mapping(train_dataset, self.column_mapping)
    self.train_dataset = train_dataset

    if eval_dataset:
        self._validate_column_mapping(eval_dataset)
        if self.column_mapping is not None:
            logger.info("Applying column mapping to the evaluation dataset")
            eval_dataset = self._apply_column_mapping(eval_dataset, self.column_mapping)
    self.eval_dataset = eval_dataset

    self.model_init = model_init
    self.metric = metric
    self.metric_kwargs = metric_kwargs
    self.logs_mapper = {}

    # Seed must be set before instantiating the model when using model_init.
    # Honour the configured seed (as `apply_hyperparameters` does) rather than a hard-coded value,
    # so `TrainingArguments(seed=...)` actually controls model initialisation.
    set_seed(self.args.seed)

    if model is None:
        if model_init is not None:
            model = self.call_model_init()
        else:
            raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument.")
    else:
        if model_init is not None:
            raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument, but not both.")
    self.model = model
    self.hp_search_backend = None

    # Setup the callbacks
    default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
    callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
    if WandbCallback in callbacks:
        # Set the W&B project via environment variables if it's not already set
        os.environ.setdefault("WANDB_PROJECT", "setfit")
    # TODO: Observe optimizer and scheduler by wrapping SentenceTransformer._get_scheduler
    self.callback_handler = CallbackHandler(callbacks, self.model, self.model.model_body.tokenizer, None, None)
    self.state = TrainerState()
    self.control = TrainerControl()
    self.add_callback(DEFAULT_PROGRESS_CALLBACK if self.args.show_progress_bar else PrinterCallback)
    # Add the callback for filling the model card data with hyperparameters
    # and evaluation results
    self.add_callback(ModelCardCallback(self))

    # Dispatch `on_init_end` exactly once, after every callback is registered, and with the resolved
    # `self.args` (the `args` parameter may be None). Previously it fired twice, so user callbacks
    # saw it a second time with `args=None`, and that second control value was discarded.
    self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
def add_callback(self, callback: Union[type, TrainerCallback]) -> None:
    """
    Register a callback on this trainer's [`~transformers.TrainerCallback`] list.

    Args:
        callback (`type` or [`~transformers.TrainerCallback`]):
            Either a callback class, which gets instantiated first, or a ready-made callback instance.
    """
    handler = self.callback_handler
    handler.add_callback(callback)
def pop_callback(self, callback: Union[type, TrainerCallback]) -> Optional[TrainerCallback]:
    """
    Remove a callback from the current list of [`~transformers.TrainerCallback`] and returns it.

    If the callback is not found, returns `None` (and no error is raised).

    Args:
        callback (`type` or [`~transformers.TrainerCallback`]):
            A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
            first case, will pop the first member of that class found in the list of callbacks.

    Returns:
        [`~transformers.TrainerCallback`] or `None`: The callback removed, if found.
    """
    # Return annotation is Optional: the handler returns None when nothing matches.
    return self.callback_handler.pop_callback(callback)
def remove_callback(self, callback: Union[type, TrainerCallback]) -> None:
    """
    Drop a callback from this trainer's [`~transformers.TrainerCallback`] list.

    Args:
        callback (`type` or [`~transformers.TrainerCallback`]):
            Either a callback class, in which case the first registered instance of it is removed,
            or the exact callback instance to remove.
    """
    handler = self.callback_handler
    handler.remove_callback(callback)
def apply_hyperparameters(self, params: Dict[str, Any], final_model: bool = False) -> None:
    """Applies a dictionary of hyperparameters to both the trainer and the model.

    The parameters are merged into the training arguments (unknown keys are ignored), the seed
    from the resulting arguments is set, and a fresh model is built through `model_init`.

    Args:
        params (`Dict[str, Any]`): The parameters, usually from `BestRun.hyperparameters`.
        final_model (`bool`, *optional*, defaults to `False`): If `True`, replace the `model_init()`
            function with a fixed model based on the parameters.

    Raises:
        RuntimeError: If no `model_init` is set, or if `model_init` has an unsupported signature or
            returns `None`.
    """
    # Fail before mutating `self.args` if there is no way to rebuild the model.
    if self.model_init is None:
        raise RuntimeError("`apply_hyperparameters` requires the `Trainer` to have been given a `model_init`.")

    if self.args is not None:
        self.args = self.args.update(params, ignore_extra=True)
    else:
        self.args = TrainingArguments.from_dict(params, ignore_extra=True)

    # Seed must be set before instantiating the model when using model_init.
    set_seed(self.args.seed)
    # Go through `call_model_init` so zero-argument `model_init` functions work too, and so a
    # `None` model is rejected, exactly as during `Trainer` initialisation.
    self.model = self.call_model_init(params)

    if final_model:
        self.model_init = None
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]) -> None:
    """Turn `trial` into hyperparameters and apply them (no-op without a backend or trial).

    Modelled on `transformers.Trainer._hp_search_setup`.

    Raises:
        ValueError: If `trial` is not a dict and the search backend is not Optuna.
    """
    if self.hp_search_backend is None or trial is None:
        return

    if isinstance(trial, Dict):
        # A plain dict passed to train() -- mostly unused for now
        params = trial
    elif self.hp_search_backend != HPSearchBackend.OPTUNA:
        raise ValueError("Invalid trial parameter")
    else:
        params = self.hp_space(trial)

    logger.info(f"Trial: {params}")
    self.apply_hyperparameters(params, final_model=False)
def call_model_init(self, params: Optional[Dict[str, Any]] = None) -> "SetFitModel":
    """Build a model with `self.model_init`, forwarding `params` when it takes one argument.

    Raises:
        RuntimeError: If `model_init` takes more than one argument or returns `None`.
    """
    arg_count = number_of_arguments(self.model_init)
    if arg_count not in (0, 1):
        raise RuntimeError("`model_init` should have 0 or 1 argument.")

    model = self.model_init(params) if arg_count == 1 else self.model_init()
    if model is None:
        raise RuntimeError("`model_init` should not return None.")
    return model
def freeze(self, component: Optional[Literal["body", "head"]] = None) -> None:
    """Deprecated: forward to `SetFitModel.freeze` after emitting a `DeprecationWarning`.

    Args:
        component (`Literal["body", "head"]`, *optional*): The component to freeze; `None` freezes both.

    Returns:
        Whatever `self.model.freeze(component)` returns.
    """
    message = (
        f"`{self.__class__.__name__}.freeze` is deprecated and will be removed in v2.0.0 of SetFit. "
        "Please use `SetFitModel.freeze` directly instead."
    )
    # stacklevel=2 points the warning at the caller of this method.
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return self.model.freeze(component)
def unfreeze(
    self, component: Optional[Literal["body", "head"]] = None, keep_body_frozen: Optional[bool] = None
) -> None:
    """Deprecated: forward to `SetFitModel.unfreeze` after emitting a `DeprecationWarning`.

    Args:
        component (`Literal["body", "head"]`, *optional*): The component to unfreeze; `None` unfreezes both.
        keep_body_frozen (`bool`, *optional*): Deprecated argument, passed through unchanged; use `component` instead.

    Returns:
        Whatever `self.model.unfreeze(...)` returns.
    """
    message = (
        f"`{self.__class__.__name__}.unfreeze` is deprecated and will be removed in v2.0.0 of SetFit. "
        "Please use `SetFitModel.unfreeze` directly instead."
    )
    # stacklevel=2 points the warning at the caller of this method.
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return self.model.unfreeze(component, keep_body_frozen=keep_body_frozen)
def train(
    self,
    args: Optional[TrainingArguments] = None,
    trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
    **kwargs,
) -> None:
    """
    Main training entry point: fine-tune the embedding body, then fit the classifier head.

    The embedding phase sees the training data plus, when an evaluation dataset exists, the
    evaluation data; the classifier phase only sees the training data.

    Args:
        args (`TrainingArguments`, *optional*):
            Temporarily change the training arguments for this training call.
        trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
            The trial run or the hyperparameter dictionary for hyperparameter search.

    Raises:
        ValueError: If the trainer was created without a `train_dataset`.
    """
    if kwargs:
        cls_name = self.__class__.__name__
        warnings.warn(
            f"`{cls_name}.train` does not accept keyword arguments anymore. "
            f"Please provide training arguments via a `TrainingArguments` instance to the `{cls_name}` "
            f"initialisation or the `{cls_name}.train` method.",
            DeprecationWarning,
            stacklevel=2,
        )

    if trial:
        # Updates the trainer's arguments and rebuilds the model from the trial's hyperparameters.
        self._hp_search_setup(trial)

    args = args or self.args or TrainingArguments()

    if self.train_dataset is None:
        raise ValueError(
            f"Training requires a `train_dataset` given to the `{self.__class__.__name__}` initialization."
        )

    train_parameters = self.dataset_to_parameters(self.train_dataset)
    if self.eval_dataset:
        full_parameters = train_parameters + self.dataset_to_parameters(self.eval_dataset)
    else:
        full_parameters = train_parameters

    self.train_embeddings(*full_parameters, args=args)
    self.train_classifier(*train_parameters, args=args)
def dataset_to_parameters(self, dataset: Dataset) -> List[Iterable]:
    """Return `[texts, labels]` from the dataset's `"text"` and `"label"` columns."""
    texts = dataset["text"]
    labels = dataset["label"]
    return [texts, labels]
def train_embeddings(
    self,
    x_train: List[str],
    y_train: Optional[Union[List[int], List[List[int]]]] = None,
    x_eval: Optional[List[str]] = None,
    y_eval: Optional[Union[List[int], List[List[int]]]] = None,
    args: Optional[TrainingArguments] = None,
) -> None:
    """
    Method to perform the embedding phase: finetuning the `SentenceTransformer` body.

    Args:
        x_train (`List[str]`): A list of training sentences.
        y_train (`Union[List[int], List[List[int]]]`): A list of labels corresponding to the training sentences.
        x_eval (`List[str]`, *optional*): Evaluation sentences; only used when `args.evaluation_strategy`
            is not `"no"`.
        y_eval (`Union[List[int], List[List[int]]]`, *optional*): Labels for `x_eval`.
        args (`TrainingArguments`, *optional*):
            Temporarily change the training arguments for this training call.
    """
    args = args or self.args or TrainingArguments()
    # Since transformers v4.32.0, the log/eval/save steps should be saved on the state instead
    self.state.logging_steps = args.logging_steps
    self.state.eval_steps = args.eval_steps
    self.state.save_steps = args.save_steps
    # Reset the state
    self.state.global_step = 0
    self.state.total_flos = 0

    # With a step budget, cap the number of sampled pairs to what those steps can consume.
    train_max_pairs = -1 if args.max_steps == -1 else args.max_steps * args.embedding_batch_size
    train_dataloader, loss_func, batch_size = self.get_dataloader(
        x_train, y_train, args=args, max_pairs=train_max_pairs
    )
    if x_eval is not None and args.evaluation_strategy != IntervalStrategy.NO:
        eval_max_pairs = -1 if args.eval_max_steps == -1 else args.eval_max_steps * args.embedding_batch_size
        eval_dataloader, _, _ = self.get_dataloader(x_eval, y_eval, args=args, max_pairs=eval_max_pairs)
    else:
        eval_dataloader = None

    if args.max_steps > 0:
        total_train_steps = args.max_steps
    else:
        total_train_steps = len(train_dataloader) * args.embedding_num_epochs

    logger.info("***** Running training *****")
    # `len(train_dataloader)` is the number of batches; report the sampled examples/pairs instead.
    logger.info(f" Num examples = {len(train_dataloader.dataset)}")
    logger.info(f" Num epochs = {args.embedding_num_epochs}")
    logger.info(f" Total optimization steps = {total_train_steps}")
    logger.info(f" Total train batch size = {batch_size}")

    warmup_steps = math.ceil(total_train_steps * args.warmup_proportion)
    self._train_sentence_transformer(
        self.model.model_body,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        args=args,
        loss_func=loss_func,
        warmup_steps=warmup_steps,
    )
def get_dataloader(
    self, x: List[str], y: Union[List[int], List[List[int]]], args: TrainingArguments, max_pairs: int = -1
) -> Tuple[DataLoader, nn.Module, int]:
    """Build the embedding-phase dataloader together with its loss module.

    Batch-triplet style losses (and `SupConLoss`) get label-grouped batches from a
    `SentenceLabelDataset` with `drop_last=True`. Every other loss gets sentence pairs from a
    `ContrastiveDataset`, limited by `max_pairs`.

    Args:
        x (`List[str]`): Sentences.
        y (`Union[List[int], List[List[int]]]`): Labels aligned with `x`.
        args (`TrainingArguments`): Supplies the loss class and the sampling/batching settings.
        max_pairs (`int`, *optional*, defaults to `-1`): Pair cap forwarded to `ContrastiveDataset`.

    Returns:
        `Tuple[DataLoader, nn.Module, int]`: The dataloader, the loss, and the effective batch size,
        which is capped at the dataset length.
    """
    # sentence-transformers adaptation
    examples = [InputExample(texts=[text], label=label) for text, label in zip(x, y)]
    label_batch_losses = [
        losses.BatchAllTripletLoss,
        losses.BatchHardTripletLoss,
        losses.BatchSemiHardTripletLoss,
        losses.BatchHardSoftMarginTripletLoss,
        SupConLoss,
    ]

    if args.loss not in label_batch_losses:
        pair_dataset = ContrastiveDataset(
            examples,
            self.model.multi_target_strategy,
            args.num_iterations,
            args.sampling_strategy,
            max_pairs=max_pairs,
        )
        # shuffle_sampler = True can be dropped in for further 'randomising'
        # NOTE(review): if ContrastiveDataset is a torch IterableDataset, DataLoader rejects
        # shuffle=True; confirm that sampling_strategy="unique" works.
        shuffle = bool(args.sampling_strategy == "unique")
        batch_size = min(args.embedding_batch_size, len(pair_dataset))
        loader = DataLoader(pair_dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False)
        return loader, args.loss(self.model.model_body), batch_size

    label_dataset = SentenceLabelDataset(examples, samples_per_label=args.samples_per_label)
    batch_size = min(args.embedding_batch_size, len(label_dataset))
    loader = DataLoader(label_dataset, batch_size=batch_size, drop_last=True)

    body = self.model.model_body
    if args.loss is SupConLoss:
        loss = args.loss(model=body)
    elif args.loss is losses.BatchHardSoftMarginTripletLoss:
        # The soft-margin variant has no `margin` parameter.
        loss = args.loss(model=body, distance_metric=args.distance_metric)
    else:
        loss = args.loss(model=body, distance_metric=args.distance_metric, margin=args.margin)
    return loader, loss, batch_size
def log(self, args: TrainingArguments, logs: Dict[str, float]) -> TrainerControl:
    """
    Log `logs` on the various objects watching training.

    Subclass and override this method to inject custom behavior.

    Args:
        args (`TrainingArguments`): The training arguments forwarded to the callbacks.
        logs (`Dict[str, float]`):
            The values to log. Keys are renamed through `self.logs_mapper` and, when known,
            the current (rounded) epoch is added.

    Returns:
        [`~transformers.TrainerControl`]: The control object returned by the callbacks' `on_log`;
        callers store it back on `self.control`.
    """
    logs = {self.logs_mapper.get(key, key): value for key, value in logs.items()}
    if self.state.epoch is not None:
        logs["epoch"] = round(self.state.epoch, 2)

    # The history entry additionally records the global step; callbacks receive `logs` without it.
    self.state.log_history.append({**logs, "step": self.state.global_step})
    return self.callback_handler.on_log(args, self.state, self.control, logs)
def _set_logs_mapper(self, logs_mapper: Dict[str, str]) -> None:
    """Install the key-renaming table that `log` applies to every logged metric.

    Args:
        logs_mapper (`Dict[str, str]`): Maps original metric names to the names to log,
            e.g. `{"eval_embedding_loss": "eval_aspect_embedding_loss"}`.
    """
    self.logs_mapper = logs_mapper
def _train_sentence_transformer(
    self,
    model_body: SentenceTransformer,
    train_dataloader: DataLoader,
    eval_dataloader: Optional[DataLoader],
    args: TrainingArguments,
    loss_func: nn.Module,
    warmup_steps: int = 10000,
) -> None:
    """
    Run the embedding fine-tuning loop for `model_body`.

    The loop optimises `loss_func` (built around `model_body`) with AdamW and a linear warmup
    schedule for `args.embedding_num_epochs` epochs of `len(train_dataloader)` steps each. Callbacks
    may stop an epoch or the whole run early. Logging, evaluation and checkpointing are delegated to
    `maybe_log_eval_save` after every step and every epoch. If `args.load_best_model_at_end` is set
    and a best checkpoint was recorded, that checkpoint is reloaded into `self.model.model_body` at the end.

    Args:
        model_body (`SentenceTransformer`): The body being fine-tuned.
        train_dataloader (`DataLoader`): Training batches; its `collate_fn` is replaced by smart batching.
        eval_dataloader (`DataLoader`, *optional*): Evaluation batches, or `None` to skip evaluation.
        args (`TrainingArguments`): The arguments for this training call.
        loss_func (`nn.Module`): The loss whose parameters are optimised.
        warmup_steps (`int`, *optional*, defaults to 10000): Linear warmup steps of the scheduler.
    """
    # TODO: args.gradient_accumulation_steps
    # TODO: fp16/bf16, etc.
    # TODO: Safetensors

    # Hardcoded training arguments
    max_grad_norm = 1
    weight_decay = 5e-3  # 5e-3 best

    self.state.epoch = 0
    start_time = time.time()
    if args.max_steps > 0:
        self.state.max_steps = args.max_steps
    else:
        self.state.max_steps = len(train_dataloader) * args.embedding_num_epochs

    steps_per_epoch = len(train_dataloader)

    if args.use_amp:
        scaler = torch.cuda.amp.GradScaler()

    model_body.to(model_body._target_device)
    loss_func.to(model_body._target_device)

    # Use smart batching
    train_dataloader.collate_fn = model_body.smart_batching_collate
    if eval_dataloader is not None:
        eval_dataloader.collate_fn = model_body.smart_batching_collate

    # Prepare optimizers: no weight decay on biases and LayerNorm parameters.
    param_optimizer = list(loss_func.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.body_embedding_learning_rate)
    scheduler_obj = model_body._get_scheduler(
        optimizer, scheduler="WarmupLinear", warmup_steps=warmup_steps, t_total=self.state.max_steps
    )
    self.callback_handler.optimizer = optimizer
    self.callback_handler.lr_scheduler = scheduler_obj
    self.callback_handler.train_dataloader = train_dataloader
    self.callback_handler.eval_dataloader = eval_dataloader

    # Dispatch `on_train_begin` exactly once, after the optimizer/scheduler/dataloaders are attached.
    # It used to fire twice, which e.g. made progress callbacks open a second progress bar.
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

    data_iterator = iter(train_dataloader)
    skip_scheduler = False
    for epoch in range(args.embedding_num_epochs):
        self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

        loss_func.zero_grad()
        loss_func.train()

        for step in range(steps_per_epoch):
            self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

            # Restart the iterator when it runs dry, so an epoch always has `steps_per_epoch` steps.
            try:
                data = next(data_iterator)
            except StopIteration:
                data_iterator = iter(train_dataloader)
                data = next(data_iterator)

            features, labels = data
            labels = labels.to(model_body._target_device)
            features = [batch_to_device(batch, model_body._target_device) for batch in features]

            if args.use_amp:
                with autocast():
                    loss_value = loss_func(features, labels)

                scale_before_step = scaler.get_scale()
                scaler.scale(loss_value).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(loss_func.parameters(), max_grad_norm)
                scaler.step(optimizer)
                scaler.update()

                # A changed scale means the scaler skipped this optimizer step; skip the scheduler too.
                skip_scheduler = scaler.get_scale() != scale_before_step
            else:
                loss_value = loss_func(features, labels)
                loss_value.backward()
                torch.nn.utils.clip_grad_norm_(loss_func.parameters(), max_grad_norm)
                optimizer.step()

            optimizer.zero_grad()

            if not skip_scheduler:
                scheduler_obj.step()

            self.state.global_step += 1
            self.state.epoch = epoch + (step + 1) / steps_per_epoch
            self.control = self.callback_handler.on_step_end(args, self.state, self.control)

            self.maybe_log_eval_save(model_body, eval_dataloader, args, scheduler_obj, loss_func, loss_value)

            if self.control.should_epoch_stop or self.control.should_training_stop:
                break

        self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)

        self.maybe_log_eval_save(model_body, eval_dataloader, args, scheduler_obj, loss_func, loss_value)

        if self.control.should_training_stop:
            break

    # Use the per-call `args` (not `self.args`), consistent with the rest of this method.
    if args.load_best_model_at_end and self.state.best_model_checkpoint:
        dir_name = Path(self.state.best_model_checkpoint).name
        if dir_name.startswith("step_"):
            step_to_load = dir_name[5:]
            logger.info(f"Loading best SentenceTransformer model from step {step_to_load}.")
            self.model.model_card_data.set_best_model_step(int(step_to_load))
        self.model.model_body = SentenceTransformer(
            self.state.best_model_checkpoint, device=model_body._target_device
        )
        self.model.model_body.to(model_body._target_device)

    # Ensure logging the speed metrics
    num_train_samples = self.state.max_steps * args.embedding_batch_size  # * args.gradient_accumulation_steps
    metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
    self.control.should_log = True
    self.log(args, metrics)

    self.control = self.callback_handler.on_train_end(args, self.state, self.control)
def maybe_log_eval_save(
    self,
    model_body: SentenceTransformer,
    eval_dataloader: Optional[DataLoader],
    args: TrainingArguments,
    scheduler_obj,
    loss_func,
    loss_value: torch.Tensor,
) -> None:
    """Log, evaluate and/or checkpoint, as requested by the flags on `self.control`.

    When a save happens in the same call as an evaluation that improves the best eval loss so far,
    the new checkpoint is recorded as `self.state.best_model_checkpoint`.

    Args:
        model_body (`SentenceTransformer`): The body being trained.
        eval_dataloader (`DataLoader`, *optional*): Evaluation batches; evaluation is skipped if `None`.
        args (`TrainingArguments`): The arguments for the current training call.
        scheduler_obj: The LR scheduler; its last learning rate is logged.
        loss_func: The training loss module; it is reset to train mode after evaluating.
        loss_value (`torch.Tensor`): The most recent training loss.
    """
    if self.control.should_log:
        learning_rate = scheduler_obj.get_last_lr()[0]
        metrics = {"embedding_loss": round(loss_value.item(), 4), "learning_rate": learning_rate}
        self.control = self.log(args, metrics)

    eval_loss = None
    if self.control.should_evaluate and eval_dataloader is not None:
        eval_loss = self._evaluate_with_loss(model_body, eval_dataloader, args, loss_func)
        learning_rate = scheduler_obj.get_last_lr()[0]
        metrics = {"eval_embedding_loss": round(eval_loss, 4), "learning_rate": learning_rate}
        self.control = self.log(args, metrics)

        self.control = self.callback_handler.on_evaluate(args, self.state, self.control, metrics)

        loss_func.zero_grad()
        loss_func.train()

    if self.control.should_save:
        # Use the per-call `args` for both the output dir and the callbacks; this previously mixed
        # in `self.args`, so `train(args=...)` overrides were ignored for checkpointing.
        checkpoint_dir = self._checkpoint(args.output_dir, args.save_total_limit, self.state.global_step)
        self.control = self.callback_handler.on_save(args, self.state, self.control)

        if eval_loss is not None and (self.state.best_metric is None or eval_loss < self.state.best_metric):
            self.state.best_metric = eval_loss
            self.state.best_model_checkpoint = checkpoint_dir
def _evaluate_with_loss(
    self,
    model_body: SentenceTransformer,
    eval_dataloader: DataLoader,
    args: TrainingArguments,
    loss_func: nn.Module,
) -> float:
    """Return the mean of `loss_func` over at most `args.eval_max_steps` evaluation batches.

    `model_body` is put in eval mode for the duration and switched back to train mode afterwards.

    Args:
        model_body (`SentenceTransformer`): The body being trained.
        eval_dataloader (`DataLoader`): Evaluation batches (smart-batching collated).
        args (`TrainingArguments`): Supplies `eval_max_steps`, `use_amp` and `show_progress_bar`.
        loss_func (`nn.Module`): The loss to evaluate.

    Returns:
        `float`: The average per-batch loss.
    """
    model_body.eval()
    # Named to avoid shadowing the module-level `losses` import.
    batch_losses = []
    eval_steps = (
        min(len(eval_dataloader), args.eval_max_steps) if args.eval_max_steps != -1 else len(eval_dataloader)
    )
    # Only loss values are needed: disable autograd so no computation graph is built per batch.
    with torch.no_grad():
        for step, data in enumerate(
            tqdm(iter(eval_dataloader), total=eval_steps, leave=False, disable=not args.show_progress_bar), start=1
        ):
            features, labels = data
            labels = labels.to(model_body._target_device)
            features = [batch_to_device(batch, model_body._target_device) for batch in features]
            if args.use_amp:
                with autocast():
                    loss_value = loss_func(features, labels)
                batch_losses.append(loss_value.item())
            else:
                batch_losses.append(loss_func(features, labels).item())
            if step >= eval_steps:
                break
    model_body.train()
    return sum(batch_losses) / len(batch_losses)
def _checkpoint(self, checkpoint_path: str, checkpoint_save_total_limit: int, step: int) -> str:
    """Save the current model to `<checkpoint_path>/step_<step>`, rotating out old checkpoints first.

    With a positive `checkpoint_save_total_limit`, the oldest `step_<n>` directories (by step number)
    are deleted so that, together with the new checkpoint, at most that many remain. The recorded
    best checkpoint is never deleted and does not count toward the limit.

    Args:
        checkpoint_path (`str`): Directory that holds the `step_<n>` checkpoint folders.
        checkpoint_save_total_limit (`int`): Maximum number of non-best checkpoints to keep;
            `None` or a value <= 0 disables rotation.
        step (`int`): The global step used to name the new checkpoint.

    Returns:
        `str`: The path of the newly written checkpoint directory.
    """
    # Delete old checkpoints
    if checkpoint_save_total_limit is not None and checkpoint_save_total_limit > 0:
        best = self.state.best_model_checkpoint
        protected = Path(best) if best is not None else None
        old_checkpoints = []
        for subdir in Path(checkpoint_path).glob("step_*"):
            step_str = subdir.name[len("step_"):]
            # Only real checkpoint directories are eligible; files named `step_*` are left alone.
            if subdir.is_dir() and step_str.isdigit() and subdir != protected:
                old_checkpoints.append((int(step_str), subdir))
        # Delete every excess checkpoint (not just one), leaving room for the one saved below.
        num_to_delete = len(old_checkpoints) - (checkpoint_save_total_limit - 1)
        if num_to_delete > 0:
            old_checkpoints.sort(key=lambda item: item[0])
            for _, old_dir in old_checkpoints[:num_to_delete]:
                shutil.rmtree(old_dir)

    checkpoint_file_path = str(Path(checkpoint_path) / f"step_{step}")
    self.model.save_pretrained(checkpoint_file_path)
    return checkpoint_file_path
def train_classifier(
    self, x_train: List[str], y_train: Union[List[int], List[List[int]]], args: Optional[TrainingArguments] = None
) -> None:
    """
    Run the classifier phase of SetFit training, i.e. fit the classification head via `self.model.fit`.

    The arguments used are `args` if given (truthy), otherwise `self.args`, otherwise a fresh default
    `TrainingArguments()`. From them, the classifier-phase settings (`classifier_num_epochs`,
    `classifier_batch_size`, `body_classifier_learning_rate`, `head_learning_rate`, `l2_weight`,
    `max_length`, `show_progress_bar`, `end_to_end`) are forwarded to `self.model.fit`.

    Args:
        x_train (`List[str]`): The training sentences.
        y_train (`Union[List[int], List[List[int]]]`): The labels matching `x_train`.
        args (`TrainingArguments`, *optional*):
            Training arguments that override `self.args` for this call only.
    """
    effective_args = args or self.args or TrainingArguments()
    fit_options = {
        "num_epochs": effective_args.classifier_num_epochs,
        "batch_size": effective_args.classifier_batch_size,
        "body_learning_rate": effective_args.body_classifier_learning_rate,
        "head_learning_rate": effective_args.head_learning_rate,
        "l2_weight": effective_args.l2_weight,
        "max_length": effective_args.max_length,
        "show_progress_bar": effective_args.show_progress_bar,
        "end_to_end": effective_args.end_to_end,
    }
    self.model.fit(x_train, y_train, **fit_options)
def evaluate(self, dataset: Optional[Dataset] = None, metric_key_prefix: str = "test") -> Dict[str, float]:
    """
    Computes the metrics for a given classifier.

    Args:
        dataset (`Dataset`, *optional*):
            The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed via
            the `eval_dataset` argument at `Trainer` initialization. A provided dataset is validated and, if a
            `column_mapping` was configured, has that mapping applied.
        metric_key_prefix (`str`, *optional*, defaults to `"test"`):
            Prefix used (as `f"{metric_key_prefix}_{key}"`) when the results are recorded in the model card
            data. The returned dictionary itself uses the un-prefixed keys.

    Returns:
        `Dict[str, float]`: The evaluation metrics. A non-dict result from the metric is wrapped as
        `{"metric": result}`.

    Raises:
        ValueError: If no evaluation dataset is available, or if `self.metric` is neither a string nor a
            callable.
    """
    if dataset is not None:
        self._validate_column_mapping(dataset)
        if self.column_mapping is not None:
            logger.info("Applying column mapping to the evaluation dataset")
            eval_dataset = self._apply_column_mapping(dataset, self.column_mapping)
        else:
            eval_dataset = dataset
    else:
        eval_dataset = self.eval_dataset

    if eval_dataset is None:
        raise ValueError("No evaluation dataset provided to `Trainer.evaluate` nor the `Trainer` initialization.")

    x_test = eval_dataset["text"]
    y_test = eval_dataset["label"]

    logger.info("***** Running evaluation *****")
    y_pred = self.model.predict(x_test, use_labels=False)
    if isinstance(y_pred, torch.Tensor):
        # Move tensor predictions to the CPU before handing them to the label encoder / metric.
        y_pred = y_pred.cpu()

    # Normalize string labels to integer ids over the union of true and predicted labels.
    # `len(...)` instead of truthiness so that array-like label containers do not raise.
    if len(y_test) > 0 and isinstance(y_test[0], str):
        encoder = LabelEncoder()
        encoder.fit(list(y_test) + list(y_pred))
        y_test = encoder.transform(y_test)
        y_pred = encoder.transform(y_pred)

    metric_kwargs = self.metric_kwargs or {}
    if isinstance(self.metric, str):
        metric_config = "multilabel" if self.model.multi_target_strategy is not None else None
        metric_fn = evaluate.load(self.metric, config_name=metric_config)
        results = metric_fn.compute(predictions=y_pred, references=y_test, **metric_kwargs)
    elif callable(self.metric):
        results = self.metric(y_pred, y_test, **metric_kwargs)
    else:
        raise ValueError("metric must be a string or a callable")

    if not isinstance(results, dict):
        results = {"metric": results}
    self.model.model_card_data.post_training_eval_results(
        {f"{metric_key_prefix}_{key}": value for key, value in results.items()}
    )
    return results
def hyperparameter_search(
    self,
    hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
    compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
    n_trials: int = 10,
    direction: str = "maximize",
    backend: Optional[Union[str, HPSearchBackend]] = None,
    hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
    **kwargs,
) -> BestRun:
    """
    Launch a hyperparameter search using `optuna`. The optimized quantity is determined
    by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
    the sum of all metrics otherwise.

    <Tip warning={true}>

    To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
    reinitialize the model at each new run.

    </Tip>

    Args:
        hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
            A function that defines the hyperparameter search space. Will default to
            [`~transformers.trainer_utils.default_hp_space_optuna`].
        compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
            A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
            method. Will default to [`~transformers.trainer_utils.default_compute_objective`] which uses the sum of metrics.
        n_trials (`int`, *optional*, defaults to 10):
            The number of trial runs to test.
        direction (`str`, *optional*, defaults to `"maximize"`):
            Whether to optimize greater or lower objectives. Can be `"minimize"` or `"maximize"`, you should pick
            `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
        backend (`str` or [`~transformers.training_utils.HPSearchBackend`], *optional*):
            The backend to use for hyperparameter search. Only optuna is supported for now.
            TODO: add support for ray and sigopt.
        hp_name (`Callable[["optuna.Trial"], str]]`, *optional*):
            A function that defines the trial/run name. Will default to None.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional keyword arguments passed along to `optuna.create_study`. For more
            information see:

            - the documentation of
              [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)

    Returns:
        [`trainer_utils.BestRun`]: All the information about the best run.

    Raises:
        RuntimeError: If optuna is unavailable, a backend other than optuna is requested, or no `model_init`
            was provided.
    """
    if backend is None:
        backend = default_hp_search_backend()
        if backend is None:
            raise RuntimeError("optuna should be installed. To install optuna run `pip install optuna`.")
    backend = HPSearchBackend(backend)
    if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
        raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
    elif backend != HPSearchBackend.OPTUNA:
        raise RuntimeError("Only optuna backend is supported for hyperparameter search.")
    if self.model_init is None:
        raise RuntimeError(
            "To use hyperparameter search, you need to pass your model through a model_init function."
        )

    self.hp_space = default_hp_space_optuna if hp_space is None else hp_space
    self.hp_name = hp_name
    self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

    backend_dict = {
        HPSearchBackend.OPTUNA: run_hp_search_optuna,
    }
    # Only set the backend once validation has passed, and always clear it afterwards, so that a failed
    # validation or a search that raises does not leave `hp_search_backend` set for later calls.
    self.hp_search_backend = backend
    try:
        best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
    finally:
        self.hp_search_backend = None
    return best_run
def push_to_hub(self, repo_id: str, **kwargs) -> str:
    """Upload the trained model to the Hugging Face Hub.

    This is a thin wrapper around `self.model.push_to_hub`. `commit_message` defaults to
    `"Add SetFit model"`; every other keyword argument is forwarded unchanged. The accepted keywords depend on
    your installed `huggingface_hub` version; see
    https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.ModelHubMixin.push_to_hub

    Args:
        repo_id (`str`):
            Full repository ID including the namespace, e.g. `"tomaarsen/setfit-sst2"`.
        config (`dict`, *optional*):
            Configuration object to be saved alongside the model weights.
        commit_message (`str`, *optional*, defaults to `"Add SetFit model"`):
            Message to commit while pushing.
        private (`bool`, *optional*, defaults to `False`):
            Whether the repository created should be private.
        api_endpoint (`str`, *optional*):
            The API endpoint to use when pushing the model to the hub.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If not set, the locally stored
            login token is used (see the `huggingface_hub` documentation for how it is stored).
        branch (`str`, *optional*):
            The git branch to push to. Defaults to the repository's default branch (usually `"main"`).
        create_pr (`bool`, *optional*, defaults to `False`):
            Whether to open a Pull Request from `branch` with this commit.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are pushed.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not pushed.

    Returns:
        `str`: The URL of the commit of your model in the given repository.

    Raises:
        ValueError: If `repo_id` contains no `/` (i.e. no namespace).
    """
    has_namespace = "/" in repo_id
    if not has_namespace:
        raise ValueError(
            '`repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-sst2".'
        )
    kwargs.setdefault("commit_message", "Add SetFit model")
    return self.model.push_to_hub(repo_id, **kwargs)
class SetFitTrainer(Trainer):
    """
    Deprecated alias of `Trainer` kept for backwards compatibility; it will be removed in v2.0.0 of SetFit.

    The flat keyword arguments of the legacy constructor are packed into a `TrainingArguments` instance
    (with `learning_rate` used for both the body and the head) before delegating to `Trainer`.
    Please use `Trainer` instead.
    """

    def __init__(
        self,
        model: Optional["SetFitModel"] = None,
        train_dataset: Optional["Dataset"] = None,
        eval_dataset: Optional["Dataset"] = None,
        model_init: Optional[Callable[[], "SetFitModel"]] = None,
        metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
        metric_kwargs: Optional[Dict[str, Any]] = None,
        loss_class=losses.CosineSimilarityLoss,
        num_iterations: int = 20,
        num_epochs: int = 1,
        learning_rate: float = 2e-5,
        batch_size: int = 16,
        seed: int = 42,
        column_mapping: Optional[Dict[str, str]] = None,
        use_amp: bool = False,
        warmup_proportion: float = 0.1,
        distance_metric: Callable = BatchHardTripletLossDistanceFunction.cosine_distance,
        margin: float = 0.25,
        samples_per_label: int = 2,
    ):
        # stacklevel=2 attributes the warning to the code that instantiated SetFitTrainer.
        warnings.warn(
            "`SetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit. "
            "Please use `Trainer` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        legacy_training_options = dict(
            num_iterations=num_iterations,
            num_epochs=num_epochs,
            body_learning_rate=learning_rate,
            head_learning_rate=learning_rate,
            batch_size=batch_size,
            seed=seed,
            use_amp=use_amp,
            warmup_proportion=warmup_proportion,
            distance_metric=distance_metric,
            margin=margin,
            samples_per_label=samples_per_label,
            loss=loss_class,
        )
        super().__init__(
            model=model,
            args=TrainingArguments(**legacy_training_options),
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            model_init=model_init,
            metric=metric,
            metric_kwargs=metric_kwargs,
            column_mapping=column_mapping,
        )
|