Spaces:
Running
Running
import { env, pipeline, SamModel, AutoProcessor, RawImage, Tensor, PreTrainedModel, Processor, AutoModel } from "@xenova/transformers/dist/transformers.js" | |
export default () => { | |
// Models are downloaded from the Hugging Face Hub, so the local model lookup is skipped.
env.allowLocalModels = false;

// Shared promise resolving to `{ model, processor }` for SAM-based segmentation.
// The async function is invoked immediately, so loading starts as soon as this
// code runs; consumers simply `await segmentationModel`.
const segmentationModel = (async (modelId = 'Xenova/slimsam-77-uniform') => {
  // The model and the processor are independent, so load them concurrently.
  const [model, processor] = await Promise.all([
    SamModel.from_pretrained(modelId, { quantized: true }),
    AutoProcessor.from_pretrained(modelId),
  ]);
  return { model, processor };
})();
// Shared promise resolving to the depth-estimation pipeline.
// The async function is invoked right away; consumers `await depthEstimationModel`.
// ('Xenova/dpt-hybrid-midas' was noted as an alternative checkpoint.)
const depthEstimationModel = (async () => pipeline(
  'depth-estimation',
  'Xenova/depth-anything-small-hf'
))();
// Shared promise resolving to `{ model, processor }` for background removal.
// The async function is invoked immediately; consumers `await bgRemoverModel`.
const bgRemoverModel = (async (modelId = 'briaai/RMBG-1.4') => {
  // The model and the processor are independent, so load them concurrently.
  const [model, processor] = await Promise.all([
    AutoModel.from_pretrained(modelId, {
      // Do not require config.json to be present in the repository
      config: { model_type: 'custom' },
    }),
    AutoProcessor.from_pretrained(modelId, {
      // Do not require config.json to be present in the repository
      config: {
        do_normalize: true,
        do_pad: false,
        do_rescale: true,
        do_resize: true,
        image_mean: [0.5, 0.5, 0.5],
        feature_extractor_type: "ImageFeatureExtractor",
        image_std: [1, 1, 1],
        resample: 2,
        rescale_factor: 0.00392156862745098, // 1 / 255
        size: { width: 1024, height: 1024 },
      },
    }),
  ]);
  return { model, processor };
})();
// Example shape of a depth-estimation pipeline result (sample values):
// {
//   predicted_depth: Tensor {
//     dims: [ 384, 384 ],
//     type: 'float32',
//     data: Float32Array(147456) [ 542.859130859375, 545.2833862304688, 546.1649169921875, ... ],
//     size: 147456
//   },
//   depth: RawImage {
//     data: Uint8Array(307200) [ 86, 86, 86, ... ],
//     width: 640,
//     height: 480,
//     channels: 1
//   }
// }
// State variables
// Image embeddings of the current image; null until a 'segment' request has
// been processed and again after a 'reset'.
let image_embeddings = null;
// Raw payload of the most recent 'segment' request, kept for later use.
let image_string = "";
// Processor output for the current image (sizes are read from it in decode()).
let image_inputs = null;
// Becomes true once the 'ready' message has been posted.
let ready = false;

/**
 * Runs the SAM mask decoder for a set of prompt points against the current image.
 *
 * @param {Array<{point: number[], label: number}>} data - Prompt points. Each
 *   `point` is scaled by the reshaped input size, so it is presumably a
 *   normalised `[x, y]` pair in the 0..1 range (confirm against the caller);
 *   `label` is converted with `BigInt`.
 * @returns {Promise<{mask: RawImage, scores: *}>} The first predicted mask as
 *   an image, plus the raw IoU score data from the model output.
 * @throws {Error} If no image has been embedded yet (no 'segment' request was
 *   processed, or the state was cleared by 'reset').
 */
async function decode(data) {
  // Snapshot the shared state once: other messages may replace or clear it
  // while this call is suspended on an await.
  const inputs = image_inputs;
  const embeddings = image_embeddings;
  if (inputs == null || embeddings == null) {
    throw new Error("decode requested before an image was embedded; send a 'segment' message first");
  }

  const sam = await segmentationModel;

  // Prepare inputs for decoding
  const reshaped = inputs.reshaped_input_sizes[0];
  // x is scaled by reshaped[1] and y by reshaped[0]
  const points = data.map(x => [x.point[0] * reshaped[1], x.point[1] * reshaped[0]]);
  const labels = data.map(x => BigInt(x.label));
  const input_points = new Tensor(
    'float32',
    points.flat(Infinity),
    [1, 1, points.length, 2],
  );
  const input_labels = new Tensor(
    'int64',
    labels.flat(Infinity),
    [1, 1, labels.length],
  );

  // Generate the mask
  const outputs = await sam.model({
    ...embeddings,
    input_points,
    input_labels,
  });
  console.log(outputs.iou_scores);

  // Post-process the mask
  // @ts-ignore
  const masks = await sam.processor.post_process_masks(
    outputs.pred_masks,
    inputs.original_sizes,
    inputs.reshaped_input_sizes,
  );
  console.log(masks);

  return {
    mask: RawImage.fromTensor(masks[0][0]),
    scores: outputs.iou_scores.data,
  };
}
/**
 * Worker message dispatcher. `e.data.type` selects the operation:
 *   - 'reset'             clears the cached image inputs and embeddings
 *   - 'remove_background' reads `e.data.url`, posts 'remove_background_start'
 *                         and then 'remove_background_end' with the mask
 *   - 'depth'             posts 'depth_start' and then 'depth_end' with the depth image
 *   - 'segment'           reads `e.data.data`, posts 'segment_result' ('start', then 'done')
 *   - 'decode'            runs decode(e.data.data) and posts 'decode_result'
 * Any other type throws. A 'ready' message is posted once, while handling the
 * first message after the segmentation model has loaded.
 */
self.onmessage = async (e) => { // eslint-disable-line no-restricted-globals
  // Every request waits for the segmentation model, including the ones that do not use it.
  const sam = await segmentationModel;
  if (!ready) {
    // Indicate that we are ready to accept requests
    ready = true;
    self.postMessage({ // eslint-disable-line no-restricted-globals
      type: 'ready',
    });
  }

  const { type } = e.data;
  if (type === 'reset') {
    image_inputs = null;
    image_embeddings = null;
  } else if (type === 'remove_background') {
    console.log('starting background removal');
    self.postMessage({ // eslint-disable-line no-restricted-globals
      type: 'remove_background_start',
    });
    const bgRemover = await bgRemoverModel;
    const url = (e.data.url || "");
    // Read image
    const image = await RawImage.fromURL(url);
    // Preprocess image
    const { pixel_values } = await bgRemover.processor(image);
    // Predict alpha matte
    const { output } = await bgRemover.model({ input: pixel_values });
    // Resize mask back to original size
    const mask = await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize(image.width, image.height);
    self.postMessage({ // eslint-disable-line no-restricted-globals
      type: 'remove_background_end',
      data: mask,
    });
  } else if (type === 'depth') {
    console.log('starting depth');
    self.postMessage({ // eslint-disable-line no-restricted-globals
      type: 'depth_start',
    });
    // Use the image supplied by the caller when it is a non-empty string
    // (URL or data URL); otherwise fall back to the demo image that was
    // previously always used.
    const fallbackUrl = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/bread_small.png';
    const requested = e.data.data;
    const source = (typeof requested === 'string' && requested !== '') ? requested : fallbackUrl;
    const dem = await depthEstimationModel;
    const depthEstimations = await dem(source);
    const depthEstimation = Array.isArray(depthEstimations) ? depthEstimations[0] : depthEstimations;
    const { depth } = depthEstimation;
    // Send the estimated depth image back to the main thread
    self.postMessage({ // eslint-disable-line no-restricted-globals
      type: 'depth_end',
      data: depth,
    });
  } else if (type === 'segment') {
    // Indicate that we are starting to segment the image
    self.postMessage({ // eslint-disable-line no-restricted-globals
      type: 'segment_result',
      data: 'start',
    });
    const data = e.data.data;
    // keep this for later use
    image_string = data;
    // Read the image and recompute image embeddings
    const image = await RawImage.read(data);
    const inputs = await sam.processor(image);
    // @ts-ignore
    const embeddings = await sam.model.get_image_embeddings(inputs);
    // Publish both values together so a decode request handled while the
    // embeddings are still being computed never sees inputs and embeddings
    // that belong to different images.
    image_inputs = inputs;
    image_embeddings = embeddings;
    // Indicate that we have computed the image embeddings, and we are ready to accept decoding requests
    self.postMessage({ // eslint-disable-line no-restricted-globals
      type: 'segment_result',
      data: 'done',
    });
  } else if (type === 'decode') {
    const inputData = e.data.data;
    const outputData = await decode(inputData);
    // Send the result back to the main thread
    self.postMessage({ // eslint-disable-line no-restricted-globals
      type: 'decode_result',
      data: outputData,
    });
  } else {
    throw new Error(`Unknown message type: ${type}`);
  }
}
} |