"""
tool_trainer_simple_robust.py - Bulletproof training for M4 Max + SmolLM3-3B
This version prioritizes reliability and compatibility over optimization tricks.
It targets Apple Silicon (MPS) via PyTorch and falls back to CPU when MPS is unavailable.
"""
import json
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import Dataset
import time
def load_training_data(file_path="tool_pairs_massive.jsonl"):
"""Load the comprehensive training dataset."""
pairs = []
with open(file_path, 'r') as f:
for line in f:
pairs.append(json.loads(line.strip()))
return pairs
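# Each JSONL line is assumed to hold a prompt/completion pair, based on the
# keys consumed in main() below. Hypothetical example line (not taken from
# the actual dataset):
#   {"prompt": "<|im_start|>user\nWhat's the weather in Paris?<|im_end|>\n...",
#    "chosen": "{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}"}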
def main():
print("πŸš€ ROBUST Training: SmolLM3-3B Function Calling (M4 Max)")
print("=" * 60)
start_time = time.time()
# 1. Setup device
if torch.backends.mps.is_available():
device = torch.device("mps")
print("βœ… Using M4 Max (MPS)")
else:
device = torch.device("cpu")
print("⚠️ Using CPU")
# 2. Load SmolLM3-3B
print("πŸ“₯ Loading SmolLM3-3B...")
model_name = "HuggingFaceTB/SmolLM3-3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
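    # SmolLM3's tokenizer may ship without a dedicated pad token; reusing EOS
    # lets the collator pad batches (assumption: EOS-as-pad is acceptable here).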
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float32, # Most compatible
trust_remote_code=True
)
# Move to device
model = model.to(device)
print(f"βœ… Model loaded: {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params")
# 3. Setup LoRA (conservative settings)
print("πŸ”© Setting up LoRA...")
lora_config = LoraConfig(
r=8, # Conservative rank
lora_alpha=16,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
lora_dropout=0.1,
bias="none",
task_type=TaskType.CAUSAL_LM
)
model = get_peft_model(model, lora_config)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"🎯 Trainable: {trainable_params:,} parameters")
# 4. Load and prepare data
print("πŸ“Š Loading training data...")
pairs = load_training_data()
# Format for training (simple approach)
training_texts = []
for pair in pairs:
full_text = pair["prompt"] + pair["chosen"] + tokenizer.eos_token
training_texts.append({"text": full_text})
print(f"βœ… {len(training_texts)} training examples ready")
# 5. Tokenize (batch processing to avoid issues)
print("πŸ”€ Tokenizing...")
    def tokenize_batch(examples):
        # Tokenize without padding; the collator pads dynamically per batch.
        # Labels are NOT set here: DataCollatorForLanguageModeling(mlm=False)
        # derives them from input_ids, and ragged manually-set labels would
        # break padding for any batch size above 1.
        return tokenizer(
            examples["text"],
            truncation=True,
            padding=False,
            max_length=512,  # Conservative context length
            return_tensors=None
        )
dataset = Dataset.from_list(training_texts)
tokenized_dataset = dataset.map(
tokenize_batch,
batched=True,
remove_columns=["text"]
)
print(f"πŸ“Š Tokenized {len(tokenized_dataset)} examples")
# 6. Training setup (ultra-conservative)
print("βš™οΈ Setting up training...")
    training_args = TrainingArguments(
        output_dir="./smollm3_robust",
        num_train_epochs=10,            # More passes over a small dataset
        per_device_train_batch_size=1,  # Smallest batch for MPS compatibility
        gradient_accumulation_steps=8,  # Effective batch size of 8
        learning_rate=5e-5,
        warmup_steps=10,
        logging_steps=2,
        save_steps=20,
        save_total_limit=2,
        remove_unused_columns=False,
        dataloader_pin_memory=False,    # pin_memory is unsupported on MPS
        report_to="none",               # None would enable all installed loggers
    )
# 7. Data collator (simple)
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False,
)
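    # With mlm=False the collator pads each batch to its longest sequence and
    # copies input_ids into labels, replacing pad positions with -100 so
    # padding tokens are ignored by the loss.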
# 8. Trainer
print("πŸ‹οΈ Initializing trainer...")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset,
data_collator=data_collator,
)
# 9. Train
print("\n🎯 Starting training...")
print(f"πŸ“Š Dataset: {len(pairs)} examples")
print(f"⏱️ Expected time: ~2-5 minutes")
train_result = trainer.train()
training_time = time.time() - start_time
print(f"\nπŸŽ‰ Training completed!")
print(f"πŸ“Š Final loss: {train_result.training_loss:.4f}")
print(f"⏱️ Training time: {training_time:.1f}s")
# 10. Save
print("\nπŸ’Ύ Saving model...")
model.save_pretrained("./smollm3_robust")
tokenizer.save_pretrained("./smollm3_robust")
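    # To reuse the adapter later (sketch; assumes the paths saved above):
    # from peft import PeftModel
    # base = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    # model = PeftModel.from_pretrained(base, "./smollm3_robust")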
# 11. Quick test
print("\nπŸ§ͺ Quick test...")
test_prompt = """<|im_start|>system
You are a helpful assistant that calls functions by responding with valid JSON when given a schema. Always respond with JSON function calls only, never prose.<|im_end|>
<schema>
{
"name": "get_weather",
"description": "Get weather for a location",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string"}
},
"required": ["location"]
}
}
</schema>
<|im_start|>user
What's the weather in Paris?<|im_end|>
<|im_start|>assistant
"""
model.eval()
inputs = tokenizer(test_prompt, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=50,
temperature=0.1,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
print(f"πŸ€– Model response: {response.strip()}")
    # Check whether the response parses as JSON
    try:
        parsed = json.loads(response.strip())
        print(f"βœ… Valid JSON! {parsed}")
    except json.JSONDecodeError:
        print("❌ Not valid JSON yet - this can happen and usually means more training is needed")
print("\nπŸ† Robust training complete!")
print("πŸ“ˆ This should show significant improvement over the first attempt")
return model, tokenizer
if __name__ == "__main__":
model, tokenizer = main()