Spaces:
Running
Running
(() => { | |
// Chat-completions endpoint of the UniDisc backend.
let API_URL = "https://rerun.aswerdlow.com/v1/chat/completions";
// The input image is overlaid with a GRID_SIZE x GRID_SIZE grid of maskable cells.
const GRID_SIZE = 8;
// Page-wide switches published on `window`. Only `skipHashChecking` is read in this
// file (by callUnidiscAPI: when true, requests are neither hashed nor looked up in
// /static/responses/). The other two are presumably read by other scripts — verify.
Object.assign(window, {
    autoResetOnMaskSelect: true,
    enable_cache: false,
    skipHashChecking: true,
});
// True while the input image is "removed" (covered by the grey overlay); in that
// state callUnidiscAPI sends no image and grid clicks are ignored.
let isImageRemoved = false;
// Secondary switch: when true, a computed request hash is not attached to the API
// payload as `request_hash`. Hashing itself is governed by window.skipHashChecking.
let DISABLE_HASH_CHECKING = false;
/**
 * Reads the mask edge length (in grid cells) from the `#cached-mask-size` input.
 *
 * The parsed value is clamped to [2, GRID_SIZE]. If the input element is missing,
 * or its value does not parse as an integer (e.g. the field was emptied), the
 * default size of 6 is returned instead of NaN.
 *
 * @returns {number} mask edge length in cells
 */
function getMaskSize() {
    const DEFAULT_MASK_SIZE = 6;
    const maskSizeInput = document.getElementById('cached-mask-size');
    if (!maskSizeInput) {
        return DEFAULT_MASK_SIZE;
    }
    const size = parseInt(maskSizeInput.value, 10);
    if (Number.isNaN(size)) {
        // NaN would pass through Math.min/Math.max and poison every mask bound.
        return DEFAULT_MASK_SIZE;
    }
    return Math.min(Math.max(size, 2), GRID_SIZE);
}
// --- Utility functions ---
/**
 * Decodes raw image bytes, center-crops them to a square, and rescales the result
 * to targetResolution x targetResolution.
 *
 * @param {ArrayBuffer|Uint8Array} imageBytes - encoded image data
 * @param {number} targetResolution - edge length of the output in pixels
 * @returns {Promise<Blob>} JPEG blob encoded at quality 0.95
 */
async function processImage(imageBytes, targetResolution) {
    const sourceBlob = new Blob([imageBytes], { type: 'image/jpeg' });
    const bitmap = await createImageBitmap(sourceBlob);
    const square = squareCrop(bitmap);

    const output = new OffscreenCanvas(targetResolution, targetResolution);
    const outputContext = output.getContext('2d');
    // Scale the square crop to fill the whole output canvas.
    outputContext.drawImage(square, 0, 0, targetResolution, targetResolution);

    return output.convertToBlob({ quality: 0.95, type: 'image/jpeg' });
}
/**
 * Center-crops an image to a square whose side equals its shorter dimension.
 *
 * @param {CanvasImageSource & {width: number, height: number}} img - source image
 * @returns {OffscreenCanvas} side x side canvas holding the centered crop
 */
function squareCrop(img) {
    const { width, height } = img;
    const side = Math.min(width, height);
    // Offsets that center the square inside the original image.
    const offsetX = (width - side) / 2;
    const offsetY = (height - side) / 2;

    const square = new OffscreenCanvas(side, side);
    square
        .getContext('2d')
        .drawImage(img, offsetX, offsetY, side, side, 0, 0, side, side);
    return square;
}
/**
 * Encodes a 2-D mask as a base64 PNG payload (without the `data:` prefix).
 *
 * Truthy cells become opaque white pixels (255,255,255,255); falsy cells become
 * fully transparent black (0,0,0,0). The image is cols x rows pixels, one pixel
 * per mask cell, in row-major order.
 *
 * @param {Array<Array<boolean>>} maskArray - non-empty list of rows; the width is
 *   taken from maskArray[0].length (rows are assumed to be equally long)
 * @returns {{data: string, width: number, height: number}} base64 PNG and its size
 */
function encodeMask(maskArray) {
    const rows = maskArray.length;
    const cols = maskArray[0].length;
    const canvas = document.createElement('canvas');
    canvas.width = cols;
    canvas.height = rows;
    const ctx = canvas.getContext('2d');
    const imageData = ctx.createImageData(cols, rows);
    const pixels = imageData.data;
    // Flatten once up front rather than re-flattening the whole mask for every cell.
    maskArray.flat().forEach((isMasked, i) => {
        const color = isMasked ? 255 : 0;
        const offset = i * 4; // RGBA: 4 bytes per pixel
        pixels[offset] = color; // R
        pixels[offset + 1] = color; // G
        pixels[offset + 2] = color; // B
        pixels[offset + 3] = color; // A: unmasked cells end up fully transparent
    });
    ctx.putImageData(imageData, 0, 0);
    const dataURL = canvas.toDataURL("image/png");
    return {
        data: dataURL.split(',')[1],
        width: cols,
        height: rows,
    };
}
/**
 * Heuristic Chrome detection: the user-agent mentions "Chrome" AND the vendor
 * string is exactly "Google Inc.".
 * @returns {boolean}
 */
function isChromeBrowser() {
    const { userAgent, vendor } = navigator;
    const mentionsChrome = /Chrome/.test(userAgent);
    return mentionsChrome && vendor === "Google Inc.";
}
/**
 * Collects generation settings from the `config-*` form controls.
 *
 * Numeric sliders/inputs are parsed with parseFloat/parseInt; `sampler` is the raw
 * value and `use_reward_models` the checkbox state. The key order is fixed because
 * the object is spread into the request payload that callUnidiscAPI hashes.
 *
 * @returns {{temperature: number, top_p: number, maskgit_r_temp: number, cfg: number,
 *   max_tokens: number, resolution: number, sampling_steps: number, sampler: string,
 *   use_reward_models: boolean}}
 */
function getApiConfig() {
    const element = (id) => document.getElementById(id);
    const readFloat = (id) => parseFloat(element(id).value);
    const readInt = (id) => parseInt(element(id).value, 10);

    const config = {
        temperature: readFloat('config-temperature'),
        top_p: readFloat('config-top_p'),
        maskgit_r_temp: readFloat('config-maskgit_r_temp'),
        cfg: readFloat('config-cfg'),
        max_tokens: readInt('config-max_tokens'),
        resolution: readInt('config-resolution'),
        sampling_steps: readInt('config-sampling_steps'),
        sampler: element('config-sampler').value,
        use_reward_models: element('config-use_reward_models').checked,
    };
    console.log("Using API Config:", config);
    return config;
}
/**
 * Mirrors each config slider's value into its companion `*-value` element, both
 * immediately and on every `input` event. Pairs where either element is missing
 * are skipped.
 */
function setupSliderValueDisplay() {
    const sliderDisplayPairs = [
        ['config-temperature', 'config-temperature-value'],
        ['config-top_p', 'config-top_p-value'],
        ['config-maskgit_r_temp', 'config-maskgit_r_temp-value'],
        ['config-cfg', 'config-cfg-value'],
    ];
    for (const [sliderId, displayId] of sliderDisplayPairs) {
        const slider = document.getElementById(sliderId);
        const display = document.getElementById(displayId);
        if (!slider || !display) {
            continue;
        }
        display.textContent = slider.value;
        slider.addEventListener('input', ({ target }) => {
            display.textContent = target.value;
        });
    }
}
/**
 * Reads a Blob into a `data:` URL string.
 * @param {Blob} blob
 * @returns {Promise<string>} resolves with the data URL; rejects if the read fails or is aborted
 */
function readBlobAsDataURL(blob) {
    return new Promise((resolve, reject) => {
        const reader = new FileReader();
        reader.onload = () => resolve(reader.result);
        // Without these handlers a failed read would leave the caller awaiting forever.
        reader.onerror = () => reject(reader.error ?? new Error('Failed to read image data'));
        reader.onabort = () => reject(new Error('Image read was aborted'));
        reader.readAsDataURL(blob);
    });
}

/**
 * Hex-encoded SHA-256 digest of a string, computed with Web Crypto.
 * @param {string} text
 * @returns {Promise<string>} 64-character lowercase hex digest
 * @throws {Error} when `crypto.subtle` is unavailable (e.g. non-secure context)
 */
async function sha256Hex(text) {
    if (typeof crypto === 'undefined' || !crypto.subtle) {
        throw new Error('Web Crypto API is not available. Please ensure you are serving your page over HTTPS or via localhost.');
    }
    const digest = await crypto.subtle.digest('SHA-256', new TextEncoder().encode(text));
    return Array.from(new Uint8Array(digest), (byte) => byte.toString(16).padStart(2, '0')).join('');
}

/**
 * Looks up a pre-computed response stored at /static/responses/<hash>.json.
 * @param {string} hash - request hash produced by sha256Hex
 * @returns {Promise<object|null>} a chat-completions-shaped wrapper around the cached
 *   message on a hit; null on a miss or on any fetch/parse failure
 */
async function fetchCachedResponse(hash) {
    try {
        const response = await fetch(`/static/responses/${hash}.json`, {
            mode: 'cors',
            headers: {
                'Accept': 'application/json',
            },
        });
        if (!response.ok) {
            console.log("Cache miss:", response);
            return null;
        }
        const jsonContent = await response.text();
        console.log("Cache hit!");
        const cachedData = JSON.parse(jsonContent);
        console.log("Cached data: ", cachedData);
        return {
            choices: [{
                index: 0,
                message: cachedData,
                finish_reason: "stop",
            }],
        };
    } catch (cacheError) {
        console.log("Cache access failed:", cacheError);
        console.log("Proceeding with direct API call");
        return null;
    }
}

/**
 * Sends a prompt (and optionally the current image plus mask) to the UniDisc API.
 *
 * `<mask>` tokens in the sentence are rewritten to `<m>`. The image (resized to the
 * configured resolution) and mask are attached as an assistant turn when the text
 * contains a mask token / is empty, or when the mask has any true cell, and the
 * image has not been removed. When window.skipHashChecking is false, the payload is
 * SHA-256 hashed and /static/responses/<hash>.json is tried before calling the API.
 *
 * NOTE: the message/payload structure and key order feed that hash; changing them
 * invalidates previously cached responses.
 *
 * @param {Blob} imageBlob - current input image
 * @param {Array<Array<boolean>>|null} maskArray - GRID_SIZE x GRID_SIZE mask, or null
 * @param {string} sentence - user prompt, possibly containing `<mask>` tokens
 * @param {object} [options] - currently unused; kept for caller compatibility
 * @returns {Promise<object>} parsed JSON response (read via `.choices[0].message` by callers)
 * @throws {Error} when the API responds with a non-2xx status or the request fails
 */
async function callUnidiscAPI(imageBlob, maskArray, sentence, options = {}) {
    // The warning concerns the hash-keyed response cache, which is only consulted
    // when hash checking is enabled; don't alert on every request otherwise.
    if (!window.skipHashChecking && !isChromeBrowser()) {
        alert("Warning: The pre-cached demo only works in Chrome due to differences in hashing algorithms.");
    }

    // The API uses <m> as its mask token.
    const apiSentence = sentence.replace(/<mask>/g, "<m>");
    console.log("Called API with sentence: ", apiSentence);

    const messages = [
        {
            role: "user",
            content: apiSentence ? [{ type: "text", text: apiSentence }] : [],
        },
        {
            role: "assistant",
            content: [],
        },
    ];

    const hasMaskedText = apiSentence.includes("<m") || !apiSentence || apiSentence.trim() === "";
    const hasMaskedImage = maskArray && maskArray.some((row) => row.some((cell) => cell === true));
    // Read the target resolution before processing the image.
    const resolution = parseInt(document.getElementById('config-resolution').value, 10);

    if ((hasMaskedText || hasMaskedImage) && !isImageRemoved) {
        const resizedImage = await processImage(await imageBlob.arrayBuffer(), resolution);
        const imageBase64 = await readBlobAsDataURL(resizedImage);
        messages[1].content.push({
            type: "image_url",
            image_url: { url: imageBase64 },
            is_mask: false,
        });
        if (maskArray) {
            const maskData = encodeMask(maskArray);
            messages[1].content.push({
                type: "image_url",
                image_url: {
                    url: `data:image/png;base64,${maskData.data}`,
                    mask_info: JSON.stringify({
                        width: maskData.width,
                        height: maskData.height,
                    }),
                },
                is_mask: true,
            });
        }
    }

    // Drop the assistant turn when nothing was attached to it.
    const lastMessage = messages[messages.length - 1];
    if (lastMessage.role === 'assistant' && lastMessage.content.length === 0) {
        console.log("Removing empty assistant message");
        messages.pop();
    }

    const payload = {
        messages,
        model: "unidisc",
        ...getApiConfig(),
    };

    // Hash the full payload (before request_hash is added) for the response cache.
    let hash = null;
    console.log("window.skipHashChecking: ", window.skipHashChecking);
    if (!window.skipHashChecking) {
        try {
            hash = await sha256Hex(JSON.stringify(payload));
            console.log("Hash generated from full payload:", hash);
        } catch (error) {
            console.error('Error generating hash:', error);
        }
    }

    if (hash && !window.skipHashChecking) {
        const cached = await fetchCachedResponse(hash);
        if (cached) {
            return cached;
        }
        console.log("Hash: ", hash);
    }

    if (hash && !DISABLE_HASH_CHECKING) {
        payload.request_hash = hash;
    }
    console.log("Payload: ", payload);

    try {
        // Only Content-Type is sent: Access-Control-Allow-* are *response* headers,
        // and sending them as request headers merely enlarges the CORS preflight.
        const response = await fetch(API_URL, {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify(payload),
        });
        if (!response.ok) {
            throw new Error(`API Error: ${response.status}`);
        }
        const data = await response.json();
        console.log("Response: ", data);
        return data;
    } catch (error) {
        console.error('API call failed:', error);
        throw error;
    }
}
// Publish the API client on `window` (presumably so other page scripts can call it).
window.callUnidiscAPI = callUnidiscAPI;

// Elements of the cached-demo section.
const section = document.getElementById('cached-section');
const grid = section.querySelector('#cached-grid');
const textInput = section.querySelector('#cached-text-input');
const submitButton = section.querySelector('#cached-submit-text');
const responseText = section.querySelector('#cached-response-text');
const inputImage = section.querySelector('#cached-input-image');
const outputImage = section.querySelector('#cached-output-image');
const imageUploadInput = section.querySelector('#cached-image-upload');

// Interaction state: last hovered cell, whether a mask is locked, and the locked
// mask's [topLeftRow, topLeftCol] (null when none).
let currentRow = 0;
let currentCol = 0;
let maskLocked = false;
let activeMask = null;

// Grid overlay cells in row-major order: cells[r * GRID_SIZE + c] is cell (r, c).
const cells = Array.from({ length: GRID_SIZE * GRID_SIZE }, (_, index) => {
    const cell = document.createElement('div');
    cell.className = 'cached-grid-cell';
    cell.dataset.row = Math.floor(index / GRID_SIZE);
    cell.dataset.col = index % GRID_SIZE;
    grid.appendChild(cell);
    return cell;
});
/**
 * Builds a GRID_SIZE x GRID_SIZE boolean mask with a getMaskSize()-wide square
 * set to true, starting at (topLeftRow, topLeftCol) and clipped to the grid.
 *
 * @param {number} topLeftRow - first masked row
 * @param {number} topLeftCol - first masked column
 * @returns {Array<Array<boolean>>}
 */
function createMaskArray(topLeftRow, topLeftCol) {
    const size = getMaskSize();
    const rowEnd = Math.min(topLeftRow + size, GRID_SIZE);
    const colEnd = Math.min(topLeftCol + size, GRID_SIZE);

    const mask = [];
    for (let r = 0; r < GRID_SIZE; r++) {
        mask.push(new Array(GRID_SIZE).fill(false));
    }
    for (let r = topLeftRow; r < rowEnd; r++) {
        for (let c = topLeftCol; c < colEnd; c++) {
            mask[r][c] = true;
        }
    }
    return mask;
}
/**
 * Previews the mask under the pointer: highlights a getMaskSize()-wide square
 * roughly centered on (row, col), clamped so it stays inside the grid, and records
 * (row, col) as the current pointer cell. Does nothing while a mask is locked.
 *
 * @param {number} row - hovered grid row
 * @param {number} col - hovered grid column
 */
function highlightCells(row, col) {
    if (maskLocked) {
        return;
    }
    for (const cell of cells) {
        cell.classList.remove('cached-highlighted');
    }

    const size = getMaskSize();
    const half = Math.floor(size / 2);
    const clampStart = (center) => Math.min(Math.max(center - half, 0), GRID_SIZE - size);
    const startRow = clampStart(row);
    const startCol = clampStart(col);

    currentRow = row;
    currentCol = col;

    const rowEnd = Math.min(startRow + size, GRID_SIZE);
    const colEnd = Math.min(startCol + size, GRID_SIZE);
    for (let r = startRow; r < rowEnd; r++) {
        for (let c = startCol; c < colEnd; c++) {
            cells[r * GRID_SIZE + c]?.classList.add('cached-highlighted');
        }
    }
}
// Convert the pointer position into a grid cell and preview the mask there.
grid.addEventListener('mousemove', (event) => {
    const { left, top, width, height } = grid.getBoundingClientRect();
    const row = Math.floor(((event.clientY - top) / height) * GRID_SIZE);
    const col = Math.floor(((event.clientX - left) / width) * GRID_SIZE);
    highlightCells(row, col);
});
/**
 * Returns the first text part of a chat message, or '' if it has none.
 * Accepts either an array of typed content parts or a plain string `content`.
 * @param {object} message
 * @returns {string}
 */
function extractMessageText(message) {
    if (Array.isArray(message.content)) {
        const textPart = message.content.find((part) => part.type === "text");
        return textPart?.text || '';
    }
    if (typeof message.content === 'string') {
        return message.content;
    }
    return '';
}

/**
 * Returns the URL of the first `image_url` part of a chat message (or of a
 * top-level `message.image_url` when `content` is not an array), else null.
 * @param {object} message
 * @returns {string|null}
 */
function extractMessageImageUrl(message) {
    if (Array.isArray(message.content)) {
        const imagePart = message.content.find((part) => part.type === "image_url");
        return imagePart?.image_url?.url || null;
    }
    return message.image_url?.url || null;
}

/**
 * Sends the current prompt, input image, and active mask to the API and renders
 * the result into the response text and output image.
 *
 * An empty prompt clears the response text and hides the output overlay without
 * calling the API. On any failure the error message is shown and the output
 * overlay is left visible; errors are not rethrown.
 */
async function updateOutput() {
    // Looked up before `try` so the error handler can reach them too.
    const responseText = section.querySelector('#cached-response-text');
    const outputOverlay = section.querySelector('#cached-output-overlay');
    try {
        const maskArray = activeMask ? createMaskArray(activeMask[0], activeMask[1]) : null;

        const sentence = textInput.value.trim();
        if (!sentence) {
            responseText.textContent = "";
            outputOverlay.style.display = "none";
            return;
        }

        const imageResponse = await fetch(inputImage.src);
        const imageBlob = await imageResponse.blob();

        // Loading state: grey overlay with no custom text.
        responseText.textContent = 'Processing your request...';
        outputOverlay.style.display = "block";
        outputOverlay.innerHTML = "";

        console.log("sentence: ", sentence);
        const response = await callUnidiscAPI(imageBlob, maskArray, sentence);
        const message = response?.choices?.[0]?.message;
        if (!message) {
            throw new Error("No message found in the API response");
        }

        const textContent = extractMessageText(message);
        if (textContent) {
            // Collapse whitespace, then drop isolated single letters other than "a"/"A".
            // NOTE(review): this also strips the pronoun "I" — confirm that is intended.
            responseText.textContent = textContent.replace(/\s+/g, ' ').replace(/ ([b-zB-Z]) /g, " ");
        } else {
            responseText.textContent = 'Image updated successfully!';
        }

        const imageUrl = extractMessageImageUrl(message);
        if (imageUrl) {
            // Use data: URLs as-is (whatever their image type); wrap bare base64 as JPEG.
            outputImage.src = imageUrl.startsWith("data:")
                ? imageUrl
                : `data:image/jpeg;base64,${imageUrl}`;
            outputOverlay.style.display = "none";
        } else {
            // Successful response without an image: label the overlay instead.
            outputOverlay.style.display = "block";
            outputOverlay.innerHTML = '<div style="display: flex; justify-content: center; align-items: center; height: 100%; color: white; font-size: 24px; font-weight: bold; text-shadow: 1px 1px 3px black;">Image Fixed</div>';
        }
    } catch (error) {
        console.error('Output update failed:', error);
        responseText.textContent = `Error: ${error?.message ?? error}`;
        // Keep the grey overlay visible and clear any "Image Fixed" label.
        if (outputOverlay) {
            outputOverlay.style.display = 'block';
            outputOverlay.innerHTML = '';
        }
    }
}
// Clicking the grid locks the previewed mask in place and requests a new output.
// Clicking the locked mask's own position again unlocks it without a request.
// Clicks are ignored while the image is removed.
grid.addEventListener('click', async () => {
    const maskSize = getMaskSize();
    const offset = Math.floor(maskSize / 2);
    const clampStart = (center) => Math.min(Math.max(center - offset, 0), GRID_SIZE - maskSize);
    const safeTopLeftRow = clampStart(currentRow);
    const safeTopLeftCol = clampStart(currentCol);

    if (isImageRemoved) {
        console.log("Image is removed, grid click ignored.");
        return;
    }

    const clickedLockedMask = maskLocked &&
        activeMask &&
        activeMask[0] === safeTopLeftRow &&
        activeMask[1] === safeTopLeftCol;

    if (clickedLockedMask) {
        console.log("Clearing mask after clicking on the same mask area");
        maskLocked = false;
        activeMask = null;
        cells.forEach((cell) => cell.classList.remove('cached-highlighted'));
        return;
    }

    maskLocked = true;
    activeMask = [safeTopLeftRow, safeTopLeftCol];
    cells.forEach((cell) => cell.classList.remove('cached-highlighted'));
    const rowEnd = Math.min(safeTopLeftRow + maskSize, GRID_SIZE);
    const colEnd = Math.min(safeTopLeftCol + maskSize, GRID_SIZE);
    for (let r = safeTopLeftRow; r < rowEnd; r++) {
        for (let c = safeTopLeftCol; c < colEnd; c++) {
            cells[r * GRID_SIZE + c]?.classList.add('cached-highlighted');
        }
    }

    try {
        await updateOutput();
    } catch (error) {
        console.error('Error updating output after grid click:', error);
    }
});
// Leaving the grid drops the hover preview unless a mask is locked in place.
grid.addEventListener('mouseleave', () => {
    if (maskLocked) {
        return;
    }
    for (const cell of cells) {
        cell.classList.remove('cached-highlighted');
    }
});
// Show an initial preview around cell (1, 1).
highlightCells(1, 1);
// Image/mask controls and the grey overlay that hides a "removed" input image.
const [resetImageButton, clearMaskButton, removeImageButton, greyOverlay] = [
    '#cached-reset-image',
    '#cached-clear-mask',
    '#cached-remove-image',
    '#cached-grey-overlay',
].map((selector) => section.querySelector(selector));
// Default demo image, restored by the reset button.
const originalImageSrc = "static/images/giraffe.png";
// "Reset image": restore the default input image, clear the mask, and return the
// output panel to its initial state.
resetImageButton.addEventListener('click', () => {
    inputImage.src = originalImageSrc;
    inputImage.style.filter = "none";
    isImageRemoved = false;
    greyOverlay.style.display = "none";

    console.log("Clearing mask after resetting the image");
    maskLocked = false;
    activeMask = null;
    for (const cell of cells) {
        cell.classList.remove('cached-highlighted');
    }

    outputImage.src = originalImageSrc;
    responseText.textContent = 'Enter a sentence and interact with the image.';
    // Show the blank output overlay (no "Image Fixed" label).
    const outputOverlay = document.querySelector('#cached-output-overlay');
    outputOverlay.style.display = "block";
    outputOverlay.innerHTML = "";
});
// "Clear mask": unlock and un-highlight the mask; images are left untouched.
clearMaskButton.addEventListener('click', () => {
    console.log("Clearing mask without affecting the image");
    maskLocked = false;
    activeMask = null;
    for (const cell of cells) {
        cell.classList.remove('cached-highlighted');
    }
    // NOTE(review): the output is not refreshed here; call updateOutput() if a new
    // request should follow clearing the mask.
    responseText.textContent = 'Mask cleared. Enter a sentence and interact with the image.';
});
// "Remove image": cover the input with the grey overlay, set isImageRemoved, drop
// any mask, and re-run the request with the current prompt.
removeImageButton.addEventListener('click', async () => {
    greyOverlay.style.display = "block";
    isImageRemoved = true;

    console.log("Clearing mask after removing image");
    maskLocked = false;
    activeMask = null;
    for (const cell of cells) {
        cell.classList.remove('cached-highlighted');
    }

    try {
        await updateOutput();
    } catch (error) {
        console.error('Error updating output after fully masking image:', error);
    }
});
// A new mask size invalidates any locked mask (its footprint no longer matches),
// so unlock it and redraw the hover preview at the last known pointer cell.
const maskSizeInput = document.getElementById('cached-mask-size');
maskSizeInput?.addEventListener('change', () => {
    const hasLockedMask = maskLocked && activeMask;
    if (hasLockedMask) {
        maskLocked = false;
        activeMask = null;
        for (const cell of cells) {
            cell.classList.remove('cached-highlighted');
        }
    }
    const pointerKnown = currentRow !== undefined && currentCol !== undefined;
    if (pointerKnown) {
        highlightCells(currentRow, currentCol);
    }
});
// Submit button: send the current prompt (and mask, if any) to the API.
const handleSubmitClick = async () => {
    console.log("Submit button clicked");
    try {
        await updateOutput();
    } catch (error) {
        console.error('Error updating output after submit click:', error);
    }
};
submitButton.addEventListener('click', handleSubmitClick);
// Pressing Enter in the prompt field submits it. `keydown` is used because
// `keypress` is deprecated; Enter presses that confirm an in-progress IME
// composition are ignored so composing text does not trigger a request.
textInput.addEventListener('keydown', async (e) => {
    if (e.key !== 'Enter' || e.isComposing) {
        return;
    }
    console.log("Enter key pressed in text input");
    e.preventDefault(); // Prevent implicit form submission if the input sits in a form
    try {
        await updateOutput();
    } catch (error) {
        console.error('Error updating output after Enter keypress:', error);
    }
});
// Initial UI: a sample prompt, a hint message, and the output overlay shown until
// the first result arrives.
textInput.value = "a happy puppy wearing a top hat, cartoon style";
responseText.textContent = 'Enter a sentence and interact with the image.';
document.querySelector('#cached-output-overlay').style.display = "block";
// No request is sent on load; calling updateOutput() here would fetch one immediately.
// Mirror the config slider values into their labels.
setupSliderValueDisplay();
// Replace the input image with a user-selected JPG/PNG and reset the demo state
// (mask cleared, image restored from "removed", output overlay reset).
imageUploadInput.addEventListener('change', (event) => {
    const file = event.target.files[0];
    if (!file) {
        return; // Selection was cancelled
    }
    if (file.type !== "image/jpeg" && file.type !== "image/png") {
        alert("Please upload a JPG or PNG image file.");
        // Reset so the same file can be selected again after the error.
        imageUploadInput.value = "";
        return;
    }

    const reader = new FileReader();
    reader.onload = (e) => {
        // Show the uploaded image as both input and (for now) output.
        inputImage.src = e.target.result;
        outputImage.src = e.target.result;
        // Make sure the input is visible and not greyed out.
        inputImage.style.filter = "none";
        isImageRemoved = false;
        greyOverlay.style.display = "none";
        // Clear any existing mask, then replace the status text it sets.
        clearMaskButton.click();
        responseText.textContent = 'New image uploaded. Interact with the image or enter text.';
        const outputOverlay = document.querySelector('#cached-output-overlay');
        outputOverlay.style.display = "block";
        outputOverlay.innerHTML = "";
        console.log("Image uploaded and displayed.");
    };
    reader.onerror = () => {
        console.error("Failed to read uploaded image:", reader.error);
        alert("Could not read the selected image file.");
    };
    reader.readAsDataURL(file);
    // The reader holds its own File reference; clearing the input lets choosing the
    // same file again (e.g. after a reset) fire `change`.
    imageUploadInput.value = "";
});
})(); |