File size: 4,033 Bytes
32df74a
2cfd6b4
 
 
 
a8938fa
2cfd6b4
 
 
a8938fa
2cfd6b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec79de9
 
 
 
 
 
 
 
 
 
 
 
 
2cfd6b4
 
 
 
 
 
 
 
7f22fcb
2cfd6b4
 
3eced31
2cfd6b4
 
 
 
 
 
 
 
 
 
 
 
 
 
7f22fcb
2cfd6b4
 
 
 
 
 
 
 
a4534a7
2cfd6b4
 
 
 
 
 
891a89b
2cfd6b4
 
 
32df74a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
<html>
  <head>
    <script
      type="module"
      crossorigin
      src="https://cdn.jsdelivr.net/npm/@gradio/lite@0.4.3/dist/lite.js"
    ></script>
    <link
      rel="stylesheet"
      href="https://cdn.jsdelivr.net/npm/@gradio/lite@0.4.3/dist/lite.css"
    />
    <script type="module">
      // Weights for the tiny MobileSAM (ViT-T) variant, fetched from the Hugging Face Hub.
      const MODEL_URL =
        "https://huggingface.co/lmz/candle-sam/resolve/main/mobile_sam-tiny-vitt.safetensors";

      // Module worker that runs the Candle/WASM SAM model off the main thread.
      const samWorker = new Worker("./samWorker.js", { type: "module" });

      /**
       * Ask the SAM worker to segment the image at the given prompt points.
       *
       * @param {string} imageURL  URL of the image to segment.
       * @param {Array} points     prompt points to feed the model.
       * @param {string} [modelURL=MODEL_URL]  URL of the safetensors weights.
       * @param {string} [modelID="sam_mobile_tiny"]  worker-side model identifier.
       * @returns {Promise} resolves with the worker output on "complete",
       *   or with undefined once the image embedding has been computed
       *   ("complete-embedding"); rejects with an Error when the worker
       *   reports one. The message listener is removed on every terminal
       *   event so handlers do not accumulate across calls.
       */
      async function segmentPoints(
        imageURL, // URL to the image file
        points, // {x, y} points to prompt image
        modelURL = MODEL_URL, // URL to the weights file
        modelID = "sam_mobile_tiny" // model ID
      ) {
        return new Promise((resolve, reject) => {
          function messageHandler(event) {
            const data = event.data;
            // The three outcomes are mutually exclusive; handle exactly one.
            if ("error" in data) {
              samWorker.removeEventListener("message", messageHandler);
              reject(new Error(data.error));
            } else if (data.status === "complete-embedding") {
              samWorker.removeEventListener("message", messageHandler);
              resolve();
            } else if (data.status === "complete") {
              samWorker.removeEventListener("message", messageHandler);
              resolve(data.output);
            }
            // Other messages (e.g. progress updates) are ignored; the
            // listener stays attached until a terminal event arrives.
          }
          samWorker.addEventListener("message", messageHandler);
          samWorker.postMessage({ modelURL, modelID, imageURL, points });
        });
      }
      // Expose to the Gradio-Lite Python code, which invokes it from a JS callback.
      globalThis.segmentPoints = segmentPoints;
    </script>
    <style>
      body {
       color: black;
       background-color: white;
      }

      @media (prefers-color-scheme: dark) {
           body {
               color: white;
               background-color: #0b0f19;
           }
      }
    </style>
  </head>

  <body>
    <!-- prettier-ignore -->
    <gradio-lite>
        import gradio as gr


        get_point_mask = """
        async function getPointMask(image, points) {
            console.log("getting point mask");
            //console.log(image, points)
            const { maskURL } = await segmentPoints(
                image,
                points
            );
            if(points.length == 0){
                return [ null ];
            }
            return [ maskURL ]; 
          }
        """
        def set_points(image, points_state, evt: gr.SelectData):
            # Record the clicked pixel as a positive prompt point, normalized
            # to [0, 1] by the image dimensions, then echo the list to both
            # the hidden JSON component and the session state.
            x_px, y_px = evt.index[0], evt.index[1]
            normalized_point = [x_px / image.width, y_px / image.height, True]
            points_state.append(normalized_point)
            return points_state, points_state
 
        # UI wiring: click prompt points on the input image (left); the mask is
        # computed entirely in the browser by the SAM worker and shown (right).
        with gr.Blocks() as demo:
            gr.Markdown("""## Segment Anything Model (SAM) with Gradio Lite
            This demo uses [Gradio Lite](https://www.gradio.app/guides/gradio-lite) as UI for running the Segment Anything Model (SAM) with WASM build with [Candle](https://github.com/huggingface/candle).    

            **Note:** The model's first run may take a few seconds as it loads and caches the model in the browser, and then creates the image embeddings. Any subsequent clicks on points will be significantly faster.
            """)
            # Accumulated prompt points as [x, y, is_positive], x/y normalized to [0, 1].
            points_state = gr.State([])
            with gr.Row():
                with gr.Column():
                    image = gr.Image(label="Input Image", type="pil")
                    clear_points = gr.Button(value="Clear Points")
                    # Hidden mirror of points_state; its change event drives the JS mask callback.
                    points = gr.JSON(label="Input Points", visible=False)
                with gr.Column():
                    mask = gr.Image(label="Output Mask")
            # Reset both the hidden JSON component and the state list.
            clear_points.click(lambda: ([], []), None, [points, points_state])
            # Each click appends a normalized point (see set_points) and updates both outputs.
            image.select(set_points, inputs=[image, points_state], outputs=[points, points_state])
            # fn=None: the work happens in browser-side JS (get_point_mask), not in Python.
            points.change(None, inputs=[image, points], outputs=[mask], _js=get_point_mask)
        demo.launch(show_api=False)
    </gradio-lite>
  </body>
</html>