"""
FunctionGemma SFT fine-tuning script.

Runs TRL SFTTrainer for FunctionGemma with two modes:
1) LoRA (recommended): faster, lower memory, less prone to overfitting
2) Full-parameter: higher cost, maximal capacity

Usage:
    # LoRA (default)
    python -m src.train \
        --model_path /path/to/model \
        --dataset_path ./data/training_data.json \
        --bf16

    # Full-parameter
    python -m src.train \
        --model_path /path/to/model \
        --dataset_path ./data/training_data.json \
        --no-use-lora \
        --bf16
"""
|
|
import os
import json
import argparse
import logging
from datetime import datetime
from pathlib import Path

import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, TaskType, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig
|
|
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_DATA_PATH = PROJECT_ROOT / "data" / "training_data.json"
DEFAULT_OUTPUT_DIR = PROJECT_ROOT / "runs"

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
|
|
|
|
def parse_args():
    """Parse CLI arguments."""
    parser = argparse.ArgumentParser(description="FunctionGemma SFT fine-tuning (LoRA / full)")

    parser.add_argument(
        "--model_path",
        type=str,
        default="google/functiongemma-270m-it",
        help="Model path or HF model id"
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="Tokenizer path (defaults to model_path)"
    )

    parser.add_argument(
        "--dataset_path",
        type=str,
        default=str(DEFAULT_DATA_PATH),
        help="Training dataset path"
    )
    parser.add_argument(
        "--val_split",
        type=float,
        default=0.1,
        help="Validation split ratio"
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        default=str(DEFAULT_OUTPUT_DIR),
        help="Root output directory"
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default=None,
        help="Run name for logging and saving"
    )

    parser.add_argument(
        "--use_lora",
        action="store_true",
        default=True,
        help="Enable LoRA (recommended); pass --no-use-lora for a full-parameter fine-tune"
    )
    parser.add_argument(
        "--no-use-lora",
        dest="use_lora",
        action="store_false",
        help="Disable LoRA and run a full-parameter fine-tune"
    )

    parser.add_argument("--lora_r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout")
    parser.add_argument(
        "--target_modules",
        type=str,
        nargs="+",
        default=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        help="Target modules for LoRA"
    )

    parser.add_argument("--num_train_epochs", type=int, default=6, help="Training epochs (official rec: 8)")
    parser.add_argument("--max_steps", type=int, default=-1, help="Max training steps (-1 to use epochs)")
    parser.add_argument("--per_device_train_batch_size", type=int, default=4, help="Train batch size per device")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=2, help="Eval batch size per device")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Gradient accumulation steps")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay")
    parser.add_argument("--warmup_ratio", type=float, default=0.0, help="Warmup ratio (the constant scheduler usually skips warmup)")
    parser.add_argument("--max_seq_length", type=int, default=2048, help="Max sequence length (model supports up to 32768)")
    parser.add_argument("--lr_scheduler_type", type=str, default="constant", help="LR scheduler type (default: constant)")

    parser.add_argument("--bf16", action="store_true", help="Use BF16")
    parser.add_argument("--fp16", action="store_true", help="Use FP16")
    parser.add_argument("--use_4bit", action="store_true", help="Enable 4-bit quantization (QLoRA)")
    parser.add_argument("--use_8bit", action="store_true", help="Enable 8-bit quantization")
    parser.add_argument("--use_flash_attention", action="store_true", help="Enable Flash Attention 2")
    parser.add_argument("--gradient_checkpointing", action="store_true", help="Enable gradient checkpointing")

    parser.add_argument("--logging_steps", type=int, default=10, help="Log every N steps")
    parser.add_argument("--save_steps", type=int, default=100, help="Save a checkpoint every N steps")
    parser.add_argument("--eval_steps", type=int, default=100, help="Evaluate every N steps")
    parser.add_argument("--save_total_limit", type=int, default=3, help="Max checkpoints to keep")

    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Resume from checkpoint")
    parser.add_argument("--push_to_hub", action="store_true", help="Push to the Hugging Face Hub")
    parser.add_argument("--hub_model_id", type=str, default=None, help="Hub model id")

    return parser.parse_args()
|
|
|
|
def load_and_prepare_dataset(dataset_path: str, val_split: float = 0.1, seed: int = 42):
    """Load and normalize dataset structure for SFT."""
    logger.info(f"Loading dataset: {dataset_path}")

    with open(dataset_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    logger.info(f"Dataset size: {len(data)} samples")

    processed_data = []
    for idx, item in enumerate(data):
        if 'input' in item and 'messages' in item['input']:
            # Deep copy via a JSON round-trip so the raw data stays untouched.
            messages = json.loads(json.dumps(item['input']['messages']))

            # Chat templates expect tool-call arguments as JSON strings, not dicts.
            for msg in messages:
                if 'tool_calls' in msg and msg['tool_calls']:
                    for tc in msg['tool_calls']:
                        if 'function' in tc and 'arguments' in tc['function']:
                            args = tc['function']['arguments']
                            if not isinstance(args, str):
                                tc['function']['arguments'] = json.dumps(args)

            # Synthesize the assistant turn from the "expected" label if missing.
            if 'expected' in item and item['expected']:
                expected = item['expected']
                if not messages or messages[-1]['role'] != 'assistant':
                    function_name = expected.get('function_name')
                    arguments = expected.get('arguments')
                    response = expected.get('response', '')

                    if function_name is not None and arguments is not None:
                        arguments_str = json.dumps(arguments) if isinstance(arguments, dict) else str(arguments)
                        assistant_msg = {
                            "role": "assistant",
                            "content": None,
                            "tool_calls": [{
                                # Synthetic call id; hash() is not stable across
                                # interpreter runs, which is fine for training data.
                                "id": f"call_{hash(function_name + arguments_str) % 1000000}",
                                "type": "function",
                                "function": {
                                    "name": function_name,
                                    "arguments": arguments_str
                                }
                            }]
                        }
                        messages.append(assistant_msg)
                        logger.debug(f"Added assistant tool_calls: {function_name}")
                    elif function_name is None and arguments is None and response:
                        # No tool call expected: train on the plain text response.
                        assistant_msg = {
                            "role": "assistant",
                            "content": response
                        }
                        messages.append(assistant_msg)
                        logger.debug(f"Added assistant refusal response: {response[:50]}")
                    else:
                        logger.warning(f"Unknown expected format: {expected}")

            processed_item = {
                'messages': messages
            }

            if 'tools' in item['input']:
                processed_item['tools'] = item['input']['tools']

            if 'id' in item:
                processed_item['id'] = item['id']

            # Final guard: coerce any remaining non-string arguments.
            for msg in processed_item['messages']:
                if 'tool_calls' in msg and msg['tool_calls']:
                    for tc in msg['tool_calls']:
                        if 'function' in tc and 'arguments' in tc['function']:
                            if not isinstance(tc['function']['arguments'], str):
                                logger.error(f"Sample {idx} arguments not string: {type(tc['function']['arguments'])}")
                                tc['function']['arguments'] = json.dumps(tc['function']['arguments'])

            processed_data.append(processed_item)

        elif 'messages' in item:
            # Already in chat format; only normalize tool-call arguments.
            messages = json.loads(json.dumps(item['messages']))
            for msg in messages:
                if 'tool_calls' in msg and msg['tool_calls']:
                    for tc in msg['tool_calls']:
                        if 'function' in tc and 'arguments' in tc['function']:
                            if not isinstance(tc['function']['arguments'], str):
                                tc['function']['arguments'] = json.dumps(tc['function']['arguments'])
            item_copy = dict(item)
            item_copy['messages'] = messages
            processed_data.append(item_copy)
        else:
            logger.warning(f"Skip malformed item: {item.get('id', 'unknown')}")

    logger.info(f"Processed dataset size: {len(processed_data)}")

    # Sanity check: count messages with tool_calls and flag stray dict arguments.
    tool_calls_count = 0
    for item in processed_data:
        for msg in item['messages']:
            if 'tool_calls' in msg and msg['tool_calls']:
                tool_calls_count += 1
                for tc in msg['tool_calls']:
                    if 'function' in tc and 'arguments' in tc['function']:
                        if not isinstance(tc['function']['arguments'], str):
                            logger.error(f"Found non-string arguments: {type(tc['function']['arguments'])}")
    logger.info(f"Messages containing tool_calls: {tool_calls_count}")

    dataset = Dataset.from_list(processed_data)

    if val_split > 0:
        dataset = dataset.train_test_split(test_size=val_split, seed=seed)
        train_dataset = dataset['train']
        eval_dataset = dataset['test']
        logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
    else:
        train_dataset = dataset
        eval_dataset = None
        logger.info(f"Train: {len(train_dataset)}, no eval split")

    return train_dataset, eval_dataset
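
# A hypothetical raw item in the "input"/"expected" layout handled above
# (field names mirror the parsing logic; values are illustrative only):
#
#   {
#     "id": "sample-001",
#     "input": {
#       "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
#       "tools": [{"type": "function", "function": {"name": "get_weather", "parameters": {...}}}]
#     },
#     "expected": {"function_name": "get_weather", "arguments": {"city": "Paris"}}
#   }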
|
|
|
|
def get_quantization_config(use_4bit: bool, use_8bit: bool):
    """Build a quantization config if requested."""
    if use_4bit:
        logger.info("Using 4-bit quantization (QLoRA)")
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
    elif use_8bit:
        logger.info("Using 8-bit quantization")
        return BitsAndBytesConfig(
            load_in_8bit=True,
        )
    return None
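
# Example QLoRA invocation (a sketch; flags as defined in parse_args above):
#   python -m src.train --use_4bit --bf16 --gradient_checkpointing \
#       --dataset_path ./data/training_data.json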
|
|
|
|
def load_model_and_tokenizer(args):
    """Load model and tokenizer."""
    logger.info(f"Loading model: {args.model_path}")

    tokenizer_path = args.tokenizer_path or args.model_path

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        trust_remote_code=True,
        padding_side="right",
    )

    # Some tokenizers ship without a pad token; fall back to EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    quantization_config = get_quantization_config(args.use_4bit, args.use_8bit)

    # Note: device_map="auto" suits single-process runs; for DDP launches
    # (e.g. torchrun) you would typically drop it.
    model_kwargs = {
        "trust_remote_code": True,
        "device_map": "auto",
    }

    if quantization_config:
        model_kwargs["quantization_config"] = quantization_config

    # Quantized weights already define their own compute dtype.
    if args.bf16 and not (args.use_4bit or args.use_8bit):
        model_kwargs["torch_dtype"] = torch.bfloat16
    elif args.fp16 and not (args.use_4bit or args.use_8bit):
        model_kwargs["torch_dtype"] = torch.float16

    if args.use_flash_attention:
        model_kwargs["attn_implementation"] = "flash_attention_2"
        logger.info("Using Flash Attention 2")

    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        **model_kwargs
    )

    # Cast norms and enable input grads so k-bit models train stably.
    if args.use_4bit or args.use_8bit:
        model = prepare_model_for_kbit_training(model)

    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        logger.info("Enabled gradient checkpointing")

    logger.info(f"Model parameters: {model.num_parameters():,}")

    return model, tokenizer
|
|
|
|
def get_lora_config(args):
    """Build the LoRA config."""
    logger.info(f"LoRA config: r={args.lora_r}, alpha={args.lora_alpha}, dropout={args.lora_dropout}")
    logger.info(f"Target modules: {args.target_modules}")

    return LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.target_modules,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
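
# With the defaults above, each adapted weight's update is scaled by
# lora_alpha / lora_r = 32 / 16 = 2.0 (standard LoRA scaling).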
|
|
|
|
def formatting_func(example):
    """
    Format function: pass data through unchanged for SFTTrainer.

    Dataset format:
    {
        "messages": [
            {"role": "developer", "content": "..."},
            {"role": "user", "content": "..."},
            {"role": "assistant", "tool_calls": [...]} or {"role": "assistant", "content": "..."}
        ],
        "tools": [...]
    }
    """
    return example
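
# Note: formatting_func is not passed to SFTTrainer below. For conversational
# datasets (a "messages" column), recent TRL versions apply the tokenizer's
# chat template internally, so the identity function above is kept only as
# documentation of the expected record shape.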
|
|
|
|
def main():
    args = parse_args()

    if args.run_name is None:
        mode = "lora" if args.use_lora else "full"
        args.run_name = f"functiongemma-{mode}-{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    output_dir = os.path.join(args.output_dir, args.run_name)
    os.makedirs(output_dir, exist_ok=True)

    logger.info("=" * 60)
    logger.info("FunctionGemma SFT training")
    logger.info("=" * 60)
    logger.info(f"Output dir: {output_dir}")

    # Persist the full CLI config next to the run artifacts.
    config_path = os.path.join(output_dir, "training_config.json")
    with open(config_path, 'w') as f:
        json.dump(vars(args), f, indent=2)
    logger.info(f"Config saved to: {config_path}")

    train_dataset, eval_dataset = load_and_prepare_dataset(
        args.dataset_path,
        args.val_split,
        seed=args.seed,
    )

    model, tokenizer = load_model_and_tokenizer(args)

    if args.use_lora:
        logger.info("=" * 60)
        logger.info("LoRA fine-tuning mode")
        logger.info("=" * 60)
        lora_config = get_lora_config(args)
    else:
        logger.info("=" * 60)
        logger.info("Full-parameter fine-tuning mode")
        logger.info("Warning: full fine-tuning needs more memory and time!")
        logger.info("=" * 60)
        lora_config = None

    training_args = SFTConfig(
        output_dir=output_dir,
        run_name=args.run_name,

        max_length=args.max_seq_length,
        packing=False,

        num_train_epochs=args.num_train_epochs,
        max_steps=args.max_steps,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,

        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
        optim="adamw_torch_fused",

        bf16=args.bf16,
        fp16=args.fp16,

        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps if eval_dataset is not None else None,
        eval_strategy="steps" if eval_dataset is not None else "no",
        save_total_limit=args.save_total_limit,
        load_best_model_at_end=eval_dataset is not None,

        seed=args.seed,
        report_to=["tensorboard"],

        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,

        gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs={"use_reentrant": False} if args.gradient_checkpointing else None,
    )
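
    # Effective batch size per optimizer step is per_device_train_batch_size *
    # gradient_accumulation_steps (* number of devices): 4 * 8 = 32 per device
    # with the defaults above.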
| |
| |
| |
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        peft_config=lora_config,
    )

    # Count params on trainer.model: when peft_config is set, SFTTrainer wraps
    # the base model in a PeftModel, so counting on `model` would report 100%
    # trainable even in LoRA mode.
    trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in trainer.model.parameters())
    trainable_percentage = 100 * trainable_params / total_params if total_params > 0 else 0

    logger.info("=" * 60)
    logger.info("Model parameter stats:")
    logger.info(f"  Total params: {total_params:,}")
    logger.info(f"  Trainable params: {trainable_params:,}")
    logger.info(f"  Trainable ratio: {trainable_percentage:.2f}%")
    logger.info(f"  Mode: {'LoRA' if args.use_lora else 'Full fine-tune'}")
    logger.info("=" * 60)

    logger.info("Start training...")

    if args.resume_from_checkpoint:
        trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    else:
        trainer.train()

    logger.info("Saving final model...")
    final_model_path = os.path.join(output_dir, "final_model")
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)

    logger.info("=" * 60)
    logger.info("Training done.")
    logger.info(f"Model saved at: {final_model_path}")

    if args.use_lora:
        # Save the adapter from trainer.model (the PEFT-wrapped model), not the
        # original base-model reference.
        lora_path = os.path.join(output_dir, "lora_adapter")
        trainer.model.save_pretrained(lora_path)
        tokenizer.save_pretrained(lora_path)
        logger.info(f"LoRA adapter saved to: {lora_path}")
        logger.info("")
        logger.info("Usage:")
        logger.info(f"  1. LoRA adapter: {lora_path}")
        logger.info("  2. Merge the adapter into your base model before inference")
    else:
        logger.info("")
        logger.info("Usage:")
        logger.info(f"  Use the model directly from: {final_model_path}")

    logger.info("=" * 60)
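
    # A minimal sketch of merging the saved adapter for inference (assumes the
    # standard peft API; paths are those produced by this run):
    #
    #   from peft import PeftModel
    #   base = AutoModelForCausalLM.from_pretrained(args.model_path)
    #   merged = PeftModel.from_pretrained(base, lora_path).merge_and_unload()
    #   merged.save_pretrained(os.path.join(output_dir, "merged_model"))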
|
|
|
|
if __name__ == "__main__":
    main()
|
|