File size: 8,697 Bytes
a57357b
 
 
20852a7
 
 
 
 
 
a57357b
20852a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a57357b
20852a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a57357b
20852a7
 
a57357b
20852a7
a57357b
20852a7
 
a57357b
20852a7
 
 
 
 
 
 
 
 
 
 
a57357b
20852a7
 
a57357b
20852a7
 
 
 
a57357b
 
20852a7
 
 
a57357b
20852a7
 
 
 
 
 
a57357b
20852a7
a57357b
20852a7
 
a57357b
20852a7
 
 
 
 
 
 
 
 
 
 
 
 
a57357b
20852a7
 
 
 
 
 
a57357b
20852a7
 
a57357b
20852a7
a57357b
20852a7
 
 
a57357b
20852a7
 
 
 
 
 
 
 
 
 
a57357b
20852a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a57357b
20852a7
 
 
 
 
 
a90f827
 
 
 
 
 
 
 
 
 
 
 
 
20852a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a90f827
20852a7
a57357b
20852a7
a57357b
20852a7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import os
import sys
import json
import logging
import gradio as gr
from pathlib import Path
import subprocess
import time
from datetime import datetime

# Configure logging
# Timestamped INFO-level messages go to stdout so they are visible in the
# container/terminal log stream alongside the training subprocess output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Configuration paths
# All three JSON config files are expected in the current working directory.
CONFIG_DIR = "."
TRANSFORMERS_CONFIG = os.path.join(CONFIG_DIR, "transformers_config.json")  # model + training hyperparams
HARDWARE_CONFIG = os.path.join(CONFIG_DIR, "hardware_config.json")          # GPU specs + memory optimizations
DATASET_CONFIG = os.path.join(CONFIG_DIR, "dataset_config.json")            # dataset selection

def load_config(config_path):
    """Load and parse a JSON configuration file.

    Args:
        config_path: Path to the JSON file to read.

    Returns:
        The parsed configuration (typically a dict), or None when the file
        is missing, unreadable, or contains invalid JSON. Failures are
        logged rather than raised so the UI can degrade gracefully.
    """
    try:
        # EAFP: attempt the read and handle specific failures, instead of
        # an os.path.exists() check that races with file removal.
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        logger.error(f"Config file not found: {config_path}")
        return None
    except (OSError, json.JSONDecodeError) as e:
        logger.error(f"Error loading config: {str(e)}")
        return None

def display_config():
    """Render the current training configuration as a Markdown summary.

    Returns an error string if any of the three config files fails to load.
    """
    transformers_config = load_config(TRANSFORMERS_CONFIG)
    hardware_config = load_config(HARDWARE_CONFIG)
    dataset_config = load_config(DATASET_CONFIG)

    if not (transformers_config and hardware_config and dataset_config):
        return "Error loading configuration files."

    # Pull out the sub-sections once instead of re-navigating the dicts.
    training = transformers_config.get("training", {})
    specs = hardware_config.get("specs", {})
    mem_opt = hardware_config.get("memory_optimization", {})

    model_name = transformers_config.get("model", {}).get("name", "")
    dataset_name = dataset_config.get("dataset", {}).get("name", "")
    batch_size = training.get("per_device_train_batch_size", 0)
    gradient_accum = training.get("gradient_accumulation_steps", 0)
    lr = training.get("learning_rate", 0)
    epochs = training.get("num_train_epochs", 0)
    gpu_count = specs.get("gpu_count", 0)
    gpu_type = specs.get("gpu_type", "")

    # Effective batch = per-GPU batch * accumulation steps * number of GPUs.
    return f"""
    ## Current Training Configuration
    
    **Model**: {model_name}
    **Dataset**: {dataset_name}
    
    **Training Parameters**:
    - Learning Rate: {lr}
    - Epochs: {epochs}
    - Batch Size/GPU: {batch_size}
    - Gradient Accumulation: {gradient_accum}
    - Effective Batch Size: {batch_size * gradient_accum * gpu_count}
    
    **Hardware**:
    - GPUs: {gpu_count}x {gpu_type}
    - Flash Attention: {mem_opt.get("use_flash_attention", False)}
    - Gradient Checkpointing: {mem_opt.get("use_gradient_checkpointing", False)}
    
    **Pre-quantized 4-bit Training**: Enabled
    """

def start_training():
    """Launch the training script as a background process.

    Uses a ``training.pid`` file in the working directory as a lock so only
    one training run can be active at a time. Appends child output to
    ``training.log`` and records the start in ``training_history.log``.

    Returns:
        A human-readable status message; never raises (errors are returned
        as strings so the UI can display them).
    """
    try:
        # Refuse to start a second run while a live process owns the PID file.
        if os.path.exists("training.pid"):
            with open("training.pid", "r") as f:
                pid = f.read().strip()
            try:
                # Signal 0 performs an existence check without sending anything.
                os.kill(int(pid), 0)
                return f"Training is already running with PID {pid}"
            except OSError:
                # Process not running, remove stale PID file
                os.remove("training.pid")

        # shell=False with an argv list avoids shell quoting/injection issues,
        # and sys.executable guarantees the same interpreter as this app.
        # The child inherits the log fd; the parent's handle is closed by the
        # `with` block (the original leaked it via stdout=open(...)).
        with open("training.log", "a") as log_file:
            process = subprocess.Popen(
                [sys.executable, "run_transformers_training.py"],
                stdout=log_file,
                stderr=subprocess.STDOUT,
            )

        # Save PID so later calls (and check_training_status) can find the run.
        with open("training.pid", "w") as f:
            f.write(str(process.pid))

        # Log start time
        with open("training_history.log", "a") as f:
            f.write(f"{datetime.now().isoformat()}: Training started (PID: {process.pid})\n")

        return f"Training started with PID {process.pid}. Check status for updates."

    except Exception as e:
        return f"Error starting training: {str(e)}"

def check_training_status():
    """Report whether training is running plus the tail of its log.

    Returns a Markdown string combining the process status (derived from the
    ``training.pid`` file) and the last 20 lines of ``training.log``.
    """
    try:
        status = "No training process is currently running"
        if os.path.exists("training.pid"):
            with open("training.pid", "r") as pid_file:
                pid = pid_file.read().strip()
            try:
                # Signal 0: existence probe only, nothing is delivered.
                os.kill(int(pid), 0)
                status = f"Training is running with PID {pid}"
            except OSError:
                # Stale PID file from a finished/killed run — clean it up.
                status = "Training process has stopped"
                os.remove("training.pid")

        # Tail the training log (last 20 lines) if one exists.
        log_content = "No training log available"
        if os.path.exists("training.log"):
            with open("training.log", "r") as log_file:
                lines = log_file.readlines()
            log_content = "".join(lines[-20:]) if lines else "Log file is empty"

        return f"{status}\n\n**Recent Log:**\n```\n{log_content}\n```"

    except Exception as e:
        return f"Error checking status: {str(e)}"

# Create the Gradio interface
# Three tabs: static config view, start/status controls, and help text.
with gr.Blocks(title="Phi-4 Unsloth Training", theme=gr.themes.Soft(primary_hue="blue")) as app:
    gr.Markdown("# Phi-4 Unsloth 4-bit Training Interface")
    
    with gr.Tabs():
        with gr.TabItem("Configuration"):
            # display_config() is evaluated once here at build time; the
            # Refresh button re-runs it on demand.
            config_output = gr.Markdown(display_config())
            refresh_btn = gr.Button("Refresh Configuration")
            refresh_btn.click(fn=display_config, outputs=config_output)
            
        with gr.TabItem("Training Control"):
            gr.Markdown("## Training Management")
            
            with gr.Row():
                start_btn = gr.Button("Start Training", variant="primary")
                check_btn = gr.Button("Check Status")
            
            status_output = gr.Markdown("Click 'Check Status' to see training progress")
            
            start_btn.click(fn=start_training, outputs=status_output)
            check_btn.click(fn=check_training_status, outputs=status_output)
            
            # Auto-refresh status
            # Injected JS locates the "Check Status" button by its label and
            # clicks it every 30 s. NOTE(review): this relies on the rendered
            # button text and on Gradio executing injected <script> tags —
            # fragile across Gradio versions; confirm it still fires.
            gr.HTML('''
            <script>
            let intervalId;
            
            document.addEventListener('DOMContentLoaded', function() {
                // Find the "Check Status" button
                const buttons = Array.from(document.querySelectorAll('button'));
                const checkBtn = buttons.find(btn => btn.textContent.includes('Check Status'));
                
                // Set up interval to click the button every 30 seconds
                if (checkBtn) {
                    intervalId = setInterval(() => {
                        checkBtn.click();
                    }, 30000);
                }
            });
            
            // Clean up on tab/window close
            window.addEventListener('beforeunload', function() {
                clearInterval(intervalId);
            });
            </script>
            ''')
    
        with gr.TabItem("Help"):
            gr.Markdown("""
            ## Phi-4 Unsloth Training Help
            
            This interface allows you to manage training of the Phi-4 model with Unsloth 4-bit optimizations.
            
            ### Installation
            
            Before starting training, ensure all dependencies are installed:
            
            ```bash
            pip install -r requirements.txt
            ```
            
            Critical packages:
            - unsloth (>=2024.3)
            - peft (>=0.9.0)
            - transformers (>=4.36.0)
            
            ### Quick Start
            
            1. Review the configuration in the Configuration tab
            2. Click "Start Training" to begin the process
            3. Use "Check Status" to monitor progress
            
            ### Notes
            
            - Training uses the pre-quantized model `unsloth/phi-4-unsloth-bnb-4bit`
            - The process maintains paper order and handles metadata appropriately
            - Training progress will be regularly saved to HuggingFace Hub
            
            ### Troubleshooting
            
            If training stops unexpectedly:
            - Check the logs for out-of-memory errors
            - Verify the VRAM usage on each GPU
            - Check for CUDA version compatibility
            - If you see "Unsloth not available" error, run: `pip install unsloth>=2024.3 peft>=0.9.0`
            """)

# Launch the app
# Only start the server when run as a script (not when imported).
if __name__ == "__main__":
    app.launch()