// Training dashboard script. All functions are top-level declarations so that
// inline HTML handlers can call them by name.

// WebSocket currently streaming single-model progress (replaced on each run).
let ws;
// NOTE(review): lossChart and accuracyChart are never assigned in this file;
// kept only in case another script reads them.
let lossChart;
let accuracyChart;
// Number of progress updates plotted since initializeCharts(); used as the
// x coordinate so that points are strictly increasing.
let trainingStep = 0;

/**
 * Show one training form and hide the other.
 * @param {string} type - 'single' shows #single-model-form; any other value
 *   shows #compare-models-form.
 */
function showTrainingForm(type) {
  const isSingle = type === 'single';
  document.getElementById('single-model-form').classList.toggle('hidden', !isSingle);
  document.getElementById('compare-models-form').classList.toggle('hidden', isSingle);
}

/**
 * Format a metric with a fixed number of decimals.
 * @param {*} value - Metric value (normally a number).
 * @param {number} digits - Number of decimal places.
 * @param {string} [suffix=''] - Text appended after the number (e.g. '%').
 * @returns {string} The formatted value plus suffix, or 'N/A' when value is
 *   not a finite number (missing metrics must not crash the UI).
 */
function formatMetric(value, digits, suffix = '') {
  return Number.isFinite(value) ? `${value.toFixed(digits)}${suffix}` : 'N/A';
}

/**
 * Return the last element of an array.
 * @param {*} values - Candidate array.
 * @returns {*} The last element, or undefined for an empty or non-array input.
 */
function lastValue(values) {
  return Array.isArray(values) && values.length > 0 ? values[values.length - 1] : undefined;
}

/**
 * Replace a container's children with one <p> per line of text.
 * Uses textContent so server-supplied strings (e.g. model names) are displayed
 * literally instead of being parsed as HTML.
 * @param {HTMLElement} container - Element whose contents are replaced.
 * @param {string[]} lines - Text of each paragraph, in order.
 */
function renderLines(container, lines) {
  const paragraphs = lines.map((text) => {
    const paragraph = document.createElement('p');
    paragraph.textContent = text;
    return paragraph;
  });
  container.replaceChildren(...paragraphs);
}

/**
 * Read a form field by id and parse it as a base-10 integer.
 * @param {string} id - Element id.
 * @returns {number} The parsed integer.
 * @throws {Error} If the field value does not start with a number.
 */
function readIntField(id) {
  const value = Number.parseInt(document.getElementById(id).value, 10);
  if (Number.isNaN(value)) {
    throw new Error(`Please enter a whole number for "${id}".`);
  }
  return value;
}

/**
 * Build a training config from the form fields whose ids start with `prefix`.
 * @param {string} [prefix=''] - '' for the single-model form, 'model1_' or
 *   'model2_' for the comparison form.
 * @returns {{kernels: number[], optimizer: string, batch_size: number, epochs: number}}
 * @throws {Error} If any integer field is empty or not a number.
 */
function readModelConfig(prefix = '') {
  return {
    kernels: [1, 2, 3].map((n) => readIntField(`${prefix}kernel${n}`)),
    optimizer: document.getElementById(`${prefix}optimizer`).value,
    batch_size: readIntField(`${prefix}batch_size`),
    epochs: readIntField(`${prefix}epochs`),
  };
}

/**
 * POST `body` as JSON to `url` and return the parsed JSON response.
 * @param {string} url - Endpoint path.
 * @param {object} body - Payload, serialized with JSON.stringify.
 * @returns {Promise<object>} Parsed response body.
 * @throws {Error} If the HTTP status is not 2xx, or the body is not valid JSON.
 */
async function postJson(url, body) {
  const response = await fetch(url, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify(body),
  });
  if (!response.ok) {
    throw new Error(`${url} responded with HTTP ${response.status}`);
  }
  return response.json();
}

/**
 * Create an empty Plotly scatter plot with one named trace per entry.
 * @param {string} elementId - Id of the plot container.
 * @param {string[]} traceNames - Legend name of each trace, in trace order.
 * @param {string} title - Plot title.
 * @param {string} yAxisTitle - Y-axis label (the x axis is always 'Iterations').
 */
function newEmptyPlot(elementId, traceNames, title, yAxisTitle) {
  const traces = traceNames.map((name) => ({ name, x: [], y: [], type: 'scatter' }));
  Plotly.newPlot(elementId, traces, {
    title,
    xaxis: { title: 'Iterations' },
    yaxis: { title: yAxisTitle },
  });
}

/**
 * (Re)create the empty single-model loss and accuracy charts and restart the
 * x-axis step counter.
 */
function initializeCharts() {
  trainingStep = 0;
  newEmptyPlot('loss-plot', ['Training Loss', 'Validation Loss'], 'Training and Validation Loss', 'Loss');
  newEmptyPlot(
    'accuracy-plot',
    ['Training Accuracy', 'Validation Accuracy'],
    'Training and Validation Accuracy',
    'Accuracy (%)',
  );
}

/**
 * Append one progress update to the single-model charts and refresh the log.
 * @param {object} data - Parsed WebSocket message. Fields read: epoch (shown
 *   as epoch + 1), train_loss, val_loss, train_acc, val_acc.
 */
function updateCharts(data) {
  // The previous formula, epoch * batch, put every epoch-0 point at x = 0 and
  // was not monotonic. No global step is read from the message, so use a
  // running count of updates instead.
  trainingStep += 1;
  const iteration = trainingStep;
  Plotly.extendTraces(
    'loss-plot',
    { x: [[iteration], [iteration]], y: [[data.train_loss], [data.val_loss]] },
    [0, 1],
  );
  Plotly.extendTraces(
    'accuracy-plot',
    { x: [[iteration], [iteration]], y: [[data.train_acc], [data.val_acc]] },
    [0, 1],
  );

  renderLines(document.getElementById('training-logs'), [
    `Epoch: ${data.epoch + 1}`,
    `Training Loss: ${formatMetric(data.train_loss, 4)}`,
    `Training Accuracy: ${formatMetric(data.train_acc, 2, '%')}`,
    `Validation Loss: ${formatMetric(data.val_loss, 4)}`,
    `Validation Accuracy: ${formatMetric(data.val_acc, 2, '%')}`,
  ]);
}

/**
 * Open the progress WebSocket for a new single-model run, closing the
 * previous run's socket first, and feed each JSON message into updateCharts.
 */
function openTrainingSocket() {
  if (ws) {
    // Without this, sockets accumulate across runs and stale streams keep
    // writing into the freshly initialized charts.
    ws.onmessage = null;
    ws.close();
  }
  // Pages served over HTTPS may only open secure (wss:) sockets.
  const scheme = window.location.protocol === 'https:' ? 'wss' : 'ws';
  ws = new WebSocket(`${scheme}://${window.location.host}/ws/train`);
  ws.onmessage = (event) => {
    let data;
    try {
      data = JSON.parse(event.data);
    } catch (error) {
      console.error('Ignoring malformed training update:', error);
      return;
    }
    updateCharts(data);
  };
}

/**
 * Read the single-model form, start training on the server and stream
 * progress into the charts. Results are reported via alert().
 */
async function trainSingleModel() {
  let config;
  try {
    config = readModelConfig();
  } catch (error) {
    alert(error.message);
    return;
  }

  document.getElementById('training-progress').classList.remove('hidden');
  initializeCharts();
  // Left open after the POST resolves so late progress messages still arrive;
  // the next run closes it.
  openTrainingSocket();

  try {
    const data = await postJson('/api/train_single', config);
    if (data.status !== 'success') {
      throw new Error(`Training finished with status "${data.status}"`);
    }
    alert('Training completed successfully!');
  } catch (error) {
    console.error('Error:', error);
    alert('Error during training. Please check console for details.');
  }
}

/**
 * Read both comparison forms, train the two models on the server and show
 * their final metrics and histories. Results are reported via alert().
 */
async function compareModels() {
  let config;
  try {
    config = {
      model1: readModelConfig('model1_'),
      model2: readModelConfig('model2_'),
    };
  } catch (error) {
    alert(error.message);
    return;
  }

  document.getElementById('comparison-progress').classList.remove('hidden');
  initializeComparisonCharts();

  try {
    const data = await postJson('/api/train_compare', config);
    if (data.status !== 'success') {
      throw new Error(`Comparison finished with status "${data.status}"`);
    }
    displayComparisonResults(data);
    alert('Model comparison completed successfully!');
  } catch (error) {
    console.error('Error:', error);
    alert('Error during model comparison. Please check console for details.');
  }
}

/** (Re)create the empty Model A / Model B comparison charts. */
function initializeComparisonCharts() {
  newEmptyPlot('comparison-loss-plot', ['Model A Loss', 'Model B Loss'], 'Loss Comparison', 'Loss');
  newEmptyPlot(
    'comparison-accuracy-plot',
    ['Model A Accuracy', 'Model B Accuracy'],
    'Accuracy Comparison',
    'Accuracy (%)',
  );
}

/**
 * Plot both models' train_loss / train_acc histories on the comparison charts.
 * Points are appended, so this assumes the charts were just created by
 * initializeComparisonCharts(). Missing or non-array histories plot nothing.
 * NOTE(review): entries are plotted against their 1-based index; confirm
 * whether the server records history per epoch or per iteration.
 * @param {object} historyA - Model A history object.
 * @param {object} historyB - Model B history object.
 */
function plotComparisonHistories(historyA, historyB) {
  const seriesFor = (key) =>
    [historyA, historyB].map((history) => (Array.isArray(history?.[key]) ? [...history[key]] : []));
  const positions = (values) => values.map((_, index) => index + 1);
  const losses = seriesFor('train_loss');
  const accuracies = seriesFor('train_acc');
  Plotly.extendTraces('comparison-loss-plot', { x: losses.map(positions), y: losses }, [0, 1]);
  Plotly.extendTraces(
    'comparison-accuracy-plot',
    { x: accuracies.map(positions), y: accuracies },
    [0, 1],
  );
}

/**
 * Render the comparison response: final loss/accuracy and model name for
 * each model (as plain text), plus their training histories on the charts.
 * @param {object} data - Response with model1_results / model2_results, each
 *   holding `history` ({train_loss: number[], train_acc: number[]}) and
 *   `model_name`.
 */
function displayComparisonResults(data) {
  const models = [
    ['Model A', data.model1_results],
    ['Model B', data.model2_results],
  ];
  plotComparisonHistories(data.model1_results.history, data.model2_results.history);

  const lines = models.flatMap(([label, results]) => {
    const history = results.history ?? {};
    return [
      label,
      `Final Loss: ${formatMetric(lastValue(history.train_loss), 4)}`,
      `Final Accuracy: ${formatMetric(lastValue(history.train_acc), 2, '%')}`,
      `Model Name: ${results.model_name ?? 'N/A'}`,
    ];
  });
  renderLines(document.getElementById('comparison-logs'), lines);
}

/**
 * Placeholder for rendering final single-model results into
 * #training-results. Not implemented: it currently does nothing, and nothing
 * in this file calls it.
 * @param {object} data - Results payload (unused).
 */
function displayResults(data) {
  // TODO: render `data` into document.getElementById('training-results').
}