Spaces:
Running on A10G
Running on A10G
| from pathlib import Path | |
| from typing import Literal, Optional, List | |
| import yaml | |
| from pydantic import BaseModel, Field | |
class DataConfig(BaseModel):
    """Settings describing where the training data comes from and how it is formatted.

    Every field has a default, so the section may be omitted from a config file.
    """

    # Dataset identifier; unset by default.
    dataset: Optional[str] = None

    # One or more local dataset entries; unset by default.
    # NOTE(review): presumably file paths — confirm against the data loader.
    local_dataset: Optional[List[str]] = None

    # Dataset layout; restricted to the listed literal values.
    format_type: Literal["auto", "alpaca", "chatml", "sharegpt"] = "auto"
class TrainingConfig(BaseModel):
    """Hyperparameters and run options for a training job.

    All fields carry defaults, so an empty ``training`` section is valid.
    """

    # --- model loading -------------------------------------------------
    training_type: Literal["lora", "full"] = "lora"
    max_seq_length: int = 2048
    load_in_4bit: bool = True

    # --- output --------------------------------------------------------
    output_dir: Path = Path("./outputs")

    # --- optimisation schedule -----------------------------------------
    num_epochs: int = 3
    learning_rate: float = 2e-4
    batch_size: int = 2
    gradient_accumulation_steps: int = 4
    warmup_steps: int = 5
    # NOTE(review): 0 presumably means "no step limit" / "no periodic save" —
    # confirm against the trainer that consumes these values.
    max_steps: int = 0
    save_steps: int = 0
    weight_decay: float = 0.01
    random_seed: int = 3407

    # --- data handling -------------------------------------------------
    packing: bool = False
    train_on_completions: bool = False

    # --- memory --------------------------------------------------------
    # Stored as one of three string literals (not a bool) and forwarded as-is.
    gradient_checkpointing: Literal["unsloth", "true", "none"] = "unsloth"
class LoraConfig(BaseModel):
    """LoRA adapter settings and layer-selection switches.

    All fields carry defaults, so an empty ``lora`` section is valid.
    """

    # --- adapter shape -------------------------------------------------
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.0

    # --- which modules receive adapters --------------------------------
    # Comma-separated module names kept as a single string.
    target_modules: str = "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj"
    vision_all_linear: bool = False

    # --- adapter variants ----------------------------------------------
    use_rslora: bool = False
    use_loftq: bool = False

    # --- layer / module selection switches -----------------------------
    finetune_vision_layers: bool = True
    finetune_language_layers: bool = True
    finetune_attention_modules: bool = True
    finetune_mlp_modules: bool = True
class LoggingConfig(BaseModel):
    """Experiment-tracking options and access tokens.

    All fields carry defaults, so an empty ``logging`` section is valid.
    The token fields are excluded from the model's ``repr`` so that printing
    or logging a config object does not expose credentials. Their values are
    still plain strings and are read exactly as before.
    """

    # --- Weights & Biases ----------------------------------------------
    enable_wandb: bool = False
    wandb_project: str = "unsloth-training"
    # repr=False keeps the secret out of repr()/str() output.
    wandb_token: Optional[str] = Field(default=None, repr=False)

    # --- TensorBoard ---------------------------------------------------
    enable_tensorboard: bool = False
    tensorboard_dir: str = "runs"

    # --- Hugging Face --------------------------------------------------
    # repr=False keeps the secret out of repr()/str() output.
    hf_token: Optional[str] = Field(default=None, repr=False)
class Config(BaseModel):
    """Root configuration grouping the model name with the four setting sections.

    Every section is built from its own defaults when omitted, so ``Config()``
    is usable without arguments. The helper methods flatten the sections into
    plain keyword dictionaries.
    """

    model: Optional[str] = None
    data: DataConfig = Field(default_factory=DataConfig)
    training: TrainingConfig = Field(default_factory=TrainingConfig)
    lora: LoraConfig = Field(default_factory=LoraConfig)
    logging: LoggingConfig = Field(default_factory=LoggingConfig)

    def apply_overrides(self, **kwargs):
        """Apply CLI overrides by matching arg names to declared config fields.

        Keys whose value is ``None`` are skipped. Each remaining key is looked
        up among the top-level fields first, then in the ``data``,
        ``training``, ``lora`` and ``logging`` sections in that order; the
        first section declaring a field with that name receives the value.
        Keys that match no declared field are ignored.

        Matching is done against declared fields rather than with ``hasattr``
        so that a key which happens to equal a method or inherited attribute
        name (for example ``copy`` or ``training_kwargs``) is not assigned
        onto the model.
        """

        def declared_fields(section):
            # pydantic v2 exposes ``model_fields`` on the class; v1 exposes
            # ``__fields__``. Either one is a mapping keyed by field name.
            cls = type(section)
            fields = getattr(cls, "model_fields", None)
            if fields is None:
                fields = getattr(cls, "__fields__", {})
            return fields

        # Search order: top-level first, then each section.
        search_order = (self, self.data, self.training, self.lora, self.logging)
        for key, value in kwargs.items():
            if value is None:
                continue
            for section in search_order:
                if key in declared_fields(section):
                    setattr(section, key, value)
                    break

    def model_kwargs(self, use_lora: bool, is_vision: bool) -> dict:
        """Return kwargs for trainer.prepare_model_for_training().

        Args:
            use_lora: Forwarded unchanged under the ``"use_lora"`` key.
            is_vision: When true together with ``use_lora``, ``target_modules``
                is ``"all-linear"`` if ``lora.vision_all_linear`` is set and
                ``None`` otherwise.

        Returns:
            A dict of LoRA and gradient-checkpointing settings. In the
            non-vision (or non-LoRA) case ``"target_modules"`` is the list of
            non-empty, whitespace-stripped module names, or ``None`` when that
            list would be empty.
        """
        # Determine target modules based on model type
        if use_lora and is_vision:
            # Vision models expect a string (e.g., "all-linear"); fall back to None to use trainer defaults
            target_modules = "all-linear" if self.lora.vision_all_linear else None
        else:
            raw = self.lora.target_modules
            if raw is None:
                # Overrides are assigned with setattr, so None can end up here;
                # str(None) would otherwise yield the bogus module name "None".
                pieces = []
            elif isinstance(raw, (list, tuple)):
                # A sequence of names is used item by item instead of
                # splitting its repr on commas.
                pieces = [str(item) for item in raw]
            else:
                pieces = str(raw).split(",")
            parsed = [name.strip() for name in pieces if name.strip()]
            target_modules = parsed or None

        lora = self.lora
        return {
            "use_lora": use_lora,
            "finetune_vision_layers": lora.finetune_vision_layers,
            "finetune_language_layers": lora.finetune_language_layers,
            "finetune_attention_modules": lora.finetune_attention_modules,
            "finetune_mlp_modules": lora.finetune_mlp_modules,
            "target_modules": target_modules,
            "lora_r": lora.lora_r,
            "lora_alpha": lora.lora_alpha,
            "lora_dropout": lora.lora_dropout,
            "use_gradient_checkpointing": self.training.gradient_checkpointing,
            "use_rslora": lora.use_rslora,
            "use_loftq": lora.use_loftq,
        }

    def training_kwargs(self) -> dict:
        """Return kwargs for trainer.start_training().

        Returns:
            A flat dict combining the ``training`` section values with the
            W&B and TensorBoard values from ``logging``. ``output_dir`` is
            converted to ``str``; every other value is passed through as
            stored.
        """
        training = self.training
        logging_cfg = self.logging
        return {
            "output_dir": str(training.output_dir),
            "num_epochs": training.num_epochs,
            "learning_rate": training.learning_rate,
            "batch_size": training.batch_size,
            "gradient_accumulation_steps": training.gradient_accumulation_steps,
            "warmup_steps": training.warmup_steps,
            "max_steps": training.max_steps,
            "save_steps": training.save_steps,
            "weight_decay": training.weight_decay,
            "random_seed": training.random_seed,
            "packing": training.packing,
            "train_on_completions": training.train_on_completions,
            "max_seq_length": training.max_seq_length,
            "enable_wandb": logging_cfg.enable_wandb,
            "wandb_project": logging_cfg.wandb_project,
            "wandb_token": logging_cfg.wandb_token,
            "enable_tensorboard": logging_cfg.enable_tensorboard,
            "tensorboard_dir": logging_cfg.tensorboard_dir,
        }
def load_config(path: Optional[Path]) -> "Config":
    """Load config from YAML/JSON file, or return defaults if no path given.

    Args:
        path: Location of the config file. A falsy value (``None`` or an
            empty string) yields a default ``Config``. Files whose suffix is
            ``.yaml`` or ``.yml`` (case-insensitive) are parsed as YAML; any
            other suffix is parsed as JSON.

    Returns:
        A ``Config`` built from the file's top-level mapping. An empty YAML
        document and a blank JSON file both produce the defaults.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the parsed document is not a mapping at the top level
            (for example a list or a bare scalar).
    """
    if not path:
        return Config()

    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Config file not found: {path}")

    text = path.read_text(encoding="utf-8")
    if path.suffix.lower() in {".yaml", ".yml"}:
        # safe_load returns None for an empty document; treat that as "no settings".
        data = yaml.safe_load(text) or {}
    else:
        import json

        # A blank or whitespace-only file means "no settings" rather than a
        # JSON parse error.
        data = json.loads(text) if text.strip() else {}

    # Fail with a message naming the file instead of the opaque TypeError
    # that ``Config(**data)`` would raise for a non-mapping.
    if not isinstance(data, dict):
        raise ValueError(
            f"Config file must contain a mapping at the top level, "
            f"got {type(data).__name__}: {path}"
        )
    return Config(**data)