import {
  env,
  pipeline,
  SamModel,
  AutoProcessor,
  RawImage,
  Tensor,
  AutoModel,
} from "@xenova/transformers/dist/transformers.js"
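// Web worker logic: the default export starts loading three models from the Hugging Face Hub
// and wires up `self.onmessage` to answer requests from the main thread:
//   - Xenova/slimsam-77-uniform       (SlimSAM, point-prompted segmentation)
//   - Xenova/depth-anything-small-hf  (monocular depth estimation)
//   - briaai/RMBG-1.4                 (background removal)
// Supported message types: 'reset', 'remove_background', 'depth', 'segment', 'decode'
// (see the onmessage handler at the bottom of the file).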
export default () => {
  // Since we will download the models from the Hugging Face Hub, we can skip the local model check
  env.allowLocalModels = false;

  // Model loading starts as soon as this function runs; consumers await the promise when they need the model
  const segmentationModel = (async (
    modelId = 'Xenova/slimsam-77-uniform') => ({
    model: await SamModel.from_pretrained(modelId, { quantized: true }),
    processor: await AutoProcessor.from_pretrained(modelId)
  }))()
  // Same pattern: the pipeline starts loading immediately and is awaited when first needed
  const depthEstimationModel = (async () => {
    const p = await pipeline(
      'depth-estimation',
      'Xenova/depth-anything-small-hf'
      // 'Xenova/dpt-hybrid-midas'
    )
    return p
  })()
  const bgRemoverModel = (async (
    modelId = 'briaai/RMBG-1.4') => ({
    model: await AutoModel.from_pretrained(modelId, {
      // Do not require config.json to be present in the repository
      config: { model_type: 'custom' }
    }),
    processor: await AutoProcessor.from_pretrained(modelId, {
      // Provide the processor configuration inline instead of fetching it from the repository
      config: {
        do_normalize: true,
        do_pad: false,
        do_rescale: true,
        do_resize: true,
        image_mean: [0.5, 0.5, 0.5],
        feature_extractor_type: "ImageFeatureExtractor",
        image_std: [1, 1, 1],
        resample: 2,
        rescale_factor: 0.00392156862745098,
        size: { width: 1024, height: 1024 },
      }
    })
  }))()
  // For reference, the depth-estimation pipeline returns an object shaped like this:
  // {
  //   predicted_depth: Tensor {
  //     dims: [ 384, 384 ],
  //     type: 'float32',
  //     data: Float32Array(147456) [ 542.859130859375, 545.2833862304688, 546.1649169921875, ... ],
  //     size: 147456
  //   },
  //   depth: RawImage {
  //     data: Uint8Array(307200) [ 86, 86, 86, ... ],
  //     width: 640,
  //     height: 480,
  //     channels: 1
  //   }
  // }
  // State variables
  let image_embeddings = null;
  let image_string = "";
  let image_inputs = null;
  let ready = false;
  async function decode(data) {
    const sam = await segmentationModel

    // Prepare inputs for decoding
    const reshaped = image_inputs.reshaped_input_sizes[0];
    const points = data.map(x => [x.point[0] * reshaped[1], x.point[1] * reshaped[0]])
    const labels = data.map(x => BigInt(x.label));
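    // SAM expects prompt coordinates shaped [batch, point_batch, nb_points, 2]
    // and labels shaped [batch, point_batch, nb_points]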
    const input_points = new Tensor(
      'float32',
      points.flat(Infinity),
      [1, 1, points.length, 2],
    )
    const input_labels = new Tensor(
      'int64',
      labels.flat(Infinity),
      [1, 1, labels.length],
    )

    // Generate the mask
    const outputs = await sam.model({
      ...image_embeddings,
      input_points,
      input_labels,
    })
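    // pred_masks holds the candidate masks; iou_scores are the model's own quality estimates for each candidate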
    console.log(outputs.iou_scores);

    // Post-process the mask
    // @ts-ignore
    const masks = await sam.processor.post_process_masks(
      outputs.pred_masks,
      image_inputs.original_sizes,
      image_inputs.reshaped_input_sizes,
    );
    console.log(masks);

    return {
      mask: RawImage.fromTensor(masks[0][0]),
      scores: outputs.iou_scores.data,
    }
  }
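  // A 'decode' request is expected to carry data shaped like [{ point: [x, y], label: 1 }, ...],
  // where x and y are assumed to be coordinates normalized to [0, 1] (decode() rescales them to the
  // processor's input size) and label 1 / 0 follows SAM's foreground / background convention.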
  self.onmessage = async (e) => { // eslint-disable-line no-restricted-globals
    const sam = await segmentationModel

    if (!ready) {
      // Indicate that we are ready to accept requests
      ready = true;
      self.postMessage({ // eslint-disable-line no-restricted-globals
        type: 'ready',
      });
    }

    const { type } = e.data;

    if (type === 'reset') {
      image_inputs = null;
      image_embeddings = null;
    } else if (type === 'remove_background') {
      console.log('starting background removal')
      self.postMessage({ // eslint-disable-line no-restricted-globals
        type: 'remove_background_start',
        // data: 'start'
      })

      const bgRemover = await bgRemoverModel
      const url = (e.data.url || "")

      // Read image
      const image = await RawImage.fromURL(url);

      // Preprocess image
      const { pixel_values } = await bgRemover.processor(image);

      // Predict alpha matte
      const { output } = await bgRemover.model({ input: pixel_values });
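      // The predicted matte is a single-channel tensor with values in [0, 1]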
      // Resize mask back to original size
      const mask = await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize(image.width, image.height);

      self.postMessage({ // eslint-disable-line no-restricted-globals
        type: 'remove_background_end',
        data: mask
      })
    } else if (type === 'depth') {
      console.log('starting depth')
      self.postMessage({ // eslint-disable-line no-restricted-globals
        type: 'depth_start',
        // data: 'start'
      })

      const data = e.data.data || ""
      const dem = await depthEstimationModel

      // Run depth estimation on the image sent by the main thread
      // const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/bread_small.png';
      const depthEstimations = await dem(data)
      const depthEstimation = Array.isArray(depthEstimations) ? depthEstimations[0] : depthEstimations
      const { depth } = depthEstimation

      // Indicate that we have computed the depth map
      self.postMessage({ // eslint-disable-line no-restricted-globals
        type: 'depth_end',
        data: depth,
      });
    } else if (type === 'segment') {
      // Indicate that we are starting to segment the image
      self.postMessage({ // eslint-disable-line no-restricted-globals
        type: 'segment_result',
        data: 'start',
      });

      const data = e.data.data;

      // keep this for later use
      image_string = data;

      // Read the image and recompute image embeddings
      const image = await RawImage.read(data);
      image_inputs = await sam.processor(image);
      // @ts-ignore
      image_embeddings = await sam.model.get_image_embeddings(image_inputs)
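      // The embeddings are cached so that subsequent 'decode' requests only need to run the lightweight mask decoder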
      // Indicate that we have computed the image embeddings, and we are ready to accept decoding requests
      self.postMessage({ // eslint-disable-line no-restricted-globals
        type: 'segment_result',
        data: 'done',
      });
    } else if (type === 'decode') {
      const inputData = e.data.data;
      const outputData = await decode(inputData)

      // Send the result back to the main thread
      self.postMessage({ // eslint-disable-line no-restricted-globals
        type: 'decode_result',
        data: outputData,
      });
    } else {
      throw new Error(`Unknown message type: ${type}`);
    }
  }
}
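// A minimal usage sketch (hypothetical wiring: `worker`, `imageDataUrl` and `imageUrl` are assumptions,
// only the message shapes are defined by this file):
//
//   worker.onmessage = (e) => console.log(e.data.type, e.data.data)
//   worker.postMessage({ type: 'segment', data: imageDataUrl })                      // compute image embeddings
//   worker.postMessage({ type: 'decode', data: [{ point: [0.5, 0.5], label: 1 }] })  // get a mask for a point
//   worker.postMessage({ type: 'remove_background', url: imageUrl })
//   worker.postMessage({ type: 'depth', data: imageDataUrl })
//   worker.postMessage({ type: 'reset' })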