// Spaces:
// Sleeping
// Sleeping
// WebSocket for the current training run; assigned by trainModel().
let ws;
/**
 * Draw empty loss and accuracy charts in the 'loss-plot' and 'accuracy-plot'
 * containers.
 *
 * Each chart has two scatter traces: training at index 0 and validation at
 * index 1. updateCharts() appends to both indices. Plotly.newPlot replaces
 * whatever plot is already in the container, so calling this again (as
 * trainModel() does on every run) starts the charts from scratch.
 */
function initializeCharts() {
    // Build a new object (with its own x/y arrays) per call so that no two
    // traces ever alias the same data arrays.
    const emptyTrace = (name) => ({ name, x: [], y: [], type: 'scatter' });

    // Plotly documents titles as `{ text: ... }`; the bare-string form used
    // previously is deprecated.
    const layoutFor = (title, yAxisTitle) => ({
        title: { text: title },
        xaxis: { title: { text: 'Iterations' } },
        yaxis: { title: { text: yAxisTitle } }
    });

    Plotly.newPlot(
        'loss-plot',
        [emptyTrace('Training Loss'), emptyTrace('Validation Loss')],
        layoutFor('Training and Validation Loss', 'Loss')
    );
    Plotly.newPlot(
        'accuracy-plot',
        [emptyTrace('Training Accuracy'), emptyTrace('Validation Accuracy')],
        layoutFor('Training and Validation Accuracy', 'Accuracy (%)')
    );
}
/**
 * Append one progress sample to the loss and accuracy charts and rewrite the
 * training-log panel with the latest values.
 *
 * @param {Object} data - Progress message from the training server. This
 *   function reads `epoch`, `train_loss`, `val_loss`, `train_acc` and
 *   `val_acc`; `epoch` appears to be 0-based (it is displayed as `epoch + 1`).
 *   NOTE(review): field set inferred from this code; confirm with the server.
 */
function updateCharts(data) {
    // A metric may be missing or non-numeric on some messages. Show "N/A"
    // instead of letting toFixed() throw and abort the whole update.
    const formatMetric = (value, digits, suffix = '') => (
        Number.isFinite(value) ? `${value.toFixed(digits)}${suffix}` : 'N/A'
    );

    // x-coordinate = number of points already plotted. The previous
    // `epoch * batch` product put every sample of epoch 0 at x = 0 and was not
    // monotonic across epochs. Plotly keeps the traces on the graph div
    // (`gd.data`), and initializeCharts() re-creates them with empty x arrays,
    // so this count restarts at 0 on every new run.
    const lossPlot = document.getElementById('loss-plot');
    const firstTrace = lossPlot && lossPlot.data ? lossPlot.data[0] : undefined;
    const iteration = firstTrace && firstTrace.x ? firstTrace.x.length : 0;

    Plotly.extendTraces('loss-plot', {
        x: [[iteration], [iteration]],
        y: [[data.train_loss], [data.val_loss]]
    }, [0, 1]);
    Plotly.extendTraces('accuracy-plot', {
        x: [[iteration], [iteration]],
        y: [[data.train_acc], [data.val_acc]]
    }, [0, 1]);

    // Update training logs
    const logsDiv = document.getElementById('training-logs');
    logsDiv.innerHTML = `
        <p>Epoch: ${data.epoch + 1}</p>
        <p>Training Loss: ${formatMetric(data.train_loss, 4)}</p>
        <p>Training Accuracy: ${formatMetric(data.train_acc, 2, '%')}</p>
        <p>Validation Loss: ${formatMetric(data.val_loss, 4)}</p>
        <p>Validation Accuracy: ${formatMetric(data.val_acc, 2, '%')}</p>
    `;
}
/**
 * Read the training form into the config object sent to the server.
 * parseInt() yields NaN for empty or non-numeric fields; trainModel() checks
 * for that before sending.
 *
 * @returns {{kernels: number[], optimizer: string, batch_size: number, epochs: number}}
 */
function readTrainingConfig() {
    const readInt = (id) => parseInt(document.getElementById(id).value, 10);
    return {
        kernels: [readInt('kernel1'), readInt('kernel2'), readInt('kernel3')],
        optimizer: document.getElementById('optimizer').value,
        batch_size: readInt('batch_size'),
        epochs: readInt('epochs')
    };
}

/**
 * Detach every handler from a previous training socket and close it.
 * A superseded run then can neither push messages into the freshly reset
 * charts nor raise a spurious connection-error alert. Closing a socket that is
 * still connecting fails the connection, which fires an error event.
 *
 * @param {WebSocket|undefined} socket - Socket to retire; ignored if falsy.
 */
function retireTrainingSocket(socket) {
    if (!socket) {
        return;
    }
    socket.onopen = null;
    socket.onmessage = null;
    socket.onerror = null;
    socket.onclose = null;
    socket.close(); // no-op if the socket is already closing or closed
}

/**
 * Start a training run.
 *
 * The function reads the config from the form and resets the progress charts.
 * It then opens a WebSocket to `/ws/train` on the current host and sends the
 * config once the socket is open. A message with status "completed" or
 * "error" raises an alert; every other message is passed to updateCharts().
 * The new socket is stored in the module-level `ws`.
 *
 * @returns {Promise<void>} Settles as soon as the socket is set up, not when
 *   training finishes. The function stays `async` so existing callers keep
 *   receiving a Promise.
 */
async function trainModel() {
    console.log("Training started..."); // Debug log
    const config = readTrainingConfig();
    console.log("Config:", config); // Debug log

    // JSON.stringify() turns NaN into null, so an empty field would otherwise
    // reach the server as `null`.
    const numericFields = [...config.kernels, config.batch_size, config.epochs];
    if (numericFields.some(Number.isNaN)) {
        alert('Please enter whole numbers for all kernel, batch size and epoch fields.');
        return;
    }

    // Stop a previous run's socket before its charts are reset.
    retireTrainingSocket(ws);

    // Show progress section and initialize charts
    document.getElementById('training-progress').classList.remove('hidden');
    initializeCharts();

    try {
        // Match the page's scheme: browsers block ws:// connections from an
        // https:// page (mixed content).
        const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
        console.log("Connecting to WebSocket..."); // Debug log
        // Handlers close over this local socket rather than the shared `ws`,
        // so a later run reassigning `ws` cannot redirect this run's send().
        const socket = new WebSocket(`${protocol}//${window.location.host}/ws/train`);
        ws = socket;

        socket.onopen = () => {
            console.log("WebSocket connection established");
            // Send configuration once connected
            socket.send(JSON.stringify(config));
            console.log("Config sent to server"); // Debug log
        };

        socket.onmessage = (event) => {
            console.log("Received message:", event.data); // Debug log
            const data = JSON.parse(event.data);
            if (data.status === "completed") {
                alert('Training completed successfully!');
            } else if (data.status === "error") {
                alert('Error during training: ' + data.message);
            } else {
                updateCharts(data);
            }
        };

        socket.onerror = (error) => {
            console.error('WebSocket error:', error);
            alert('Error connecting to training server');
        };

        socket.onclose = () => {
            console.log('WebSocket connection closed');
        };
    } catch (error) {
        // Only synchronous failures land here, e.g. an invalid URL passed to
        // the WebSocket constructor.
        console.error('Error:', error);
        alert('Error during training: ' + error.message);
    }
}