<!DOCTYPE html>
<html lang="en">

<head>
  <meta charset="UTF-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>DeOldify Quantized (Browser)</title>
  <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
  <style>
    body {
      font-family: sans-serif;
      max-width: 800px;
      margin: 0 auto;
      padding: 20px;
    }

    h1 {
      text-align: center;
    }

    .container {
      display: flex;
      flex-direction: column;
      align-items: center;
      gap: 20px;
    }

    canvas {
      border: 1px solid #ccc;
      max-width: 100%;
    }

    .controls {
      margin-bottom: 20px;
    }

    #status {
      font-weight: bold;
      margin-top: 10px;
    }
  </style>
</head>

<body>
  <h1>DeOldify Quantized Model</h1>
  <p style="text-align: center;">Faster, smaller download (61 MB), slightly lower quality.</p>
  <div class="container">
    <div class="controls">
      <input type="file" id="imageInput" accept="image/*" />
    </div>
    <div id="status">Select an image to start...</div>
    <canvas id="outputCanvas"></canvas>
  </div>

  <script>
    const MODEL_URL = "https://huggingface.co/thookham/DeOldify-on-Browser/resolve/main/deoldify-quant.onnx";
    let session = null;
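
    // Build the model input: normalize RGB to [0, 1] and lay the values out in planar
    // CHW order, matching the [1, 3, height, width] tensor shape created below.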
    const preprocess = (input_imageData, width, height) => {
      const floatArr = new Float32Array(width * height * 3);
      const numPixels = width * height;
      for (let p = 0; p < numPixels; p++) {
        // Write each channel into its own plane: R plane, then G, then B.
        floatArr[p] = input_imageData.data[p * 4] / 255.0;
        floatArr[numPixels + p] = input_imageData.data[p * 4 + 1] / 255.0;
        floatArr[2 * numPixels + p] = input_imageData.data[p * 4 + 2] / 255.0;
      }
      return floatArr;
    };
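
    // Convert the model's [1, 3, H, W] float output back into an RGBA ImageData,
    // clamping each value to the [0, 255] range.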
    const postprocess = (tensor) => {
      const channels = tensor.dims[1];
      const height = tensor.dims[2];
      const width = tensor.dims[3];
      const imageData = new ImageData(width, height);
      const data = imageData.data;
      const tensorData = new Float32Array(tensor.data);

      for (let h = 0; h < height; h++) {
        for (let w = 0; w < width; w++) {
          const rgb = [];
          for (let c = 0; c < channels; c++) {
            const tensorIndex = (c * height + h) * width + w;
            let val = tensorData[tensorIndex] * 255.0;
            if (val < 0) val = 0;
            if (val > 255) val = 255;
            rgb.push(Math.round(val));
          }
          const pixelIndex = (h * width + w) * 4;
          data[pixelIndex] = rgb[0];
          data[pixelIndex + 1] = rgb[1];
          data[pixelIndex + 2] = rgb[2];
          data[pixelIndex + 3] = 255;
        }
      }
      return imageData;
    };
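
    // Load the ONNX model, preferring the Cache API over a fresh download,
    // then create the onnxruntime-web inference session.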
    async function init() {
      const status = document.getElementById('status');
      status.innerText = "Checking cache...";
      try {
        let buffer;
        const cacheName = 'deoldify-models-v1';

        // First, try to load the model bytes from the Cache API.
        try {
          const cache = await caches.open(cacheName);
          const cachedResponse = await cache.match(MODEL_URL);

          if (cachedResponse) {
            status.innerText = "Loading model from cache...";
            const blob = await cachedResponse.blob();
            buffer = await blob.arrayBuffer();
          }
        } catch (e) {
          console.warn("Cache API not supported or failed:", e);
        }

        // Otherwise, stream the download and report progress.
        if (!buffer) {
          status.innerText = "Downloading model from Hugging Face... 0%";
          const response = await fetch(MODEL_URL);
          if (!response.ok) throw new Error(`Failed to fetch model: ${response.statusText}`);

          const contentLength = response.headers.get('content-length');
          const total = contentLength ? parseInt(contentLength, 10) : 0;
          let loaded = 0;

          const reader = response.body.getReader();
          const chunks = [];

          while (true) {
            const { done, value } = await reader.read();
            if (done) break;
            chunks.push(value);
            loaded += value.length;
            if (total) {
              const progress = Math.round((loaded / total) * 100);
              status.innerText = `Downloading model from Hugging Face... ${progress}%`;
            } else {
              status.innerText = `Downloading model from Hugging Face... ${(loaded / 1024 / 1024).toFixed(1)} MB`;
            }
          }

          const blob = new Blob(chunks);
          buffer = await blob.arrayBuffer();

          // Cache the downloaded model so later visits can skip the download.
          try {
            const cache = await caches.open(cacheName);
            await cache.put(MODEL_URL, new Response(blob));
            console.log("Model saved to cache");
          } catch (e) {
            console.warn("Failed to save to cache:", e);
          }
        }

        status.innerText = "Initializing session...";
        session = await ort.InferenceSession.create(buffer);

        status.innerText = "Model loaded! Select an image.";
        console.log("Session created:", session);
      } catch (e) {
        status.innerText = "Error loading model: " + e.message;
        console.error(e);
        if (e.message.includes("Failed to fetch")) {
          status.innerHTML += "<br><br>⚠️ <b>CORS Error Detected</b>: If you are running this file directly (file://), you must use a local server.<br>Run <code>python -m http.server 8000</code> in the terminal and visit <code>http://localhost:8000/quantized.html</code>";
        }
      }
    }
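
    // When the user picks an image: make sure the model is loaded, resize the image
    // to 256x256, run inference, and draw the colorized result scaled back to the
    // original dimensions.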
    document.getElementById('imageInput').addEventListener('change', async function (e) {
      if (!session) {
        await init();
        if (!session) return; // model failed to load; init() already reports the error
      }

      const file = e.target.files[0];
      if (!file) return;

      if (!file.type.startsWith('image/')) {
        alert('Please select a valid image file.');
        return;
      }

      const image = new Image();
      const objectUrl = URL.createObjectURL(file);

      image.onload = async function () {
        document.getElementById('status').innerText = "Processing...";

        // Downscale the input to the 256x256 resolution the model expects.
        const canvas = document.createElement("canvas");
        const size = 256;
        canvas.width = size;
        canvas.height = size;
        const ctx = canvas.getContext("2d");
        ctx.drawImage(image, 0, 0, size, size);

        const input_img = ctx.getImageData(0, 0, size, size);
        const inputData = preprocess(input_img, size, size);
        const input = new ort.Tensor(inputData, [1, 3, size, size]);

        try {
          const result = await session.run({ "input": input });

          const output = result["output"] || result["out"] || Object.values(result)[0];

          if (!output) throw new Error("No output tensor found in model result");

          const imgdata = postprocess(output);

          // Draw the 256x256 result onto the output canvas, scaled back up to the
          // original image size.
          const outCanvas = document.getElementById('outputCanvas');
          outCanvas.width = image.width;
          outCanvas.height = image.height;
          const outCtx = outCanvas.getContext('2d');

          const tempCanvas = document.createElement('canvas');
          tempCanvas.width = size;
          tempCanvas.height = size;
          tempCanvas.getContext('2d').putImageData(imgdata, 0, 0);

          outCtx.drawImage(tempCanvas, 0, 0, image.width, image.height);

          document.getElementById('status').innerText = "Done!";
        } catch (err) {
          document.getElementById('status').innerText = "Error processing: " + err.message;
          console.error(err);
        } finally {
          URL.revokeObjectURL(objectUrl);
        }
      };

      image.src = objectUrl;
    });
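
    // Start loading the model as soon as the page opens.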
    init();
  </script>
</body>

</html>