| | import argparse |
| | import json |
| | import inspect |
| | import math |
| | import time |
| | from pathlib import Path |
| | from typing import Any, Dict, Optional, Tuple, List |
| |
|
| | import torch |
| | import yaml |
| | from datasets import load_dataset, DatasetDict |
| | from huggingface_hub import snapshot_download |
| | from transformers import ( |
| | AutoTokenizer, |
| | AutoModelForCausalLM, |
| | AutoModel, |
| | AutoConfig, |
| | BitsAndBytesConfig, |
| | TrainingArguments, |
| | Trainer, |
| | TrainerCallback, |
| | EarlyStoppingCallback, |
| | default_data_collator, |
| | set_seed, |
| | ) |
| | from transformers.trainer_utils import get_last_checkpoint |
| | from peft import ( |
| | LoraConfig, |
| | get_peft_model, |
| | prepare_model_for_kbit_training, |
| | PeftModel, |
| | ) |
| |
|
| | try: |
| | import wandb |
| | WANDB_AVAILABLE = True |
| | except ImportError: |
| | WANDB_AVAILABLE = False |
| | wandb = None |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def _dtype_from_str(s: str) -> torch.dtype:
    """Translate a dtype name from the config into a ``torch.dtype``.

    Matching is case-insensitive and accepts both the long and short
    spellings (``float16``/``fp16``, ``bfloat16``/``bf16``,
    ``float32``/``fp32``). ``None`` or an empty string is treated as an
    unknown name.

    Raises:
        ValueError: if the name is not one of the supported spellings.
    """
    key = (s or "").lower()
    aliases = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    resolved = aliases.get(key)
    if resolved is None:
        raise ValueError(f"Unknown torch_dtype: {key}")
    return resolved
| |
|
| |
|
def _now_iso() -> str:
    """Return the current local time as ``YYYY-MM-DDTHH:MM:SS`` (no timezone suffix)."""
    stamp_format = "%Y-%m-%dT%H:%M:%S"
    local_now = time.localtime()
    return time.strftime(stamp_format, local_now)
| |
|
| |
|
def _safe_exp(x: float) -> float:
    """Return ``exp(x)`` with the exponent capped at 50 to avoid overflow.

    Used to turn a loss into a perplexity without blowing up on very
    large losses.
    """
    exponent = float(x)
    if exponent > 50.0:
        exponent = 50.0
    return float(math.exp(exponent))
| |
|
| |
|
def _ensure_dir(p: Path) -> Path:
    """Create directory ``p`` (and any missing parents) and return ``p``.

    An already existing directory is not an error.
    """
    p.mkdir(exist_ok=True, parents=True)
    return p
| |
|
| |
|
def _looks_like_model_dir(p: Path) -> bool:
    """Heuristically decide whether ``p`` is a directory holding model files.

    A directory qualifies when it contains ``config.json`` or at least one
    weight file matching ``*.safetensors`` or ``pytorch_model*.bin``.
    Non-existent paths and regular files never qualify.
    """
    # is_dir() is already False for paths that do not exist.
    if not p.is_dir():
        return False
    if (p / "config.json").exists():
        return True
    weight_patterns = ("*.safetensors", "pytorch_model*.bin")
    return any(any(p.glob(pattern)) for pattern in weight_patterns)
| |
|
| |
|
def _infer_target_modules(model) -> List[str]:
    """Guess LoRA target module names from the model's submodule names.

    Only the last component of each dotted module name is considered.
    Known attention-projection layouts are tried in priority order and the
    first layout that is fully present wins. If none matches completely,
    every individually known projection name that is present is returned.

    Raises:
        ValueError: if no known projection name exists in the model.
    """
    leaf_names = {
        qualified.rsplit(".", 1)[-1] for qualified, _ in model.named_modules()
    }

    known_layouts = [
        ["q_proj", "k_proj", "v_proj", "o_proj"],
        ["Wqkv", "out_proj"],
        ["query_key_value", "dense"],
        ["c_attn", "c_proj"],
    ]
    for layout in known_layouts:
        if leaf_names.issuperset(layout):
            return layout

    candidates = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "c_attn",
        "c_proj",
        "out_proj",
        "dense",
    ]
    partial_match = [name for name in candidates if name in leaf_names]
    if not partial_match:
        raise ValueError(
            "Could not auto-infer target_modules. Set peft.target_modules explicitly."
        )
    return partial_match
| |
|
| |
|
def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
    """Return ``model.attn_implementation`` from the config, or ``None`` if unset.

    A missing ``model`` section and a ``model`` key that is present but
    empty (which YAML loads as ``None``) are both treated as "not set".
    """
    model_section = cfg.get("model") or {}
    return model_section.get("attn_implementation")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
    """Initialize Wandb if enabled in configuration.

    Reads the ``wandb`` section of ``cfg`` (``enabled``, ``project``,
    ``entity``, ``name``, ``tags``, ``notes``) and starts a run whose files
    are written under ``run_dir``. The ``model``, ``data``, ``peft`` and
    ``train`` config sections are attached to the run as its config.

    Returns:
        The ``wandb`` module when a run was started, otherwise ``None``
        (logging disabled, package not installed, or initialization failed).
    """
    # An empty `wandb:` key in YAML loads as None; treat it like a missing section.
    wandb_cfg = cfg.get("wandb") or {}

    if not wandb_cfg.get("enabled", False):
        print("Wandb logging disabled")
        return None

    if not WANDB_AVAILABLE:
        print("Wandb not available. Install with: pip install wandb")
        return None

    project = wandb_cfg.get("project", "sft-training")
    entity = wandb_cfg.get("entity", None)
    name = wandb_cfg.get("name", None)
    tags = wandb_cfg.get("tags", [])
    notes = wandb_cfg.get("notes", None)

    # Experiment tracking is best-effort: a failure here (offline machine,
    # bad credentials, ...) must not abort the training run itself.
    try:
        wandb.init(
            project=project,
            entity=entity,
            name=name,
            tags=tags,
            notes=notes,
            dir=str(run_dir),
            config={
                "model": cfg.get("model", {}),
                "data": cfg.get("data", {}),
                "peft": cfg.get("peft", {}),
                "train": cfg.get("train", {}),
                "run_dir": str(run_dir),
            },
        )
    except Exception as e:
        print(f"Failed to initialize Wandb: {e}")
        return None

    print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
    return wandb
| |
|
| |
|
def finish_wandb():
    """Close the active Wandb run, if the package is installed and a run exists."""
    if not WANDB_AVAILABLE:
        return
    if wandb.run is None:
        return
    wandb.finish()
    print("Wandb run finished")
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class JsonlLoggerCallback(TrainerCallback):
    """Trainer callback that appends log and eval records to JSONL files.

    Training log events go to ``<run_dir>/logs/train.jsonl`` and evaluation
    events to ``<run_dir>/logs/eval.jsonl``; each event is one JSON object
    per line. Training records carry progress percentages and an ETA,
    evaluation records carry a perplexity derived from ``eval_loss``.
    """

    def __init__(self, run_dir: Path):
        self.run_dir = run_dir
        logs_dir = _ensure_dir(run_dir / "logs")
        self.train_log_path = logs_dir / "train.jsonl"
        self.eval_log_path = logs_dir / "eval.jsonl"
        # Wall-clock time and global step at which this process started
        # training; both are set in on_train_begin.
        self.start_time = None
        self.start_step = 0

    @staticmethod
    def _append_record(path: Path, payload: Dict[str, Any]) -> None:
        """Append ``payload`` as a single JSON line to ``path``."""
        with path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")

    def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
        """Estimate remaining time as ``HH:MM:SS``, or ``None`` if unknown.

        The per-step duration is measured over the steps executed since
        ``on_train_begin`` in this process. When training resumes from a
        checkpoint ``global_step`` does not start at zero, so dividing the
        elapsed time by the absolute step would underestimate the ETA.
        """
        if self.start_time is None or global_step <= 0 or max_steps <= 0:
            return None
        steps_done = global_step - getattr(self, "start_step", 0)
        if steps_done <= 0:
            return None
        elapsed = time.time() - self.start_time
        sec_per_step = elapsed / steps_done
        remaining = max(0, max_steps - global_step) * sec_per_step
        hours = int(remaining // 3600)
        minutes = int((remaining % 3600) // 60)
        seconds = int(remaining % 60)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}"

    def on_train_begin(self, args, state, control, **kwargs):
        """Record when, and at which global step, training started."""
        self.start_time = time.time()
        # Non-zero when resuming from a checkpoint.
        self.start_step = int(getattr(state, "global_step", 0) or 0)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Write one training record enriched with progress information."""
        if not logs:
            return

        max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
        progress_pct = (
            (100.0 * state.global_step / max_steps) if max_steps > 0 else None
        )
        epoch_pct = None
        if (
            state.epoch is not None
            and args.num_train_epochs
            and args.num_train_epochs > 0
        ):
            epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))

        payload = {
            "ts": _now_iso(),
            "event": "train_log",
            "step": int(state.global_step),
            "epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
            "progress_pct": (
                round(progress_pct, 2) if progress_pct is not None else None
            ),
            "epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
            "eta": self._eta(int(state.global_step), max_steps),
            "max_grad_norm": getattr(args, "max_grad_norm", None),
            # Trainer-provided values come last so they override the
            # derived fields above when keys collide (e.g. "epoch").
            **logs,
        }
        self._append_record(self.train_log_path, payload)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Write one evaluation record including the derived perplexity."""
        if not metrics:
            return
        eval_loss = metrics.get("eval_loss", None)
        ppl = _safe_exp(eval_loss) if eval_loss is not None else None

        payload = {
            "ts": _now_iso(),
            "event": "eval",
            "step": int(state.global_step),
            "epoch": float(state.epoch) if state.epoch is not None else None,
            **metrics,
            "perplexity": ppl,
        }
        self._append_record(self.eval_log_path, payload)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def _field_text(value: Any) -> str:
    """Return ``value`` as text, mapping a missing/null field to an empty string."""
    return "" if value is None else str(value)


def _with_eos(text: str, eos_token: Optional[str]) -> str:
    """Append ``eos_token`` unless ``text`` already ends with it.

    Trailing whitespace is ignored for the check: chat templates often end
    a turn with the end token followed by a newline, and a second end
    token must not be appended in that case.
    """
    if eos_token and not text.rstrip().endswith(eos_token):
        return text + eos_token
    return text


def _chat_response_start(
    tokenizer, messages: List[Dict[str, str]], formatted_text: str, output_text: str
) -> int:
    """Find the character offset where the assistant reply starts.

    The prompt-only conversation is rendered with the generation prompt;
    when the full text starts with that rendering, its length is the exact
    boundary. Otherwise fall back to textual heuristics: known assistant
    markers, then the last occurrence of the reply text, then the last
    newline before the reply.
    """
    prompt_text = tokenizer.apply_chat_template(
        messages[:-1], tokenize=False, add_generation_prompt=True
    )
    if prompt_text and formatted_text.startswith(prompt_text):
        return len(prompt_text)

    for marker in ("<|im_start|>assistant", "<|assistant|>", "Assistant:", "assistant\n"):
        marker_idx = formatted_text.find(marker)
        if marker_idx != -1:
            newline_idx = formatted_text.find("\n", marker_idx)
            if newline_idx != -1:
                return newline_idx + 1
            # Marker present but no newline after it: use the next heuristic.
            break

    if output_text:
        # The assistant turn is rendered last, so the reply is the LAST
        # occurrence; the same text may also appear inside the prompt.
        output_idx = formatted_text.rfind(output_text)
        if output_idx != -1:
            return output_idx

    return formatted_text.rfind("\n", 0, len(formatted_text) - len(output_text)) + 1


def format_instruction(
    example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
) -> Dict[str, Any]:
    """
    Format one instruction example for training.

    Supported ``data.format_type`` values:

    * ``chatml`` (default): builds system/user/assistant messages and
      renders them with the tokenizer's chat template.
    * ``alpaca``: the fixed Alpaca prompt, with or without an input section.
    * ``custom``: ``data.custom_template`` with ``{instruction}``,
      ``{input}`` and ``{output}`` placeholders. Text after ``{output}``
      is kept and rendered after the response.

    Field names are read from ``data.instruction_field``,
    ``data.input_field`` and ``data.output_field``. Missing or null fields
    are treated as empty strings. The tokenizer's EOS token is appended
    unless the text already ends with it.

    Returns:
        ``{"text": <formatted string>, "response_start_pos": <int>}`` where
        ``response_start_pos`` is the character offset at which the
        response begins (used for loss masking).

    Raises:
        ValueError: for an unsupported ``format_type``.
    """
    data_cfg = cfg["data"]
    format_type = data_cfg.get("format_type", "chatml")

    input_field = data_cfg.get("input_field", "input")
    output_field = data_cfg.get("output_field", "output")
    instruction_field = data_cfg.get("instruction_field", "instruction")

    # Rows that lack a column come back as None from JSON datasets.
    instruction = _field_text(example.get(instruction_field, ""))
    input_text = _field_text(example.get(input_field, ""))
    output_text = _field_text(example.get(output_field, ""))

    eos_token = tokenizer.eos_token

    if format_type == "chatml":
        system_prompt = data_cfg.get("system_prompt", "You are a helpful assistant.")

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        user_content = instruction
        if input_text:
            user_content = f"{instruction}\n\n{input_text}"
        messages.append({"role": "user", "content": user_content})
        messages.append({"role": "assistant", "content": output_text})

        formatted_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        formatted_text = _with_eos(formatted_text, eos_token)
        response_start_pos = _chat_response_start(
            tokenizer, messages, formatted_text, output_text
        )

    elif format_type == "alpaca":
        if input_text:
            prefix = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
        else:
            prefix = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"

        formatted_text = _with_eos(prefix + output_text, eos_token)
        response_start_pos = len(prefix)

    elif format_type == "custom":
        template = data_cfg.get("custom_template", "{instruction}\n{input}\n{output}")

        # Without a per-example instruction, fall back to the system prompt.
        if not instruction:
            instruction = _field_text(data_cfg.get("system_prompt", ""))

        # Split at the first {output}: everything before it is the prompt,
        # everything after it is rendered behind the response.
        prefix_template, _, suffix_template = template.partition("{output}")
        prefix = prefix_template.format(instruction=instruction, input=input_text)
        suffix = suffix_template.format(
            instruction=instruction, input=input_text, output=output_text
        )

        formatted_text = _with_eos(prefix + output_text + suffix, eos_token)
        response_start_pos = len(prefix)

    else:
        raise ValueError(f"Unsupported format_type: {format_type}")

    return {"text": formatted_text, "response_start_pos": response_start_pos}
| |
|
| |
|
def _first_response_token(offsets, response_start_pos: int) -> int:
    """Return the index of the first token that covers response characters.

    ``offsets`` is the per-token ``(char_start, char_end)`` list produced by
    a fast tokenizer. Zero-width entries (special tokens added by the
    tokenizer) are skipped. Returns ``len(offsets)`` when no token reaches
    the response, e.g. because truncation cut it off.
    """
    for token_idx, (char_start, char_end) in enumerate(offsets):
        if char_end > char_start and char_end > response_start_pos:
            return token_idx
    return len(offsets)


def build_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
    """
    Build tokenized train/eval datasets for instruction fine-tuning.

    Pipeline: load ``data.train_jsonl`` (and ``data.eval_jsonl`` if given,
    otherwise optionally split off ``data.eval_split_ratio`` of the train
    set), format every row with ``format_instruction``, tokenize with
    truncation to ``data.max_length`` and build labels in which all prompt
    tokens are masked with ``-100`` so that only the response contributes
    to the loss.

    The returned datasets contain only ``input_ids``, ``attention_mask``
    and ``labels``; sequences are not padded.

    Returns:
        ``(train_dataset, eval_dataset)``; ``eval_dataset`` is ``None``
        when no eval data is configured.
    """
    data_cfg = cfg["data"]
    train_path = data_cfg["train_jsonl"]
    eval_path = data_cfg.get("eval_jsonl", None)
    split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
    max_length = int(data_cfg.get("max_length", 2048))
    shuffle = bool(data_cfg.get("shuffle", True))
    num_proc = int(data_cfg.get("num_proc", 4))
    seed = int(cfg["run"].get("seed", 42))

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_raw = load_dataset("json", data_files={"train": train_path})["train"]
    eval_raw = None
    if eval_path:
        eval_raw = load_dataset("json", data_files={"eval": eval_path})["eval"]
    elif 0.0 < split_ratio < 1.0:
        split = train_raw.train_test_split(test_size=split_ratio, seed=seed)
        train_raw, eval_raw = split["train"], split["test"]

    def format_fn(examples):
        # `examples` is a column-oriented batch; rebuild one dict per row.
        columns = list(examples.keys())
        num_rows = len(examples[columns[0]])
        texts = []
        response_starts = []
        for row_idx in range(num_rows):
            row = {column: examples[column][row_idx] for column in columns}
            formatted = format_instruction(row, cfg, tokenizer)
            texts.append(formatted["text"])
            response_starts.append(formatted["response_start_pos"])
        return {"text": texts, "response_start_pos": response_starts}

    # Character offsets per token are only available from fast tokenizers.
    use_offsets = bool(getattr(tokenizer, "is_fast", False))

    def tokenize_and_mask_fn(examples):
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding=False,
            max_length=max_length,
            return_overflowing_tokens=False,
            return_offsets_mapping=use_offsets,
        )
        # Offsets are needed for masking only and must not become a column.
        offset_batches = tokenized.pop("offset_mapping", None)

        labels = []
        attention_masks = []
        for row_idx, input_ids in enumerate(tokenized["input_ids"]):
            response_start_pos = examples["response_start_pos"][row_idx]

            if offset_batches is not None:
                # Exact mapping from character offset to token index.
                response_start_idx = _first_response_token(
                    offset_batches[row_idx], response_start_pos
                )
            else:
                # Slow tokenizer: count the tokens of the prompt part alone.
                prompt_ids = tokenizer(
                    examples["text"][row_idx][:response_start_pos],
                    truncation=True,
                    padding=False,
                    max_length=max_length,
                )["input_ids"]
                response_start_idx = len(prompt_ids)

            # Never supervise the first token and always keep at least the
            # last token unmasked, so a row whose response was truncated
            # away does not end up with an all-masked (NaN) loss.
            response_start_idx = max(1, min(response_start_idx, len(input_ids) - 1))

            labels.append(
                [-100] * response_start_idx + list(input_ids[response_start_idx:])
            )
            attention_masks.append([1] * len(input_ids))

        tokenized["labels"] = labels
        tokenized["attention_mask"] = attention_masks
        return tokenized

    def prepare(raw_split, split_name):
        formatted = raw_split.map(
            format_fn,
            batched=True,
            num_proc=num_proc,
            remove_columns=raw_split.column_names,
            desc=f"Formatting {split_name} instructions",
        )
        # Drop the helper columns ("text", "response_start_pos") so that
        # only model inputs reach the data collator.
        return formatted.map(
            tokenize_and_mask_fn,
            batched=True,
            num_proc=num_proc,
            remove_columns=formatted.column_names,
            desc=f"Tokenizing and masking {split_name}",
        )

    tokenized_train = prepare(train_raw, "train")
    tokenized_eval = prepare(eval_raw, "eval") if eval_raw is not None else None

    if shuffle:
        tokenized_train = tokenized_train.shuffle(seed=seed)

    return tokenized_train, tokenized_eval
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
    """Load the base model and its tokenizer from ``base_dir``.

    Reads from ``cfg["model"]``: ``trust_remote_code``,
    ``tokenizer_use_fast``, ``device_map``, ``torch_dtype``, ``use_4bit``
    and the ``bnb_4bit_*`` options, plus ``attn_implementation``.

    * The tokenizer's pad token falls back to its EOS token.
    * With ``use_4bit`` the model is loaded through a ``BitsAndBytesConfig``
      and ``torch_dtype`` is left to the quantization config.
    * Mistral3 checkpoints are loaded with
      ``Mistral3ForConditionalGeneration``; everything else with
      ``AutoModelForCausalLM``.
    * If loading with the requested ``attn_implementation`` fails, one
      retry is made with the library's default attention.
    * A ``vision_tower`` submodule, if present, is frozen.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        Whatever the underlying ``from_pretrained`` calls raise; the error
        is printed before being re-raised.
    """
    model_cfg = cfg["model"]
    # NOTE: trust_remote_code executes Python shipped with the checkpoint.
    # It defaults to True here; set model.trust_remote_code: false for
    # checkpoints from untrusted sources.
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
    use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
    device_map = model_cfg.get("device_map", "auto")

    tokenizer = AutoTokenizer.from_pretrained(
        str(base_dir),
        use_fast=use_fast,
        trust_remote_code=trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_4bit = bool(model_cfg.get("use_4bit", False))

    quant_cfg = None
    if use_4bit:
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
            bnb_4bit_use_double_quant=bool(
                model_cfg.get("bnb_4bit_use_double_quant", True)
            ),
            bnb_4bit_compute_dtype=_dtype_from_str(
                model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")
            ),
        )

    attn_impl = _choose_attn_impl(cfg)

    try:
        # Honor the configured trust_remote_code for the config as well.
        config = AutoConfig.from_pretrained(
            str(base_dir), trust_remote_code=trust_remote_code
        )
        model_type = config.model_type
        architectures = getattr(config, "architectures", None) or []

        load_kwargs = dict(
            device_map=device_map,
            low_cpu_mem_usage=True,
            torch_dtype=(torch_dtype if not use_4bit else None),
            quantization_config=quant_cfg,
        )

        if model_type == "mistral3" or (
            architectures and "Mistral3" in architectures[0]
        ):
            print("[info] Detected Mistral3 model architecture, loading with specific class")
            from transformers.models.mistral3.modeling_mistral3 import (
                Mistral3ForConditionalGeneration,
            )

            model_cls = Mistral3ForConditionalGeneration
            load_kwargs["config"] = config
        else:
            model_cls = AutoModelForCausalLM
            load_kwargs["trust_remote_code"] = trust_remote_code

        try:
            model = model_cls.from_pretrained(
                str(base_dir), attn_implementation=attn_impl, **load_kwargs
            )
        except Exception as e:
            # Only retry when a specific attention backend was requested;
            # otherwise there is nothing different to fall back to.
            if attn_impl is None:
                raise
            print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
            print("[warn] Falling back to default attention implementation.")
            model = model_cls.from_pretrained(str(base_dir), **load_kwargs)
    except Exception as e:
        print(f"[error] Failed to load model: {e}")
        raise

    print("[info] Ensuring all parameters are materialized...")
    meta_params = [
        name
        for name, param in model.named_parameters()
        if param.device.type == "meta"
    ]
    if meta_params:
        # Only reported; parameters left on the meta device are not loaded here.
        print(f"[warn] Found {len(meta_params)} parameters on meta device")

    if hasattr(model, "vision_tower"):
        print("[info] Freezing vision tower for text-only training")
        for param in model.vision_tower.parameters():
            param.requires_grad = False

    return model, tokenizer
| |
|
| |
|
def apply_peft(cfg: Dict[str, Any], model):
    """Wrap ``model`` with LoRA adapters according to ``cfg["peft"]``.

    Steps, in order:

    1. If ``peft.enabled`` is false, return the model unchanged.
    2. If ``train.gradient_checkpointing`` is on, enable it (on
       ``model.language_model`` when that submodule exists, otherwise on
       the model) and disable the KV cache.
    3. For 4-bit models, run ``prepare_model_for_kbit_training``. For
       non-quantized models with gradient checkpointing, enable input
       gradients so that checkpointed blocks receive a tensor that
       requires grad even though the base weights are frozen.
    4. Resolve ``peft.target_modules`` (``"auto"`` infers them from the
       module names) and build the ``LoraConfig``.

    Returns:
        ``(peft_model, lora_config)``, or ``(model, None)`` when PEFT is
        disabled.
    """
    peft_cfg = cfg["peft"]
    model_cfg = cfg["model"]
    tr_cfg = cfg["train"]

    if not bool(peft_cfg.get("enabled", True)):
        return model, None

    use_4bit = bool(model_cfg.get("use_4bit", False))
    gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))

    if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        if hasattr(model, "vision_tower"):
            print("[info] Disabling gradient checkpointing for vision tower")

        if hasattr(model, "language_model"):
            # Multimodal wrapper: checkpoint only the text model.
            model.language_model.gradient_checkpointing_enable()
        else:
            model.gradient_checkpointing_enable()

        if hasattr(model, "config"):
            # The KV cache is incompatible with gradient checkpointing.
            model.config.use_cache = False

        # prepare_model_for_kbit_training does this for quantized models;
        # the non-quantized path has to do it itself.
        if not use_4bit and hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()

    if use_4bit:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=gradient_checkpointing,
        )

    target_modules = peft_cfg.get("target_modules", "auto")
    if target_modules == "auto":
        target_modules = _infer_target_modules(model)

    if hasattr(model, "vision_tower") and isinstance(target_modules, list):
        print("[info] Filtering target modules to exclude vision tower")
        # NOTE(review): this only drops names that themselves contain
        # "vision"; plain names such as "q_proj" still match modules
        # inside the vision tower. Confirm whether that is intended.
        target_modules = [m for m in target_modules if "vision" not in m.lower()]
        print(f"[info] LoRA target modules: {target_modules}")

    lora_config = LoraConfig(
        r=int(peft_cfg.get("r", 16)),
        lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
        lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
        bias=str(peft_cfg.get("bias", "none")),
        task_type="CAUSAL_LM",
        target_modules=target_modules,
        modules_to_save=None,
    )
    model = get_peft_model(model, lora_config)
    return model, lora_config
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def merge_adapter(
    cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path
):
    """Merge a LoRA adapter into its base model and save the result.

    The base model is loaded on CPU in ``merge.merged_dtype`` (default
    ``float16``), the adapter from ``adapter_dir`` is merged into it, and
    the merged weights are written to ``final_dir`` as safetensors shards
    of at most ``merge.max_shard_size`` (default ``2GB``). The base
    tokenizer is saved next to the weights.

    Mistral3 checkpoints are loaded with the same dedicated class used for
    training, so that the adapter's module names line up with the base
    model.
    """
    print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")

    model_cfg = cfg["model"]
    # An empty `merge:` key in YAML loads as None.
    merge_cfg = cfg.get("merge") or {}
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))

    merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
    max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))

    config = AutoConfig.from_pretrained(
        str(base_dir), trust_remote_code=trust_remote_code
    )
    architectures = getattr(config, "architectures", None) or []
    is_mistral3 = getattr(config, "model_type", None) == "mistral3" or bool(
        architectures and "Mistral3" in architectures[0]
    )

    if is_mistral3:
        from transformers.models.mistral3.modeling_mistral3 import (
            Mistral3ForConditionalGeneration,
        )

        base = Mistral3ForConditionalGeneration.from_pretrained(
            str(base_dir),
            config=config,
            torch_dtype=merged_dtype,
            device_map="cpu",
            low_cpu_mem_usage=True,
        )
    else:
        base = AutoModelForCausalLM.from_pretrained(
            str(base_dir),
            torch_dtype=merged_dtype,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=trust_remote_code,
        )

    merged = PeftModel.from_pretrained(base, str(adapter_dir))
    merged = merged.merge_and_unload()

    _ensure_dir(final_dir)
    merged.save_pretrained(
        str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
    )

    tok = AutoTokenizer.from_pretrained(
        str(base_dir), trust_remote_code=trust_remote_code
    )
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.save_pretrained(str(final_dir))

    print("--- Merge complete ---")
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def main():
    """Command-line entry point: fine-tune, evaluate and optionally merge.

    Flow:

    1. Parse ``--config`` (YAML) and ``--merge-only``.
    2. Create the run directory and store a copy of the config in it.
    3. Resolve the base model: ``model.repo_id`` is used directly when it
       is a local model directory, otherwise the repo is downloaded into
       ``<run_dir>/<model.base_local_dir>``.
    4. With ``--merge-only``: merge ``<run_dir>/best_adapter`` and exit.
    5. Otherwise train with the Hugging Face ``Trainer``, save the adapter
       to ``<run_dir>/best_adapter``, write ``eval_final.json`` if eval
       data exists, and merge when ``merge.enabled`` is set.

    Raises:
        ValueError: if ``model.repo_id`` is a directory that does not look
            like a model directory.
        FileNotFoundError: in ``--merge-only`` mode when no adapter exists.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="Path to YAML config")
    ap.add_argument(
        "--merge-only", action="store_true", help="Skip training, just merge adapter"
    )
    args = ap.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
    _ensure_dir(run_dir / "logs")

    with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
        yaml.safe_dump(cfg, f, sort_keys=False)

    model_cfg = cfg["model"]
    repo_id = str(model_cfg["repo_id"]).strip()
    repo_path = Path(repo_id)

    if repo_path.is_dir():
        if not _looks_like_model_dir(repo_path):
            raise ValueError(
                f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {repo_path}"
            )
        base_dir = repo_path
        print(f"Using local model at: {base_dir}")
    else:
        # Treat repo_id as a Hub repository and cache it inside the run dir.
        base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
        if not _looks_like_model_dir(base_dir):
            print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
            snapshot_download(
                repo_id=repo_id,
                revision=model_cfg.get("revision", None),
                local_dir=str(base_dir),
                local_dir_use_symlinks=False,
            )

    ckpt_dir = _ensure_dir(run_dir / "checkpoints")
    best_adapter_dir = _ensure_dir(run_dir / "best_adapter")

    merge_cfg = cfg.get("merge", {}) or {}
    if merge_cfg.get("output_dir"):
        od = Path(str(merge_cfg["output_dir"]))
        final_dir = od if od.is_absolute() else (run_dir / od)
    else:
        final_dir = run_dir / "final_model"

    if args.merge_only:
        if not _looks_like_model_dir(best_adapter_dir):
            raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
        return

    wandb_run = setup_wandb(cfg, run_dir)

    # Everything after Wandb initialization runs under try/finally so the
    # run is closed even when training fails.
    try:
        set_seed(int(cfg["run"].get("seed", 42)))

        model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
        model, _ = apply_peft(cfg, model)

        train_ds, eval_ds = build_datasets(cfg, tokenizer)

        # Keep only the model inputs; helper columns must not reach the
        # collator or the model's forward().
        model_columns = ("input_ids", "attention_mask", "labels")
        extra_train = [c for c in train_ds.column_names if c not in model_columns]
        if extra_train:
            train_ds = train_ds.remove_columns(extra_train)
        if eval_ds is not None:
            extra_eval = [c for c in eval_ds.column_names if c not in model_columns]
            if extra_eval:
                eval_ds = eval_ds.remove_columns(extra_eval)

        tr_cfg = cfg["train"]

        dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
        use_fp16 = dtype == torch.float16
        use_bf16 = dtype == torch.bfloat16

        max_steps = int(tr_cfg.get("max_steps", 0) or 0)
        num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))

        # The evaluation-strategy argument was renamed across transformers
        # versions; use whichever name this installation accepts.
        ta_params = inspect.signature(TrainingArguments.__init__).parameters
        eval_key = (
            "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
        )

        report_to = []
        if wandb_run is not None:
            report_to.append("wandb")

        ta_kwargs = dict(
            output_dir=str(ckpt_dir),
            max_steps=max_steps if max_steps > 0 else -1,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
            per_device_eval_batch_size=int(
                tr_cfg.get(
                    "per_device_eval_batch_size",
                    tr_cfg.get("per_device_train_batch_size", 1),
                )
            ),
            gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),
            learning_rate=float(tr_cfg.get("learning_rate", 2e-5)),
            weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
            warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
            lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),
            optim=str(
                tr_cfg.get(
                    "optim",
                    (
                        "paged_adamw_8bit"
                        if bool(model_cfg.get("use_4bit", False))
                        else "adamw_torch"
                    ),
                )
            ),
            max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)),
            logging_steps=int(tr_cfg.get("logging_steps", 10)),
            save_strategy=str(tr_cfg.get("save_strategy", "steps")),
            save_steps=int(tr_cfg.get("save_steps", 200)),
            save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
            eval_steps=int(tr_cfg.get("eval_steps", 200)),
            load_best_model_at_end=(
                bool(tr_cfg.get("load_best_model_at_end", True))
                if eval_ds is not None
                else False
            ),
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            fp16=use_fp16,
            bf16=use_bf16,
            report_to=report_to,
            remove_unused_columns=False,
        )
        ta_kwargs[eval_key] = str(
            tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no")
        )

        training_args = TrainingArguments(**ta_kwargs)

        callbacks = [JsonlLoggerCallback(run_dir)]

        early_stopping_cfg = tr_cfg.get("early_stopping") or {}
        if early_stopping_cfg.get("enabled", False) and eval_ds is not None:
            patience = int(early_stopping_cfg.get("patience", 3))
            min_delta = float(early_stopping_cfg.get("min_delta", 0.001))
            callbacks.append(
                EarlyStoppingCallback(
                    early_stopping_patience=patience,
                    early_stopping_threshold=min_delta,
                )
            )
            print(f"Early stopping enabled: patience={early_stopping_cfg.get('patience', 3)}, "
                  f"min_delta={early_stopping_cfg.get('min_delta', 0.001)}")

        # The datasets hold unpadded, variable-length sequences, so each
        # batch must be padded dynamically; labels are padded with -100 so
        # that padding never contributes to the loss.
        from transformers import DataCollatorForSeq2Seq

        data_collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding=True,
            label_pad_token_id=-100,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_ds,
            eval_dataset=eval_ds,
            data_collator=data_collator,
            callbacks=callbacks,
        )

        resume_from = tr_cfg.get("resume_from_checkpoint", None)
        if resume_from == "auto":
            last = get_last_checkpoint(str(ckpt_dir))
            resume_from = last if last else None
        if resume_from:
            print(f"Resuming from {resume_from}")

        print("Starting instruction fine-tuning...")
        trainer.train(resume_from_checkpoint=resume_from)

        trainer.save_model(str(best_adapter_dir))
        print(f"Saved best adapter -> {best_adapter_dir}")

        if eval_ds is not None:
            metrics = trainer.evaluate()
            eval_loss = metrics.get("eval_loss", None)
            metrics["perplexity"] = _safe_exp(eval_loss) if eval_loss is not None else None
            with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
                json.dump(metrics, f, indent=2)
            print(f"Final eval_loss={eval_loss}, ppl={metrics['perplexity']}")

        if bool(merge_cfg.get("enabled", False)):
            # Release the training model before loading a second copy for
            # the merge.
            del trainer, model
            torch.cuda.empty_cache()
            merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
        else:
            print("Merge disabled. Run with --merge-only later if needed.")
    finally:
        finish_wandb()
| |
|
| |
|
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
| |
|