from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from datasets import load_dataset
import torch
import os
import psutil
import gc

# Memory management and environment setup
def cleanup_memory():
    gc.collect()
    # Only clear caches for backends that are actually available on this machine
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Set MPS allocator limits; set these before the first MPS allocation so the
# allocator picks them up. Watermark ratio lowered to a more conservative value.
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.7'  # Changed from 0.8
os.environ['PYTORCH_MPS_LOW_WATERMARK_RATIO'] = '0.5'   # Added explicit low watermark
os.environ['PYTORCH_MPS_ALLOCATOR_POLICY'] = 'garbage_collection_conservative'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Memory monitoring
def print_memory_stats():
    process = psutil.Process()
    print(f"RAM Memory usage: {process.memory_info().rss / 1024 / 1024:.2f} MB")
    if hasattr(torch.mps, 'current_allocated_memory'):
        print(f"MPS Memory allocated: {torch.mps.current_allocated_memory() / 1024 / 1024:.2f} MB")

# Custom callback for memory monitoring
class MemoryCallback(TrainerCallback):
    def __init__(self, print_memory_stats_fn):
        self.print_memory_stats_fn = print_memory_stats_fn
        
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 100 == 0:
            print(f"\nStep {state.global_step}:")
            self.print_memory_stats_fn()
            cleanup_memory()

# Set device
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")

# Load model and tokenizer
model_name = "distilgpt2"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    use_cache=False,
    torch_dtype=torch.float32
)
model.to(device)  # Explicitly move model to device
tokenizer = AutoTokenizer.from_pretrained(model_name)

# GPT-2-style tokenizers have no pad token; reuse the EOS token for padding
tokenizer.pad_token = tokenizer.eos_token

# Load and filter dataset
train_data = load_dataset("json", data_files={"train": "data.json"})
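# Each record in data.json is expected to carry "prompt" and "completion"
# fields (consumed by filter_dataset and preprocess_function below), e.g.
# (illustrative values only):
#   {"prompt": "Question: ...", "completion": " Answer: ..."}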

def filter_dataset(example):
    # Rough character-length pre-filter; the tokenizer below truncates to 256
    # tokens anyway, so this just drops the longest prompt/completion pairs early.
    return len(example["prompt"]) + len(example["completion"]) <= 512

train_data = train_data.filter(filter_dataset)

# Preprocess function
def preprocess_function(examples):
    inputs = [prompt + tokenizer.eos_token + completion 
              for prompt, completion in zip(examples["prompt"], examples["completion"])]
    
    model_inputs = tokenizer(
        inputs, 
        max_length=256, 
        truncation=True, 
        padding="max_length"
    )
    
    model_inputs["labels"] = model_inputs["input_ids"].copy()
    return model_inputs

# Preprocess the dataset
train_dataset = train_data["train"].map(preprocess_function, batched=True)

# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=15,  # Upper bound only; max_steps=1000 below takes precedence
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,  # Reduced from 32
    logging_dir="./logs",
    fp16=False,
    eval_strategy="no",
    learning_rate=1e-5,  # Reduced from 5e-5
    save_steps=100,
    save_total_limit=2,
    gradient_checkpointing=True,
    optim="adamw_torch",
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    torch_compile=False,
    max_grad_norm=1.0,  # Increased from 0.5
    logging_steps=5,  # More frequent logging
    max_steps=1000,
    warmup_steps=300,  # Increased warmup steps; takes precedence over warmup_ratio
    weight_decay=0.2,  # Increased from 0.01
    logging_first_step=True,
    lr_scheduler_type="cosine_with_restarts",  # Changed to cosine with restarts
    warmup_ratio=0.15,  # Ignored while warmup_steps is set above
)
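# With per_device_train_batch_size=1 and gradient_accumulation_steps=8, the
# effective batch size is 1 * 8 = 8 sequences per optimizer update.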

# Clear cache before training
cleanup_memory()

# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[MemoryCallback(print_memory_stats)]
)

# Monitor initial memory usage
print("Initial memory usage:")
print_memory_stats()

# Training with error handling
try:
    trainer.train()
except Exception as e:
    print(f"Training error: {str(e)}")
    cleanup_memory()
    try:
        model.save_pretrained("./lockin_model_partial")
        tokenizer.save_pretrained("./lockin_model_partial")
        print("Saved partial progress")
    except Exception:
        print("Could not save partial progress")
    raise  # Re-raise the original training error with its traceback
finally:
    cleanup_memory()

# Save the complete model
try:
    model.save_pretrained("./lockin_model")
    tokenizer.save_pretrained("./lockin_model")
    print("Model saved successfully")
except Exception as e:
    print(f"Error saving model: {str(e)}")

# Final cleanup
cleanup_memory()
print("\nFinal memory usage:")
print_memory_stats()
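
# Optional smoke test (a minimal sketch, assuming the save above succeeded and
# "./lockin_model" exists). The prompt string is illustrative only.
from transformers import pipeline

generator = pipeline(
    "text-generation",
    model="./lockin_model",
    tokenizer="./lockin_model",
    device=device,
)
output = generator("Example prompt:", max_new_tokens=50, do_sample=True, top_p=0.9)
print(output[0]["generated_text"])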