/*
 * File-viewer page header captured with this file (not part of the module):
 * dillonlaird's picture | added js | 8040aeb | raw | history blame | 11 kB
 */
'use client';
import { useState, useRef, useEffect } from 'react';
import axios from 'axios';
import useImage from 'use-image';
import Konva from 'konva';
import { Stage, Image, Layer } from 'react-konva';
import { BASE_URL } from './constants';
import Dropdown from './dropdown';
/**
 * Colour expressed as separate red, green and blue channel values.
 * maskFilter writes these three values into the R, G and B bytes of every
 * non-black pixel it recolours.
 */
interface RGB {
r: number;
g: number;
b: number;
}
/**
 * Rectangle described by an origin and a size. showBox forwards all four
 * fields unchanged to a Konva.Rect.
 * NOTE(review): presumably expressed in image pixel coordinates — confirm
 * against the API that returns the bounding boxes.
 */
interface Box {
x: number;
y: number;
width: number;
height: number;
}
/**
 * Recolours a mask bitmap in place.
 *
 * The buffer is walked four bytes (R, G, B, A) at a time:
 *  - a pixel whose R, G and B are all 0 has its alpha set to 0;
 *  - every other pixel is overwritten with `color` and its alpha set to 128.
 *
 * @param imageData Pixel buffer to modify; `imageData.data` is mutated.
 * @param color     Replacement colour for non-black pixels.
 */
const maskFilter = (imageData: ImageData, color: RGB) => {
  const pixels = imageData.data;
  for (let offset = 0; offset < pixels.length; offset += 4) {
    const isBlack =
      pixels[offset] === 0 &&
      pixels[offset + 1] === 0 &&
      pixels[offset + 2] === 0;

    if (isBlack) {
      // Background pixel: only the alpha byte changes.
      pixels[offset + 3] = 0;
      continue;
    }

    // Mask pixel: tint with the instance colour at alpha 128.
    pixels[offset] = color.r;
    pixels[offset + 1] = color.g;
    pixels[offset + 2] = color.b;
    pixels[offset + 3] = 128;
  }
};
/**
 * Labelling canvas for a single image.
 *
 * Renders the image on a Konva stage and lets the user build per-instance
 * overlays: clicking the stage requests a mask prediction for the clicked
 * points, "Pin Instance" freezes the current overlay under the selected
 * class, and the remaining buttons load, delete, save or auto-predict
 * instances through the HTTP API rooted at BASE_URL.
 *
 * @param imageUrl  URL of the image drawn on the stage; changing it clears
 *                  every overlay.
 * @param imageName Identifier interpolated into the API paths.
 */
export default function Canvas({ imageUrl, imageName }: { imageUrl: string, imageName: string }) {
  const [image] = useImage(imageUrl, 'anonymous');
  const stageRef = useRef<Konva.Stage>(null);
  const layerRef = useRef<Konva.Layer>(null);
  // One Konva.Group per instance; pinInstance appends an extra group that
  // receives the next instance's overlay.
  const groupRef = useRef<Konva.Group[]>([]);
  // Click positions and their 0/1 labels for the instance being edited.
  const pointsRef = useRef<Array<[number, number]>>([]);
  const labelsRef = useRef<Array<number>>([]);
  // Colour per instance; the last entry is the colour used for new overlays.
  const instanceColorsRef = useRef<string[]>([]);
  // Class name of every pinned instance, in pin order.
  const [classList, setClassList] = useState<string[]>([]);
  const classOptionsRef = useRef<string[]>([]);
  const classSelectionRef = useRef<string>('');
  // True while the last group holds an overlay that has not been pinned yet.
  const [useLatest, setUseLatest] = useState<boolean>(false);
  // Index into classList of the highlighted instance, or -1 for none. A single
  // piece of state keeps the highlight and the deletion target in sync.
  const [selectedInstance, setSelectedInstance] = useState<number>(-1);

  /** Destroys every overlay group and resets all per-image labelling state. */
  const clearLayers = () => {
    groupRef.current.forEach((group) => {
      group.destroy();
    });
    groupRef.current = [];
    layerRef.current?.draw();
    pointsRef.current = [];
    labelsRef.current = [];
    instanceColorsRef.current = [Konva.Util.getRandomColor()];
    setClassList([]);
    setSelectedInstance(-1);
  };

  // Start from a clean canvas on mount and whenever a different image is shown.
  useEffect(() => {
    clearLayers();
  }, [imageUrl]);

  /**
   * Decodes a base64 PNG mask and adds it to `group`, tinted with the current
   * instance colour. The Konva nodes are created once the PNG has loaded.
   *
   * @param data  Base64-encoded PNG (without the `data:` prefix).
   * @param group Group that receives the mask image.
   */
  const showMask = (data: string, group: Konva.Group) => {
    const layer = layerRef.current;
    if (!layer) return;
    // Resolve the colour now: by the time onload fires another instance may
    // already have pushed a new colour.
    const color = Konva.Util.getRGB(
      instanceColorsRef.current[instanceColorsRef.current.length - 1]
    );
    const width = image?.width;
    const height = image?.height;
    const maskObj = new window.Image();
    maskObj.onload = () => {
      const maskImage = new Konva.Image({
        x: 0,
        y: 0,
        image: maskObj,
        width: width,
        height: height,
      });
      maskImage.filters([(imageData: ImageData) => maskFilter(imageData, color)]);
      // Called right after setting the filter and before the node is drawn.
      maskImage.cache();
      group.add(maskImage);
      if (groupRef.current.length === 0)
        groupRef.current.push(group);
      layer.add(group);
      layer.draw();
    };
    maskObj.src = `data:image/png;base64,${data}`;
    setUseLatest(true);
  };

  /**
   * Draws a bounding-box outline in the current instance colour inside `group`.
   *
   * @param box   Rectangle to outline.
   * @param group Group that receives the rectangle.
   */
  const showBox = (box: Box, group: Konva.Group) => {
    const layer = layerRef.current;
    if (!layer) return;
    // Instance colours are only ever produced by Konva.Util.getRandomColor(),
    // so the stored string can be used as the stroke directly.
    const stroke = instanceColorsRef.current[instanceColorsRef.current.length - 1];
    const rect = new Konva.Rect({
      x: box.x,
      y: box.y,
      width: box.width,
      height: box.height,
      stroke: stroke,
      strokeWidth: 2,
    });
    if (groupRef.current.length === 0)
      groupRef.current.push(group);
    group.add(rect);
    layer.add(group);
    layer.draw();
    setUseLatest(true);
  };

  /**
   * Records the clicked point, asks the backend for a mask based on all
   * points recorded for the current instance, and shows the first mask
   * returned. Right-button clicks are sent with label 0, all others with 1.
   */
  const handleStageClick = async (e: Konva.KonvaEventObject<MouseEvent>) => {
    const stage = stageRef.current;
    const layer = layerRef.current;
    if (!stage || !layer) return;
    const pos = e.target.getStage()?.getPointerPosition();
    if (!pos) return;

    const labels = [...labelsRef.current, e.evt.button === 2 ? 0 : 1];
    labelsRef.current = labels;

    // re-draw last layer so we don't overlap masks on clicks
    if (groupRef.current.length > 0) {
      groupRef.current[groupRef.current.length - 1].destroy();
    } else {
      groupRef.current.push(new Konva.Group());
    }
    layer.draw();

    const points: Array<[number, number]> = [...pointsRef.current, [pos.x, pos.y]];
    pointsRef.current = points;

    try {
      const res = await axios.post(
        `${BASE_URL}/v1/get_label_preds/${imageName}`,
        {
          points: points,
          labels: labels,
        }
      );
      // The canvas may have been cleared while the request was in flight.
      const group = groupRef.current[groupRef.current.length - 1];
      if (!group) return;
      showMask(res.data.masks[0], group);
    } catch (err) {
      console.log(err);
    }
  };

  /**
   * Freezes the current overlay under the selected class and prepares a new
   * group and colour for the next instance. Alerts when no class is selected.
   */
  const pinInstance = () => {
    const layer = layerRef.current;
    if (!layer) return;
    // Read the ref now. The state updater below may run later, after the
    // loaders have already overwritten classSelectionRef for the next
    // instance, which would give every queued entry the last label.
    const selectedClass = classSelectionRef.current;
    if (selectedClass === '') {
      alert('Please select a class');
      return;
    }
    groupRef.current.push(new Konva.Group());
    pointsRef.current = [];
    labelsRef.current = [];
    instanceColorsRef.current.push(Konva.Util.getRandomColor());
    setClassList(prev => ([...prev, selectedClass]));
    setUseLatest(false);
  };

  /** Removes the highlighted instance: its group, colour and class entry. */
  const deleteSelectedInstance = () => {
    const layer = layerRef.current;
    if (!layer) return;
    const index = selectedInstance;
    if (index === -1) return;
    // always 1 extra group for the next instance
    if (groupRef.current.length <= 1) return;
    const group = groupRef.current[index];
    // Stale selection pointing past the existing groups: nothing to delete.
    if (!group) return;
    group.destroy();
    groupRef.current.splice(index, 1);
    setClassList(prev => prev.filter((_, i) => i !== index));
    instanceColorsRef.current.splice(index, 1);
    layer.draw();
    setSelectedInstance(-1);
  };

  /**
   * Uploads every instance overlay (as a data URL) together with the class
   * list, then clears the canvas. On failure the error is logged and the
   * canvas is left untouched.
   */
  const submitLabels = async () => {
    const groupList = groupRef.current;
    if (groupList.length === 0) return;
    // Skip the trailing group unless it holds an unpinned overlay.
    // NOTE(review): when useLatest is true this sends one more mask than
    // there are entries in classList — confirm the server expects that.
    const count = useLatest ? groupList.length : groupList.length - 1;
    const masks: string[] = [];
    for (let i = 0; i < count; i++) {
      const labelData = groupList[i].toDataURL({ x: 0, y: 0, width: image?.width, height: image?.height });
      masks.push(labelData);
    }
    try {
      await axios.put(`${BASE_URL}/v1/label_image/${imageName}`,
        {
          masks: masks,
          labels: classList,
        }
      );
      // clearLayers also resets the recorded points and labels.
      clearLayers();
    } catch (err) {
      console.log(err);
    }
  };

  /** Replaces the canvas content with the masks stored for this image. */
  const getLabels = async () => {
    clearLayers();
    try {
      const res = await axios.get(`${BASE_URL}/v1/get_labels/${imageName}`);
      classOptionsRef.current = [...new Set(res.data.labels)];
      // add an initial group to start with
      groupRef.current.push(new Konva.Group());
      for (let i = 0; i < res.data.masks.length; i++) {
        showMask(res.data.masks[i], groupRef.current[groupRef.current.length - 1]);
        classSelectionRef.current = res.data.labels[i];
        pinInstance();
      }
    } catch (err) {
      console.log(err);
    }
  };

  /**
   * Sends the most recent instance overlay and its class to the multi-label
   * prediction endpoint and replaces the canvas content with the result.
   */
  const predLabels = async () => {
    const count = useLatest ? groupRef.current.length : groupRef.current.length - 1;
    // Covers both "no groups at all" and "only the empty trailing group",
    // where indexing count - 1 would otherwise hit an undefined group.
    if (count <= 0) {
      alert('Please pin an instance');
      return;
    }
    const mask = groupRef.current[count - 1].toDataURL({ x: 0, y: 0, width: image?.width, height: image?.height });
    // NOTE(review): with an unpinned latest overlay (useLatest true) there is
    // no classList entry at count - 1, so `label` is undefined — confirm.
    const label = classList[count - 1];
    clearLayers();
    try {
      const res = await axios.post(`${BASE_URL}/v1/get_multi_label_preds/${imageName}`, {
        mask: mask,
        label: label,
      });
      groupRef.current.push(new Konva.Group());
      for (let i = 0; i < res.data.masks.length; i++) {
        showMask(res.data.masks[i], groupRef.current[groupRef.current.length - 1]);
        classSelectionRef.current = res.data.labels[i];
        pinInstance();
      }
    } catch (err) {
      console.log(err);
    }
  };

  /** Replaces the canvas content with the bounding boxes stored for this image. */
  const getBboxes = async () => {
    clearLayers();
    try {
      const res = await axios.get(`${BASE_URL}/v1/get_labels/${imageName}`);
      classOptionsRef.current = [...new Set(res.data.labels)];
      groupRef.current.push(new Konva.Group());
      for (let i = 0; i < res.data.bboxes.length; i++) {
        showBox(res.data.bboxes[i], groupRef.current[groupRef.current.length - 1]);
        classSelectionRef.current = res.data.labels[i];
        pinInstance();
      }
    } catch (err) {
      console.log(err);
    }
  };

  return (
    <div className="flex flex-row columns-2 gap-20">
      <div className="">
        <h1>Instance Class List</h1>
        <ul>
          {classList.map((item, index) => (
            <li key={index}>
              <div
                style={{
                  background: instanceColorsRef.current[index],
                  border: selectedInstance === index ? '2px solid red' : 'none'
                }}
                onClick={() => {
                  // Clicking the highlighted row again deselects it.
                  setSelectedInstance(prev => (prev === index ? -1 : index));
                }}
              >{item}</div>
            </li>
          ))}
        </ul>
        <br />
        <Dropdown
          options={classOptionsRef}
          selectedOption={classSelectionRef}
        />
      </div>
      <div className="">
        <Stage
          width={image?.width}
          height={image?.height}
          ref={stageRef}
          onClick={handleStageClick}
          onContextMenu={(e) => e.evt.preventDefault()}
        >
          <Layer ref={layerRef}>
            <Image image={image} alt="image" />
          </Layer>
        </Stage>
        <br />
        <button className="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-2"
          onClick={clearLayers}>Clear All</button>&nbsp;
        <button className="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-2"
          onClick={getBboxes}>Load BBoxes</button>&nbsp;
        <button className="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-2"
          onClick={getLabels}>Load Masks</button>&nbsp;
        <button className="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-2"
          onClick={deleteSelectedInstance}>Delete Selected Instance</button>&nbsp;
        <button className="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-2"
          onClick={pinInstance}>Pin Instance</button>&nbsp;
        <button className="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-2"
          onClick={submitLabels}>Save All</button>&nbsp;
        <button className="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-2"
          onClick={predLabels}>Auto Label</button>
      </div>
    </div>
  );
}