// MnistStudio/static/js/train_compare.js
// Last commit 30d27e9 (Shilpaj): "Feat: Completed logic for multiple models training and comparison"
// Script-level binding, declared without an initial value.
// Presumably meant to hold the comparison WebSocket connection — TODO confirm against its users.
let ws;
/**
 * Creates (or resets) the two empty comparison charts.
 *
 * Renders a loss chart into the element with id `comparison-loss-plot` and an
 * accuracy chart into `comparison-accuracy-plot`. Each chart gets two empty
 * scatter traces: index 0 for Model A and index 1 for Model B.
 *
 * Traces are created with `y` only. The progress updates in this file append
 * to `y` alone, and Plotly plots a scatter trace only up to the shorter of its
 * `x` and `y` arrays, so an explicit empty `x` would keep every appended point
 * hidden. Without `x`, Plotly numbers the points 0, 1, 2, ... which matches
 * the "Iterations" axis.
 */
function initializeComparisonCharts() {
    // Builds one empty scatter trace; `x` is intentionally omitted (see above).
    const emptyTrace = (name) => ({ name, y: [], type: 'scatter' });

    const lossData = [emptyTrace('Model A Loss'), emptyTrace('Model B Loss')];
    const accuracyData = [emptyTrace('Model A Accuracy'), emptyTrace('Model B Accuracy')];

    Plotly.newPlot('comparison-loss-plot', lossData, {
        title: 'Loss Comparison',
        xaxis: { title: 'Iterations' },
        yaxis: { title: 'Loss' }
    });
    Plotly.newPlot('comparison-accuracy-plot', accuracyData, {
        title: 'Accuracy Comparison',
        xaxis: { title: 'Iterations' },
        yaxis: { title: 'Accuracy (%)' }
    });
}
/**
 * Trains both models through the REST endpoint and shows the final results.
 *
 * Reads the hyper-parameters of both models from the form, reveals the
 * progress section, resets the charts, POSTs the configuration as JSON to
 * `/api/train_compare`, and passes a successful response to
 * `displayComparisonResults`.
 *
 * Any failure (network error, non-2xx HTTP status, unparsable body, a response
 * whose `status` is not `'success'`, or an error while rendering) is logged to
 * the console and reported to the user with an alert; nothing is rethrown.
 *
 * @returns {Promise<void>}
 */
async function compareModels() {
    // Reads one model's settings from the inputs whose ids start with `prefix`.
    const readModelConfig = (prefix) => {
        const field = (suffix) => document.getElementById(`${prefix}_${suffix}`).value;
        return {
            kernels: [
                parseInt(field('kernel1'), 10),
                parseInt(field('kernel2'), 10),
                parseInt(field('kernel3'), 10)
            ],
            optimizer: field('optimizer'),
            batch_size: parseInt(field('batch_size'), 10),
            epochs: parseInt(field('epochs'), 10)
        };
    };

    const config = {
        model1: readModelConfig('model1'),
        model2: readModelConfig('model2')
    };

    // Show comparison progress section
    document.getElementById('comparison-progress').classList.remove('hidden');
    initializeComparisonCharts();

    try {
        const response = await fetch('/api/train_compare', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify(config)
        });
        // fetch only rejects on network failure, so HTTP errors must be checked explicitly.
        if (!response.ok) {
            throw new Error(`Server responded with HTTP ${response.status}`);
        }

        const data = await response.json();
        // Surface an unsuccessful run instead of leaving the user without feedback.
        if (data.status !== 'success') {
            throw new Error(`Model comparison returned status: ${data.status}`);
        }

        displayComparisonResults(data);
        alert('Model comparison completed successfully!');
    } catch (error) {
        console.error('Error:', error);
        alert('Error during model comparison. Please check console for details.');
    }
}
/**
 * Renders the final metrics of both models into the `comparison-logs` element.
 *
 * For each of `data.model1_results` (shown as "Model A") and
 * `data.model2_results` (shown as "Model B") it prints the last entry of
 * `history.train_loss` (4 decimals), the last entry of `history.train_acc`
 * (2 decimals, followed by `%`) and `model_name`.
 *
 * Missing results, missing/empty history arrays or non-numeric entries are
 * rendered as `N/A` instead of throwing. The model name is HTML-escaped
 * because it is inserted through `innerHTML`.
 *
 * @param {Object} data - Object expected to carry `model1_results` and `model2_results`.
 */
function displayComparisonResults(data) {
    const HTML_ESCAPES = { '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;' };
    const escapeHtml = (value) => String(value).replace(/[&<>"']/g, (ch) => HTML_ESCAPES[ch]);

    // Last numeric entry of `series` formatted with `digits` decimals, or null when unavailable.
    const finalValue = (series, digits) => {
        const last = Array.isArray(series) ? series[series.length - 1] : undefined;
        return typeof last === 'number' ? last.toFixed(digits) : null;
    };

    // Builds the markup for one model's summary card.
    const renderModel = (title, results) => {
        const history = (results && results.history) || {};
        const loss = finalValue(history.train_loss, 4);
        const accuracy = finalValue(history.train_acc, 2);
        const hasName = Boolean(results) && results.model_name !== undefined && results.model_name !== null;
        const name = hasName ? escapeHtml(results.model_name) : 'N/A';
        return `
<div class="comparison-model">
<h4>${title}</h4>
<p>Final Loss: ${loss === null ? 'N/A' : loss}</p>
<p>Final Accuracy: ${accuracy === null ? 'N/A' : `${accuracy}%`}</p>
<p>Model Name: ${name}</p>
</div>`;
    };

    const results = data || {};
    const logsDiv = document.getElementById('comparison-logs');
    logsDiv.innerHTML = `${renderModel('Model A', results.model1_results)}${renderModel('Model B', results.model2_results)}
`;
}
// Helper functions that read the training parameters from the form inputs
/**
 * Collects and validates the hyper-parameters of both models from the form.
 *
 * Fields with ids `model1_*` become `model_a` and `model2_*` become `model_b`.
 * Kernel counts, batch size and epochs are parsed as base-10 integers; the
 * optimizer is taken as the raw field value.
 *
 * @returns {{model_a: Object, model_b: Object}} Each entry has `block1`,
 *   `block2`, `block3`, `optimizer`, `batch_size` and `epochs`.
 * @throws {Error} If any value is null, undefined, an empty string or NaN.
 *   A missing form element surfaces as the TypeError raised when reading its
 *   value. Every error is logged to the console before being rethrown.
 */
function getModelParameters() {
    const readInt = (id) => parseInt(document.getElementById(id).value, 10);

    // Reads one model's fields from the inputs whose ids start with `prefix`.
    const readModel = (prefix) => ({
        block1: readInt(`${prefix}_kernel1`),
        block2: readInt(`${prefix}_kernel2`),
        block3: readInt(`${prefix}_kernel3`),
        optimizer: document.getElementById(`${prefix}_optimizer`).value,
        batch_size: readInt(`${prefix}_batch_size`),
        epochs: readInt(`${prefix}_epochs`)
    });

    try {
        const params = {
            model_a: readModel('model1'),
            model_b: readModel('model2')
        };

        // Validate that all values are present and valid. The empty-string check
        // covers the optimizer, which is a string and can never be NaN.
        for (const [model, fields] of Object.entries(params)) {
            for (const [key, value] of Object.entries(fields)) {
                const isMissing = value === null || value === undefined || value === '';
                if (isMissing || Number.isNaN(value)) {
                    throw new Error(`Invalid value for ${model} ${key}: ${value}`);
                }
            }
        }

        console.log('Collected and validated model parameters:', params);
        return params;
    } catch (error) {
        console.error('Error in getModelParameters:', error);
        throw error;
    }
}
/**
 * Builds the dataset settings sent along with the model parameters.
 *
 * The dataset batch size is taken from Model 1's batch-size field (parsed as a
 * base-10 integer) and shuffling is always enabled.
 *
 * @returns {{batch_size: number, shuffle: boolean}} `batch_size` is NaN when
 *   the field does not start with a decimal number.
 */
function getDatasetParameters() {
    const rawBatchSize = document.getElementById('model1_batch_size').value;
    return {
        // Explicit radix so the value is always read as decimal.
        batch_size: parseInt(rawBatchSize, 10),
        shuffle: true
    };
}
// Start-comparison button: validates the form, collects the training parameters,
// resets the charts and then streams training progress over a WebSocket.
document.getElementById('startComparisonBtn').addEventListener('click', function() {
    console.log('Start Comparison button clicked');

    // Validate form inputs before proceeding (number inputs and selects, e.g. the optimizer)
    const formInputs = document.querySelectorAll('input[type="number"], select');
    let isValid = true;
    const formValues = {};
    formInputs.forEach(input => {
        console.log(`Checking input ${input.id}: ${input.value}`);
        formValues[input.id] = input.value;
        if (!input.value) {
            console.error(`Missing value for ${input.id}`);
            isValid = false;
        }
    });
    console.log('Form values:', formValues); // Log all form values
    if (!isValid) {
        alert('Please fill in all required fields');
        return;
    }

    // Collect the parameters before connecting: getModelParameters() throws on
    // invalid input, and failing here avoids opening a socket that would never be used.
    let message;
    try {
        message = {
            action: 'start_training',
            parameters: {
                model_params: getModelParameters(),
                dataset_params: getDatasetParameters()
            }
        };
    } catch (error) {
        console.error('Error collecting training parameters:', error);
        alert('Invalid training parameters. Please check console for details.');
        return;
    }

    // Drop a connection left over from an earlier click so that two runs cannot
    // write into the same charts. Handlers are detached first so closing the old
    // socket does not raise a spurious error alert.
    if (ws && ws.readyState !== WebSocket.CLOSED) {
        ws.onopen = null;
        ws.onmessage = null;
        ws.onerror = null;
        ws.onclose = null;
        ws.close();
    }

    // Show comparison progress section
    document.getElementById('comparison-progress').classList.remove('hidden');
    initializeComparisonCharts();
    console.log('Initialized comparison charts');

    console.log('Attempting WebSocket connection...');
    // Pages served over HTTPS may only open secure WebSockets.
    const scheme = window.location.protocol === 'https:' ? 'wss' : 'ws';
    const socket = new WebSocket(`${scheme}://${window.location.host}/ws/compare`);
    // Keep the active connection in the script-level variable; the handlers below use
    // the local `socket` so a later click cannot redirect them to another connection.
    ws = socket;

    socket.onopen = function() {
        console.log('WebSocket connection established');
        console.log('Preparing to send message:', JSON.stringify(message, null, 2));
        // The open event fires once the socket is in the OPEN state, so no delay is needed.
        try {
            socket.send(JSON.stringify(message));
            console.log('Message sent successfully');
        } catch (error) {
            console.error('Error sending message:', error);
            alert('Error sending training parameters. Please check console for details.');
        }
    };
    socket.onmessage = function(event) {
        console.log('Received WebSocket message:', event.data);
        try {
            const data = JSON.parse(event.data);
            console.log('Parsed message data:', data);
            updateTrainingProgress(data);
        } catch (error) {
            console.error('Error processing message:', error);
        }
    };
    socket.onerror = function(error) {
        console.error('WebSocket error:', error);
        alert('Connection error occurred. Please check console for details.');
    };
    socket.onclose = function(event) {
        console.log('WebSocket connection closed. Code:', event.code, 'Reason:', event.reason);
    };
});
/**
 * Applies one parsed progress message to the comparison view.
 *
 * Dispatches on `data.status`:
 * - `'training'`: appends `data.metrics.loss` and `data.metrics.accuracy` to the
 *   loss and accuracy charts (trace 0 when `data.model` is `'A'`, trace 1
 *   otherwise) and updates the progress text with the 1-based epoch.
 * - `'complete'`: marks training as complete and hands `data.metrics` to
 *   `displayComparisonResults`.
 *   NOTE(review): assumes `data.metrics` has the shape that function reads — confirm against the server.
 * - `'error'`: logs `data.message` and shows it in an alert.
 * Any other status is ignored.
 *
 * @param {Object} data - Parsed message object.
 */
function updateTrainingProgress(data) {
    switch (data.status) {
        case 'training': {
            const isModelA = data.model === 'A';
            const traceIndex = isModelA ? 0 : 1;

            // Append the new sample to the matching trace of each chart.
            const lossUpdate = { y: [[data.metrics.loss]] };
            Plotly.extendTraces('comparison-loss-plot', lossUpdate, [traceIndex]);
            const accuracyUpdate = { y: [[data.metrics.accuracy]] };
            Plotly.extendTraces('comparison-accuracy-plot', accuracyUpdate, [traceIndex]);

            // Epochs are shown 1-based.
            const modelLabel = isModelA ? 'Model A' : 'Model B';
            const progressLabel = document.getElementById('training-progress-text');
            progressLabel.textContent = `Training ${modelLabel} - Epoch ${data.epoch + 1}`;
            break;
        }
        case 'complete': {
            const progressLabel = document.getElementById('training-progress-text');
            progressLabel.textContent = 'Training Complete!';
            displayComparisonResults(data.metrics);
            break;
        }
        case 'error':
            console.error('Training error:', data.message);
            alert(`Training error: ${data.message}`);
            break;
        default:
            break;
    }
}