| | from dataclasses import dataclass, field, fields, asdict |
| | from typing import Optional, List, Literal, Dict, Any, Union |
| | from transformers import TrainingArguments, Trainer |
| | from omegaconf import OmegaConf |
| | import sys |
| |
|
| |
|
@dataclass
class ModelConfig:
    """Model/backbone settings: checkpoint name, sequence handling, and adapter paths."""

    model_name: str = ""                  # HF hub id or local path of the base model
    dropout: float = 0.0
    model_max_seq_length: int = 512
    # metadata kept: HfArgumentParser surfaces it as --help text
    data_collator_mode: str = field(default='fixed', metadata={"help": "fixed or dynamic padding in DataCollator"})
    lambda_reg: float = field(default=1e-4, metadata={"help": "The control strength of regularity"})
    adapter_path: Optional[str] = None    # load a pre-trained adapter if given

    # Used only by the merge-adapter workflow.
    merge_adapter_path: Optional[str] = None
    merge_output_path: Optional[str] = None
| |
|
@dataclass
class RotationConfig:
    """Settings for the rotation adapter (rank, count, and injection targets)."""

    r: int = 4                 # adapter rank
    num_rotations: int = 4
    task_type: str = "CAUSAL_LM"
    # mutable default -> default_factory, as required by dataclasses
    target_modules: List[str] = field(default_factory=lambda: ["q_proj"])
| |
|
@dataclass
class DataConfig:
    """Dataset selection: which corpus to load, how to split it, and which fields to use."""

    dataset_name: str = 'math'
    split_ratio: float = 0.01          # fraction held out (presumably for eval -- TODO confirm against caller)
    path: str = "./nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json"
    # metadata kept: HfArgumentParser surfaces it as --help text
    dataset_split: str = field(default="train[:1000]", metadata={"help": "(`['train', 'test', 'eval']`):"})
    adapter_names: List[Optional[str]] = field(default_factory=lambda: ["default"])
    dataset_field: List[str] = field(default_factory=list, metadata={"help": "Fields of dataset input and output."})
| |
|
| |
|
@dataclass
class TrainingOverride:
    """Overrides forwarded to HF `TrainingArguments` (see `convert_to_trainer_args`).

    Field names intentionally mirror the `TrainingArguments` keyword names so the
    whole dataclass can be splatted into its constructor.
    """

    # Optimizer / batching
    optim: str = "adamw_torch"
    eval_strategy: str = 'no'
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8

    # Schedule
    learning_rate: float = 1e-05
    lr_scheduler_type: str = 'cosine'
    warmup_steps: int = 0

    # Memory / checkpointing
    gradient_checkpointing: bool = False
    gradient_accumulation_steps: int = 1
    output_dir: str = "runs"
    save_steps: float = 0
    save_strategy: str = 'no'

    # Precision / serialization
    bf16: bool = False
    bf16_full_eval: bool = False
    save_safetensors: bool = False

    # Logging / evaluation cadence
    report_to: Union[None, str, list[str]] = "none"
    logging_steps: int = 25
    eval_steps: Union[None, int] = None

    # DataLoader plumbing
    dataloader_num_workers: int = 1
    dataloader_pin_memory: bool = True
    dataloader_persistent_workers: bool = True
    dataloader_prefetch_factor: int = 1

    # Duration
    num_train_epochs: float = 1.0
    max_steps: int = -1  # -1 defers to num_train_epochs
    # NOTE(review): True while save/eval strategies default to 'no' looks
    # inconsistent -- confirm HF accepts this combination for your version.
    load_best_model_at_end: bool = True
| |
|
@dataclass
class GlueConfig:
    """GLUE benchmark settings."""

    task_name: str = 'mnli'
    pad_to_max_length: bool = True
| |
|
| |
|
@dataclass
class MainConfig:
    """Top-level experiment configuration aggregating every sub-config.

    The `trainer_args` section is extracted by `convert_to_trainer_args`;
    all other sections travel inside `HFTrainingArguments.extension`.
    """

    # Nested sections use default_factory so each MainConfig gets fresh instances.
    model: ModelConfig = field(default_factory=ModelConfig)
    rotation_adapter_config: RotationConfig = field(default_factory=RotationConfig)
    data: DataConfig = field(default_factory=DataConfig)
    trainer_args: TrainingOverride = field(default_factory=TrainingOverride)

    glue: GlueConfig = field(default_factory=GlueConfig)
    project_name: str = "llm_rotation"
    seed: int = 42
    run_text: str = 'def'  # free-form run label
| | |
| |
|
@dataclass
class HFTrainingArguments(TrainingArguments):
    """HF `TrainingArguments` extended with an `extension` payload.

    `extension` carries the serialized non-training sections of MainConfig so
    they remain accessible from the Trainer without widening its interface.
    """

    extension: Optional[Dict[str, Any]] = field(
        default=None, metadata={"help": "Serialized MainConfig excluding training args"}
    )
| |
|
def convert_to_trainer_args(main_cfg: MainConfig) -> HFTrainingArguments:
    """
    Map a MainConfig onto HFTrainingArguments.

    Logic:
    1. Extract the 'trainer_args' section -> pass to the TrainingArguments constructor.
    2. Pack the remaining sections ('model', 'data', etc.) -> store in `extension`.

    Args:
        main_cfg: Fully populated MainConfig instance.

    Returns:
        HFTrainingArguments whose `extension` holds the non-training sections.

    Exits:
        Calls sys.exit(1) when 'trainer_args' contains keys unknown to
        HF TrainingArguments.
    """
    KEY = "trainer_args"

    # asdict() recursively converts nested dataclasses to plain dicts.
    full_dict = asdict(main_cfg)

    # Split the training kwargs from everything else; what remains is the payload.
    train_args_dict = full_dict.pop(KEY)
    extension_payload = full_dict

    try:
        args = HFTrainingArguments(**train_args_dict)
    except TypeError as e:
        # Message previously blamed a nonexistent 'training' section; the
        # offending section is actually 'trainer_args'.
        print(f"Error: Your 'trainer_args' config contains keys unknown to HF TrainingArguments: {e}")
        sys.exit(1)

    args.extension = extension_payload

    return args
| |
|
| |
|
| |
|
| |
|
@dataclass
class Training:
    """Flat training configuration for the HRFT fine-tuning entry point.

    NOTE(review): `data_path` and `dataset_field` default to None despite
    non-Optional annotations; annotations are left untouched because
    HfArgumentParser-style tooling may parse them -- confirm before changing.
    """

    model_name_or_path: Optional[str] = "huggyllama/llama-7b"
    adapter_name_or_path: Optional[str] = None
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    dataset_split: str = field(
        default="train[:100000]", metadata={"help": "(`['train', 'test', 'eval']`):"}
    )
    dataset_field: List[str] = field(
        default=None, metadata={"help": "Fields of dataset input and output."}
    )
    optim: str = "adamw_torch"
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    hrft_r: int = field(
        default=8,
        metadata={"help": "The rank of the adapter. When passing `None` and `adapter_name_or_path` is also `None`, full fine-tuning is used."},
    )
    init_a: float = field(default=1e-4, metadata={"help": "The initial weights"})
    eps: float = field(default=1e-4, metadata={"help": "The control strength of COFT. The freedom of rotation."})
    lamda: float = field(default=1e-4, metadata={"help": "The control strength of regularity"})
    add_orth: str = field(default='none', metadata={"help": ""})
    init_weights: Literal[True, "pissa"] = field(
        default=True,
        metadata={
            "help": (
                "Passing True (default) results in the LoRA initialization."
                "Passing `pissa` results in PiSSA initialization."
            ),
        },
    )
    extension: Optional[Dict[str, Any]] = field(
        default=None,
        metadata={"help": "Serialized MainConfig excluding training args"},
    )
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |