"""
Training script for ScrapeGoat Music models using local model files with HCF optimization.

Optimized for local training with the models in the provided directory structure.
"""
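
# Example invocation (illustrative only; the script filename and the dataset/output
# paths below are placeholders, not fixed by this file):
#
#   python train_local_hcf.py --model 7b \
#       --dataset_path ./processed_dataset \
#       --output_dir ./trained_models/7b_hcf \
#       --device auto --batch_size 4 --num_epochs 3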

import os
import sys
import json
import logging
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, List, Dict, Tuple, Any

import torch
import numpy as np
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from datasets import Dataset
from accelerate import Accelerator
from safetensors import safe_open
from safetensors.torch import save_file, load_file

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
XCODEC_PATH = os.path.join(SCRIPT_DIR, "xcodec_mini_infer")
sys.path.append(XCODEC_PATH)

# HCF utilities live in the local train_hcf module alongside this script.
from train_hcf import (
    TensorInfo,
    SafeTensorHCFAnalyzer,
    TrainingStatistics,
    HCFTrainingOptimizer,
    HCFAwareTrainer,
)


@dataclass
class LocalModelConfig:
    """Configuration for local model directories"""
    model_path: str
    name: str

    @property
    def model_dir(self) -> str:
        return os.path.abspath(self.model_path)


class LocalFineTuner:
    """Fine-tuner that works with local model files"""

    def __init__(
        self,
        model_config: LocalModelConfig,
        dataset_path: str,
        output_dir: str,
        device: str = "auto",
        batch_size: int = 4,
        gradient_accumulation_steps: int = 4,
        learning_rate: float = 1e-5,
        num_epochs: int = 3,
        use_hcf: bool = True
    ):
        self.model_config = model_config
        self.dataset_path = Path(dataset_path)
        self.output_dir = Path(output_dir)
        self.device = self._setup_device(device)
        self.use_hcf = use_hcf

        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.training_args = TrainingArguments(
            output_dir=str(self.output_dir),
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            num_train_epochs=num_epochs,
            logging_steps=100,
            save_steps=1000,
            evaluation_strategy="steps",
            eval_steps=500,
            save_total_limit=3,
            load_best_model_at_end=True,
            gradient_checkpointing=True,
            # Use bf16 mixed precision to stay consistent with the bfloat16 dtype
            # the model is loaded in on CUDA.
            bf16=torch.cuda.is_available(),
            optim="adamw_torch"
        )

    def _setup_device(self, device: str) -> str:
        """Set up the training device"""
        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif torch.backends.mps.is_available():
                return "mps"
            else:
                return "cpu"
        return device

    def _load_model_and_tokenizer(self):
        """Load model and tokenizer from local path"""
        logger.info(f"Loading model from {self.model_config.model_dir}")

        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        # flash_attention_2 requires the flash-attn package when running on CUDA.
        model = AutoModelForCausalLM.from_pretrained(
            self.model_config.model_dir,
            torch_dtype=dtype,
            device_map="auto" if self.device == "cuda" else None,
            attn_implementation="flash_attention_2" if self.device == "cuda" else "eager",
            local_files_only=True
        )

        tokenizer = AutoTokenizer.from_pretrained(
            self.model_config.model_dir,
            local_files_only=True
        )

        # Some causal-LM tokenizers ship without a padding token; fall back to EOS
        # so that padding in _prepare_dataset and the data collator work.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        return model, tokenizer
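
    # _prepare_dataset reads <dataset_path>/metadata/dataset_info.json and expects a
    # "files" list whose entries provide the fields used below. A minimal sketch of
    # that file (values are illustrative only):
    #
    #   {
    #     "files": [
    #       {"genre": "ambient", "duration": 182.5,
    #        "title": "Example Track", "artist": "Example Artist"}
    #     ]
    #   }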

    def _prepare_dataset(self, tokenizer):
        """Prepare dataset for training"""
        logger.info("Preparing dataset")

        with open(self.dataset_path / "metadata" / "dataset_info.json") as f:
            metadata = json.load(f)

        def generate_text(item):
            return (
                f"Genre: {item['genre']}\n"
                f"Duration: {item['duration']:.2f}s\n"
                f"Title: {item['title']}\n"
                f"Artist: {item['artist']}\n"
            )

        texts = [generate_text(item) for item in metadata["files"]]
        dataset = Dataset.from_dict({"text": texts})

        def tokenize(examples):
            # Dataset.map stores plain python lists; the data collator converts
            # batches to tensors at training time.
            return tokenizer(
                examples["text"],
                truncation=True,
                padding="max_length",
                max_length=512
            )

        tokenized_dataset = dataset.map(
            tokenize,
            batched=True,
            remove_columns=dataset.column_names
        )

        return tokenized_dataset

    def train(self):
        """Train the model with HCF optimization"""
        self.output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Training {self.model_config.name} model with HCF optimization")
        logger.info(f"Model path: {self.model_config.model_dir}")
        logger.info(f"Dataset path: {self.dataset_path}")
        logger.info(f"Output directory: {self.output_dir}")
        logger.info(f"Device: {self.device}")
        logger.info(f"HCF optimization: {'enabled' if self.use_hcf else 'disabled'}")

        model, tokenizer = self._load_model_and_tokenizer()

        dataset = self._prepare_dataset(tokenizer)
        dataset = dataset.train_test_split(test_size=0.1)

        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

        if self.use_hcf:
            logger.info("Using HCF-aware training")

            optimizer = HCFTrainingOptimizer(
                model.parameters(),
                lr=self.training_args.learning_rate,
                weight_quantization=True,
                maintain_patterns=True
            )

            hcf_trainer = HCFAwareTrainer(model, optimizer)

            # Collate tokenized examples into padded tensor batches with labels.
            train_loader = torch.utils.data.DataLoader(
                dataset["train"],
                batch_size=self.training_args.per_device_train_batch_size,
                shuffle=True,
                collate_fn=data_collator
            )

            criterion = torch.nn.CrossEntropyLoss()
            for epoch in range(int(self.training_args.num_train_epochs)):
                stats = hcf_trainer.train_epoch(train_loader, criterion, epoch)

                logger.info(f"Epoch {epoch} completed")
                logger.info(f"Memory Savings: {stats.memory_savings/1024/1024:.2f}MB")
                logger.info(f"Quantization Error: {stats.quantization_error:.6f}")
                logger.info(f"Convergence Rate: {stats.convergence_rate:.4f}")

                self._save_hcf_checkpoint(model, tokenizer, epoch)
        else:
            logger.info("Using standard training")
            trainer = Trainer(
                model=model,
                args=self.training_args,
                train_dataset=dataset["train"],
                eval_dataset=dataset["test"],
                data_collator=data_collator,
            )

            logger.info("Starting training")
            trainer.train()

        logger.info("Saving model")
        final_output_dir = self.output_dir / "final_model"
        final_output_dir.mkdir(exist_ok=True)

        model.save_pretrained(str(final_output_dir))
        tokenizer.save_pretrained(str(final_output_dir))

        logger.info(f"Training complete. Model saved to {final_output_dir}")

    def _save_hcf_checkpoint(self, model, tokenizer, epoch):
        """Save checkpoint with HCF metadata"""
        checkpoint_dir = self.output_dir / f"checkpoint-{epoch}"
        checkpoint_dir.mkdir(exist_ok=True)

        model.save_pretrained(str(checkpoint_dir))
        tokenizer.save_pretrained(str(checkpoint_dir))

        analyzer = SafeTensorHCFAnalyzer()

        model_path = str(checkpoint_dir / "model.safetensors")
        if os.path.exists(model_path):
            results = analyzer.analyze_safetensor_weights(model_path)

            with open(checkpoint_dir / "hcf_analysis.json", "w") as f:
                json.dump(results, f, indent=2)

        logger.info(f"Saved checkpoint at {checkpoint_dir}")
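

# Output layout produced by LocalFineTuner.train(): with HCF optimization enabled,
# each epoch is written to checkpoint-<epoch>/ (including an hcf_analysis.json when a
# model.safetensors file is present); the final weights are saved under final_model/.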


def main():
    """Main function for training"""
    import argparse
    parser = argparse.ArgumentParser(description="Retrain ScrapeGoat Music models with HCF optimization")
    parser.add_argument("--model", type=str, choices=["7b", "1b"], required=True,
                        help="Model size to train")
    parser.add_argument("--dataset_path", type=str, required=True,
                        help="Path to processed dataset")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save trained model")
    parser.add_argument("--device", type=str, default="auto",
                        help="Device to use (cuda, cpu, mps, or auto)")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Batch size for training")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4,
                        help="Gradient accumulation steps")
    parser.add_argument("--learning_rate", type=float, default=1e-5,
                        help="Learning rate")
    parser.add_argument("--num_epochs", type=int, default=3,
                        help="Number of training epochs")
    # Defaults to True; pass --no_use_hcf (argparse.BooleanOptionalAction, Python 3.9+)
    # to disable HCF optimization.
    parser.add_argument("--use_hcf", action=argparse.BooleanOptionalAction, default=True,
                        help="Enable HCF optimization")
    parser.add_argument("--base_dir", type=str, default=os.getcwd(),
                        help="Base directory containing model folders")

    args = parser.parse_args()

    if args.model == "7b":
        model_path = os.path.join(args.base_dir, "scrapegoat/ScrapeGoat-Music-Stage1")
        model_config = LocalModelConfig(
            model_path=model_path,
            name="ScrapeGoatMusic 7B"
        )
    else:
        model_path = os.path.join(args.base_dir, "scrapegoat/ScrapeGoat-Music-Stage2")
        model_config = LocalModelConfig(
            model_path=model_path,
            name="ScrapeGoatMusic 1B"
        )

    fine_tuner = LocalFineTuner(
        model_config=model_config,
        dataset_path=args.dataset_path,
        output_dir=args.output_dir,
        device=args.device,
        batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        use_hcf=args.use_hcf
    )

    fine_tuner.train()


if __name__ == "__main__":
    main()