// Socket for the live (WebSocket) comparison run. Kept at module scope so a
// new run can close the previous connection instead of leaking it.
let ws;

/**
 * Read a base-10 integer from the input element with the given id.
 * @param {string} id - DOM id of the input.
 * @returns {number} The parsed integer, or NaN when the field is empty/invalid.
 */
function readIntField(id) {
  return parseInt(document.getElementById(id).value, 10);
}

/**
 * Collect the form values for one model column.
 * @param {string} prefix - Id prefix of that model's inputs ('model1' or 'model2').
 * @returns {{kernels: number[], optimizer: string, batch_size: number, epochs: number}}
 */
function readModelForm(prefix) {
  return {
    kernels: [1, 2, 3].map((i) => readIntField(`${prefix}_kernel${i}`)),
    optimizer: document.getElementById(`${prefix}_optimizer`).value,
    batch_size: readIntField(`${prefix}_batch_size`),
    epochs: readIntField(`${prefix}_epochs`),
  };
}

/**
 * (Re)create the empty loss and accuracy comparison plots, one trace per model
 * (trace 0 = Model A, trace 1 = Model B).
 *
 * Traces intentionally have no `x` array: updateTrainingProgress extends only
 * `y`, and Plotly draws min(x.length, y.length) points, so an `x: []` would
 * keep every trace blank. Without `x`, Plotly indexes points 0, 1, 2, ...
 */
function initializeComparisonCharts() {
  const makeTraces = (metric) => [
    { name: `Model A ${metric}`, y: [], type: 'scatter' },
    { name: `Model B ${metric}`, y: [], type: 'scatter' },
  ];

  Plotly.newPlot('comparison-loss-plot', makeTraces('Loss'), {
    title: { text: 'Loss Comparison' },
    xaxis: { title: { text: 'Iterations' } },
    yaxis: { title: { text: 'Loss' } },
  });
  Plotly.newPlot('comparison-accuracy-plot', makeTraces('Accuracy'), {
    title: { text: 'Accuracy Comparison' },
    xaxis: { title: { text: 'Iterations' } },
    yaxis: { title: { text: 'Accuracy (%)' } },
  });
}

/**
 * Run a comparison through the REST endpoint `/api/train_compare` and show the
 * final results. Any failure (network, non-2xx HTTP, non-success status) is
 * logged and reported to the user via alert().
 */
async function compareModels() {
  const config = {
    model1: readModelForm('model1'),
    model2: readModelForm('model2'),
  };

  // Show comparison progress section
  document.getElementById('comparison-progress').classList.remove('hidden');
  initializeComparisonCharts();

  try {
    const response = await fetch('/api/train_compare', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(config),
    });
    if (!response.ok) {
      throw new Error(`Server responded with HTTP ${response.status}`);
    }
    const data = await response.json();
    if (data.status !== 'success') {
      throw new Error(`Comparison failed: ${data.message || data.status}`);
    }
    displayComparisonResults(data);
    alert('Model comparison completed successfully!');
  } catch (error) {
    console.error('Error:', error);
    alert('Error during model comparison. Please check console for details.');
  }
}

/**
 * Escape text for safe interpolation into HTML (server-supplied values).
 * @param {*} value - Any value; converted with String().
 * @returns {string}
 */
function escapeHtml(value) {
  return String(value)
    .replace(/&/g, '&amp;')
    .replace(/</g, '&lt;')
    .replace(/>/g, '&gt;')
    .replace(/"/g, '&quot;')
    .replace(/'/g, '&#39;');
}

/**
 * Format the last numeric entry of a metric history.
 * @param {number[]|undefined} values - Metric history.
 * @param {number} digits - Decimal places for toFixed().
 * @param {string} [suffix=''] - Appended to the formatted number (e.g. '%').
 * @returns {string} The formatted value, or 'N/A' when missing/empty/non-numeric.
 */
function formatLastValue(values, digits, suffix = '') {
  if (!Array.isArray(values) || values.length === 0) {
    return 'N/A';
  }
  const last = values[values.length - 1];
  return typeof last === 'number' ? `${last.toFixed(digits)}${suffix}` : 'N/A';
}

/**
 * Extract display strings for one model's results object.
 * @param {Object|undefined} results - Presumably `{ history: { train_loss, train_acc }, model_name }`
 *   (shape inferred from the fields read here; confirm against the server).
 * @returns {{loss: string, accuracy: string, name: string}} HTML-safe strings.
 */
function summarizeModelResults(results) {
  const history = (results && results.history) || {};
  const hasName = Boolean(results) && results.model_name !== undefined && results.model_name !== null;
  return {
    loss: formatLastValue(history.train_loss, 4),
    accuracy: formatLastValue(history.train_acc, 2, '%'),
    name: hasName ? escapeHtml(results.model_name) : 'N/A',
  };
}

/**
 * Render final loss, accuracy and model name for both models into #comparison-logs.
 * Missing histories show 'N/A' instead of throwing; model names are HTML-escaped.
 * @param {Object} data - Object with `model1_results` and `model2_results`.
 */
function displayComparisonResults(data) {
  const logsDiv = document.getElementById('comparison-logs');
  const modelA = summarizeModelResults(data && data.model1_results);
  const modelB = summarizeModelResults(data && data.model2_results);
  logsDiv.innerHTML = `

Model A

Final Loss: ${modelA.loss}

Final Accuracy: ${modelA.accuracy}

Model Name: ${modelA.name}

Model B

Final Loss: ${modelB.loss}

Final Accuracy: ${modelB.accuracy}

Model Name: ${modelB.name}

`;
}

/**
 * Collect and validate both models' hyper-parameters in the shape sent over the
 * WebSocket (`block1..block3` rather than a `kernels` array).
 * @returns {{model_a: Object, model_b: Object}}
 * @throws {Error} If any value is missing, empty, or NaN.
 */
function getModelParameters() {
  try {
    const toBlockParams = (form) => ({
      block1: form.kernels[0],
      block2: form.kernels[1],
      block3: form.kernels[2],
      optimizer: form.optimizer,
      batch_size: form.batch_size,
      epochs: form.epochs,
    });
    const params = {
      model_a: toBlockParams(readModelForm('model1')),
      model_b: toBlockParams(readModelForm('model2')),
    };

    // Validate that all values are present and valid ('' catches an unselected optimizer)
    for (const model of ['model_a', 'model_b']) {
      for (const [key, value] of Object.entries(params[model])) {
        if (value === null || value === undefined || value === '' || Number.isNaN(value)) {
          throw new Error(`Invalid value for ${model} ${key}: ${value}`);
        }
      }
    }

    console.log('Collected and validated model parameters:', params);
    return params;
  } catch (error) {
    console.error('Error in getModelParameters:', error);
    throw error;
  }
}

/**
 * Dataset-loader parameters for the WebSocket run.
 * @returns {{batch_size: number, shuffle: boolean}}
 */
function getDatasetParameters() {
  return {
    batch_size: readIntField('model1_batch_size'), // Using model1's batch size for dataset
    shuffle: true,
  };
}

/**
 * Click handler for #startComparisonBtn: validate the form, build the training
 * message, then open a WebSocket to /ws/compare and stream progress into the charts.
 */
function startWebSocketComparison() {
  console.log('Start Comparison button clicked');

  // Every number input and select on the page must have a value (select covers optimizers).
  const formInputs = document.querySelectorAll('input[type="number"], select');
  const formValues = {};
  let isValid = true;
  formInputs.forEach((input) => {
    console.log(`Checking input ${input.id}: ${input.value}`);
    formValues[input.id] = input.value;
    if (!input.value) {
      console.error(`Missing value for ${input.id}`);
      isValid = false;
    }
  });
  console.log('Form values:', formValues);
  if (!isValid) {
    alert('Please fill in all required fields');
    return;
  }

  // Build the payload before connecting so invalid values are reported here,
  // rather than throwing inside onopen and leaving a socket open.
  let message;
  try {
    message = {
      action: 'start_training',
      parameters: {
        model_params: getModelParameters(),
        dataset_params: getDatasetParameters(),
      },
    };
  } catch (error) {
    alert('Invalid training parameters. Please check console for details.');
    return;
  }

  // Show comparison progress section
  document.getElementById('comparison-progress').classList.remove('hidden');
  initializeComparisonCharts();
  console.log('Initialized comparison charts');

  // Only one live run at a time: close any previous connection first.
  if (ws && ws.readyState !== WebSocket.CLOSED) {
    ws.close();
  }

  // Match the page scheme; ws:// from an https page is blocked as mixed content.
  const scheme = window.location.protocol === 'https:' ? 'wss' : 'ws';
  console.log('Attempting WebSocket connection...');
  const socket = new WebSocket(`${scheme}://${window.location.host}/ws/compare`);
  ws = socket;

  socket.onopen = () => {
    console.log('WebSocket connection established');
    console.log('Preparing to send message:', JSON.stringify(message, null, 2));
    // onopen fires only once readyState is OPEN, so send() can run immediately.
    try {
      socket.send(JSON.stringify(message));
      console.log('Message sent successfully');
    } catch (error) {
      console.error('Error sending message:', error);
      alert('Error sending training parameters. Please check console for details.');
    }
  };

  socket.onmessage = (event) => {
    console.log('Received WebSocket message:', event.data);
    try {
      const data = JSON.parse(event.data);
      console.log('Parsed message data:', data);
      updateTrainingProgress(data);
    } catch (error) {
      console.error('Error processing message:', error);
    }
  };

  socket.onerror = (error) => {
    console.error('WebSocket error:', error);
    alert('Connection error occurred. Please check console for details.');
  };

  socket.onclose = (event) => {
    console.log('WebSocket connection closed. Code:', event.code, 'Reason:', event.reason);
    if (ws === socket) {
      ws = undefined;
    }
  };
}

document.getElementById('startComparisonBtn').addEventListener('click', startWebSocketComparison);

/**
 * Set the progress label text, if the label exists on the page.
 * @param {string} text
 */
function setProgressText(text) {
  const progressText = document.getElementById('training-progress-text');
  if (progressText) {
    progressText.textContent = text;
  }
}

/**
 * Apply one server message to the UI.
 * - 'training': append the loss/accuracy point to the model's trace (A -> 0, otherwise 1)
 *   and update the progress label.
 * - 'complete': mark training complete and render `data.metrics` as final results.
 * - 'error': log and alert `data.message`.
 * @param {Object} data - Parsed WebSocket message.
 */
function updateTrainingProgress(data) {
  if (data.status === 'training') {
    const traceIndex = data.model === 'A' ? 0 : 1;
    // Only y is extended; the traces have no x array, so Plotly indexes points implicitly.
    Plotly.extendTraces('comparison-loss-plot', { y: [[data.metrics.loss]] }, [traceIndex]);
    Plotly.extendTraces('comparison-accuracy-plot', { y: [[data.metrics.accuracy]] }, [traceIndex]);
    setProgressText(`Training ${data.model === 'A' ? 'Model A' : 'Model B'} - Epoch ${data.epoch + 1}`);
  } else if (data.status === 'complete') {
    setProgressText('Training Complete!');
    displayComparisonResults(data.metrics);
  } else if (data.status === 'error') {
    console.error('Training error:', data.message);
    alert(`Training error: ${data.message}`);
  }
}