| import argparse | |
| from typing import TYPE_CHECKING, Any, Dict, List, Union | |
| from finetrainers.utils import ArgsConfigMixin | |
| if TYPE_CHECKING: | |
| from finetrainers.args import BaseArgs | |
class SFTLowRankConfig(ArgsConfigMixin):
    r"""
    Configuration class for SFT low rank training.

    Args:
        rank (int):
            Rank of the low rank approximation matrix. Defaults to 64.
        lora_alpha (int):
            The lora_alpha parameter to compute scaling factor (lora_alpha / rank) for low-rank matrices.
            Defaults to 64.
        target_modules (`str` or `List[str]`):
            Target modules for the low rank approximation matrices. Can be a regex string or a list of regex strings.
    """

    rank: int = 64
    lora_alpha: int = 64
    # NOTE: the "." in "to_out.0" is an unescaped regex wildcard (it also matches a literal ".").
    target_modules: Union[str, List[str]] = "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)"

    def add_args(self, parser: argparse.ArgumentParser):
        """Register the low-rank CLI options (``--rank``, ``--lora_alpha``, ``--target_modules``) on ``parser``.

        Defaults are derived from this config's attributes so the CLI defaults and
        the class defaults stay in sync.
        """
        # argparse with nargs="+" produces a list, so the default is given as a list too.
        if isinstance(self.target_modules, str):
            default_target_modules = [self.target_modules]
        else:
            default_target_modules = list(self.target_modules)
        parser.add_argument("--rank", type=int, default=self.rank)
        parser.add_argument("--lora_alpha", type=int, default=self.lora_alpha)
        parser.add_argument(
            "--target_modules",
            type=str,
            nargs="+",
            default=default_target_modules,
        )

    def validate_args(self, args: "BaseArgs"):
        """Check that ``rank`` and ``lora_alpha`` are positive.

        The values are read from ``args``, because ``map_args`` writes the parsed
        CLI values there and never onto ``self``. ``self``'s values are used only
        when ``args`` does not carry the attribute, or carries ``None``.

        Raises:
            AssertionError: if ``rank`` or ``lora_alpha`` is not positive.
        """
        rank = getattr(args, "rank", None)
        if rank is None:
            rank = self.rank
        lora_alpha = getattr(args, "lora_alpha", None)
        if lora_alpha is None:
            lora_alpha = self.lora_alpha
        assert rank > 0, "Rank must be a positive integer."
        assert lora_alpha > 0, "lora_alpha must be a positive integer."

    def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
        """Copy the parsed low-rank options from ``argparse_args`` onto ``mapped_args``.

        A single ``--target_modules`` value is unwrapped to a plain string.
        Multiple values are kept as a list.
        """
        mapped_args.rank = argparse_args.rank
        mapped_args.lora_alpha = argparse_args.lora_alpha
        target_modules = argparse_args.target_modules
        mapped_args.target_modules = target_modules[0] if len(target_modules) == 1 else target_modules

    def to_dict(self) -> Dict[str, Any]:
        """Return this config's ``rank``, ``lora_alpha`` and ``target_modules`` as a dict."""
        return {"rank": self.rank, "lora_alpha": self.lora_alpha, "target_modules": self.target_modules}
class SFTFullRankConfig(ArgsConfigMixin):
    r"""
    Configuration class for SFT full rank training.

    Full-rank training defines no options of its own, so each of the hooks
    below deliberately does nothing and returns ``None``.
    """

    def add_args(self, parser: argparse.ArgumentParser):
        """Add nothing to ``parser``; there are no full-rank specific CLI options."""

    def validate_args(self, args: "BaseArgs"):
        """Skip validation; there is nothing full-rank specific to check."""

    def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
        """Map nothing from ``argparse_args`` onto ``mapped_args``."""