| | """ |
| | Helion-OSC Training Script |
| | Fine-tuning and training utilities for Helion-OSC model |
| | """ |
| |
|
import logging
import torch
from typing import Optional, Dict, List
from dataclasses import dataclass, field
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback
)
from datasets import load_dataset, load_from_disk, Dataset, DatasetDict
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """Arguments for model configuration"""
    model_name_or_path: str = field(
        default="DeepXR/Helion-OSC",
        metadata={"help": "Path to pretrained model or model identifier"}
    )
    use_lora: bool = field(
        default=True,
        metadata={"help": "Whether to use LoRA for efficient fine-tuning"}
    )
    lora_r: int = field(
        default=16,
        metadata={"help": "LoRA attention dimension"}
    )
    lora_alpha: int = field(
        default=32,
        metadata={"help": "LoRA alpha parameter"}
    )
    lora_dropout: float = field(
        default=0.05,
        metadata={"help": "LoRA dropout probability"}
    )
    load_in_8bit: bool = field(
        default=False,
        metadata={"help": "Load model in 8-bit precision"}
    )
    load_in_4bit: bool = field(
        default=False,
        metadata={"help": "Load model in 4-bit precision"}
    )
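
# Illustrative only: a typical QLoRA-style setup combines 4-bit loading with LoRA
# adapters. The values below are example choices, not tuned recommendations.
#
#     model_args = ModelArguments(
#         model_name_or_path="DeepXR/Helion-OSC",
#         use_lora=True,
#         lora_r=16,
#         lora_alpha=32,
#         lora_dropout=0.05,
#         load_in_4bit=True,
#     )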


@dataclass
class DataArguments:
    """Arguments for data processing"""
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "Name of the dataset to use"}
    )
    dataset_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to local dataset"}
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to training data file"}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to validation data file"}
    )
    max_seq_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length"}
    )
    preprocessing_num_workers: int = field(
        default=4,
        metadata={"help": "Number of workers for preprocessing"}
    )
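
# Note on file formats (illustrative): _preprocess_function below accepts records
# with either a single "text" field or a "prompt"/"completion" pair, so a JSONL
# --train_file could contain lines such as:
#
#     {"prompt": "Write a Python function that reverses a string.",
#      "completion": "def reverse(s):\n    return s[::-1]"}
#
#     {"text": "Plain pre-formatted training text goes here."}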


class HelionOSCTrainer:
    """Trainer class for the Helion-OSC model"""

    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DataArguments,
        training_args: TrainingArguments
    ):
        self.model_args = model_args
        self.data_args = data_args
        self.training_args = training_args

        # Load tokenizer, model, and datasets up front
        self.tokenizer = self._load_tokenizer()
        self.model = self._load_model()
        self.datasets = self._load_datasets()

        logger.info("Trainer initialized successfully")

    def _load_tokenizer(self):
        """Load and configure tokenizer"""
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_args.model_name_or_path,
            trust_remote_code=True,
            padding_side="right"
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        return tokenizer

    def _load_model(self):
        """Load and configure model"""
        logger.info("Loading model...")

        model_kwargs = {
            "trust_remote_code": True,
            "low_cpu_mem_usage": True
        }

        # Quantized loading goes through BitsAndBytesConfig; otherwise load bf16 weights
        if self.model_args.load_in_8bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        elif self.model_args.load_in_4bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16

        model = AutoModelForCausalLM.from_pretrained(
            self.model_args.model_name_or_path,
            **model_kwargs
        )

        if self.model_args.use_lora:
            logger.info("Applying LoRA configuration...")

            if self.model_args.load_in_8bit or self.model_args.load_in_4bit:
                model = prepare_model_for_kbit_training(model)

            lora_config = LoraConfig(
                r=self.model_args.lora_r,
                lora_alpha=self.model_args.lora_alpha,
                target_modules=[
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "gate_proj",
                    "up_proj",
                    "down_proj"
                ],
                lora_dropout=self.model_args.lora_dropout,
                bias="none",
                task_type=TaskType.CAUSAL_LM
            )

            model = get_peft_model(model, lora_config)
            model.print_trainable_parameters()

        return model

    def _load_datasets(self) -> DatasetDict:
        """Load and preprocess datasets"""
        logger.info("Loading datasets...")

        if self.data_args.dataset_name:
            # Hub dataset
            datasets = load_dataset(self.data_args.dataset_name)
        elif self.data_args.dataset_path:
            # Treat dataset_path as a DatasetDict previously written with save_to_disk
            datasets = load_from_disk(self.data_args.dataset_path)
        elif self.data_args.train_file:
            # Local JSON/JSONL files
            data_files = {"train": self.data_args.train_file}
            if self.data_args.validation_file:
                data_files["validation"] = self.data_args.validation_file

            datasets = load_dataset("json", data_files=data_files)
        else:
            raise ValueError("Must provide dataset_name, dataset_path, or train_file")

        logger.info("Preprocessing datasets...")
        datasets = datasets.map(
            self._preprocess_function,
            batched=True,
            num_proc=self.data_args.preprocessing_num_workers,
            remove_columns=datasets["train"].column_names,
            desc="Preprocessing datasets"
        )

        return datasets

    def _preprocess_function(self, examples):
        """Preprocess examples for training"""
        # Accept either prompt/completion pairs or plain text
        if "prompt" in examples and "completion" in examples:
            texts = [
                f"{prompt}\n{completion}"
                for prompt, completion in zip(examples["prompt"], examples["completion"])
            ]
        elif "text" in examples:
            texts = examples["text"]
        else:
            raise ValueError("Dataset must contain 'text' or 'prompt'/'completion' columns")

        tokenized = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.data_args.max_seq_length,
            padding="max_length",
            return_tensors=None
        )

        # Causal LM: labels mirror the input ids (the collator masks padding to -100)
        tokenized["labels"] = tokenized["input_ids"].copy()

        return tokenized

    def train(self):
        """Train the model"""
        logger.info("Starting training...")

        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False
        )

        eval_dataset = self.datasets.get("validation")

        # Early stopping only makes sense when there is an eval split to monitor
        callbacks = []
        if eval_dataset is not None:
            callbacks.append(EarlyStoppingCallback(early_stopping_patience=3))

        trainer = Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.datasets["train"],
            eval_dataset=eval_dataset,
            tokenizer=self.tokenizer,
            data_collator=data_collator,
            callbacks=callbacks
        )

        train_result = trainer.train()

        trainer.save_model()

        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

        logger.info("Training completed successfully!")

        return trainer, metrics

    def evaluate(self, trainer: Optional[Trainer] = None):
        """Evaluate the model"""
        if self.datasets.get("validation") is None:
            raise ValueError("No validation split available for evaluation")

        if trainer is None:
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=self.tokenizer,
                mlm=False
            )

            trainer = Trainer(
                model=self.model,
                args=self.training_args,
                eval_dataset=self.datasets.get("validation"),
                tokenizer=self.tokenizer,
                data_collator=data_collator
            )

        logger.info("Evaluating model...")
        metrics = trainer.evaluate()

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

        return metrics


def create_code_dataset(examples: List[Dict[str, str]]) -> Dataset:
    """
    Create a dataset from code examples

    Args:
        examples: List of dictionaries with 'prompt' and 'completion' keys

    Returns:
        Dataset object
    """
    return Dataset.from_dict({
        "prompt": [ex["prompt"] for ex in examples],
        "completion": [ex["completion"] for ex in examples]
    })
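
# Illustrative usage (sample data invented for the example):
#
#     code_examples = [
#         {"prompt": "Write a function that checks whether a number is prime.",
#          "completion": "def is_prime(n):\n    return n > 1 and all(n % i for i in range(2, int(n ** 0.5) + 1))"},
#     ]
#     code_ds = create_code_dataset(code_examples)
#     # code_ds has "prompt" and "completion" columns, matching _preprocess_function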


def create_math_dataset(examples: List[Dict[str, str]]) -> Dataset:
    """
    Create a dataset from math examples

    Args:
        examples: List of dictionaries with 'problem' and 'solution' keys

    Returns:
        Dataset object
    """
    return Dataset.from_dict({
        "prompt": [f"Problem: {ex['problem']}\nSolution:" for ex in examples],
        "completion": [ex["solution"] for ex in examples]
    })
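
# Illustrative usage (sample data invented for the example):
#
#     math_examples = [
#         {"problem": "What is 12 * 8?", "solution": "12 * 8 = 96"},
#     ]
#     math_ds = create_math_dataset(math_examples)
#     # Each row becomes "Problem: ...\nSolution:" plus the worked answer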


def main():
    """Main training script"""
    import argparse

    parser = argparse.ArgumentParser(description="Train Helion-OSC model")

    # Model arguments
    parser.add_argument("--model_name_or_path", type=str, default="DeepXR/Helion-OSC")
    # BooleanOptionalAction exposes both --use_lora and --no-use_lora (Python 3.9+)
    parser.add_argument("--use_lora", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument("--load_in_8bit", action="store_true")
    parser.add_argument("--load_in_4bit", action="store_true")

    # Data arguments
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--dataset_path", type=str, default=None)
    parser.add_argument("--train_file", type=str, default=None)
    parser.add_argument("--validation_file", type=str, default=None)
    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--preprocessing_num_workers", type=int, default=4)

    # Training arguments
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=500)
    parser.add_argument("--eval_steps", type=int, default=500)
    parser.add_argument("--save_total_limit", type=int, default=3)
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument("--use_wandb", action="store_true")

    args = parser.parse_args()

    model_args = ModelArguments(
        model_name_or_path=args.model_name_or_path,
        use_lora=args.use_lora,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit
    )

    data_args = DataArguments(
        dataset_name=args.dataset_name,
        dataset_path=args.dataset_path,
        train_file=args.train_file,
        validation_file=args.validation_file,
        max_seq_length=args.max_seq_length,
        preprocessing_num_workers=args.preprocessing_num_workers
    )

    # Periodic evaluation, best-model tracking, and early stopping only apply
    # when a validation file is supplied
    has_eval = args.validation_file is not None

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        save_total_limit=args.save_total_limit,
        fp16=args.fp16,
        bf16=args.bf16,
        gradient_checkpointing=args.gradient_checkpointing,
        report_to="wandb" if args.use_wandb else "none",
        load_best_model_at_end=has_eval,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        evaluation_strategy="steps" if has_eval else "no",
        save_strategy="steps",
        logging_dir=f"{args.output_dir}/logs",
        remove_unused_columns=False
    )

    helion_trainer = HelionOSCTrainer(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args
    )

    trainer, metrics = helion_trainer.train()

    if args.validation_file:
        eval_metrics = helion_trainer.evaluate(trainer)
        logger.info(f"Evaluation metrics: {eval_metrics}")

    logger.info("Training pipeline completed!")


if __name__ == "__main__":
    main()
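
# Example invocation (script name, paths, and hyperparameter values are placeholders;
# the flags themselves are the ones defined in main() above):
#
#     python train_helion_osc.py \
#         --train_file data/train.jsonl \
#         --validation_file data/val.jsonl \
#         --output_dir outputs/helion-osc-lora \
#         --load_in_4bit \
#         --bf16 \
#         --gradient_checkpointing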