// MnistStudio — static/js/inference.js
// Last commit 61f0070 by Shilpaj: "Feat: Complete single model training and inference"
// Shared handles to the drawing surface. They are assigned once the page has
// loaded and are read by setupCanvas() and predict().
let canvas;
let ctx;

// Page-load hook: grab the drawing canvas and its 2D context, then wire up
// the drawing handlers.
window.onload = () => {
  canvas = document.getElementById('drawing-canvas');
  ctx = canvas.getContext('2d');
  setupCanvas();
};
/**
 * Paints the drawing canvas white and attaches freehand-drawing handlers.
 *
 * Uses the module-level `canvas` and `ctx`, so it must run after they are
 * assigned. Strokes are 15px wide, round-capped and black. Both mouse input
 * and single-finger touch input draw on the canvas.
 */
function setupCanvas() {
  ctx.fillStyle = "white";
  ctx.fillRect(0, 0, canvas.width, canvas.height);

  let isDrawing = false;

  /**
   * Converts viewport (client) coordinates to canvas bitmap coordinates.
   * The canvas may be displayed at a different CSS size than its
   * width/height attributes. The offset is therefore scaled by
   * bitmap size / displayed size; without this, strokes drift away from the
   * pointer whenever CSS resizes the canvas.
   */
  function toCanvasPoint(clientX, clientY) {
    const rect = canvas.getBoundingClientRect();
    // A zero-size (e.g. hidden) canvas would otherwise yield Infinity/NaN.
    const scaleX = rect.width > 0 ? canvas.width / rect.width : 1;
    const scaleY = rect.height > 0 ? canvas.height / rect.height : 1;
    return {
      x: (clientX - rect.left) * scaleX,
      y: (clientY - rect.top) * scaleY,
    };
  }

  function startDrawing(e) {
    isDrawing = true;
    draw(e);
  }

  function draw(e) {
    if (!isDrawing) return;
    const { x, y } = toCanvasPoint(e.clientX, e.clientY);
    ctx.lineWidth = 15;
    ctx.lineCap = 'round';
    ctx.strokeStyle = 'black';
    ctx.lineTo(x, y);
    ctx.stroke();
    // Restart the path at the current point so each stroke() only paints the
    // newest segment instead of re-stroking the whole accumulated path.
    ctx.beginPath();
    ctx.moveTo(x, y);
  }

  function stopDrawing() {
    isDrawing = false;
    // Drop the current path so the next stroke does not connect to this one.
    ctx.beginPath();
  }

  /**
   * Adapts a pointer handler for touch input. It passes the first active
   * touch point, which exposes clientX/clientY like a mouse event. It also
   * cancels the default action so the page does not scroll or zoom while
   * drawing.
   */
  function withFirstTouch(handler) {
    return (e) => {
      e.preventDefault();
      const touch = e.touches[0];
      if (touch) handler(touch);
    };
  }

  canvas.addEventListener('mousedown', startDrawing);
  canvas.addEventListener('mousemove', draw);
  canvas.addEventListener('mouseup', stopDrawing);
  canvas.addEventListener('mouseout', stopDrawing);

  // passive: false so preventDefault() is honoured for touch gestures.
  canvas.addEventListener('touchstart', withFirstTouch(startDrawing), { passive: false });
  canvas.addEventListener('touchmove', withFirstTouch(draw), { passive: false });
  canvas.addEventListener('touchend', stopDrawing);
  canvas.addEventListener('touchcancel', stopDrawing);
}
/**
 * Resets the drawing to a blank white canvas and hides any previously shown
 * prediction.
 */
function clearCanvas() {
  const board = document.getElementById('drawing-canvas');
  const pen = board.getContext('2d');
  const { width, height } = board;

  pen.fillStyle = "white";
  pen.fillRect(0, 0, width, height);
  // Discard the in-progress path so the next stroke starts fresh.
  pen.beginPath();

  const resultPanel = document.getElementById('prediction-result');
  resultPanel.classList.add('hidden');
  resultPanel.innerHTML = '';
}
/**
 * Sends the current drawing to the inference endpoint and shows the result.
 *
 * Reads the model name from the #model-select element and the image from the
 * module-level `canvas` as a PNG data URL, then POSTs both as JSON to
 * /api/inference. On success, `data.prediction` is passed to
 * displayPrediction(); any failure is logged and reported with alert().
 */
async function predict() {
  // Turns an error `detail` field into display text. Non-string details
  // (e.g. a list of validation errors) are JSON-encoded, because passing
  // them to Error() directly would render as "[object Object]".
  function describeErrorDetail(detail) {
    if (detail == null) return '';
    return typeof detail === 'string' ? detail : JSON.stringify(detail);
  }

  const modelSelect = document.getElementById('model-select');
  const selectedModel = modelSelect.value;
  if (!selectedModel) {
    alert('Please train a model first');
    return;
  }

  const imageData = canvas.toDataURL('image/png');

  try {
    const response = await fetch('/api/inference', {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify({
        image: imageData,
        model_name: selectedModel,
      }),
    });

    if (!response.ok) {
      // The error body is not guaranteed to be JSON (e.g. an HTML error page
      // from a proxy or a crashed server). Treat an unparsable body as "no
      // detail" so the user sees the generic message, not a JSON SyntaxError.
      const errorBody = await response.json().catch(() => null);
      const detail = errorBody ? errorBody.detail : undefined;
      throw new Error(describeErrorDetail(detail) || 'Prediction failed');
    }

    const data = await response.json();
    displayPrediction(data.prediction);
  } catch (error) {
    console.error('Error:', error);
    alert(error.message || 'Error during prediction');
  }
}
/**
 * Reveals the #prediction-result panel and renders the predicted digit.
 *
 * The value comes straight from the server response. It is HTML-escaped
 * before being interpolated into innerHTML, so unexpected markup is shown as
 * text instead of being parsed into the page.
 *
 * @param {*} prediction - Value shown as the predicted digit.
 */
function displayPrediction(prediction) {
  const htmlEscapes = { '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;' };
  const safePrediction = String(prediction).replace(/[&<>"']/g, (ch) => htmlEscapes[ch]);

  const resultDiv = document.getElementById('prediction-result');
  resultDiv.classList.remove('hidden');
  // NOTE(review): the confidence bar is always drawn at 100%; no confidence
  // value reaches this function.
  resultDiv.innerHTML = `
<h2>Prediction Result</h2>
<p class="prediction-text">Predicted Digit: ${safePrediction}</p>
<div class="confidence-bar">
<div class="confidence-level" style="width: 100%"></div>
</div>
`;
}