// MnistStudio / static/js/train_single.js
// Author: Shilpaj
// Commit 4c1a791: "Feat: Create frontend and backend for the project"
/** WebSocket of the current training run; reassigned by trainModel(). */
let ws = undefined;
/**
 * Creates (or re-creates) the empty loss and accuracy charts.
 *
 * Each chart receives two empty scatter traces: index 0 for the training
 * metric and index 1 for the validation metric. updateCharts() extends
 * these traces by index as progress messages arrive.
 */
function initializeCharts() {
    // Build a fresh trace per call so no two traces share x/y arrays.
    const emptyTrace = (name) => ({ name, x: [], y: [], type: 'scatter' });

    // Titles use the `{ text }` object form (supported since plotly.js 1.43).
    // The plain-string shorthand is deprecated and no longer accepted by
    // plotly.js 3.
    const layout = (title, yTitle) => ({
        title: { text: title },
        xaxis: { title: { text: 'Iterations' } },
        yaxis: { title: { text: yTitle } }
    });

    Plotly.newPlot(
        'loss-plot',
        [emptyTrace('Training Loss'), emptyTrace('Validation Loss')],
        layout('Training and Validation Loss', 'Loss')
    );
    Plotly.newPlot(
        'accuracy-plot',
        [emptyTrace('Training Accuracy'), emptyTrace('Validation Accuracy')],
        layout('Training and Validation Accuracy', 'Accuracy (%)')
    );
}
/**
 * Appends one progress update to both charts and rewrites the training log.
 *
 * Trace 0 of each chart receives the training metric and trace 1 the
 * validation metric.
 *
 * @param {object} data - One parsed progress message from the training
 *   socket. Reads `epoch`, `train_loss`, `train_acc`, `val_loss` and
 *   `val_acc`. `epoch` is displayed as `epoch + 1`, so it is presumably
 *   0-based. NOTE(review): this schema is inferred from usage; confirm it
 *   against the server.
 */
function updateCharts(data) {
    // Plot each update at its 1-based arrival index. The old `epoch * batch`
    // product was 0 for every update in epoch 0. Later it was neither unique
    // nor increasing (1 * 100 > 2 * 1), so points piled up at x = 0 and the
    // lines zig-zagged. The count already plotted comes from Plotly's
    // `gd.data`, which starts empty each time the charts are (re)created.
    const lossPlot = document.getElementById('loss-plot');
    const plotted = lossPlot && lossPlot.data && lossPlot.data[0]
        ? lossPlot.data[0].x.length
        : 0;
    const iteration = plotted + 1;

    Plotly.extendTraces('loss-plot', {
        x: [[iteration], [iteration]],
        y: [[data.train_loss], [data.val_loss]]
    }, [0, 1]);
    Plotly.extendTraces('accuracy-plot', {
        x: [[iteration], [iteration]],
        y: [[data.train_acc], [data.val_acc]]
    }, [0, 1]);

    // If a metric is missing or not a finite number, show "N/A". Otherwise
    // toFixed() would throw and the log would silently stop updating.
    const fmt = (value, digits) => (Number.isFinite(value) ? value.toFixed(digits) : 'N/A');

    // Update training logs
    const logsDiv = document.getElementById('training-logs');
    logsDiv.innerHTML = `
<p>Epoch: ${data.epoch + 1}</p>
<p>Training Loss: ${fmt(data.train_loss, 4)}</p>
<p>Training Accuracy: ${fmt(data.train_acc, 2)}%</p>
<p>Validation Loss: ${fmt(data.val_loss, 4)}</p>
<p>Validation Accuracy: ${fmt(data.val_acc, 2)}%</p>
`;
}
/**
 * Reads the training configuration from the form, resets the progress charts,
 * and streams training updates from the server over a WebSocket.
 *
 * The config is sent once the socket opens. Each incoming JSON message is
 * either a terminal status ("completed" / "error"), reported with an alert,
 * or a progress update forwarded to updateCharts().
 *
 * @returns {Promise<void>} Resolves as soon as the socket has been created.
 *   Training progress is reported asynchronously by the socket's handlers.
 */
async function trainModel() {
    console.log("Training started..."); // Debug log

    // Form inputs are strings; parse the numeric ones explicitly as base 10.
    const readInt = (id) => parseInt(document.getElementById(id).value, 10);
    const config = {
        kernels: [readInt('kernel1'), readInt('kernel2'), readInt('kernel3')],
        optimizer: document.getElementById('optimizer').value,
        batch_size: readInt('batch_size'),
        epochs: readInt('epochs')
    };
    console.log("Config:", config); // Debug log

    // Show progress section and initialize charts
    document.getElementById('training-progress').classList.remove('hidden');
    initializeCharts();

    try {
        if (ws) {
            // Stop the previous run's socket so its messages cannot keep
            // feeding the freshly reset charts. Detach its handlers first,
            // because closing a socket that is still connecting fires an
            // error event, which would show a misleading alert.
            ws.onopen = null;
            ws.onmessage = null;
            ws.onerror = null;
            ws.close();
        }

        // Match the page's scheme: browsers block plain ws:// from https:// pages.
        const scheme = window.location.protocol === 'https:' ? 'wss' : 'ws';

        // Connect to WebSocket
        console.log("Connecting to WebSocket..."); // Debug log
        const socket = new WebSocket(`${scheme}://${window.location.host}/ws/train`);
        ws = socket;

        // Handlers use their own `socket`, not the global `ws`. Otherwise a
        // socket replaced before it opened would send its config on the newer one.
        socket.onopen = function() {
            console.log("WebSocket connection established");
            // Send configuration once connected
            socket.send(JSON.stringify(config));
            console.log("Config sent to server"); // Debug log
        };

        socket.onmessage = function(event) {
            console.log("Received message:", event.data); // Debug log
            const data = JSON.parse(event.data);
            if (data.status === "completed") {
                alert('Training completed successfully!');
            } else if (data.status === "error") {
                alert('Error during training: ' + data.message);
            } else {
                updateCharts(data);
            }
        };

        socket.onerror = function(error) {
            console.error('WebSocket error:', error);
            alert('Error connecting to training server');
        };

        socket.onclose = function() {
            console.log('WebSocket connection closed');
        };
    } catch (error) {
        console.error('Error:', error);
        alert('Error during training: ' + error.message);
    }
}