// MnistStudio: static/js/train.js
// Last commit by Shilpaj (30d27e9): "Feat: Completed logic for multiple models training and comparison"
// Module-level state shared by the functions in this file.

// WebSocket carrying live single-model progress; assigned in trainSingleModel().
let ws;
// NOTE(review): accuracyChart and lossChart are never assigned or read in this
// file (charts are addressed by element id instead). Confirm no other script
// relies on these globals before removing them.
let accuracyChart;
let lossChart;
/**
 * Show either the single-model training form or the compare-models form.
 *
 * @param {string} type - 'single' shows the single-model form and hides the
 *   compare form; any other value does the opposite.
 */
function showTrainingForm(type) {
  const forms = {
    single: document.getElementById('single-model-form'),
    compare: document.getElementById('compare-models-form'),
  };
  const showSingle = type === 'single';

  // toggle(name, force) adds the class when force is true and removes it when false.
  forms.single.classList.toggle('hidden', !showSingle);
  forms.compare.classList.toggle('hidden', showSingle);
}
/**
 * Create (or re-create) the empty loss and accuracy charts for single-model
 * training. Each chart has a training trace at index 0 and a validation trace at
 * index 1; updateCharts() extends those two traces as progress messages arrive.
 */
function initializeCharts() {
  const emptyTrace = (name) => ({ name, x: [], y: [], type: 'scatter' });
  const layoutFor = (title, yAxisTitle) => ({
    title,
    xaxis: { title: 'Iterations' },
    yaxis: { title: yAxisTitle },
  });

  Plotly.newPlot(
    'loss-plot',
    [emptyTrace('Training Loss'), emptyTrace('Validation Loss')],
    layoutFor('Training and Validation Loss', 'Loss'),
  );
  Plotly.newPlot(
    'accuracy-plot',
    [emptyTrace('Training Accuracy'), emptyTrace('Validation Accuracy')],
    layoutFor('Training and Validation Accuracy', 'Accuracy (%)'),
  );
}
/**
 * Append one training-progress update to the live loss/accuracy charts and
 * replace the contents of the training log panel with its metrics.
 *
 * @param {Object} data - Parsed WebSocket message. Reads `epoch` (displayed as
 *   `epoch + 1`, so presumably 0-based; confirm with the server), `train_loss`,
 *   `val_loss`, `train_acc` and `val_acc`.
 */
function updateCharts(data) {
  // x value for this update: the number of points already on the loss chart's
  // first trace. Plotly keeps the plotted traces on the graph div's `data`
  // property, extendTraces appends to them, and initializeCharts() re-creates the
  // chart with empty traces. This makes it a 0-based update counter that restarts
  // with every training run. `data.epoch * data.batch` is not usable as an x
  // value: it is 0 for every update in epoch 0 and is not monotonic across epochs.
  const lossPlot = document.getElementById('loss-plot');
  const plottedTraces = lossPlot ? lossPlot.data : undefined;
  const iteration =
    Array.isArray(plottedTraces) && plottedTraces[0] && Array.isArray(plottedTraces[0].x)
      ? plottedTraces[0].x.length
      : 0;

  Plotly.extendTraces('loss-plot', {
    x: [[iteration], [iteration]],
    y: [[data.train_loss], [data.val_loss]],
  }, [0, 1]);
  Plotly.extendTraces('accuracy-plot', {
    x: [[iteration], [iteration]],
    y: [[data.train_acc], [data.val_acc]],
  }, [0, 1]);

  // Show 'N/A' for a missing metric instead of letting `.toFixed` throw
  // (e.g. a message that carries no validation metrics).
  const format = (value, digits) => (typeof value === 'number' ? value.toFixed(digits) : 'N/A');

  // Update training logs
  const logsDiv = document.getElementById('training-logs');
  logsDiv.innerHTML = `
    <p>Epoch: ${data.epoch + 1}</p>
    <p>Training Loss: ${format(data.train_loss, 4)}</p>
    <p>Training Accuracy: ${format(data.train_acc, 2)}%</p>
    <p>Validation Loss: ${format(data.val_loss, 4)}</p>
    <p>Validation Accuracy: ${format(data.val_acc, 2)}%</p>
  `;
}
/**
 * Train a single model with the hyper-parameters entered in the single-model form.
 *
 * The function does the following, in order:
 * - shows the progress section and resets the live charts;
 * - opens a WebSocket on `/ws/train` whose messages are passed to updateCharts();
 * - POSTs the config to `/api/train_single`;
 * - reports the outcome with an alert.
 */
async function trainSingleModel() {
  const readInt = (id) => parseInt(document.getElementById(id).value, 10);
  const config = {
    kernels: [readInt('kernel1'), readInt('kernel2'), readInt('kernel3')],
    optimizer: document.getElementById('optimizer').value,
    batch_size: readInt('batch_size'),
    epochs: readInt('epochs'),
  };

  // Show progress section and initialize charts
  document.getElementById('training-progress').classList.remove('hidden');
  initializeCharts();

  // A socket left over from a previous run would keep pushing its messages into
  // the freshly reset charts, so detach and close it before opening a new one.
  if (ws) {
    ws.onmessage = null;
    ws.close();
  }
  // Match the page's scheme: browsers refuse insecure ws:// URLs from https:// pages.
  const wsScheme = window.location.protocol === 'https:' ? 'wss' : 'ws';
  ws = new WebSocket(`${wsScheme}://${window.location.host}/ws/train`);
  ws.onmessage = function (event) {
    const data = JSON.parse(event.data);
    updateCharts(data);
  };
  ws.onerror = function (event) {
    console.error('WebSocket error:', event);
  };

  try {
    const response = await fetch('/api/train_single', {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(config),
    });
    if (!response.ok) {
      throw new Error(`Training request failed with HTTP ${response.status}`);
    }
    const data = await response.json();
    if (data.status === 'success') {
      alert('Training completed successfully!');
    } else {
      // Previously a non-success status was silently ignored.
      console.error('Training did not succeed:', data);
      alert('Error during training. Please check console for details.');
    }
  } catch (error) {
    console.error('Error:', error);
    alert('Error during training. Please check console for details.');
  }
}
/**
 * Read one model's hyper-parameters from the compare-models form.
 *
 * @param {string} prefix - Field-id prefix, e.g. 'model1'. The form fields read are:
 *   `${prefix}_kernel1` to `${prefix}_kernel3`, `${prefix}_optimizer`,
 *   `${prefix}_batch_size` and `${prefix}_epochs`.
 * @returns {{kernels: number[], optimizer: string, batch_size: number, epochs: number}}
 */
function readComparisonModelConfig(prefix) {
  const field = (name) => document.getElementById(`${prefix}_${name}`);
  const readInt = (name) => parseInt(field(name).value, 10);
  return {
    kernels: [readInt('kernel1'), readInt('kernel2'), readInt('kernel3')],
    optimizer: field('optimizer').value,
    batch_size: readInt('batch_size'),
    epochs: readInt('epochs'),
  };
}

/**
 * Train two models with the configurations from the compare-models form.
 *
 * POSTs both configs to `/api/train_compare`. On success it renders the results
 * with displayComparisonResults(); on any failure it alerts the user.
 */
async function compareModels() {
  const config = {
    model1: readComparisonModelConfig('model1'),
    model2: readComparisonModelConfig('model2'),
  };

  // Show comparison progress section
  document.getElementById('comparison-progress').classList.remove('hidden');
  initializeComparisonCharts();

  try {
    const response = await fetch('/api/train_compare', {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(config),
    });
    if (!response.ok) {
      throw new Error(`Comparison request failed with HTTP ${response.status}`);
    }
    const data = await response.json();
    if (data.status === 'success') {
      displayComparisonResults(data);
      alert('Model comparison completed successfully!');
    } else {
      // Previously a non-success status was silently ignored.
      console.error('Model comparison did not succeed:', data);
      alert('Error during model comparison. Please check console for details.');
    }
  } catch (error) {
    console.error('Error:', error);
    alert('Error during model comparison. Please check console for details.');
  }
}
/**
 * Create (or re-create) the empty loss and accuracy comparison charts.
 * In each chart, trace 0 belongs to Model A and trace 1 to Model B.
 */
function initializeComparisonCharts() {
  const chartSpecs = [
    {
      elementId: 'comparison-loss-plot',
      traceNames: ['Model A Loss', 'Model B Loss'],
      title: 'Loss Comparison',
      yAxisTitle: 'Loss',
    },
    {
      elementId: 'comparison-accuracy-plot',
      traceNames: ['Model A Accuracy', 'Model B Accuracy'],
      title: 'Accuracy Comparison',
      yAxisTitle: 'Accuracy (%)',
    },
  ];

  for (const { elementId, traceNames, title, yAxisTitle } of chartSpecs) {
    const traces = traceNames.map((name) => ({ name, x: [], y: [], type: 'scatter' }));
    Plotly.newPlot(elementId, traces, {
      title,
      xaxis: { title: 'Iterations' },
      yaxis: { title: yAxisTitle },
    });
  }
}
/**
 * Render both models' final metrics into the comparison log panel and plot
 * their recorded training histories on the comparison charts.
 *
 * @param {Object} data - Response body from `/api/train_compare`. Reads
 *   `model1_results` and `model2_results`. Each has a `model_name` and a
 *   `history` holding `train_loss` and `train_acc` arrays.
 */
function displayComparisonResults(data) {
  const HTML_ESCAPES = { '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;' };
  // model_name comes from the server response; escape it before it is
  // interpolated into innerHTML.
  const escapeHtml = (value) => String(value).replace(/[&<>"']/g, (ch) => HTML_ESCAPES[ch]);
  // A missing history metric is treated as an empty series.
  const seriesOf = (results, key) => {
    const series = results.history ? results.history[key] : undefined;
    return Array.isArray(series) ? series : [];
  };
  // Last value of a series with `digits` decimals, or 'N/A' when the series is
  // empty (its "last element" is undefined, and undefined.toFixed would throw).
  const formatLast = (series, digits) => {
    const last = series[series.length - 1];
    return typeof last === 'number' ? last.toFixed(digits) : 'N/A';
  };

  const models = [
    { label: 'Model A', results: data.model1_results },
    { label: 'Model B', results: data.model2_results },
  ];

  const logsDiv = document.getElementById('comparison-logs');
  logsDiv.innerHTML = models
    .map(({ label, results }) => `
    <div class="comparison-model">
      <h4>${label}</h4>
      <p>Final Loss: ${formatLast(seriesOf(results, 'train_loss'), 4)}</p>
      <p>Final Accuracy: ${formatLast(seriesOf(results, 'train_acc'), 2)}%</p>
      <p>Model Name: ${escapeHtml(results.model_name)}</p>
    </div>`)
    .join('');

  // Nothing else in this file writes to the charts created by
  // initializeComparisonCharts(), so plot the histories here. Trace 0 is
  // Model A and trace 1 is Model B. The x value is the entry's index within
  // the history array. NOTE(review): this assumes one history entry per logged
  // step; confirm the server's granularity.
  const plotHistories = (plotId, key) => {
    const ys = models.map(({ results }) => seriesOf(results, key));
    if (ys.every((series) => series.length === 0)) {
      return;
    }
    const xs = ys.map((series) => series.map((_, index) => index));
    Plotly.extendTraces(plotId, { x: xs, y: ys }, [0, 1]);
  };
  plotHistories('comparison-loss-plot', 'train_loss');
  plotHistories('comparison-accuracy-plot', 'train_acc');
}
/**
 * Placeholder for rendering single-model training results; it currently does nothing.
 *
 * @param {Object} data - Result payload; unused until this is implemented.
 */
function displayResults(data) {
  // TODO: render `data` into the element with id 'training-results'. The element
  // is not looked up until there is something to write into it.
}