// NOTE: page-scrape residue, not part of the module (captured header text: "Spaces: Build error / Build error").
'use client';

import { useState, useRef, useEffect } from 'react';
import axios from 'axios';
import useImage from 'use-image';
import Konva from 'konva';
import { Stage, Image, Layer } from 'react-konva';
import { BASE_URL } from './constants';
import Dropdown from './dropdown';
/**
 * A colour expressed as separate red, green and blue channels.
 *
 * In this file it carries the result of `Konva.Util.getRGB` into `maskFilter`,
 * which writes the channels straight into `ImageData` bytes (so values are
 * presumably in the 0-255 range).
 */
interface RGB {
  /** Red channel. */
  r: number;
  /** Green channel. */
  g: number;
  /** Blue channel. */
  b: number;
}
/**
 * A rectangle described by its origin and size.
 *
 * `showBox` copies these four fields directly into a `Konva.Rect`, so they are
 * presumably expressed in stage/image pixel coordinates (confirm against the
 * backend that produces `bboxes`).
 */
interface Box {
  /** Horizontal position of the rectangle's origin. */
  x: number;
  /** Vertical position of the rectangle's origin. */
  y: number;
  /** Horizontal extent. */
  width: number;
  /** Vertical extent. */
  height: number;
}
/**
 * Recolours a mask image in place so it can be overlaid on the source image.
 *
 * Pixels that are exactly black (r = g = b = 0) are treated as background and
 * made fully transparent. Every other pixel is painted with `color` at
 * alpha 128 (semi-transparent).
 *
 * @param imageData Pixel buffer to mutate; read as consecutive RGBA quadruplets.
 * @param color     Colour applied to every non-black pixel.
 */
const maskFilter = (imageData: ImageData, color: RGB): void => {
  const { data } = imageData;
  // Each pixel occupies four consecutive bytes: R, G, B, A.
  for (let offset = 0; offset < data.length; offset += 4) {
    const isBackground =
      data[offset] === 0 && data[offset + 1] === 0 && data[offset + 2] === 0;
    if (isBackground) {
      data[offset + 3] = 0;
      continue;
    }
    data[offset] = color.r;
    data[offset + 1] = color.g;
    data[offset + 2] = color.b;
    data[offset + 3] = 128;
  }
};
/**
 * Interactive instance-labelling canvas.
 *
 * Renders `imageUrl` on a Konva stage and lets the user build per-instance
 * masks by clicking (left click = positive prompt, right click = negative
 * prompt). Each click posts the accumulated prompts to
 * `/v1/get_label_preds/{imageName}` and overlays the first returned mask.
 * Instances can be pinned to a class, deleted, saved, reloaded as masks or
 * bounding boxes, or auto-propagated through `/v1/get_multi_label_preds`.
 *
 * @param imageUrl  URL of the image to annotate; changing it clears all annotations.
 * @param imageName Identifier appended to every backend endpoint path.
 */
export default function Canvas({ imageUrl, imageName }: { imageUrl: string, imageName: string }) {
  const [image] = useImage(imageUrl, 'anonymous');
  const stageRef = useRef<Konva.Stage>(null);
  const layerRef = useRef<Konva.Layer>(null);
  // One group per instance. After the first pin the last entry is always the
  // still-unpinned "working" group for the next instance.
  const groupRef = useRef<Konva.Group[]>([]);
  // Click prompts for the working instance: coordinates and 1/0 labels.
  const pointsRef = useRef<Array<[number, number]>>([]);
  const labelsRef = useRef<Array<number>>([]);
  // Colour per instance; the last entry belongs to the working instance.
  const instanceColorsRef = useRef<string[]>([]);
  const [classList, setClassList] = useState<string[]>([]);
  const classOptionsRef = useRef<string[]>([]);
  const classSelectionRef = useRef<string>('');
  // True when the working group holds drawn content that has not been pinned.
  const [useLatest, setUseLatest] = useState<boolean>(false);
  const selectedInstanceRef = useRef<number>(-1);
  const [selectedInstanceStyle, setSelectedInstanceStyle] = useState<boolean>(false);

  const buttonClass = 'bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-2';

  /** Removes every instance from the layer and resets all annotation state. */
  const clearLayers = () => {
    for (const group of groupRef.current) {
      group.destroy();
    }
    groupRef.current = [];
    layerRef.current?.draw();
    pointsRef.current = [];
    labelsRef.current = [];
    instanceColorsRef.current = [Konva.Util.getRandomColor()];
    // A stale selection or "unpinned content" flag would refer to groups that
    // no longer exist, so reset both along with the list.
    selectedInstanceRef.current = -1;
    setSelectedInstanceStyle(false);
    setUseLatest(false);
    setClassList([]);
  };

  // Start from a clean slate whenever a different image is shown.
  useEffect(() => {
    clearLayers();
  }, [imageUrl]);

  /**
   * Decodes a base64 PNG mask and adds it, tinted with the working instance's
   * colour, to `group` on the layer. The drawing happens asynchronously once
   * the browser has decoded the image.
   */
  const showMask = (data: string, group: Konva.Group) => {
    const layer = layerRef.current;
    if (!layer) return;
    // Resolve the colour now: by the time `onload` fires the colour list may
    // already have grown (e.g. when loading several instances in a loop).
    const color = Konva.Util.getRGB(
      instanceColorsRef.current[instanceColorsRef.current.length - 1]
    );
    const width = image?.width;
    const height = image?.height;
    const maskObj = new window.Image();
    maskObj.onload = () => {
      const maskImage = new Konva.Image({
        x: 0,
        y: 0,
        image: maskObj,
        width: width,
        height: height,
      });
      maskImage.filters([(imageData: ImageData) => maskFilter(imageData, color)]);
      // Konva only applies filters to cached nodes.
      maskImage.cache();
      group.add(maskImage);
      if (groupRef.current.length === 0) {
        groupRef.current.push(group);
      }
      layer.add(group);
      layer.draw();
    };
    maskObj.src = `data:image/png;base64,${data}`;
    setUseLatest(true);
  };

  /** Draws `box` as an outlined rectangle in the working instance's colour. */
  const showBox = (box: Box, group: Konva.Group) => {
    const layer = layerRef.current;
    if (!layer) return;
    const color = Konva.Util.getRGB(
      instanceColorsRef.current[instanceColorsRef.current.length - 1]
    );
    // Setting bit 24 guarantees seven hex digits; slice(1) drops that leading
    // "1" and leaves a zero-padded six-digit colour.
    const rgbToHex = (r: number, g: number, b: number): string =>
      '#' + ((1 << 24) | (r << 16) | (g << 8) | b).toString(16).slice(1);
    const rect = new Konva.Rect({
      x: box.x,
      y: box.y,
      width: box.width,
      height: box.height,
      stroke: rgbToHex(color.r, color.g, color.b),
      strokeWidth: 2,
    });
    if (groupRef.current.length === 0) {
      groupRef.current.push(group);
    }
    group.add(rect);
    layer.add(group);
    layer.draw();
    setUseLatest(true);
  };

  /**
   * Records a click prompt, discards the working instance's previous preview
   * and requests a new mask for the accumulated prompts.
   */
  const handleStageClick = async (e: Konva.KonvaEventObject<MouseEvent>) => {
    const layer = layerRef.current;
    if (!stageRef.current || !layer) return;
    const pos = e.target.getStage()?.getPointerPosition();
    if (!pos) return;

    // Right button (2) marks a negative prompt, any other button a positive one.
    const labels = [...labelsRef.current, e.evt.button === 2 ? 0 : 1];
    labelsRef.current = labels;

    // Replace the working group so the new mask does not stack on the old one.
    // A destroyed Konva node must not be reused, hence the fresh group.
    const groups = groupRef.current;
    if (groups.length > 0) {
      groups[groups.length - 1].destroy();
      groups[groups.length - 1] = new Konva.Group();
    } else {
      groups.push(new Konva.Group());
    }
    layer.draw();

    const points: Array<[number, number]> = [...pointsRef.current, [pos.x, pos.y]];
    pointsRef.current = points;

    try {
      const res = await axios.post(
        `${BASE_URL}/v1/get_label_preds/${imageName}`,
        {
          points: points,
          labels: labels,
        }
      );
      // Re-read the ref: the canvas may have been cleared while waiting.
      const target = groupRef.current[groupRef.current.length - 1];
      if (!target) return;
      showMask(res.data.masks[0], target);
    } catch (err) {
      console.log(err);
    }
  };

  /**
   * Pins the working instance to the currently selected class and opens a new
   * working group with a fresh colour and empty prompts.
   */
  const pinInstance = () => {
    if (!layerRef.current) return;
    // Snapshot the selection: the state updater below runs later, by which
    // time the ref may already hold a different class.
    const selectedClass = classSelectionRef.current;
    if (selectedClass === '') {
      alert('Please select a class');
      return;
    }
    groupRef.current.push(new Konva.Group());
    pointsRef.current = [];
    labelsRef.current = [];
    instanceColorsRef.current.push(Konva.Util.getRandomColor());
    setClassList(prev => [...prev, selectedClass]);
    setUseLatest(false);
  };

  /** Removes the instance chosen in the class list, if any. */
  const deleteSelectedInstance = () => {
    const layer = layerRef.current;
    if (!layer) return;
    const index = selectedInstanceRef.current;
    if (index === -1) return;
    // always 1 extra group for the next instance
    if (groupRef.current.length <= 1) return;
    const doomed = groupRef.current[index];
    if (!doomed) return;
    doomed.destroy();
    groupRef.current.splice(index, 1);
    // Functional update avoids dropping entries queued since the last render.
    setClassList(prev => prev.filter((_, i) => i !== index));
    instanceColorsRef.current.splice(index, 1);
    layer.draw();
    selectedInstanceRef.current = -1;
    setSelectedInstanceStyle(false);
  };

  /**
   * Uploads every instance as a data-URL mask together with the class list,
   * then clears the canvas. On failure the error is logged and the
   * annotations are kept so the user can retry.
   */
  const submitLabels = async () => {
    const groupList = groupRef.current;
    if (groupList.length === 0) return;
    // Skip the trailing working group unless it holds unpinned content.
    const count = useLatest ? groupList.length : groupList.length - 1;
    const masks: string[] = [];
    for (let i = 0; i < count; i++) {
      masks.push(
        groupList[i].toDataURL({ x: 0, y: 0, width: image?.width, height: image?.height })
      );
    }
    try {
      await axios.put(`${BASE_URL}/v1/label_image/${imageName}`,
        {
          masks: masks,
          labels: classList,
        }
      );
    } catch (err) {
      console.log(err);
      return;
    }
    clearLayers();
  };

  /**
   * Adds `count` instances: opens a working group, then for each index draws
   * it via `draw` and pins it under `labels[i]`.
   */
  const addPinnedInstances = (
    count: number,
    labels: string[],
    draw: (index: number, group: Konva.Group) => void
  ) => {
    // add an initial group to start with
    groupRef.current.push(new Konva.Group());
    for (let i = 0; i < count; i++) {
      draw(i, groupRef.current[groupRef.current.length - 1]);
      classSelectionRef.current = labels[i];
      pinInstance();
    }
  };

  /** Replaces the canvas content with the stored masks for this image. */
  const getLabels = async () => {
    clearLayers();
    try {
      const res = await axios.get(`${BASE_URL}/v1/get_labels/${imageName}`);
      classOptionsRef.current = [...new Set<string>(res.data.labels)];
      addPinnedInstances(res.data.masks.length, res.data.labels, (i, group) =>
        showMask(res.data.masks[i], group)
      );
    } catch (err) {
      console.log(err);
    }
  };

  /**
   * Sends the most recent instance (mask and class) to the backend and
   * replaces the canvas content with the instances it predicts.
   */
  const predLabels = async () => {
    const groups = groupRef.current;
    const count = useLatest ? groups.length : groups.length - 1;
    // Covers both "no groups" and "only an empty working group".
    if (count < 1) {
      alert('Please pin an instance');
      return;
    }
    const mask = groups[count - 1].toDataURL({ x: 0, y: 0, width: image?.width, height: image?.height });
    const label = classList[count - 1];
    clearLayers();
    try {
      const res = await axios.post(`${BASE_URL}/v1/get_multi_label_preds/${imageName}`, {
        mask: mask,
        label: label,
      });
      addPinnedInstances(res.data.masks.length, res.data.labels, (i, group) =>
        showMask(res.data.masks[i], group)
      );
    } catch (err) {
      console.log(err);
    }
  };

  /** Replaces the canvas content with the stored bounding boxes for this image. */
  const getBboxes = async () => {
    clearLayers();
    try {
      const res = await axios.get(`${BASE_URL}/v1/get_labels/${imageName}`);
      classOptionsRef.current = [...new Set<string>(res.data.labels)];
      addPinnedInstances(res.data.bboxes.length, res.data.labels, (i, group) =>
        showBox(res.data.bboxes[i], group)
      );
    } catch (err) {
      console.log(err);
    }
  };

  return (
    <div className="flex flex-row columns-2 gap-20">
      <div className="">
        <h1>Instance Class List</h1>
        <ul>
          {classList.map((item, index) => (
            <li key={index}>
              <div
                style={{
                  background: instanceColorsRef.current[index],
                  border: selectedInstanceRef.current === index && selectedInstanceStyle ? '2px solid red' : 'none'
                }}
                onClick={() => {
                  // Clicking the same row toggles its highlight; clicking a
                  // different row always highlights the new one.
                  const sameRow = selectedInstanceRef.current === index;
                  selectedInstanceRef.current = index;
                  setSelectedInstanceStyle(prev => (sameRow ? !prev : true));
                }}
              >{item}</div>
            </li>
          ))}
        </ul>
        <br />
        <Dropdown
          options={classOptionsRef}
          selectedOption={classSelectionRef}
        />
      </div>
      <div className="">
        <Stage
          width={image?.width}
          height={image?.height}
          ref={stageRef}
          onClick={handleStageClick}
          onContextMenu={(e) => e.evt.preventDefault()}
        >
          <Layer ref={layerRef}>
            <Image image={image} alt="image" />
          </Layer>
        </Stage>
        <br />
        <button className={buttonClass} onClick={clearLayers}>Clear All</button>
        <button className={buttonClass} onClick={getBboxes}>Load BBoxes</button>
        <button className={buttonClass} onClick={getLabels}>Load Masks</button>
        <button className={buttonClass} onClick={deleteSelectedInstance}>Delete Selected Instance</button>
        <button className={buttonClass} onClick={pinInstance}>Pin Instance</button>
        <button className={buttonClass} onClick={submitLabels}>Save All</button>
        <button className={buttonClass} onClick={predLabels}>Auto Label</button>
      </div>
    </div>
  );
}