import os
from dataclasses import dataclass, field
from typing import Literal, Optional

from transformers import Seq2SeqTrainingArguments
from transformers.utils.versions import require_version

from swift.plugin import LOSS_MAPPING
from swift.trainers import TrainerFactory
from swift.trainers.arguments import TrainArgumentsMixin
from swift.utils import (add_version_to_work_dir, get_device_count, get_logger, get_pai_tensorboard_dir, is_master,
                         is_mp, is_pai_training_job, is_swanlab_available)
from .base_args import BaseArguments, to_abspath
from .tuner_args import TunerArguments

logger = get_logger()


@dataclass
class Seq2SeqTrainingOverrideArguments(TrainArgumentsMixin, Seq2SeqTrainingArguments):
    """Override the default value in `Seq2SeqTrainingArguments`"""
    output_dir: Optional[str] = None
    learning_rate: Optional[float] = None
    eval_strategy: Optional[str] = None
    fp16: Optional[bool] = None
    bf16: Optional[bool] = None

    def _init_output_dir(self):
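        """If output_dir is unset, default it to 'output/{model_suffix}', then resolve it to an absolute path."""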
        if self.output_dir is None:
            self.output_dir = f'output/{self.model_suffix}'
        self.output_dir = to_abspath(self.output_dir)

    def _init_eval_strategy(self):
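        """Derive the eval strategy from the save strategy and keep eval_steps/split_dataset_ratio consistent."""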
        if self.eval_strategy is None:
            self.eval_strategy = self.save_strategy
        if self.eval_strategy == 'no':
            self.eval_steps = None
            self.split_dataset_ratio = 0.
            logger.info(f'Setting args.split_dataset_ratio: {self.split_dataset_ratio}')
        elif self.eval_strategy == 'steps' and self.eval_steps is None:
            self.eval_steps = self.save_steps
        self.evaluation_strategy = self.eval_strategy

    def _init_metric_for_best_model(self):
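        """Default metric_for_best_model to 'rouge-l' when predict_with_generate is set, otherwise to 'loss'."""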
        if self.metric_for_best_model is None:
            self.metric_for_best_model = 'rouge-l' if self.predict_with_generate else 'loss'

    def __post_init__(self):
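        """Fill in the overridden defaults: output_dir, best-model metric, learning_rate and eval strategy."""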
        self._init_output_dir()
        self._init_metric_for_best_model()
        if self.greater_is_better is None and self.metric_for_best_model is not None:
            self.greater_is_better = 'loss' not in self.metric_for_best_model

        if self.learning_rate is None:
            if self.train_type == 'full':
                self.learning_rate = 1e-5
            else:
                self.learning_rate = 1e-4
        self._init_eval_strategy()


@dataclass
class SwanlabArguments:
    """Arguments for reporting training to SwanLab."""
    swanlab_token: Optional[str] = None
    swanlab_project: Optional[str] = None
    swanlab_workspace: Optional[str] = None
    swanlab_exp_name: Optional[str] = None
    swanlab_mode: Literal['cloud', 'local'] = 'cloud'

    def _init_swanlab(self):
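        """Check that swanlab is installed, optionally log in, and register the SwanLab callback with transformers."""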
        if not is_swanlab_available():
            raise ValueError('You are using swanlab as `report_to`, please install swanlab by `pip install swanlab`')
        if not self.swanlab_exp_name:
            self.swanlab_exp_name = self.output_dir
        from transformers.integrations import INTEGRATION_TO_CALLBACK
        import swanlab
        from swanlab.integration.transformers import SwanLabCallback
        if self.swanlab_token:
            swanlab.login(self.swanlab_token)
        INTEGRATION_TO_CALLBACK['swanlab'] = SwanLabCallback(
            project=self.swanlab_project,
            workspace=self.swanlab_workspace,
            experiment_name=self.swanlab_exp_name,
            config={'Framework': '🐦⬛ms-swift'},
            mode=self.swanlab_mode,
        )


@dataclass
class TrainArguments(SwanlabArguments, TunerArguments, Seq2SeqTrainingOverrideArguments, BaseArguments):
    """
    TrainArguments is a dataclass that inherits from multiple argument classes:
    SwanlabArguments, TunerArguments, Seq2SeqTrainingOverrideArguments, and BaseArguments.

    Args:
        add_version (bool): Flag to add version information to output_dir. Default is True.
        resume_only_model (bool): Flag to resume training only the model. Default is False.
        loss_type (Optional[str]): Type of loss function to use. Default is None.
        packing (bool): Flag to enable packing of datasets. Default is False.
        lazy_tokenize (Optional[bool]): Flag to enable lazy tokenization. Default is None.
        max_new_tokens (int): Maximum number of new tokens to generate. Default is 64.
        temperature (float): Temperature for sampling. Default is 0.
        optimizer (Optional[str]): Optimizer type to use; defined in the plugin package. Default is None.
        metric (Optional[str]): Metric to use for evaluation; defined in the plugin package. Default is None.
    """
    add_version: bool = True
    resume_only_model: bool = False
    create_checkpoint_symlink: bool = False

    packing: bool = False
    lazy_tokenize: Optional[bool] = None

    loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'})
    optimizer: Optional[str] = None
    metric: Optional[str] = None

    max_new_tokens: int = 64
    temperature: float = 0.
    load_args: bool = False

    zero_hpz_partition_size: Optional[int] = None

    def _init_lazy_tokenize(self):
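        """Resolve lazy_tokenize: it is incompatible with streaming and defaults to True for multimodal models."""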
        if self.streaming and self.lazy_tokenize:
            self.lazy_tokenize = False
            logger.warning('Streaming and lazy_tokenize are incompatible. '
                           f'Setting args.lazy_tokenize: {self.lazy_tokenize}.')
        if self.lazy_tokenize is None:
            self.lazy_tokenize = self.model_meta.is_multimodal and not self.streaming
            logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}')

    def __post_init__(self) -> None:
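        """Validate the configuration, resolve resume/optimizer/dataset settings and build self.training_args."""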
        if self.packing and self.attn_impl != 'flash_attn':
            raise ValueError('The "packing" feature needs to be used in conjunction with "flash_attn". '
                             'Please specify `--attn_impl flash_attn`.')
        if self.resume_from_checkpoint:
            self.resume_from_checkpoint = to_abspath(self.resume_from_checkpoint, True)
            if self.resume_only_model:
                if self.train_type == 'full':
                    self.model = self.resume_from_checkpoint
                else:
                    self.adapters = [self.resume_from_checkpoint]
        BaseArguments.__post_init__(self)
        Seq2SeqTrainingOverrideArguments.__post_init__(self)
        TunerArguments.__post_init__(self)

        if self.optimizer is None:
            if self.lorap_lr_ratio:
                self.optimizer = 'lorap'
            elif self.use_galore:
                self.optimizer = 'galore'

        if len(self.dataset) == 0:
            raise ValueError(f'self.dataset: {self.dataset}. Please provide a training dataset.')

        self._handle_pai_compat()

        self._init_deepspeed()
        self._init_device()
        self._init_lazy_tokenize()

        if getattr(self, 'accelerator_config', None) is None:
            self.accelerator_config = {'dispatch_batches': False}
        self.training_args = TrainerFactory.get_training_args(self)
        self.training_args.remove_unused_columns = False

        self._add_version()

        if 'swanlab' in self.report_to:
            self._init_swanlab()

    def _init_deepspeed(self):
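        """Resolve the DeepSpeed config: map built-in names (e.g. 'zero2', 'zero3_offload') to the bundled JSON files, then parse it into a dict."""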
        if self.deepspeed:
            require_version('deepspeed')
            if is_mp():
                raise ValueError('DeepSpeed is not compatible with `device_map`. '
                                 f'n_gpu: {get_device_count()}, '
                                 f'local_world_size: {self.local_world_size}.')

            ds_config_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'ds_config'))
            deepspeed_mapping = {
                name: f'{name}.json'
                for name in ['zero0', 'zero1', 'zero2', 'zero3', 'zero2_offload', 'zero3_offload']
            }
            for ds_name, ds_config in deepspeed_mapping.items():
                if self.deepspeed == ds_name:
                    self.deepspeed = os.path.join(ds_config_folder, ds_config)
                    break

            self.deepspeed = self.parse_to_dict(self.deepspeed)
            if self.zero_hpz_partition_size is not None:
                assert 'zero_optimization' in self.deepspeed
                self.deepspeed['zero_optimization']['zero_hpz_partition_size'] = self.zero_hpz_partition_size
                logger.warning('If `zero_hpz_partition_size` (ZeRO++) causes grad_norm NaN, please'
                               ' try `--torch_dtype float16`')
            logger.info(f'Using deepspeed: {self.deepspeed}')

    def _handle_pai_compat(self) -> None:
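        """On PAI training jobs, default logging_dir to the PAI tensorboard dir and disable output_dir versioning."""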
        if not is_pai_training_job():
            return

        logger.info('Handle pai compat...')
        pai_tensorboard_dir = get_pai_tensorboard_dir()
        if self.logging_dir is None and pai_tensorboard_dir is not None:
            self.logging_dir = pai_tensorboard_dir
            logger.info(f'Setting args.logging_dir: {self.logging_dir}')
        self.add_version = False
        logger.info(f'Setting args.add_version: {self.add_version}')

    def _add_version(self):
        """Prepare the output_dir"""
        if self.add_version:
            self.output_dir = add_version_to_work_dir(self.output_dir)
            logger.info(f'output_dir: {self.output_dir}')

        if self.logging_dir is None:
            self.logging_dir = f'{self.output_dir}/runs'

        self.logging_dir = to_abspath(self.logging_dir)
        if is_master():
            os.makedirs(self.output_dir, exist_ok=True)

        if self.run_name is None:
            self.run_name = self.output_dir

        self.training_args.output_dir = self.output_dir
        self.training_args.run_name = self.run_name
        self.training_args.logging_dir = self.logging_dir