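"""Gradio interface for managing Phi-4 Unsloth 4-bit training.

Shows the current training configuration, starts the training script as a
background process tracked by a PID file, and reports status from its log.
"""
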
import json
import logging
import os
import subprocess
import sys
from datetime import datetime

import gradio as gr

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Configuration paths
CONFIG_DIR = "."
TRANSFORMERS_CONFIG = os.path.join(CONFIG_DIR, "transformers_config.json")
HARDWARE_CONFIG = os.path.join(CONFIG_DIR, "hardware_config.json")
DATASET_CONFIG = os.path.join(CONFIG_DIR, "dataset_config.json")
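
# Illustrative shape of the config files read by display_config() below.
# Keys are inferred from the .get() calls; the real files may carry more:
#   transformers_config.json -> {"model": {"name": ...},
#                                "training": {"per_device_train_batch_size": ...,
#                                             "gradient_accumulation_steps": ...,
#                                             "learning_rate": ...,
#                                             "num_train_epochs": ...}}
#   hardware_config.json     -> {"specs": {"gpu_count": ..., "gpu_type": ...},
#                                "memory_optimization": {"use_flash_attention": ...,
#                                                        "use_gradient_checkpointing": ...}}
#   dataset_config.json      -> {"dataset": {"name": ...}}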


def load_config(config_path):
    """Load configuration from JSON file."""
    try:
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                return json.load(f)
        else:
            logger.error(f"Config file not found: {config_path}")
            return None
    except Exception as e:
        logger.error(f"Error loading config: {str(e)}")
        return None


def display_config():
    """Display current training configuration."""
    transformers_config = load_config(TRANSFORMERS_CONFIG)
    hardware_config = load_config(HARDWARE_CONFIG)
    dataset_config = load_config(DATASET_CONFIG)

    if not all([transformers_config, hardware_config, dataset_config]):
        return "Error loading configuration files."

    # Extract key parameters
    training = transformers_config.get("training", {})
    specs = hardware_config.get("specs", {})
    memory_opt = hardware_config.get("memory_optimization", {})

    model_name = transformers_config.get("model", {}).get("name", "")
    dataset_name = dataset_config.get("dataset", {}).get("name", "")
    batch_size = training.get("per_device_train_batch_size", 0)
    gradient_accum = training.get("gradient_accumulation_steps", 0)
    lr = training.get("learning_rate", 0)
    epochs = training.get("num_train_epochs", 0)
    gpu_count = specs.get("gpu_count", 0)
    gpu_type = specs.get("gpu_type", "")

    return f"""
## Current Training Configuration

**Model**: {model_name}

**Dataset**: {dataset_name}

**Training Parameters**:
- Learning Rate: {lr}
- Epochs: {epochs}
- Batch Size/GPU: {batch_size}
- Gradient Accumulation: {gradient_accum}
- Effective Batch Size: {batch_size * gradient_accum * gpu_count}

**Hardware**:
- GPUs: {gpu_count}x {gpu_type}
- Flash Attention: {memory_opt.get("use_flash_attention", False)}
- Gradient Checkpointing: {memory_opt.get("use_gradient_checkpointing", False)}

**Pre-quantized 4-bit Training**: Enabled
"""


def start_training():
    """Start the training process."""
    try:
        # Check if already running
        if os.path.exists("training.pid"):
            with open("training.pid", "r") as f:
                pid = f.read().strip()
            try:
                # Signal 0 only tests whether the process exists
                os.kill(int(pid), 0)
                return f"Training is already running with PID {pid}"
            except OSError:
                # Process not running, remove stale PID file
                os.remove("training.pid")

        # Start training in the background; the log handle is deliberately
        # kept open for the lifetime of the child process
        cmd = "python run_transformers_training.py"
        process = subprocess.Popen(
            cmd,
            shell=True,
            stdout=open("training.log", "a"),
            stderr=subprocess.STDOUT,
        )

        # Save PID
        with open("training.pid", "w") as f:
            f.write(str(process.pid))

        # Log start time
        with open("training_history.log", "a") as f:
            f.write(f"{datetime.now().isoformat()}: Training started (PID: {process.pid})\n")

        return f"Training started with PID {process.pid}. Check status for updates."
    except Exception as e:
        return f"Error starting training: {str(e)}"


def check_training_status():
    """Check the status of training."""
    try:
        # Check if training is running
        if os.path.exists("training.pid"):
            with open("training.pid", "r") as f:
                pid = f.read().strip()
            try:
                # Check if process is still running
                os.kill(int(pid), 0)
                status = f"Training is running with PID {pid}"
            except OSError:
                status = "Training process has stopped"
                os.remove("training.pid")
        else:
            status = "No training process is currently running"

        # Get last lines from training log
        log_content = "No training log available"
        if os.path.exists("training.log"):
            with open("training.log", "r") as f:
                lines = f.readlines()
            log_content = "".join(lines[-20:]) if lines else "Log file is empty"

        return f"{status}\n\n**Recent Log:**\n```\n{log_content}\n```"
    except Exception as e:
        return f"Error checking status: {str(e)}"


# Create the Gradio interface
with gr.Blocks(title="Phi-4 Unsloth Training", theme=gr.themes.Soft(primary_hue="blue")) as app:
    gr.Markdown("# Phi-4 Unsloth 4-bit Training Interface")

    with gr.Tabs():
        with gr.TabItem("Configuration"):
            config_output = gr.Markdown(display_config())
            refresh_btn = gr.Button("Refresh Configuration")
            refresh_btn.click(fn=display_config, outputs=config_output)

        with gr.TabItem("Training Control"):
            gr.Markdown("## Training Management")

            with gr.Row():
                start_btn = gr.Button("Start Training", variant="primary")
                check_btn = gr.Button("Check Status")

            status_output = gr.Markdown("Click 'Check Status' to see training progress")
            start_btn.click(fn=start_training, outputs=status_output)
            check_btn.click(fn=check_training_status, outputs=status_output)

            # Auto-refresh status by programmatically clicking "Check Status".
            # NOTE: depending on the Gradio version, scripts injected through
            # gr.HTML may be sanitized and never execute; a server-side
            # polling fallback is added right after this block.
            gr.HTML('''
            <script>
            let intervalId;
            document.addEventListener('DOMContentLoaded', function() {
                // Find the "Check Status" button
                const buttons = Array.from(document.querySelectorAll('button'));
                const checkBtn = buttons.find(btn => btn.textContent.includes('Check Status'));
                // Set up interval to click the button every 30 seconds
                if (checkBtn) {
                    intervalId = setInterval(() => {
                        checkBtn.click();
                    }, 30000);
                }
            });
            // Clean up on tab/window close
            window.addEventListener('beforeunload', function() {
                clearInterval(intervalId);
            });
            </script>
            ''')
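
            # Fallback: server-side polling via the `every` parameter, which
            # re-runs the status check on an interval. A sketch assuming a
            # Gradio release whose event listeners accept `every`; older
            # releases may additionally require enabling app.queue().
            app.load(fn=check_training_status, outputs=status_output, every=30)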

        with gr.TabItem("Help"):
            gr.Markdown("""
            ## Phi-4 Unsloth Training Help

            This interface allows you to manage training of the Phi-4 model with Unsloth 4-bit optimizations.

            ### Installation

            Before starting training, ensure all dependencies are installed:

            ```bash
            pip install -r requirements.txt
            ```

            Critical packages:
            - unsloth (>=2024.3)
            - peft (>=0.9.0)
            - transformers (>=4.36.0)

            ### Quick Start

            1. Review the configuration in the Configuration tab
            2. Click "Start Training" to begin the process
            3. Use "Check Status" to monitor progress

            ### Notes

            - Training uses the pre-quantized model `unsloth/phi-4-unsloth-bnb-4bit`
            - The process maintains paper order and handles metadata appropriately
            - Training progress will be regularly saved to HuggingFace Hub

            ### Troubleshooting

            If training stops unexpectedly:
            - Check the logs for out-of-memory errors
            - Verify the VRAM usage on each GPU (see the command below)
            - Check for CUDA version compatibility
            - If you see an "Unsloth not available" error, run: `pip install unsloth>=2024.3 peft>=0.9.0`
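
            One way to check VRAM per GPU from a terminal while training runs
            (standard `nvidia-smi` query flags; the 5-second refresh interval is arbitrary):

            ```bash
            nvidia-smi --query-gpu=index,memory.used,memory.total --format=csv -l 5
            ```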
""")


# Launch the app
if __name__ == "__main__":
    app.launch()
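    # On Hugging Face Spaces the default arguments are typically sufficient;
    # for local testing you could expose the UI on the network, e.g.
    # app.launch(server_name="0.0.0.0", server_port=7860), both standard
    # gradio launch() parameters.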