<!doctype html>
<html lang="en">
  <head>
    <!-- Standards-mode document: charset first, then viewport and a page title. -->
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Segment Anything Model (SAM) with Gradio Lite</title>
    <!-- Gradio Lite runtime and stylesheet, both pinned to version 4.26.0. -->
    <script
      type="module"
      crossorigin
      src="https://cdn.jsdelivr.net/npm/@gradio/lite@4.26.0/dist/lite.js"
    ></script>
    <link
      rel="stylesheet"
      href="https://cdn.jsdelivr.net/npm/@gradio/lite@4.26.0/dist/lite.css"
    >
|
<script type="module">
  // Default model weights: the mobile_sam-tiny-vitt checkpoint from the
  // lmz/candle-sam repository on the Hugging Face Hub.
  const MODEL_URL =
    "https://huggingface.co/lmz/candle-sam/resolve/main/mobile_sam-tiny-vitt.safetensors";

  // One module worker, created at page load and shared by every request.
  const samWorker = new Worker("./samWorker.js", { type: "module" });

  // The worker replies on a single "message" channel with no request id, so
  // two overlapping requests would both settle on whichever reply arrived
  // first (and the second reply would be dropped). Every request is chained
  // onto this tail so that at most one is in flight at a time.
  let requestQueue = Promise.resolve();

  /**
   * Post one request to the SAM worker and settle on its reply.
   *
   * Messages that carry neither an `error` field nor a "complete" /
   * "complete-embedding" status are logged and otherwise ignored. Both
   * listeners are removed on every outcome.
   *
   * @param {object} request Payload posted to the worker
   *   ({ modelURL, modelID, imageURL, points }).
   * @returns {Promise<any>} Resolves with `output` on status "complete" and
   *   with `undefined` on "complete-embedding"; rejects when a reply has an
   *   `error` field or when the worker raises an uncaught error.
   */
  function runWorkerRequest(request) {
    return new Promise((resolve, reject) => {
      function cleanup() {
        samWorker.removeEventListener("message", onMessage);
        samWorker.removeEventListener("error", onWorkerError);
      }

      function onMessage(event) {
        const data = event.data;
        console.log(data);
        // The `in` operator throws on primitives; skip non-object messages.
        if (data === null || typeof data !== "object") {
          return;
        }
        if ("error" in data) {
          cleanup();
          reject(new Error(data.error));
        } else if (data.status === "complete-embedding") {
          cleanup();
          resolve();
        } else if (data.status === "complete") {
          cleanup();
          resolve(data.output);
        }
      }

      // Without this, a failure inside the worker (or a failed worker load)
      // would leave this promise, and every queued request after it, pending.
      function onWorkerError(event) {
        cleanup();
        reject(new Error(event.message || "SAM worker failed"));
      }

      samWorker.addEventListener("message", onMessage);
      samWorker.addEventListener("error", onWorkerError);
      samWorker.postMessage(request);
    });
  }

  /**
   * Run SAM on an image for the given prompt points, in the web worker.
   *
   * Requests are serialized: each call waits for the previous one to settle
   * (success or failure) before it is posted to the worker.
   *
   * @param {*} imageURL Image reference forwarded to the worker as `imageURL`.
   * @param {*} points Prompt points forwarded to the worker unchanged.
   * @param {string} [modelURL=MODEL_URL] URL of the model weights.
   * @param {string} [modelID="sam_mobile_tiny"] Model identifier sent to the worker.
   * @returns {Promise<any>} The worker's `output` on status "complete", or
   *   `undefined` on "complete-embedding". Rejects with an Error if the worker
   *   replies with an `error` field or crashes.
   */
  async function segmentPoints(
    imageURL,
    points,
    modelURL = MODEL_URL,
    modelID = "sam_mobile_tiny"
  ) {
    const run = () => runWorkerRequest({ modelURL, modelID, imageURL, points });
    const result = requestQueue.then(run);
    // The next request waits for this one whether it resolves or rejects.
    requestQueue = result.catch(() => {});
    return result;
  }

  // Module-scope names are not global; expose segmentPoints so the Gradio
  // `js=` callback (which calls it as a bare name) can reach it.
  globalThis.segmentPoints = segmentPoints;
</script>
|
<style>
  /*
   * Page colours follow the OS colour-scheme preference:
   * black text on white by default, white text on #0b0f19 in dark mode.
   */
  :root {
    --sam-demo-text: black;
    --sam-demo-background: white;
  }

  @media (prefers-color-scheme: dark) {
    :root {
      --sam-demo-text: white;
      --sam-demo-background: #0b0f19;
    }
  }

  body {
    color: var(--sam-demo-text);
    background-color: var(--sam-demo-background);
  }
</style>
|
</head> |
|
|
|
<body> |
|
|
|
<gradio-lite> |
|
import gradio as gr |
|
|
|
|
|
# JavaScript callback executed in the browser via `points.change(js=...)`.
# Gradio calls it with the current values of the `image` and `points`
# inputs and uses the returned list as the value of the single output
# (`mask`). `segmentPoints` is the global installed by the page's module script.
get_point_mask = """
async function getPointMask(image, points) {
  // Nothing to segment: clear the mask without waking the worker. This check
  // must come first; running the model on a missing image or an empty point
  // list is wasted work, and a reply without output would make the
  // destructuring below throw.
  if (!image || !points || points.length === 0) {
    return [ null ];
  }
  console.log("getting point mask");
  console.log(points);
  const { maskURL } = await segmentPoints(
    image,
    points
  );
  return [ maskURL ];
}
"""
|
def set_points(image, points_state, evt: gr.SelectData):
    """Append the clicked pixel to the point list, normalised by image size.

    Args:
        image: The PIL image the user clicked on (its `width`/`height` are used
            to scale the click position).
        points_state: The session's list of points; mutated in place.
        evt: Gradio select event; `evt.index` holds the clicked (x, y) pixel.

    Returns:
        The same list twice: once for the hidden JSON component and once to
        write back into the State.
    """
    click_x = evt.index[0]
    click_y = evt.index[1]
    # Third element is always True here; presumably a foreground/positive
    # label consumed by the worker. TODO confirm against samWorker.js.
    normalised_point = [click_x / image.width, click_y / image.height, True]
    points_state.append(normalised_point)
    return points_state, points_state
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("""## Segment Anything Model (SAM) with Gradio Lite

This demo uses [Gradio Lite](https://www.gradio.app/guides/gradio-lite) as UI for running the Segment Anything Model (SAM) with WASM build with [Candle](https://github.com/huggingface/candle).

**Note:** The model's first run may take a few seconds as it loads and caches the model in the browser, and then creates the image embeddings. Any subsequent clicks on points will be significantly faster.
""")
    # Per-session list of clicked points, normalised to the image size.
    points_state = gr.State([])
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input Image", type="pil")
            clear_points = gr.Button(value="Clear Points")
            # Hidden mirror of points_state; its change event drives the JS
            # segmentation callback.
            points = gr.JSON(label="Input Points", visible=False)
        with gr.Column():
            mask = gr.Image(label="Output Mask")

    clear_points.click(lambda: ([], []), None, [points, points_state])
    # Points are normalised against the image they were clicked on, so a new
    # (or cleared) image must start from an empty point list and mask;
    # otherwise stale points from the previous image leak into the next mask.
    image.change(lambda: ([], [], None), None, [points, points_state, mask])
    image.select(set_points, inputs=[image, points_state], outputs=[points, points_state])
    # Segmentation runs entirely in the browser (js= callback, no Python fn).
    points.change(None, inputs=[image, points], outputs=[mask], js=get_point_mask)

demo.launch(show_api=False)
|
</gradio-lite> |
|
</body> |
|
</html> |
|
|