// sql-generator-ai / script.js
// Last commit 5c3d90a (verified) by jercox: "No debe ser simulado" ("It must not be simulated").
// File size at that commit: 12.5 kB.
/**
 * Browser UI controller for a small TensorFlow.js model that learns to map
 * natural-language descriptions to SQL clauses.
 *
 * Relies on the page globals `tf` (TensorFlow.js) and `feather` (Feather
 * icons), and on the DOM element ids listed in `initializeElements`.
 *
 * Encoding scheme: text is lower-cased and split on whitespace; each token is
 * mapped to an index from a vocabulary built over the training data
 * (0 = padding / unknown), and every sequence is padded or truncated to
 * `MAX_SEQUENCE_LENGTH`. The model regresses the SQL token indices directly
 * (mean squared error), so predictions are rounded back to the nearest index
 * when decoding.
 */
class SQLGeneratorAI {
  /** Length every encoded description / SQL sequence is padded or truncated to. */
  static get MAX_SEQUENCE_LENGTH() {
    return 20;
  }

  /**
   * Size of the embedding table (`inputDim`). Token indices at or above this
   * value cannot be embedded, so they are encoded as 0 (unknown).
   */
  static get VOCAB_SIZE() {
    return 1000;
  }

  constructor() {
    this.trainingData = [];
    this.model = null;
    // NOTE(review): not read anywhere in this class; kept for external callers.
    this.tokenizer = null;
    // Vocabulary (Map token -> index) used for the most recent training run.
    this.vocab = null;
    this.isModelInitialized = false;
    this.isTraining = false;
    this.epochsTrained = 0;
    this.currentLoss = 0;
    this.initializeElements();
    this.attachEventListeners();
    this.updateUI();
  }

  /** Looks up and caches every DOM element the controller touches. */
  initializeElements() {
    this.elements = {
      descriptionInput: document.getElementById('descriptionInput'),
      sqlInput: document.getElementById('sqlInput'),
      testInput: document.getElementById('testInput'),
      addDataBtn: document.getElementById('addDataBtn'),
      initModelBtn: document.getElementById('initModelBtn'),
      trainModelBtn: document.getElementById('trainModelBtn'),
      generateBtn: document.getElementById('generateBtn'),
      trainingDataList: document.getElementById('trainingDataList'),
      output: document.getElementById('output'),
      modelStatus: document.getElementById('modelStatus'),
      statusIndicator: document.getElementById('statusIndicator'),
      trainingProgress: document.getElementById('trainingProgress'),
      progressText: document.getElementById('progressText'),
      dataCount: document.getElementById('dataCount'),
      epochsTrained: document.getElementById('epochsTrained'),
      lossValue: document.getElementById('lossValue'),
    };
  }

  /** Wires buttons and Enter-key shortcuts to their actions. */
  attachEventListeners() {
    this.elements.addDataBtn.addEventListener('click', () => this.addTrainingData());
    this.elements.initModelBtn.addEventListener('click', () => this.initializeModel());
    this.elements.trainModelBtn.addEventListener('click', () => this.trainModel());
    this.elements.generateBtn.addEventListener('click', () => this.generateSQL());
    // Enter in the SQL field adds the current pair as training data.
    this.elements.sqlInput.addEventListener('keypress', (e) => {
      if (e.key === 'Enter') this.addTrainingData();
    });
    // Enter in the test field generates SQL for it.
    this.elements.testInput.addEventListener('keypress', (e) => {
      if (e.key === 'Enter') this.generateSQL();
    });
  }

  /** Appends the (description, SQL) pair from the inputs to the training data. */
  addTrainingData() {
    const description = this.elements.descriptionInput.value.trim();
    const sql = this.elements.sqlInput.value.trim();
    if (!description || !sql) {
      alert('Please enter both description and SQL clause');
      return;
    }
    this.trainingData.push({ description, sql });
    this.elements.descriptionInput.value = '';
    this.elements.sqlInput.value = '';
    this.elements.descriptionInput.focus();
    this.updateTrainingDataList();
    this.updateUI();
  }

  /**
   * Re-renders the training-data list, including per-item delete buttons.
   * User-entered text is HTML-escaped because it is inserted via innerHTML.
   */
  updateTrainingDataList() {
    const list = this.elements.trainingDataList;
    if (this.trainingData.length === 0) {
      list.innerHTML = `
      <div class="text-center py-8 text-gray-500">
      <i data-feather="database" class="w-12 h-12 mx-auto mb-2"></i>
      <p>No training data yet</p>
      </div>
      `;
    } else {
      list.innerHTML = this.trainingData.map((item, index) => `
      <div class="training-item p-3 border-b border-gray-100 flex justify-between items-center">
      <div>
      <div class="font-medium">${this.escapeHtml(item.description)}</div>
      <div class="text-sm text-gray-600 font-mono">${this.escapeHtml(item.sql)}</div>
      </div>
      <button
      class="delete-btn text-red-500 hover:text-red-700"
      data-index="${index}"
      >
      <i data-feather="x"></i>
      </button>
      </div>
      `).join('');
      // Only this list's buttons; the elements are new, so no listener is duplicated.
      list.querySelectorAll('.delete-btn').forEach((button) => {
        button.addEventListener('click', (e) => {
          const index = parseInt(e.currentTarget.getAttribute('data-index'), 10);
          this.trainingData.splice(index, 1);
          this.updateTrainingDataList();
          this.updateUI();
        });
      });
    }
    // Render <i data-feather> placeholders in both branches, including the empty-state icon.
    feather.replace();
  }

  /**
   * Builds and compiles a fresh model, disposing any previous one.
   * Does nothing while a training run is in progress.
   */
  async initializeModel() {
    if (this.isTraining) {
      return; // disposing the model mid-fit() would break the running training
    }
    try {
      this.updateStatus('Initializing model...', 'training');
      if (this.model) {
        this.model.dispose(); // release the old model's weights
        this.model = null;
      }
      // A new model has learned nothing yet: drop stale training state.
      this.isModelInitialized = false;
      this.vocab = null;
      this.epochsTrained = 0;
      this.currentLoss = 0;

      const seqLength = SQLGeneratorAI.MAX_SEQUENCE_LENGTH;
      this.model = tf.sequential({
        layers: [
          tf.layers.embedding({ inputDim: SQLGeneratorAI.VOCAB_SIZE, outputDim: 64, inputLength: seqLength }),
          tf.layers.lstm({ units: 32, returnSequences: true }),
          tf.layers.lstm({ units: 32 }),
          tf.layers.dense({ units: 64, activation: 'relu' }),
          // One regression output per position of the padded SQL sequence, so the
          // output shape matches the training targets built in prepareTrainingData.
          tf.layers.dense({ units: seqLength, activation: 'linear' }),
        ],
      });
      this.model.compile({
        optimizer: tf.train.adam(0.001),
        loss: 'meanSquaredError',
        metrics: ['accuracy'],
      });
      this.isModelInitialized = true;
      this.updateStatus('Model initialized', 'active');
    } catch (error) {
      console.error('Error initializing model:', error);
      this.updateStatus('Initialization failed', 'error');
    } finally {
      this.updateUI();
    }
  }

  /**
   * Trains the model for 20 epochs on the current training data, updating the
   * progress bar and loss after each epoch. Input tensors are always disposed,
   * even when training fails.
   */
  async trainModel() {
    if (!this.isModelInitialized || this.trainingData.length === 0) {
      alert('Please initialize model and add training data first');
      return;
    }
    if (this.isTraining) {
      return; // a fit() call is already running on this model
    }
    this.isTraining = true;
    this.updateUI();
    let xs = null;
    let ys = null;
    try {
      this.updateStatus('Training model...', 'training');
      this.elements.trainingProgress.style.width = '0%';
      this.elements.progressText.textContent = '0%';

      const { descriptions, sqlClauses, vocab } = this.prepareTrainingData();
      xs = tf.tensor2d(descriptions, [descriptions.length, descriptions[0].length]);
      ys = tf.tensor2d(sqlClauses, [sqlClauses.length, sqlClauses[0].length]);
      // Generation must encode and decode with the indices this model learns.
      this.vocab = vocab;

      const epochs = 20;
      await this.model.fit(xs, ys, {
        epochs,
        batchSize: 4,
        callbacks: {
          onEpochEnd: (epoch, logs) => {
            const progress = ((epoch + 1) / epochs) * 100;
            this.elements.trainingProgress.style.width = `${progress}%`;
            this.elements.progressText.textContent = `${Math.round(progress)}%`;
            this.epochsTrained = epoch + 1;
            this.currentLoss = logs.loss.toFixed(4);
            this.updateModelInfo();
          },
        },
      });
      this.updateStatus('Training complete', 'active');
    } catch (error) {
      console.error('Error training model:', error);
      this.updateStatus('Training failed', 'error');
    } finally {
      if (xs) xs.dispose();
      if (ys) ys.dispose();
      this.isTraining = false;
      this.updateUI();
    }
  }

  /**
   * Encodes the training data into fixed-length index sequences.
   * @returns {{descriptions: number[][], sqlClauses: number[][], vocab: Map<string, number>}}
   *   One padded sequence per training item, plus the vocabulary used.
   */
  prepareTrainingData() {
    const vocab = this.buildVocabulary();
    const descriptions = this.trainingData.map((item) => this.encode(item.description, vocab));
    const sqlClauses = this.trainingData.map((item) => this.encode(item.sql, vocab));
    return { descriptions, sqlClauses, vocab };
  }

  /**
   * Lower-cases text and splits it on runs of whitespace, dropping empty tokens.
   * @param {string} text
   * @returns {string[]}
   */
  tokenize(text) {
    return text.toLowerCase().split(/\s+/).filter((token) => token.length > 0);
  }

  /**
   * Assigns indices 1, 2, ... to tokens in first-seen order over every
   * description and SQL clause. Index 0 is reserved for padding / unknown.
   * A Map is used so tokens like "constructor" or "__proto__" are not
   * confused with inherited object properties.
   * @returns {Map<string, number>}
   */
  buildVocabulary() {
    const vocab = new Map();
    this.trainingData.forEach(({ description, sql }) => {
      [...this.tokenize(description), ...this.tokenize(sql)].forEach((token) => {
        if (!vocab.has(token)) {
          vocab.set(token, vocab.size + 1);
        }
      });
    });
    return vocab;
  }

  /**
   * Converts text into a padded/truncated sequence of vocabulary indices.
   * Unknown tokens, and tokens whose index does not fit the embedding table, become 0.
   * @param {string} text
   * @param {Map<string, number>} vocab
   * @param {number} [maxLength=SQLGeneratorAI.MAX_SEQUENCE_LENGTH]
   * @returns {number[]}
   */
  encode(text, vocab, maxLength = SQLGeneratorAI.MAX_SEQUENCE_LENGTH) {
    const indices = this.tokenize(text).map((token) => {
      const index = vocab.has(token) ? vocab.get(token) : 0;
      return index < SQLGeneratorAI.VOCAB_SIZE ? index : 0;
    });
    return this.padSequence(indices, maxLength);
  }

  /**
   * Truncates `sequence` to `maxLength`, or right-pads it with zeros up to it.
   * @param {number[]} sequence
   * @param {number} maxLength
   * @returns {number[]} a new array of exactly `maxLength` elements
   */
  padSequence(sequence, maxLength) {
    if (sequence.length > maxLength) {
      return sequence.slice(0, maxLength);
    }
    return [...sequence, ...Array(maxLength - sequence.length).fill(0)];
  }

  /**
   * Maps raw model outputs back to tokens by rounding each value to the
   * nearest vocabulary index. Padding (0) and indices not in `vocab` are dropped.
   * @param {ArrayLike<number>} values
   * @param {Map<string, number>} vocab
   * @returns {string[]}
   */
  decodeSequence(values, vocab) {
    const indexToToken = new Map();
    vocab.forEach((index, token) => indexToToken.set(index, token));
    return Array.from(values)
      .map((value) => Math.round(value))
      .filter((index) => index < SQLGeneratorAI.VOCAB_SIZE && indexToToken.has(index))
      .map((index) => indexToToken.get(index));
  }

  /**
   * Renders decoded tokens as a WHERE clause, or "WHERE 1=1" when none were predicted.
   * @param {string[]} tokens
   * @returns {string}
   */
  formatSqlClause(tokens) {
    // Training clauses may already start with "where"; avoid emitting it twice.
    const body = tokens[0] === 'where' ? tokens.slice(1) : tokens;
    return `WHERE ${body.length > 0 ? body.join(' ') : '1=1'}`;
  }

  /**
   * Escapes text for safe interpolation into HTML markup.
   * @param {*} value
   * @returns {string}
   */
  escapeHtml(value) {
    return String(value)
      .replace(/&/g, '&amp;')
      .replace(/</g, '&lt;')
      .replace(/>/g, '&gt;')
      .replace(/"/g, '&quot;')
      .replace(/'/g, '&#39;');
  }

  /**
   * Runs the model on the test description and shows the decoded SQL clause.
   * Uses the vocabulary from the last training run when there is one, and
   * otherwise builds one from the current training data.
   */
  generateSQL() {
    if (!this.isModelInitialized) {
      alert('Please initialize the model first');
      return;
    }
    const input = this.elements.testInput.value.trim();
    if (!input) {
      alert('Please enter a description');
      return;
    }
    let inputTensor = null;
    let prediction = null;
    try {
      const vocab = this.vocab || this.buildVocabulary();
      inputTensor = tf.tensor2d([this.encode(input, vocab)]);
      prediction = this.model.predict(inputTensor);
      const tokens = this.decodeSequence(prediction.dataSync(), vocab);
      const sqlClause = this.formatSqlClause(tokens);
      this.elements.output.innerHTML = `
      <div class="output-highlight">
      <span class="text-gray-800">${this.escapeHtml(sqlClause)}</span>
      </div>
      `;
    } catch (error) {
      console.error('Error generating SQL:', error);
      this.elements.output.innerHTML = `
      <div class="text-red-500">
      Error generating SQL. Please check console for details.
      </div>
      `;
    } finally {
      // Free tensor memory even when prediction fails.
      if (inputTensor) inputTensor.dispose();
      if (prediction) prediction.dispose();
    }
  }

  /**
   * Shows a status message and swaps the indicator's status CSS class.
   * @param {string} message
   * @param {string} status - CSS class name, e.g. 'training', 'active', 'error'
   */
  updateStatus(message, status) {
    this.elements.modelStatus.textContent = message;
    this.elements.statusIndicator.className = `w-3 h-3 rounded-full mr-2 status-indicator ${status}`;
  }

  /** Refreshes the data-count, epoch and loss read-outs. */
  updateModelInfo() {
    this.elements.dataCount.textContent = this.trainingData.length;
    this.elements.epochsTrained.textContent = this.epochsTrained;
    this.elements.lossValue.textContent = this.currentLoss;
  }

  /**
   * Enables or disables the buttons for the current state, then refreshes the
   * model info. While training, Train and Init stay disabled so no second
   * fit() starts and the model is not disposed mid-training.
   */
  updateUI() {
    const { trainModelBtn, initModelBtn, generateBtn } = this.elements;
    trainModelBtn.disabled = this.isTraining || !this.isModelInitialized || this.trainingData.length === 0;
    initModelBtn.disabled = this.isTraining;
    generateBtn.disabled = !this.isModelInitialized;
    this.updateModelInfo();
  }
}
// Once the document has been parsed, create the single SQLGeneratorAI instance
// and store it on window.sqlGenerator.
document.addEventListener('DOMContentLoaded', function bootSQLGenerator() {
  const generator = new SQLGeneratorAI();
  window.sqlGenerator = generator;
});