from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from datasets import load_dataset
import torch
import os
import psutil
import gc
# Memory management and environment setup
def cleanup_memory():
    gc.collect()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# Set MPS memory limits and environment variables
# Note: Changed watermark ratio to a more conservative value
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.7' # Changed from 0.8
os.environ['PYTORCH_MPS_LOW_WATERMARK_RATIO'] = '0.5' # Added explicit low watermark
os.environ['PYTORCH_MPS_ALLOCATOR_POLICY'] = 'garbage_collection_conservative'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
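# (The watermark ratios bound MPS memory use as fractions of the device's
# recommended working-set size; exact semantics depend on the PyTorch version,
# so treat the values above as conservative starting points rather than tuned limits.)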
# Memory monitoring
def print_memory_stats():
    process = psutil.Process()
    print(f"RAM Memory usage: {process.memory_info().rss / 1024 / 1024:.2f} MB")
    if hasattr(torch.mps, 'current_allocated_memory'):
        print(f"MPS Memory allocated: {torch.mps.current_allocated_memory() / 1024 / 1024:.2f} MB")
# Custom callback for memory monitoring
class MemoryCallback(TrainerCallback):
    def __init__(self, print_memory_stats_fn):
        self.print_memory_stats_fn = print_memory_stats_fn

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 100 == 0:
            print(f"\nStep {state.global_step}:")
            self.print_memory_stats_fn()
            cleanup_memory()
# Set device
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")
# Load model and tokenizer
model_name = "distilgpt2"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    use_cache=False,
    torch_dtype=torch.float32
)
model.to(device) # Explicitly move model to device
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Add pad token
tokenizer.pad_token = tokenizer.eos_token
# Load and filter dataset
train_data = load_dataset("json", data_files={"train": "data.json"})
def filter_dataset(example):
    return len(example["prompt"]) + len(example["completion"]) <= 512
train_data = train_data.filter(filter_dataset)
# Preprocess function
def preprocess_function(examples):
    inputs = [prompt + tokenizer.eos_token + completion
              for prompt, completion in zip(examples["prompt"], examples["completion"])]
    model_inputs = tokenizer(
        inputs,
        max_length=256,
        truncation=True,
        padding="max_length"
    )
    model_inputs["labels"] = model_inputs["input_ids"].copy()
    return model_inputs
# Preprocess the dataset
train_dataset = train_data["train"].map(preprocess_function, batched=True)
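# Optional variant (illustrative sketch, not used by the run above): because
# labels are a straight copy of input_ids, the loss is also computed on the
# padding positions. A common alternative is to mask padded positions with -100
# so the loss ignores them:
def preprocess_function_masked(examples):
    inputs = [prompt + tokenizer.eos_token + completion
              for prompt, completion in zip(examples["prompt"], examples["completion"])]
    model_inputs = tokenizer(
        inputs,
        max_length=256,
        truncation=True,
        padding="max_length"
    )
    # Replace label ids at padding positions (attention_mask == 0) with -100
    model_inputs["labels"] = [
        [token if mask == 1 else -100 for token, mask in zip(ids, attn)]
        for ids, attn in zip(model_inputs["input_ids"], model_inputs["attention_mask"])
    ]
    return model_inputs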
# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=15,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,  # Reduced from 32
    logging_dir="./logs",
    fp16=False,
    eval_strategy="no",
    learning_rate=1e-5,  # Reduced from 5e-5
    save_steps=100,
    save_total_limit=2,
    gradient_checkpointing=True,
    optim="adamw_torch",
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    torch_compile=False,
    max_grad_norm=1.0,  # Increased from 0.5
    logging_steps=5,  # More frequent logging
    max_steps=1000,
    warmup_steps=300,  # Increased warmup steps (takes precedence over warmup_ratio below)
    weight_decay=0.2,  # Increased from 0.01
    logging_first_step=True,
    lr_scheduler_type="cosine_with_restarts",  # Changed to cosine with restarts
    warmup_ratio=0.15,  # Increased warmup ratio
)
# Clear cache before training
cleanup_memory()
# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[MemoryCallback(print_memory_stats)]
)
# Monitor initial memory usage
print("Initial memory usage:")
print_memory_stats()
# Training with error handling
try:
    trainer.train()
except Exception as e:
    print(f"Training error: {str(e)}")
    cleanup_memory()
    try:
        model.save_pretrained("./lockin_model_partial")
        tokenizer.save_pretrained("./lockin_model_partial")
        print("Saved partial progress")
    except Exception:
        print("Could not save partial progress")
    raise e
finally:
    cleanup_memory()
# Save the complete model
try:
    model.save_pretrained("./lockin_model")
    tokenizer.save_pretrained("./lockin_model")
    print("Model saved successfully")
except Exception as e:
    print(f"Error saving model: {str(e)}")
# Final cleanup
cleanup_memory()
print("\nFinal memory usage:")
print_memory_stats()
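# Post-training sanity check (illustrative sketch; the prompt text and
# generation settings below are example values, not taken from the original
# script): reload the saved model and generate one short completion.
check_model = AutoModelForCausalLM.from_pretrained("./lockin_model").to(device)
check_tokenizer = AutoTokenizer.from_pretrained("./lockin_model")
check_inputs = check_tokenizer("Example prompt:", return_tensors="pt").to(device)
with torch.no_grad():
    check_output = check_model.generate(
        **check_inputs,
        max_new_tokens=50,
        do_sample=True,
        top_p=0.9,
        pad_token_id=check_tokenizer.eos_token_id,
    )
print(check_tokenizer.decode(check_output[0], skip_special_tokens=True))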