Spaces:
Running
Running
/**
 * Browser demo that collects (description, SQL clause) pairs, trains a small
 * TensorFlow.js LSTM regressor that maps description token ids to SQL token
 * ids, and renders a placeholder SQL clause from the model's prediction.
 *
 * Relies on page globals: `tf` (TensorFlow.js), `feather` (Feather icons) and
 * the DOM elements looked up in initializeElements().
 */
class SQLGeneratorAI {
  /** Length every token-id sequence is padded/truncated to (model input and output width). */
  static get SEQUENCE_LENGTH() {
    return 20;
  }

  /** Rows in the embedding table; token ids >= this value are encoded as 0. */
  static get VOCAB_SIZE() {
    return 1000;
  }

  /** Number of epochs run per training request. */
  static get EPOCHS() {
    return 20;
  }

  /**
   * Escape text for safe interpolation into an HTML template literal.
   * SQL routinely contains `<`, `>` and quotes, which would otherwise be
   * parsed as markup.
   * @param {*} text - Value to escape (coerced with String()).
   * @returns {string} HTML-safe text.
   */
  static escapeHtml(text) {
    return String(text)
      .replace(/&/g, '&amp;')
      .replace(/</g, '&lt;')
      .replace(/>/g, '&gt;')
      .replace(/"/g, '&quot;')
      .replace(/'/g, '&#39;');
  }

  constructor() {
    this.trainingData = [];
    this.model = null;
    this.tokenizer = null;
    this.isModelInitialized = false;
    // True while model.fit() is running; used to block re-init and a second fit().
    this.isTraining = false;
    this.epochsTrained = 0;
    this.currentLoss = 0;
    this.initializeElements();
    this.attachEventListeners();
    this.updateUI();
  }

  /** Look up and cache every DOM element the app reads or writes. */
  initializeElements() {
    this.elements = {
      descriptionInput: document.getElementById('descriptionInput'),
      sqlInput: document.getElementById('sqlInput'),
      testInput: document.getElementById('testInput'),
      addDataBtn: document.getElementById('addDataBtn'),
      initModelBtn: document.getElementById('initModelBtn'),
      trainModelBtn: document.getElementById('trainModelBtn'),
      generateBtn: document.getElementById('generateBtn'),
      trainingDataList: document.getElementById('trainingDataList'),
      output: document.getElementById('output'),
      modelStatus: document.getElementById('modelStatus'),
      statusIndicator: document.getElementById('statusIndicator'),
      trainingProgress: document.getElementById('trainingProgress'),
      progressText: document.getElementById('progressText'),
      dataCount: document.getElementById('dataCount'),
      epochsTrained: document.getElementById('epochsTrained'),
      lossValue: document.getElementById('lossValue')
    };
  }

  /** Wire button clicks and Enter-key shortcuts to their actions. */
  attachEventListeners() {
    const { addDataBtn, initModelBtn, trainModelBtn, generateBtn, sqlInput, testInput } = this.elements;
    addDataBtn.addEventListener('click', () => this.addTrainingData());
    initModelBtn.addEventListener('click', () => this.initializeModel());
    trainModelBtn.addEventListener('click', () => this.trainModel());
    generateBtn.addEventListener('click', () => this.generateSQL());

    // `keypress` is deprecated; use `keydown` and ignore the Enter press that
    // merely confirms an IME composition.
    const onEnter = (handler) => (e) => {
      if (e.key === 'Enter' && !e.isComposing) {
        handler();
      }
    };
    // Enter in the SQL field adds the pair; Enter in the test field generates SQL.
    sqlInput.addEventListener('keydown', onEnter(() => this.addTrainingData()));
    testInput.addEventListener('keydown', onEnter(() => this.generateSQL()));
  }

  /** Append the current (description, SQL) inputs as a training pair, then reset the form. */
  addTrainingData() {
    const description = this.elements.descriptionInput.value.trim();
    const sql = this.elements.sqlInput.value.trim();
    if (!description || !sql) {
      alert('Please enter both description and SQL clause');
      return;
    }
    this.trainingData.push({ description, sql });
    this.elements.descriptionInput.value = '';
    this.elements.sqlInput.value = '';
    this.elements.descriptionInput.focus();
    this.updateTrainingDataList();
    this.updateUI();
  }

  /**
   * Re-render the training-data list (or its empty state) and wire each
   * row's delete button. User text is HTML-escaped before interpolation.
   */
  updateTrainingDataList() {
    const list = this.elements.trainingDataList;
    if (this.trainingData.length === 0) {
      list.innerHTML = `
        <div class="text-center py-8 text-gray-500">
          <i data-feather="database" class="w-12 h-12 mx-auto mb-2"></i>
          <p>No training data yet</p>
        </div>
      `;
    } else {
      const escape = SQLGeneratorAI.escapeHtml;
      list.innerHTML = this.trainingData.map((item, index) => `
        <div class="training-item p-3 border-b border-gray-100 flex justify-between items-center">
          <div>
            <div class="font-medium">${escape(item.description)}</div>
            <div class="text-sm text-gray-600 font-mono">${escape(item.sql)}</div>
          </div>
          <button
            class="delete-btn text-red-500 hover:text-red-700"
            data-index="${index}"
          >
            <i data-feather="x"></i>
          </button>
        </div>
      `).join('');
      // Scope the lookup to this list so unrelated `.delete-btn`s elsewhere are untouched.
      list.querySelectorAll('.delete-btn').forEach(button => {
        button.addEventListener('click', (e) => {
          const index = Number.parseInt(e.currentTarget.getAttribute('data-index'), 10);
          this.trainingData.splice(index, 1);
          this.updateTrainingDataList();
          this.updateUI();
        });
      });
    }
    // Render Feather icons for both states (the empty state's database icon included).
    if (typeof feather !== 'undefined') {
      feather.replace();
    }
  }

  /**
   * Build and compile a fresh model, disposing any previous one.
   * Refuses to run while training is in progress, because disposing the
   * model mid-fit would crash the running fit().
   */
  async initializeModel() {
    if (this.isTraining) {
      return;
    }
    try {
      this.updateStatus('Initializing model...', 'training');
      if (this.model) {
        // Release the previous model's weights before replacing it.
        this.model.dispose();
        this.model = null;
        this.isModelInitialized = false;
      }
      const seqLen = SQLGeneratorAI.SEQUENCE_LENGTH;
      this.model = tf.sequential({
        layers: [
          tf.layers.embedding({ inputDim: SQLGeneratorAI.VOCAB_SIZE, outputDim: 64, inputLength: seqLen }),
          tf.layers.lstm({ units: 32, returnSequences: true }),
          tf.layers.lstm({ units: 32 }),
          tf.layers.dense({ units: 64, activation: 'relu' }),
          // One regression output per target token position, so the output
          // shape [batch, seqLen] matches the padded SQL id sequences.
          tf.layers.dense({ units: seqLen, activation: 'linear' })
        ]
      });
      // Plain MSE regression on token ids; classification "accuracy" is meaningless here.
      this.model.compile({
        optimizer: tf.train.adam(0.001),
        loss: 'meanSquaredError'
      });
      this.isModelInitialized = true;
      // A new model has not been trained yet.
      this.epochsTrained = 0;
      this.currentLoss = 0;
      this.elements.trainingProgress.style.width = '0%';
      this.elements.progressText.textContent = '0%';
      this.updateStatus('Model initialized', 'active');
      this.updateUI();
    } catch (error) {
      console.error('Error initializing model:', error);
      this.updateStatus('Initialization failed', 'error');
      this.updateUI();
    }
  }

  /**
   * Fit the model on the current training data for EPOCHS epochs, updating
   * the progress bar and loss after each epoch. Tensors are always disposed,
   * and the buttons are always restored, even if fit() throws.
   */
  async trainModel() {
    if (!this.isModelInitialized || this.trainingData.length === 0) {
      alert('Please initialize model and add training data first');
      return;
    }
    if (this.isTraining) {
      return;
    }
    this.isTraining = true;
    this.updateUI();
    let xs = null;
    let ys = null;
    try {
      this.updateStatus('Training model...', 'training');
      const { descriptions, sqlClauses } = this.prepareTrainingData();
      const seqLen = SQLGeneratorAI.SEQUENCE_LENGTH;
      xs = tf.tensor2d(descriptions, [descriptions.length, seqLen]);
      ys = tf.tensor2d(sqlClauses, [sqlClauses.length, seqLen]);
      const epochs = SQLGeneratorAI.EPOCHS;
      await this.model.fit(xs, ys, {
        epochs,
        batchSize: 4,
        callbacks: {
          onEpochEnd: (epoch, logs) => {
            const progress = ((epoch + 1) / epochs) * 100;
            this.elements.trainingProgress.style.width = `${progress}%`;
            this.elements.progressText.textContent = `${Math.round(progress)}%`;
            this.epochsTrained = epoch + 1;
            this.currentLoss = logs.loss.toFixed(4);
            this.updateModelInfo();
          }
        }
      });
      this.updateStatus('Training complete', 'active');
    } catch (error) {
      console.error('Error training model:', error);
      this.updateStatus('Training failed', 'error');
    } finally {
      if (xs) xs.dispose();
      if (ys) ys.dispose();
      this.isTraining = false;
      this.updateUI();
    }
  }

  /**
   * Split text into lowercase, whitespace-separated tokens (no empty tokens).
   * @param {string} text
   * @returns {string[]}
   */
  tokenize(text) {
    return text.toLowerCase().split(/\s+/).filter(token => token.length > 0);
  }

  /**
   * Build a token -> id vocabulary over every description and SQL clause in
   * the training data, in first-seen order, starting at 1. Id 0 is reserved
   * for padding and unknown tokens. A Map is used so that tokens such as
   * "constructor" do not collide with Object.prototype members.
   * @returns {Map<string, number>}
   */
  buildVocabulary() {
    const vocab = new Map();
    for (const item of this.trainingData) {
      for (const token of [...this.tokenize(item.description), ...this.tokenize(item.sql)]) {
        if (!vocab.has(token)) {
          vocab.set(token, vocab.size + 1);
        }
      }
    }
    return vocab;
  }

  /**
   * Encode text as a fixed-length (SEQUENCE_LENGTH) array of token ids.
   * Unknown tokens, and ids that do not fit the embedding table, become 0.
   * @param {string} text
   * @param {Map<string, number>} vocab - As returned by buildVocabulary().
   * @returns {number[]}
   */
  encodeText(text, vocab) {
    const ids = this.tokenize(text).map(token => {
      const id = vocab.get(token);
      return id !== undefined && id < SQLGeneratorAI.VOCAB_SIZE ? id : 0;
    });
    return this.padSequence(ids, SQLGeneratorAI.SEQUENCE_LENGTH);
  }

  /**
   * Encode all training pairs as padded id sequences.
   * @returns {{descriptions: number[][], sqlClauses: number[][]}}
   */
  prepareTrainingData() {
    const vocab = this.buildVocabulary();
    const descriptions = this.trainingData.map(item => this.encodeText(item.description, vocab));
    const sqlClauses = this.trainingData.map(item => this.encodeText(item.sql, vocab));
    return { descriptions, sqlClauses };
  }

  /**
   * Truncate `sequence` to `maxLength`, or right-pad it with zeros.
   * Always returns a new array.
   * @param {number[]} sequence
   * @param {number} maxLength
   * @returns {number[]}
   */
  padSequence(sequence, maxLength) {
    if (sequence.length > maxLength) {
      return sequence.slice(0, maxLength);
    }
    return [...sequence, ...Array(maxLength - sequence.length).fill(0)];
  }

  /**
   * Run the model on the test description and render a demonstration SQL
   * clause. Predicted ids are rounded, and ids <= 0 are dropped. At most
   * 10 ids are kept and shown as placeholder column names.
   */
  generateSQL() {
    if (!this.isModelInitialized) {
      alert('Please initialize the model first');
      return;
    }
    const input = this.elements.testInput.value.trim();
    if (!input) {
      alert('Please enter a description');
      return;
    }
    let inputTensor = null;
    let prediction = null;
    try {
      const sequence = this.encodeText(input, this.buildVocabulary());
      inputTensor = tf.tensor2d([sequence]);
      prediction = this.model.predict(inputTensor);
      const predictedTokens = Array.from(prediction.dataSync())
        .map(val => Math.round(val))
        .filter(val => val > 0)
        .slice(0, 10);
      // Demonstration decoding only: ids become placeholder column names.
      let sqlClause = 'WHERE ';
      if (predictedTokens.length > 0) {
        sqlClause += predictedTokens.map(val => `col${val}`).join(' AND ');
      } else {
        sqlClause += '1=1'; // Default clause
      }
      this.elements.output.innerHTML = `
        <div class="output-highlight">
          <span class="text-gray-800">${SQLGeneratorAI.escapeHtml(sqlClause)}</span>
        </div>
      `;
    } catch (error) {
      console.error('Error generating SQL:', error);
      this.elements.output.innerHTML = `
        <div class="text-red-500">
          Error generating SQL. Please check console for details.
        </div>
      `;
    } finally {
      // Free tensor memory even when predict() or dataSync() throws.
      if (inputTensor) inputTensor.dispose();
      if (prediction) prediction.dispose();
    }
  }

  /**
   * Show a status message, and color the indicator by applying `status` as a CSS class.
   * @param {string} message
   * @param {string} status - Class name such as 'training', 'active' or 'error'.
   */
  updateStatus(message, status) {
    this.elements.modelStatus.textContent = message;
    this.elements.statusIndicator.className = `w-3 h-3 rounded-full mr-2 status-indicator ${status}`;
  }

  /** Refresh the data-count, epochs and loss read-outs. */
  updateModelInfo() {
    this.elements.dataCount.textContent = this.trainingData.length;
    this.elements.epochsTrained.textContent = this.epochsTrained;
    this.elements.lossValue.textContent = this.currentLoss;
  }

  /**
   * Enable or disable buttons for the current state, and refresh the info read-outs.
   * Train and Initialize stay disabled while a fit() is running.
   */
  updateUI() {
    const hasData = this.trainingData.length > 0;
    this.elements.trainModelBtn.disabled = !this.isModelInitialized || !hasData || Boolean(this.isTraining);
    this.elements.generateBtn.disabled = !this.isModelInitialized;
    this.elements.initModelBtn.disabled = Boolean(this.isTraining);
    this.updateModelInfo();
  }
}
// Start the app once the DOM is parsed, so the element lookups succeed. If this
// script runs after parsing has already finished (e.g. loaded with `async` or
// injected dynamically), DOMContentLoaded will not fire again, so start now.
if (document.readyState === 'loading') {
  document.addEventListener('DOMContentLoaded', () => {
    window.sqlGenerator = new SQLGeneratorAI();
  });
} else {
  window.sqlGenerator = new SQLGeneratorAI();
}