import argparse
import os
import sys

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments


def parse_args():
    parser = argparse.ArgumentParser(description="Fine-tune Charm 15 AI Model")
    parser.add_argument("--model_name", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1",
                        help="Base model name or local path (default: Mixtral-8x7B)")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to training dataset (JSON or text file)")
    parser.add_argument("--eval_dataset", type=str, default=None,
                        help="Path to optional validation dataset")
    parser.add_argument("--epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Per-device training batch size (lowered for GPU compatibility)")
    parser.add_argument("--lr", type=float, default=5e-5,
                        help="Learning rate")
    parser.add_argument("--output_dir", type=str, default="./finetuned_charm15",
                        help="Model save directory")
    parser.add_argument("--max_length", type=int, default=512,
                        help="Max token length for training")
    return parser.parse_args()


def tokenize_function(examples, tokenizer, max_length):
    """Tokenize dataset and prepare labels for causal LM."""
    tokenized = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    # For causal LM fine-tuning the labels are the input ids themselves.
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized


def main():
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Ensure output and log directories exist
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs("./logs", exist_ok=True)

    # Load tokenizer
    print(f"Loading tokenizer from {args.model_name}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        sys.exit(1)

    # Load model with memory optimizations. device_map="auto" already places the
    # weights across available devices, so no explicit .to(device) call is needed
    # (and adding one can break models sharded across GPU and CPU).
    print(f"Loading model {args.model_name}...")
    try:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name,
            torch_dtype=torch.bfloat16,   # Efficient precision
            device_map="auto",            # Spread across GPU/CPU
            low_cpu_mem_usage=True,       # Reduce RAM during loading
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)

    # Load dataset
    print(f"Loading dataset from {args.dataset}...")
    try:
        if args.dataset.endswith(".json"):
            dataset = load_dataset("json", data_files={"train": args.dataset})
        else:
            dataset = load_dataset("text", data_files={"train": args.dataset})

        eval_dataset = None
        if args.eval_dataset:
            if args.eval_dataset.endswith(".json"):
                eval_dataset = load_dataset("json", data_files={"train": args.eval_dataset})["train"]
            else:
                eval_dataset = load_dataset("text", data_files={"train": args.eval_dataset})["train"]
    except Exception as e:
        print(f"Error loading dataset: {e}")
        sys.exit(1)

    # Tokenize datasets
    print("Tokenizing dataset...")
    train_dataset = dataset["train"].map(
        lambda x: tokenize_function(x, tokenizer, args.max_length),
        batched=True,
        remove_columns=["text"],
    )
    eval_dataset = eval_dataset.map(
        lambda x: tokenize_function(x, tokenizer, args.max_length),
        batched=True,
        remove_columns=["text"],
    ) if args.eval_dataset else None

    # Training arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        gradient_accumulation_steps=8,   # Effective batch size = batch_size * 8
        bf16=True,                       # Match the model's bfloat16 dtype
        fp16=False,
        save_total_limit=2,
        logging_dir="./logs",
        logging_steps=100,
        report_to="none",
        evaluation_strategy="epoch" if eval_dataset else "no",
        save_strategy="epoch",           # Checkpoint per epoch; must match eval strategy
                                         # for load_best_model_at_end (save_steps is ignored)
        load_best_model_at_end=bool(eval_dataset),
        metric_for_best_model="loss",
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )

    # Train
    print("Starting fine-tuning...")
    try:
        trainer.train()
    except RuntimeError as e:
        print(f"Training failed: {e} (Try reducing batch_size or max_length)")
        sys.exit(1)

    # Save final model and tokenizer
    print(f"Saving fine-tuned model to {args.output_dir}")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    # Cleanup
    del model
    torch.cuda.empty_cache()
    print("Training complete. Memory cleared.")


if __name__ == "__main__":
    main()
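# A minimal usage sketch. The script filename and dataset paths below are
# illustrative assumptions, not taken from the repository; only the flags
# themselves come from parse_args() above.
#
#   python finetune_charm15.py \
#       --dataset data/train.json \
#       --eval_dataset data/val.json \
#       --epochs 3 \
#       --batch_size 1 \
#       --lr 5e-5 \
#       --max_length 512 \
#       --output_dir ./finetuned_charm15
#
# With batch_size=1 and gradient_accumulation_steps=8, each optimizer step
# sees an effective batch of 8 examples per device.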