Xenova HF staff commited on
Commit
b21b1c5
·
verified ·
1 Parent(s): f31a6ca

Upload 4 files

Browse files
Files changed (4) hide show
  1. index.css +116 -0
  2. index.html +39 -18
  3. index.js +295 -0
  4. worker.js +109 -0
index.css ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ * {
2
+ box-sizing: border-box;
3
+ padding: 0;
4
+ margin: 0;
5
+ font-family: sans-serif;
6
+ }
7
+
8
+ html,
9
+ body {
10
+ height: 100%;
11
+ }
12
+
13
+ body {
14
+ padding: 16px 32px;
15
+ }
16
+
17
+ body,
18
+ #container,
19
+ #upload-button {
20
+ display: flex;
21
+ flex-direction: column;
22
+ justify-content: center;
23
+ align-items: center;
24
+ }
25
+
26
+ h1 {
27
+ text-align: center;
28
+ }
29
+
30
+ #container {
31
+ position: relative;
32
+ width: 640px;
33
+ height: 420px;
34
+ max-width: 100%;
35
+ max-height: 100%;
36
+ border: 2px dashed #D1D5DB;
37
+ border-radius: 0.75rem;
38
+ overflow: hidden;
39
+ cursor: pointer;
40
+ margin-top: 1rem;
41
+ background-size: 100% 100%;
42
+ background-position: center;
43
+ background-repeat: no-repeat;
44
+ }
45
+
46
+ #mask-output {
47
+ position: absolute;
48
+ width: 100%;
49
+ height: 100%;
50
+ pointer-events: none;
51
+ }
52
+
53
+ #upload-button {
54
+ gap: 0.4rem;
55
+ font-size: 18px;
56
+ cursor: pointer;
57
+ }
58
+
59
+ #upload {
60
+ display: none;
61
+ }
62
+
63
+ svg {
64
+ pointer-events: none;
65
+ }
66
+
67
+ #example {
68
+ font-size: 14px;
69
+ text-decoration: underline;
70
+ cursor: pointer;
71
+ }
72
+
73
+ #example:hover {
74
+ color: #2563EB;
75
+ }
76
+
77
+ canvas {
78
+ position: absolute;
79
+ width: 100%;
80
+ height: 100%;
81
+ opacity: 0.6;
82
+ }
83
+
84
+ #status {
85
+ min-height: 16px;
86
+ margin: 8px 0;
87
+ }
88
+
89
+ .icon {
90
+ height: 16px;
91
+ width: 16px;
92
+ position: absolute;
93
+ transform: translate(-50%, -50%);
94
+ }
95
+
96
+ #controls>button {
97
+ padding: 6px 12px;
98
+ background-color: #3498db;
99
+ color: white;
100
+ border: 1px solid #2980b9;
101
+ border-radius: 5px;
102
+ cursor: pointer;
103
+ font-size: 16px;
104
+ }
105
+
106
+ #controls>button:disabled {
107
+ background-color: #d1d5db;
108
+ color: #6b7280;
109
+ border: 1px solid #9ca3af;
110
+ cursor: not-allowed;
111
+ }
112
+
113
+ #information {
114
+ margin-top: 0.25rem;
115
+ font-size: 15px;
116
+ }
index.html CHANGED
@@ -1,19 +1,40 @@
1
  <!DOCTYPE html>
2
- <html>
3
- <head>
4
- <meta charset="utf-8" />
5
- <meta name="viewport" content="width=device-width" />
6
- <title>My static Space</title>
7
- <link rel="stylesheet" href="style.css" />
8
- </head>
9
- <body>
10
- <div class="card">
11
- <h1>Welcome to your static Space!</h1>
12
- <p>You can modify this app directly by editing <i>index.html</i> in the Files and versions tab.</p>
13
- <p>
14
- Also don't forget to check the
15
- <a href="https://huggingface.co/docs/hub/spaces" target="_blank">Spaces documentation</a>.
16
- </p>
17
- </div>
18
- </body>
19
- </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8" />
6
+ <link rel="stylesheet" href="index.css" />
7
+
8
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
9
+ <title>Transformers.js - Segment Anything</title>
10
+ </head>
11
+
12
+ <body>
13
+ <h1>Segment Anything w/ 🤗 Transformers.js</h1>
14
+ <div id="container">
15
+ <label id="upload-button" for="upload">
16
+ <svg width="25" height="25" viewBox="0 0 25 25" fill="none" xmlns="http://www.w3.org/2000/svg">
17
+ <path fill="#000"
18
+ d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z">
19
+ </path>
20
+ </svg>
21
+ Click to upload image
22
+ <label id="example">(or try example)</label>
23
+ </label>
24
+ <canvas id="mask-output"></canvas>
25
+ </div>
26
+ <label id="status"></label>
27
+ <div id="controls">
28
+ <button id="reset-image">Reset image</button>
29
+ <button id="clear-points">Clear points</button>
30
+ <button id="cut-mask" disabled>Cut mask</button>
31
+ </div>
32
+ <p id="information">
33
+ Left click = positive points, right click = negative points.
34
+ </p>
35
+ <input id="upload" type="file" accept="image/*" />
36
+
37
+ <script src="index.js" type="module"></script>
38
+ </body>
39
+
40
+ </html>
index.js ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ // Reference the elements we will use
3
+ const statusLabel = document.getElementById('status');
4
+ const fileUpload = document.getElementById('upload');
5
+ const imageContainer = document.getElementById('container');
6
+ const example = document.getElementById('example');
7
+ const maskCanvas = document.getElementById('mask-output');
8
+ const uploadButton = document.getElementById('upload-button');
9
+ const resetButton = document.getElementById('reset-image');
10
+ const clearButton = document.getElementById('clear-points');
11
+ const cutButton = document.getElementById('cut-mask');
12
+
13
+ // State variables
14
+ let lastPoints = null;
15
+ let isEncoded = false;
16
+ let isDecoding = false;
17
+ let isMultiMaskMode = false;
18
+ let modelReady = false;
19
+ let imageDataURI = null;
20
+
21
+ // Constants
22
+ const BASE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/';
23
+ const EXAMPLE_URL = BASE_URL + 'corgi.jpg';
24
+
25
+ // Create a web worker so that the main (UI) thread is not blocked during inference.
26
+ const worker = new Worker('worker.js', {
27
+ type: 'module',
28
+ });
29
+
30
+ // Preload star and cross images to avoid lag on first click
31
+ const star = new Image();
32
+ star.src = BASE_URL + 'star-icon.png';
33
+ star.className = 'icon';
34
+
35
+ const cross = new Image();
36
+ cross.src = BASE_URL + 'cross-icon.png';
37
+ cross.className = 'icon';
38
+
39
+ // Set up message handler
40
+ worker.addEventListener('message', (e) => {
41
+ const { type, data } = e.data;
42
+ if (type === 'ready') {
43
+ modelReady = true;
44
+ statusLabel.textContent = 'Ready';
45
+
46
+ } else if (type === 'decode_result') {
47
+ isDecoding = false;
48
+
49
+ if (!isEncoded) {
50
+ return; // We are not ready to decode yet
51
+ }
52
+
53
+ if (!isMultiMaskMode && lastPoints) {
54
+ // Perform decoding with the last point
55
+ decode();
56
+ lastPoints = null;
57
+ }
58
+
59
+ const { mask, scores } = data;
60
+
61
+ // Update canvas dimensions (if different)
62
+ if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
63
+ maskCanvas.width = mask.width;
64
+ maskCanvas.height = mask.height;
65
+ }
66
+
67
+ // Create context and allocate buffer for pixel data
68
+ const context = maskCanvas.getContext('2d');
69
+ const imageData = context.createImageData(maskCanvas.width, maskCanvas.height);
70
+
71
+ // Select best mask
72
+ const numMasks = scores.length; // 3
73
+ let bestIndex = 0;
74
+ for (let i = 1; i < numMasks; ++i) {
75
+ if (scores[i] > scores[bestIndex]) {
76
+ bestIndex = i;
77
+ }
78
+ }
79
+ statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;
80
+
81
+ // Fill mask with colour
82
+ const pixelData = imageData.data;
83
+ for (let i = 0; i < pixelData.length; ++i) {
84
+ if (mask.data[numMasks * i + bestIndex] === 1) {
85
+ const offset = 4 * i;
86
+ pixelData[offset] = 0; // red
87
+ pixelData[offset + 1] = 114; // green
88
+ pixelData[offset + 2] = 189; // blue
89
+ pixelData[offset + 3] = 255; // alpha
90
+ }
91
+ }
92
+
93
+ // Draw image data to context
94
+ context.putImageData(imageData, 0, 0);
95
+
96
+ } else if (type === 'segment_result') {
97
+ if (data === 'start') {
98
+ statusLabel.textContent = 'Extracting image embedding...';
99
+ } else {
100
+ statusLabel.textContent = 'Embedding extracted!';
101
+ isEncoded = true;
102
+ }
103
+ }
104
+ });
105
+
106
+ function decode() {
107
+ isDecoding = true;
108
+ worker.postMessage({ type: 'decode', data: lastPoints });
109
+ }
110
+
111
+ function clearPointsAndMask() {
112
+ // Reset state
113
+ isMultiMaskMode = false;
114
+ lastPoints = null;
115
+
116
+ // Remove points from previous mask (if any)
117
+ document.querySelectorAll('.icon').forEach(e => e.remove());
118
+
119
+ // Disable cut button
120
+ cutButton.disabled = true;
121
+
122
+ // Reset mask canvas
123
+ maskCanvas.getContext('2d').clearRect(0, 0, maskCanvas.width, maskCanvas.height);
124
+ }
125
+ clearButton.addEventListener('click', clearPointsAndMask);
126
+
127
+ resetButton.addEventListener('click', () => {
128
+ // Update state
129
+ isEncoded = false;
130
+ imageDataURI = null;
131
+
132
+ // Indicate to worker that we have reset the state
133
+ worker.postMessage({ type: 'reset' });
134
+
135
+ // Clear points and mask (if present)
136
+ clearPointsAndMask();
137
+
138
+ // Update UI
139
+ cutButton.disabled = true;
140
+ imageContainer.style.backgroundImage = 'none';
141
+ uploadButton.style.display = 'flex';
142
+ statusLabel.textContent = 'Ready';
143
+ });
144
+
145
+ function segment(data) {
146
+ // Update state
147
+ isEncoded = false;
148
+ if (!modelReady) {
149
+ statusLabel.textContent = 'Loading model...';
150
+ }
151
+ imageDataURI = data;
152
+
153
+ // Update UI
154
+ imageContainer.style.backgroundImage = `url(${data})`;
155
+ uploadButton.style.display = 'none';
156
+ cutButton.disabled = true;
157
+
158
+ // Instruct worker to segment the image
159
+ worker.postMessage({ type: 'segment', data });
160
+ }
161
+
162
+ // Handle file selection
163
+ fileUpload.addEventListener('change', function (e) {
164
+ const file = e.target.files[0];
165
+ if (!file) {
166
+ return;
167
+ }
168
+
169
+ const reader = new FileReader();
170
+
171
+ // Set up a callback when the file is loaded
172
+ reader.onload = e2 => segment(e2.target.result);
173
+
174
+ reader.readAsDataURL(file);
175
+ });
176
+
177
+ example.addEventListener('click', (e) => {
178
+ e.preventDefault();
179
+ segment(EXAMPLE_URL);
180
+ });
181
+
182
+ function addIcon({ point, label }) {
183
+ const icon = (label === 1 ? star : cross).cloneNode();
184
+ icon.style.left = `${point[0] * 100}%`;
185
+ icon.style.top = `${point[1] * 100}%`;
186
+ imageContainer.appendChild(icon);
187
+ }
188
+
189
+ // Attach hover event to image container
190
+ imageContainer.addEventListener('mousedown', e => {
191
+ if (e.button !== 0 && e.button !== 2) {
192
+ return; // Ignore other buttons
193
+ }
194
+ if (!isEncoded) {
195
+ return; // Ignore if not encoded yet
196
+ }
197
+ if (!isMultiMaskMode) {
198
+ lastPoints = [];
199
+ isMultiMaskMode = true;
200
+ cutButton.disabled = false;
201
+ }
202
+
203
+ const point = getPoint(e);
204
+ lastPoints.push(point);
205
+
206
+ // add icon
207
+ addIcon(point);
208
+
209
+ decode();
210
+ });
211
+
212
+
213
+ // Clamp a value inside a range [min, max]
214
+ function clamp(x, min = 0, max = 1) {
215
+ return Math.max(Math.min(x, max), min)
216
+ }
217
+
218
+ function getPoint(e) {
219
+ // Get bounding box
220
+ const bb = imageContainer.getBoundingClientRect();
221
+
222
+ // Get the mouse coordinates relative to the container
223
+ const mouseX = clamp((e.clientX - bb.left) / bb.width);
224
+ const mouseY = clamp((e.clientY - bb.top) / bb.height);
225
+
226
+ return {
227
+ point: [mouseX, mouseY],
228
+ label: e.button === 2 // right click
229
+ ? 0 // negative prompt
230
+ : 1, // positive prompt
231
+ }
232
+ }
233
+
234
+ // Do not show context menu on right click
235
+ imageContainer.addEventListener('contextmenu', e => {
236
+ e.preventDefault();
237
+ });
238
+
239
+ // Attach hover event to image container
240
+ imageContainer.addEventListener('mousemove', e => {
241
+ if (!isEncoded || isMultiMaskMode) {
242
+ // Ignore mousemove events if the image is not encoded yet,
243
+ // or we are in multi-mask mode
244
+ return;
245
+ }
246
+ lastPoints = [getPoint(e)];
247
+
248
+ if (!isDecoding) {
249
+ decode(); // Only decode if we are not already decoding
250
+ }
251
+ });
252
+
253
+ // Handle cut button click
254
+ cutButton.addEventListener('click', () => {
255
+ const [w, h] = [maskCanvas.width, maskCanvas.height];
256
+
257
+ // Get the mask pixel data
258
+ const maskContext = maskCanvas.getContext('2d');
259
+ const maskPixelData = maskContext.getImageData(0, 0, w, h);
260
+
261
+ // Load the image
262
+ const image = new Image();
263
+ image.crossOrigin = 'anonymous';
264
+ image.onload = async () => {
265
+ // Create a new canvas to hold the image
266
+ const imageCanvas = new OffscreenCanvas(w, h);
267
+ const imageContext = imageCanvas.getContext('2d');
268
+ imageContext.drawImage(image, 0, 0, w, h);
269
+ const imagePixelData = imageContext.getImageData(0, 0, w, h);
270
+
271
+ // Create a new canvas to hold the cut-out
272
+ const cutCanvas = new OffscreenCanvas(w, h);
273
+ const cutContext = cutCanvas.getContext('2d');
274
+ const cutPixelData = cutContext.getImageData(0, 0, w, h);
275
+
276
+ // Copy the image pixel data to the cut canvas
277
+ for (let i = 3; i < maskPixelData.data.length; i += 4) {
278
+ if (maskPixelData.data[i] > 0) {
279
+ for (let j = 0; j < 4; ++j) {
280
+ const offset = i - j;
281
+ cutPixelData.data[offset] = imagePixelData.data[offset];
282
+ }
283
+ }
284
+ }
285
+ cutContext.putImageData(cutPixelData, 0, 0);
286
+
287
+ // Download image
288
+ const link = document.createElement('a');
289
+ link.download = 'image.png';
290
+ link.href = URL.createObjectURL(await cutCanvas.convertToBlob());
291
+ link.click();
292
+ link.remove();
293
+ }
294
+ image.src = imageDataURI;
295
+ });
worker.js ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { env, SamModel, AutoProcessor, RawImage, Tensor } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.14.0';
2
+
3
+ // Since we will download the model from the Hugging Face Hub, we can skip the local model check
4
+ env.allowLocalModels = false;
5
+
6
+ // We adopt the singleton pattern to enable lazy-loading of the model and processor.
7
+ export class SegmentAnythingSingleton {
8
+ static model_id = 'Xenova/slimsam-77-uniform';
9
+ static model;
10
+ static processor;
11
+ static quantized = true;
12
+
13
+ static getInstance() {
14
+ if (!this.model) {
15
+ this.model = SamModel.from_pretrained(this.model_id, {
16
+ quantized: this.quantized,
17
+ });
18
+ }
19
+ if (!this.processor) {
20
+ this.processor = AutoProcessor.from_pretrained(this.model_id);
21
+ }
22
+
23
+ return Promise.all([this.model, this.processor]);
24
+ }
25
+ }
26
+
27
+
28
+ // State variables
29
+ let image_embeddings = null;
30
+ let image_inputs = null;
31
+ let ready = false;
32
+
33
+ self.onmessage = async (e) => {
34
+ const [model, processor] = await SegmentAnythingSingleton.getInstance();
35
+ if (!ready) {
36
+ // Indicate that we are ready to accept requests
37
+ ready = true;
38
+ self.postMessage({
39
+ type: 'ready',
40
+ });
41
+ }
42
+
43
+ const { type, data } = e.data;
44
+ if (type === 'reset') {
45
+ image_inputs = null;
46
+ image_embeddings = null;
47
+
48
+ } else if (type === 'segment') {
49
+ // Indicate that we are starting to segment the image
50
+ self.postMessage({
51
+ type: 'segment_result',
52
+ data: 'start',
53
+ });
54
+
55
+ // Read the image and recompute image embeddings
56
+ const image = await RawImage.read(e.data.data);
57
+ image_inputs = await processor(image);
58
+ image_embeddings = await model.get_image_embeddings(image_inputs)
59
+
60
+ // Indicate that we have computed the image embeddings, and we are ready to accept decoding requests
61
+ self.postMessage({
62
+ type: 'segment_result',
63
+ data: 'done',
64
+ });
65
+
66
+ } else if (type === 'decode') {
67
+ // Prepare inputs for decoding
68
+ const reshaped = image_inputs.reshaped_input_sizes[0];
69
+ const points = data.map(x => [x.point[0] * reshaped[1], x.point[1] * reshaped[0]])
70
+ const labels = data.map(x => BigInt(x.label));
71
+
72
+ const input_points = new Tensor(
73
+ 'float32',
74
+ points.flat(Infinity),
75
+ [1, 1, points.length, 2],
76
+ )
77
+ const input_labels = new Tensor(
78
+ 'int64',
79
+ labels.flat(Infinity),
80
+ [1, 1, labels.length],
81
+ )
82
+
83
+ // Generate the mask
84
+ const outputs = await model({
85
+ ...image_embeddings,
86
+ input_points,
87
+ input_labels,
88
+ })
89
+
90
+ // Post-process the mask
91
+ const masks = await processor.post_process_masks(
92
+ outputs.pred_masks,
93
+ image_inputs.original_sizes,
94
+ image_inputs.reshaped_input_sizes,
95
+ );
96
+
97
+ // Send the result back to the main thread
98
+ self.postMessage({
99
+ type: 'decode_result',
100
+ data: {
101
+ mask: RawImage.fromTensor(masks[0][0]),
102
+ scores: outputs.iou_scores.data,
103
+ },
104
+ });
105
+
106
+ } else {
107
+ throw new Error(`Unknown message type: ${type}`);
108
+ }
109
+ }