File size: 4,033 Bytes
32df74a 2cfd6b4 a8938fa 2cfd6b4 a8938fa 2cfd6b4 ec79de9 2cfd6b4 7f22fcb 2cfd6b4 3eced31 2cfd6b4 7f22fcb 2cfd6b4 a4534a7 2cfd6b4 891a89b 2cfd6b4 32df74a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
<html>
<head>
<!-- Gradio Lite runtime and styles, pinned to 0.4.3; it runs the Python app inside <gradio-lite> in the browser. -->
<script type="module" crossorigin src="https://cdn.jsdelivr.net/npm/@gradio/lite@0.4.3/dist/lite.js"></script>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@gradio/lite@0.4.3/dist/lite.css" />
<script type="module">
// Weights for the tiny MobileSAM variant; the URL is handed to the worker with each request.
const MODEL_URL =
  "https://huggingface.co/lmz/candle-sam/resolve/main/mobile_sam-tiny-vitt.safetensors";
const samWorker = new Worker("./samWorker.js", { type: "module" });

// Tail of the request chain. All requests share the worker's single "message"
// channel, so two overlapping calls would both settle on whichever "complete"
// message arrives first. Chaining keeps exactly one request in flight.
let pendingRequest = Promise.resolve();

/**
 * Ask the SAM worker to segment an image at the given prompt points.
 *
 * @param {string} imageURL - URL of the image to segment.
 * @param {Array} points - prompt points, forwarded verbatim to the worker.
 * @param {string} [modelURL=MODEL_URL] - URL of the weights file.
 * @param {string} [modelID="sam_mobile_tiny"] - model identifier passed to the worker.
 * @returns {Promise<any>} Resolves with `event.data.output` when the worker reports
 *   status "complete", or with `undefined` for status "complete-embedding".
 *   Rejects when a worker message carries an `error` field or the worker raises
 *   an uncaught error event.
 */
async function segmentPoints(
  imageURL,
  points,
  modelURL = MODEL_URL,
  modelID = "sam_mobile_tiny"
) {
  const run = () =>
    new Promise((resolve, reject) => {
      function cleanup() {
        samWorker.removeEventListener("message", messageHandler);
        samWorker.removeEventListener("error", errorHandler);
      }
      function messageHandler(event) {
        const data = event.data;
        console.log(data);
        // The `in` operator throws on primitives/null, which would leave this
        // promise pending forever; ignore such messages instead.
        if (data === null || typeof data !== "object") {
          return;
        }
        if ("error" in data) {
          cleanup();
          reject(new Error(data.error));
        } else if (data.status === "complete-embedding") {
          cleanup();
          resolve();
        } else if (data.status === "complete") {
          cleanup();
          resolve(data.output);
        }
        // Any other message (e.g. progress status) keeps waiting.
      }
      function errorHandler(event) {
        // Fired for errors the worker does not catch itself, e.g. a failed script load.
        cleanup();
        reject(new Error(event.message || "SAM worker error"));
      }
      samWorker.addEventListener("message", messageHandler);
      samWorker.addEventListener("error", errorHandler);
      samWorker.postMessage({
        modelURL,
        modelID,
        imageURL,
        points,
      });
    });
  const result = pendingRequest.then(run);
  // A failed request must not block the ones queued after it.
  pendingRequest = result.catch(() => {});
  return result;
}
globalThis.segmentPoints = segmentPoints;
</script>
<style>
/* Light theme by default; follow the OS preference when it asks for dark. */
body {
  background-color: #fff;
  color: #000;
}

@media (prefers-color-scheme: dark) {
  body {
    background-color: #0b0f19;
    color: #fff;
  }
}
</style>
</head>
<body>
<!-- prettier-ignore -->
<gradio-lite>
import gradio as gr
# JavaScript run in the browser by `points.change(..., _js=get_point_mask)`.
# It forwards the input image and clicked points to the page-level
# `segmentPoints` helper and returns a one-element list for the `mask` output.
get_point_mask = """
async function getPointMask(image, points) {
  console.log("getting point mask");
  // Nothing to segment without an image (e.g. after the input was cleared).
  if (!image) {
    return [ null ];
  }
  // The worker is still called for an empty point list (presumably so it can
  // prepare the image embedding). That path resolves with no value, so the
  // result must not be destructured unconditionally.
  const result = await segmentPoints(image, points);
  if (!points || points.length === 0) {
    return [ null ];
  }
  return [ result?.maskURL ?? null ];
}
"""
def set_points(image, points_state, evt: gr.SelectData):
    """Record a click on the input image as a point relative to the image size.

    Gradio supplies ``evt`` because of the ``gr.SelectData`` annotation, so the
    annotation must stay. ``evt.index[0]`` / ``evt.index[1]`` are divided by the
    image width / height, respectively.

    Args:
        image: The input image (a PIL image, given ``type="pil"`` on the input).
        points_state: Per-session point list; it is extended in place.
        evt: Selection event data provided by Gradio.

    Returns:
        The same ``points_state`` object twice: once for the hidden JSON
        component and once for the state.
    """
    col, row = evt.index[0], evt.index[1]
    # Third element is always True (presumably a foreground/positive label — confirm in worker).
    new_point = [col / image.width, row / image.height, True]
    points_state.append(new_point)
    return points_state, points_state
with gr.Blocks() as demo:
    gr.Markdown("""## Segment Anything Model (SAM) with Gradio Lite
This demo uses [Gradio Lite](https://www.gradio.app/guides/gradio-lite) as UI for running the Segment Anything Model (SAM) with WASM build with [Candle](https://github.com/huggingface/candle).
**Note:** The model's first run may take a few seconds as it loads and caches the model in the browser, and then creates the image embeddings. Any subsequent clicks on points will be significantly faster.
""")
    # Clicked points as [x / width, y / height, True] entries. They are mirrored
    # into the hidden `points` JSON, whose change event runs segmentation in the browser.
    points_state = gr.State([])
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input Image", type="pil")
            clear_points = gr.Button(value="Clear Points")
            points = gr.JSON(label="Input Points", visible=False)
        with gr.Column():
            mask = gr.Image(label="Output Mask")
    clear_points.click(lambda: ([], []), None, [points, points_state])
    # Points are stored relative to the image size, so clicks made on a previous
    # image would map onto unrelated spots of a new one. Start fresh on every change.
    image.change(lambda: ([], []), None, [points, points_state])
    image.select(set_points, inputs=[image, points_state], outputs=[points, points_state])
    # No Python function: segmentation runs entirely in the browser via the JS snippet.
    points.change(None, inputs=[image, points], outputs=[mask], _js=get_point_mask)

demo.launch(show_api=False)
</gradio-lite>
</body>
</html>
|