Plachta committed
Commit
fcdfd72
1 Parent(s): 6aa57fc

Upload 50 files

Files changed (50)
  1. README.md +1 -1
  2. app.py +131 -0
  3. gradio_image_prompter-0.1.0-py3-none-any.whl +0 -0
  4. models/__init__.py +15 -0
  5. models/automatic_mask_generator.py +372 -0
  6. models/build_sam.py +107 -0
  7. models/grasp_mods.py +318 -0
  8. models/modeling/__init__.py +11 -0
  9. models/modeling/common.py +43 -0
  10. models/modeling/image_encoder.py +395 -0
  11. models/modeling/mask_decoder.py +176 -0
  12. models/modeling/prompt_encoder.py +214 -0
  13. models/modeling/sam.py +174 -0
  14. models/modeling/transformer.py +240 -0
  15. models/predictor.py +269 -0
  16. models/utils/__init__.py +5 -0
  17. models/utils/amg.py +346 -0
  18. models/utils/onnx.py +144 -0
  19. models/utils/transforms.py +102 -0
  20. requirements.txt +5 -0
  21. src/.gitignore +9 -0
  22. src/LICENSE +201 -0
  23. src/README.md +48 -0
  24. src/backend/gradio_image_prompter/__init__.py +3 -0
  25. src/backend/gradio_image_prompter/image_prompter.py +133 -0
  26. src/backend/gradio_image_prompter/image_prompter.pyi +134 -0
  27. src/backend/gradio_image_prompter/templates/component/__vite-browser-external-2447137e.js +4 -0
  28. src/backend/gradio_image_prompter/templates/component/index.js +0 -0
  29. src/backend/gradio_image_prompter/templates/component/style.css +1 -0
  30. src/backend/gradio_image_prompter/templates/component/wrapper-6f348d45-f837cf34.js +2455 -0
  31. src/backend/gradio_image_prompter/templates/example/index.js +263 -0
  32. src/backend/gradio_image_prompter/templates/example/style.css +1 -0
  33. src/demo/__init__.py +0 -0
  34. src/demo/app.py +9 -0
  35. src/frontend/Example.svelte +44 -0
  36. src/frontend/Index.svelte +167 -0
  37. src/frontend/package-lock.json +718 -0
  38. src/frontend/package.json +28 -0
  39. src/frontend/shared/BoxDrawer.svelte +237 -0
  40. src/frontend/shared/ClearImage.svelte +48 -0
  41. src/frontend/shared/Image.svelte +15 -0
  42. src/frontend/shared/ImagePreview.svelte +88 -0
  43. src/frontend/shared/ImageUploader.svelte +192 -0
  44. src/frontend/shared/utils.ts +24 -0
  45. src/pyproject.toml +43 -0
  46. structures/__init__.py +0 -0
  47. structures/bounding_box.py +323 -0
  48. structures/grasp_box.py +127 -0
  49. structures/image_list.py +67 -0
  50. structures/segmentation_mask.py +298 -0
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: GraspAnything
- emoji: 🚀
+ emoji: 🤖✊
  colorFrom: gray
  colorTo: purple
  sdk: gradio
app.py ADDED
@@ -0,0 +1,131 @@
+ import copy
+ import numpy as np
+ import torch
+
+ import sys
+ sys.path.append("./")
+ from models import sam_model_registry
+ from models.grasp_mods import modify_forward
+ from models.utils.transforms import ResizeLongestSide
+
+ from gradio_image_prompter import ImagePrompter
+ from structures.grasp_box import GraspCoder
+ img_resize = ResizeLongestSide(1024)
+ import cv2
+
+ import gradio as gr
+
+ from models.grasp_mods import add_inference_method
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model_type = "vit_b"
+
+ mean = np.array([103.53, 116.28, 123.675])[:, np.newaxis, np.newaxis]
+ std = np.array([57.375, 57.12, 58.395])[:, np.newaxis, np.newaxis]
+
+ sam = sam_model_registry[model_type]()
+ sam.to(device=device)
+
+ sam.forward = modify_forward(sam)
+ sam.infer = add_inference_method(sam)
+
+ pretrained_model_path = "E:/epoch_9_step_535390.pth"
+
+ if pretrained_model_path != "":
+     sd = torch.load(pretrained_model_path)
+     # strip prefix "module." from keys
+     new_sd = {}
+     for k, v in sd.items():
+         if k.startswith("module."):
+             k = k[7:]
+         new_sd[k] = v
+     sam.load_state_dict(new_sd)
+
+ sam.eval()
+
+ def predict(input, topk):
+     np_image = input["image"]
+     points = input["points"]
+     orig_size = np_image.shape[:2]
+     # normalize image
+     np_image = np_image.transpose(2, 0, 1)
+
+     image = (np_image - mean) / std
+     image = torch.tensor(image).float().to(device)
+     image = image.unsqueeze(0)
+     t_image = img_resize.apply_image_torch(image)
+     t_orig_size = t_image.shape[-2:]
+     # pad to 1024x1024
+     t_image = torch.nn.functional.pad(t_image, (0, 1024 - t_image.shape[-1], 0, 1024 - t_image.shape[-2]))
+
+     # get box prompt
+     valid_boxes = []
+     for point in points:
+         x1, y1, type1, x2, y2, type2 = point
+         if type1 == 2 and type2 == 3:
+             valid_boxes.append([x1, y1, x2, y2])
+     if len(valid_boxes) == 0:
+         return np_image
+     t_boxes = np.array(valid_boxes)
+     t_boxes = img_resize.apply_boxes(t_boxes, orig_size)
+     box_torch = torch.as_tensor(t_boxes, dtype=torch.float, device=device)
+     batched_inputs = [{"image": t_image[0], "boxes": box_torch}]
+     with torch.no_grad():
+         outputs = sam.infer(batched_inputs, multimask_output=False)
+     # visualize and post on tensorboard
+     # recover image
+     recovered_img = batched_inputs[0]['image'].cpu().numpy()
+     recovered_img = recovered_img * std + mean
+     recovered_img = recovered_img.transpose(1, 2, 0).astype(np.uint8).clip(0, 255)
+
+     for i in range(len(outputs.pred_masks)):
+         # get predicted mask
+         pred_mask = outputs.pred_masks[i].detach().sigmoid().cpu().numpy() > 0.5
+         pred_mask = pred_mask.transpose(1, 2, 0).repeat(3, axis=2)
+
+         # get predicted grasp
+         pred_logits = outputs.logits[i].detach().cpu().numpy()
+         top_ind = pred_logits[:, 0].argsort()[-topk:][::-1]
+         pred_grasp = outputs.pred_boxes[i].detach().cpu().numpy()[top_ind]
+         coded_grasp = GraspCoder(1024, 1024, None, grasp_annos_reformat=pred_grasp)
+         _ = coded_grasp.decode()
+         decoded_grasp = copy.deepcopy(coded_grasp.grasp_annos)
+
+         # draw mask
+         mask_color = np.array([0, 255, 0])[None, None, :]
+         recovered_img[pred_mask] = recovered_img[pred_mask] * 0.5 + (pred_mask * mask_color)[pred_mask] * 0.5
+
+         # draw grasp
+         recovered_img = np.ascontiguousarray(recovered_img)
+         for grasp in decoded_grasp:
+             grasp = grasp.astype(int)
+             cv2.line(recovered_img, tuple(grasp[0:2]), tuple(grasp[2:4]), (255, 0, 0), 1)
+             cv2.line(recovered_img, tuple(grasp[4:6]), tuple(grasp[6:8]), (255, 0, 0), 1)
+             cv2.line(recovered_img, tuple(grasp[2:4]), tuple(grasp[4:6]), (0, 0, 255), 2)
+             cv2.line(recovered_img, tuple(grasp[6:8]), tuple(grasp[0:2]), (0, 0, 255), 2)
+
+     recovered_img = recovered_img[:t_orig_size[0], :t_orig_size[1]]
+     # resize to original size
+     recovered_img = cv2.resize(recovered_img, (orig_size[0], orig_size[1]))
+     return recovered_img
+
+ if __name__ == "__main__":
+     app = gr.Blocks(title="GraspAnything")
+     with app:
+         gr.Markdown("""
+         # GraspAnything <br>
+         Upload an image and draw a box around the object you want to grasp. Set top k to be the number of grasps you want to predict for each object.
+         """)
+         with gr.Column():
+             prompter = ImagePrompter(show_label=False)
+             top_k = gr.Slider(minimum=1, maximum=20, step=1, value=3, label="Top K Grasps")
+         with gr.Column():
+             image_output = gr.Image()
+         btn = gr.Button("Generate!")
+         btn.click(predict,
+                   inputs=[prompter, top_k],
+                   outputs=[image_output])
+     app.launch()
+
+
+
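For reference, the dictionary that ImagePrompter hands to predict() carries the uploaded RGB array under "image" and an N x 6 prompt array under "points"; per the loop above, a drawn box arrives as a row (x1, y1, 2, x2, y2, 3). A minimal sketch of driving predict() without the Gradio UI, assuming the checkpoint above has been loaded; the image path and box coordinates are placeholders:

import cv2
import numpy as np

# Placeholder test image; any RGB image works.
rgb = cv2.cvtColor(cv2.imread("test.jpg"), cv2.COLOR_BGR2RGB)

# One box prompt: (x1, y1, 2, x2, y2, 3) encodes the two box corners.
prompt = {
    "image": rgb,
    "points": np.array([[100.0, 120.0, 2.0, 300.0, 340.0, 3.0]]),
}

# Returns the image with the predicted mask and the top-k grasps drawn on it.
vis = predict(prompt, topk=3)
cv2.imwrite("grasp_vis.png", cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))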
gradio_image_prompter-0.1.0-py3-none-any.whl ADDED
Binary file (96.2 kB).
 
models/__init__.py ADDED
@@ -0,0 +1,15 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .build_sam import (
+     build_sam,
+     build_sam_vit_h,
+     build_sam_vit_l,
+     build_sam_vit_b,
+     sam_model_registry,
+ )
+ from .predictor import SamPredictor
+ from .automatic_mask_generator import SamAutomaticMaskGenerator
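These exports mirror the upstream segment-anything package: sam_model_registry (defined in build_sam.py below) maps "default", "vit_h", "vit_l" and "vit_b" to builder functions. A minimal usage sketch, with the checkpoint path left as a placeholder:

from models import sam_model_registry, SamPredictor

# Build an untrained ViT-B SAM; pass checkpoint="path/to/weights.pth" to load weights instead.
sam = sam_model_registry["vit_b"](checkpoint=None)
predictor = SamPredictor(sam)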
models/automatic_mask_generator.py ADDED
@@ -0,0 +1,372 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
10
+
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+
13
+ from .modeling import Sam
14
+ from .predictor import SamPredictor
15
+ from .utils.amg import (
16
+ MaskData,
17
+ area_from_rle,
18
+ batch_iterator,
19
+ batched_mask_to_box,
20
+ box_xyxy_to_xywh,
21
+ build_all_layer_point_grids,
22
+ calculate_stability_score,
23
+ coco_encode_rle,
24
+ generate_crop_boxes,
25
+ is_box_near_crop_edge,
26
+ mask_to_rle_pytorch,
27
+ remove_small_regions,
28
+ rle_to_mask,
29
+ uncrop_boxes_xyxy,
30
+ uncrop_masks,
31
+ uncrop_points,
32
+ )
33
+
34
+
35
+ class SamAutomaticMaskGenerator:
36
+ def __init__(
37
+ self,
38
+ model: Sam,
39
+ points_per_side: Optional[int] = 32,
40
+ points_per_batch: int = 64,
41
+ pred_iou_thresh: float = 0.88,
42
+ stability_score_thresh: float = 0.95,
43
+ stability_score_offset: float = 1.0,
44
+ box_nms_thresh: float = 0.7,
45
+ crop_n_layers: int = 0,
46
+ crop_nms_thresh: float = 0.7,
47
+ crop_overlap_ratio: float = 512 / 1500,
48
+ crop_n_points_downscale_factor: int = 1,
49
+ point_grids: Optional[List[np.ndarray]] = None,
50
+ min_mask_region_area: int = 0,
51
+ output_mode: str = "binary_mask",
52
+ ) -> None:
53
+ """
54
+ Using a SAM model, generates masks for the entire image.
55
+ Generates a grid of point prompts over the image, then filters
56
+ low quality and duplicate masks. The default settings are chosen
57
+ for SAM with a ViT-H backbone.
58
+
59
+ Arguments:
60
+ model (Sam): The SAM model to use for mask prediction.
61
+ points_per_side (int or None): The number of points to be sampled
62
+ along one side of the image. The total number of points is
63
+ points_per_side**2. If None, 'point_grids' must provide explicit
64
+ point sampling.
65
+ points_per_batch (int): Sets the number of points run simultaneously
66
+ by the model. Higher numbers may be faster but use more GPU memory.
67
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
68
+ model's predicted mask quality.
69
+ stability_score_thresh (float): A filtering threshold in [0,1], using
70
+ the stability of the mask under changes to the cutoff used to binarize
71
+ the model's mask predictions.
72
+ stability_score_offset (float): The amount to shift the cutoff when
73
+ calculating the stability score.
74
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
75
+ suppression to filter duplicate masks.
76
+ crop_n_layers (int): If >0, mask prediction will be run again on
77
+ crops of the image. Sets the number of layers to run, where each
78
+ layer has 2**i_layer number of image crops.
79
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
80
+ suppression to filter duplicate masks between different crops.
81
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
82
+ In the first crop layer, crops will overlap by this fraction of
83
+ the image length. Later layers with more crops scale down this overlap.
84
+ crop_n_points_downscale_factor (int): The number of points-per-side
85
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
86
+ point_grids (list(np.ndarray) or None): A list over explicit grids
87
+ of points used for sampling, normalized to [0,1]. The nth grid in the
88
+ list is used in the nth crop layer. Exclusive with points_per_side.
89
+ min_mask_region_area (int): If >0, postprocessing will be applied
90
+ to remove disconnected regions and holes in masks with area smaller
91
+ than min_mask_region_area. Requires opencv.
92
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
93
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
94
+ For large resolutions, 'binary_mask' may consume large amounts of
95
+ memory.
96
+ """
97
+
98
+ assert (points_per_side is None) != (
99
+ point_grids is None
100
+ ), "Exactly one of points_per_side or point_grid must be provided."
101
+ if points_per_side is not None:
102
+ self.point_grids = build_all_layer_point_grids(
103
+ points_per_side,
104
+ crop_n_layers,
105
+ crop_n_points_downscale_factor,
106
+ )
107
+ elif point_grids is not None:
108
+ self.point_grids = point_grids
109
+ else:
110
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
111
+
112
+ assert output_mode in [
113
+ "binary_mask",
114
+ "uncompressed_rle",
115
+ "coco_rle",
116
+ ], f"Unknown output_mode {output_mode}."
117
+ if output_mode == "coco_rle":
118
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
119
+
120
+ if min_mask_region_area > 0:
121
+ import cv2 # type: ignore # noqa: F401
122
+
123
+ self.predictor = SamPredictor(model)
124
+ self.points_per_batch = points_per_batch
125
+ self.pred_iou_thresh = pred_iou_thresh
126
+ self.stability_score_thresh = stability_score_thresh
127
+ self.stability_score_offset = stability_score_offset
128
+ self.box_nms_thresh = box_nms_thresh
129
+ self.crop_n_layers = crop_n_layers
130
+ self.crop_nms_thresh = crop_nms_thresh
131
+ self.crop_overlap_ratio = crop_overlap_ratio
132
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
133
+ self.min_mask_region_area = min_mask_region_area
134
+ self.output_mode = output_mode
135
+
136
+ @torch.no_grad()
137
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
138
+ """
139
+ Generates masks for the given image.
140
+
141
+ Arguments:
142
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
143
+
144
+ Returns:
145
+ list(dict(str, any)): A list over records for masks. Each record is
146
+ a dict containing the following keys:
147
+ segmentation (dict(str, any) or np.ndarray): The mask. If
148
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
149
+ is a dictionary containing the RLE.
150
+ bbox (list(float)): The box around the mask, in XYWH format.
151
+ area (int): The area in pixels of the mask.
152
+ predicted_iou (float): The model's own prediction of the mask's
153
+ quality. This is filtered by the pred_iou_thresh parameter.
154
+ point_coords (list(list(float))): The point coordinates input
155
+ to the model to generate this mask.
156
+ stability_score (float): A measure of the mask's quality. This
157
+ is filtered on using the stability_score_thresh parameter.
158
+ crop_box (list(float)): The crop of the image used to generate
159
+ the mask, given in XYWH format.
160
+ """
161
+
162
+ # Generate masks
163
+ mask_data = self._generate_masks(image)
164
+
165
+ # Filter small disconnected regions and holes in masks
166
+ if self.min_mask_region_area > 0:
167
+ mask_data = self.postprocess_small_regions(
168
+ mask_data,
169
+ self.min_mask_region_area,
170
+ max(self.box_nms_thresh, self.crop_nms_thresh),
171
+ )
172
+
173
+ # Encode masks
174
+ if self.output_mode == "coco_rle":
175
+ mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
176
+ elif self.output_mode == "binary_mask":
177
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
178
+ else:
179
+ mask_data["segmentations"] = mask_data["rles"]
180
+
181
+ # Write mask records
182
+ curr_anns = []
183
+ for idx in range(len(mask_data["segmentations"])):
184
+ ann = {
185
+ "segmentation": mask_data["segmentations"][idx],
186
+ "area": area_from_rle(mask_data["rles"][idx]),
187
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
188
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
189
+ "point_coords": [mask_data["points"][idx].tolist()],
190
+ "stability_score": mask_data["stability_score"][idx].item(),
191
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
192
+ }
193
+ curr_anns.append(ann)
194
+
195
+ return curr_anns
196
+
197
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
198
+ orig_size = image.shape[:2]
199
+ crop_boxes, layer_idxs = generate_crop_boxes(
200
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
201
+ )
202
+
203
+ # Iterate over image crops
204
+ data = MaskData()
205
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
206
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
207
+ data.cat(crop_data)
208
+
209
+ # Remove duplicate masks between crops
210
+ if len(crop_boxes) > 1:
211
+ # Prefer masks from smaller crops
212
+ scores = 1 / box_area(data["crop_boxes"])
213
+ scores = scores.to(data["boxes"].device)
214
+ keep_by_nms = batched_nms(
215
+ data["boxes"].float(),
216
+ scores,
217
+ torch.zeros_like(data["boxes"][:, 0]), # categories
218
+ iou_threshold=self.crop_nms_thresh,
219
+ )
220
+ data.filter(keep_by_nms)
221
+
222
+ data.to_numpy()
223
+ return data
224
+
225
+ def _process_crop(
226
+ self,
227
+ image: np.ndarray,
228
+ crop_box: List[int],
229
+ crop_layer_idx: int,
230
+ orig_size: Tuple[int, ...],
231
+ ) -> MaskData:
232
+ # Crop the image and calculate embeddings
233
+ x0, y0, x1, y1 = crop_box
234
+ cropped_im = image[y0:y1, x0:x1, :]
235
+ cropped_im_size = cropped_im.shape[:2]
236
+ self.predictor.set_image(cropped_im)
237
+
238
+ # Get points for this crop
239
+ points_scale = np.array(cropped_im_size)[None, ::-1]
240
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
241
+
242
+ # Generate masks for this crop in batches
243
+ data = MaskData()
244
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
245
+ batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
246
+ data.cat(batch_data)
247
+ del batch_data
248
+ self.predictor.reset_image()
249
+
250
+ # Remove duplicates within this crop.
251
+ keep_by_nms = batched_nms(
252
+ data["boxes"].float(),
253
+ data["iou_preds"],
254
+ torch.zeros_like(data["boxes"][:, 0]), # categories
255
+ iou_threshold=self.box_nms_thresh,
256
+ )
257
+ data.filter(keep_by_nms)
258
+
259
+ # Return to the original image frame
260
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
261
+ data["points"] = uncrop_points(data["points"], crop_box)
262
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
263
+
264
+ return data
265
+
266
+ def _process_batch(
267
+ self,
268
+ points: np.ndarray,
269
+ im_size: Tuple[int, ...],
270
+ crop_box: List[int],
271
+ orig_size: Tuple[int, ...],
272
+ ) -> MaskData:
273
+ orig_h, orig_w = orig_size
274
+
275
+ # Run model on this batch
276
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
277
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
278
+ in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
279
+ masks, iou_preds, _ = self.predictor.predict_torch(
280
+ in_points[:, None, :],
281
+ in_labels[:, None],
282
+ multimask_output=True,
283
+ return_logits=True,
284
+ )
285
+
286
+ # Serialize predictions and store in MaskData
287
+ data = MaskData(
288
+ masks=masks.flatten(0, 1),
289
+ iou_preds=iou_preds.flatten(0, 1),
290
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
291
+ )
292
+ del masks
293
+
294
+ # Filter by predicted IoU
295
+ if self.pred_iou_thresh > 0.0:
296
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
297
+ data.filter(keep_mask)
298
+
299
+ # Calculate stability score
300
+ data["stability_score"] = calculate_stability_score(
301
+ data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
302
+ )
303
+ if self.stability_score_thresh > 0.0:
304
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
305
+ data.filter(keep_mask)
306
+
307
+ # Threshold masks and calculate boxes
308
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
309
+ data["boxes"] = batched_mask_to_box(data["masks"])
310
+
311
+ # Filter boxes that touch crop boundaries
312
+ keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
313
+ if not torch.all(keep_mask):
314
+ data.filter(keep_mask)
315
+
316
+ # Compress to RLE
317
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
318
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
319
+ del data["masks"]
320
+
321
+ return data
322
+
323
+ @staticmethod
324
+ def postprocess_small_regions(
325
+ mask_data: MaskData, min_area: int, nms_thresh: float
326
+ ) -> MaskData:
327
+ """
328
+ Removes small disconnected regions and holes in masks, then reruns
329
+ box NMS to remove any new duplicates.
330
+
331
+ Edits mask_data in place.
332
+
333
+ Requires open-cv as a dependency.
334
+ """
335
+ if len(mask_data["rles"]) == 0:
336
+ return mask_data
337
+
338
+ # Filter small disconnected regions and holes
339
+ new_masks = []
340
+ scores = []
341
+ for rle in mask_data["rles"]:
342
+ mask = rle_to_mask(rle)
343
+
344
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
345
+ unchanged = not changed
346
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
347
+ unchanged = unchanged and not changed
348
+
349
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
350
+ # Give score=0 to changed masks and score=1 to unchanged masks
351
+ # so NMS will prefer ones that didn't need postprocessing
352
+ scores.append(float(unchanged))
353
+
354
+ # Recalculate boxes and remove any new duplicates
355
+ masks = torch.cat(new_masks, dim=0)
356
+ boxes = batched_mask_to_box(masks)
357
+ keep_by_nms = batched_nms(
358
+ boxes.float(),
359
+ torch.as_tensor(scores),
360
+ torch.zeros_like(boxes[:, 0]), # categories
361
+ iou_threshold=nms_thresh,
362
+ )
363
+
364
+ # Only recalculate RLEs for masks that have changed
365
+ for i_mask in keep_by_nms:
366
+ if scores[i_mask] == 0.0:
367
+ mask_torch = masks[i_mask].unsqueeze(0)
368
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
369
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
370
+ mask_data.filter(keep_by_nms)
371
+
372
+ return mask_data
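A short usage sketch for the generator above, following its own docstrings; the checkpoint and image paths are placeholders, and whether it runs unchanged in this repo depends on the modified predictor/decoder:

import cv2
from models import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry["vit_b"](checkpoint="path/to/weights.pth")
generator = SamAutomaticMaskGenerator(sam, points_per_side=32, pred_iou_thresh=0.88)

image = cv2.cvtColor(cv2.imread("scene.jpg"), cv2.COLOR_BGR2RGB)  # HWC uint8, as generate() expects
masks = generator.generate(image)
# Each record carries "segmentation", "bbox" (XYWH), "area", "predicted_iou", "stability_score", "crop_box".
largest = max(masks, key=lambda m: m["area"])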
models/build_sam.py ADDED
@@ -0,0 +1,107 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+
+ from functools import partial
+
+ from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+
+
+ def build_sam_vit_h(checkpoint=None):
+     return _build_sam(
+         encoder_embed_dim=1280,
+         encoder_depth=32,
+         encoder_num_heads=16,
+         encoder_global_attn_indexes=[7, 15, 23, 31],
+         checkpoint=checkpoint,
+     )
+
+
+ build_sam = build_sam_vit_h
+
+
+ def build_sam_vit_l(checkpoint=None):
+     return _build_sam(
+         encoder_embed_dim=1024,
+         encoder_depth=24,
+         encoder_num_heads=16,
+         encoder_global_attn_indexes=[5, 11, 17, 23],
+         checkpoint=checkpoint,
+     )
+
+
+ def build_sam_vit_b(checkpoint=None):
+     return _build_sam(
+         encoder_embed_dim=768,
+         encoder_depth=12,
+         encoder_num_heads=12,
+         encoder_global_attn_indexes=[2, 5, 8, 11],
+         checkpoint=checkpoint,
+     )
+
+
+ sam_model_registry = {
+     "default": build_sam_vit_h,
+     "vit_h": build_sam_vit_h,
+     "vit_l": build_sam_vit_l,
+     "vit_b": build_sam_vit_b,
+ }
+
+
+ def _build_sam(
+     encoder_embed_dim,
+     encoder_depth,
+     encoder_num_heads,
+     encoder_global_attn_indexes,
+     checkpoint=None,
+ ):
+     prompt_embed_dim = 256
+     image_size = 1024
+     vit_patch_size = 16
+     image_embedding_size = image_size // vit_patch_size
+     sam = Sam(
+         image_encoder=ImageEncoderViT(
+             depth=encoder_depth,
+             embed_dim=encoder_embed_dim,
+             img_size=image_size,
+             mlp_ratio=4,
+             norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+             num_heads=encoder_num_heads,
+             patch_size=vit_patch_size,
+             qkv_bias=True,
+             use_rel_pos=True,
+             global_attn_indexes=encoder_global_attn_indexes,
+             window_size=14,
+             out_chans=prompt_embed_dim,
+         ),
+         prompt_encoder=PromptEncoder(
+             embed_dim=prompt_embed_dim,
+             image_embedding_size=(image_embedding_size, image_embedding_size),
+             input_image_size=(image_size, image_size),
+             mask_in_chans=16,
+         ),
+         mask_decoder=MaskDecoder(
+             num_multimask_outputs=3,
+             transformer=TwoWayTransformer(
+                 depth=2,
+                 embedding_dim=prompt_embed_dim,
+                 mlp_dim=2048,
+                 num_heads=8,
+             ),
+             transformer_dim=prompt_embed_dim,
+             iou_head_depth=3,
+             iou_head_hidden_dim=256,
+         ),
+         pixel_mean=[123.675, 116.28, 103.53],
+         pixel_std=[58.395, 57.12, 57.375],
+     )
+     sam.eval()
+     if checkpoint is not None:
+         with open(checkpoint, "rb") as f:
+             state_dict = torch.load(f)
+         sam.load_state_dict(state_dict)
+     return sam
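The three builders differ only in encoder depth/width and in which blocks use global attention; with image_size=1024 and vit_patch_size=16 the embedding grid is 1024 // 16 = 64, so every variant feeds a 64x64x256 feature map to the prompt encoder and mask decoder. A small sketch (the checkpoint path is illustrative):

from models.build_sam import build_sam_vit_b

sam = build_sam_vit_b(checkpoint=None)            # randomly initialized ViT-B SAM
# sam = build_sam_vit_b("path/to/sam_vit_b.pth")  # or load pretrained SAM weights
print(sum(p.numel() for p in sam.parameters()) / 1e6, "M parameters")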
models/grasp_mods.py ADDED
@@ -0,0 +1,318 @@
+ """
+ Add additional grasp decoder for Segment Anything model.
+ The structure should follow the grasp decoder structure in GraspDETR.
+ """
+ import torch
+ import torch.nn as nn
+ from transformers.models.detr.configuration_detr import DetrConfig
+ from transformers.models.detr.modeling_detr import DetrHungarianMatcher, DetrLoss, DetrSegmentationOutput, DetrDecoder, sigmoid_focal_loss, dice_loss
+ from typing import Any, Dict, List, Tuple
+ from transformers.models.detr.modeling_detr import generalized_box_iou
+ from transformers.image_transforms import center_to_corners_format
+ from scipy.optimize import linear_sum_assignment
+
+ def modify_matcher_forward(self):
+     @torch.no_grad()
+     def matcher_forward(outputs, targets):
+
+         batch_size, num_queries = outputs["logits"].shape[:2]
+
+         # We flatten to compute the cost matrices in a batch
+         out_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
+         out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
+
+         # Also concat the target labels and boxes
+         target_ids = torch.cat([v["class_labels"] for v in targets])
+         target_bbox = torch.cat([v["boxes"] for v in targets])
+
+         # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+         # but approximate it in 1 - proba[target class].
+         # The 1 is a constant that doesn't change the matching, it can be omitted.
+         class_cost = -out_prob[:, target_ids]
+
+         # Compute the L1 cost between boxes
+         bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
+
+         # Compute the giou cost between boxes
+         giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox[:, :4]), center_to_corners_format(target_bbox[:, :4]))
+
+         # Final cost matrix
+         cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
+         cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
+
+         sizes = [len(v["boxes"]) for v in targets]
+         indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
+         return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+     return matcher_forward
+
+ def modify_grasp_loss_forward(self):
+     def modified_loss_labels(outputs, targets, indices, num_boxes):
+         """
+         Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim
+         [nb_target_boxes]
+         """
+         num_classes = 1  # model v9 always uses class-agnostic grasp
+         if "logits" not in outputs:
+             raise KeyError("No logits were found in the outputs")
+         source_logits = outputs["logits"]
+
+         idx = self._get_source_permutation_idx(indices)
+         target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
+         target_classes = torch.full(
+             source_logits.shape[:2], num_classes, dtype=torch.int64, device=source_logits.device
+         )
+         target_classes[idx] = target_classes_o
+
+         loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes)
+         losses = {"loss_ce": loss_ce}
+
+         return losses
+
+     def modified_loss_boxes(outputs, targets, indices, num_boxes):
+
+         if "pred_boxes" not in outputs:
+             raise KeyError("No predicted boxes found in outputs")
+         idx = self._get_source_permutation_idx(indices)
+         source_boxes = outputs["pred_boxes"][idx]
+         target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+         loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
+
+         losses = {}
+         losses["loss_bbox"] = loss_bbox.sum() / num_boxes
+
+         loss_giou = 1 - torch.diag(
+             generalized_box_iou(center_to_corners_format(source_boxes[:, :4]), center_to_corners_format(target_boxes[:, :4]))
+         )
+         losses["loss_giou"] = loss_giou.sum() / num_boxes
+         return losses
+     return modified_loss_labels, modified_loss_boxes
+
+ def modify_forward(self):
+     """
+     Modify the following methods to make SAM perform grasp detection after segmentation:
+     1. Add a parallel decoder for grasping detection: 1(+1) classes, 5 values to regress (bbox & rotation)
+     Returns:
+         Modified model
+     """
+     # 1. We instantiate a new module in self.base_model, as another decoder
+     self.grasp_decoder_config = DetrConfig()
+     self.grasp_decoder = DetrDecoder(self.grasp_decoder_config).to(self.device)
+     self.grasp_query_position_embeddings = nn.Embedding(20, 256).to(self.device)
+     # 2. Base model forward method is not directly used, no modification needs to be done
+     # self.detr.model.forward = modify_base_model_forward(self.detr.model)
+     # 3. Add additional classification head & bbox regression head for grasp_decoder output
+     self.grasp_predictor = torch.nn.Sequential(
+         torch.nn.Linear(256, 256),
+         torch.nn.Linear(256, 256),
+         torch.nn.Linear(256, 5)
+     ).to(self.device)
+     self.grasp_label_classifier = torch.nn.Linear(256, 2).to(self.device)
+     # 4. Add positional embedding
+     # name it as grasp_img_pos_embed to avoid name conflict
+     class ImagePosEmbed(nn.Module):
+         def __init__(self, img_size=64, hidden_dim=256):
+             super().__init__()
+             self.pos_embed = nn.Parameter(
+                 torch.randn(1, img_size, img_size, hidden_dim)
+             )
+         def forward(self, x):
+             return x + self.pos_embed
+
+     self.grasp_img_pos_embed = ImagePosEmbed().to(self.device)
+
+     def modified_forward(
+         batched_input: List[Dict[str, Any]],
+         multimask_output: bool,
+     ):
+         input_images = torch.stack([x["image"] for x in batched_input], dim=0)
+         image_embeddings = self.image_encoder(input_images)
+
+         outputs = []
+         srcs = []
+         for image_record, curr_embedding in zip(batched_input, image_embeddings):
+             if "point_coords" in image_record:
+                 points = (image_record["point_coords"], image_record["point_labels"])
+             else:
+                 points = None
+             sparse_embeddings, dense_embeddings = self.prompt_encoder(
+                 points=points,
+                 boxes=image_record.get("boxes", None),
+                 masks=image_record.get("mask_inputs", None),
+             )
+             low_res_masks, iou_predictions, src = self.mask_decoder(
+                 image_embeddings=curr_embedding.unsqueeze(0),
+                 image_pe=self.prompt_encoder.get_dense_pe(),
+                 sparse_prompt_embeddings=sparse_embeddings,
+                 dense_prompt_embeddings=dense_embeddings,
+                 multimask_output=multimask_output,
+             )
+             outputs.append(
+                 {
+                     "iou_predictions": iou_predictions,
+                     "low_res_logits": low_res_masks,
+                 }
+             )
+             srcs.append(src[0])
+         srcs = torch.stack(srcs, dim=0)
+         # forward grasp decoder here
+         # 1. Get encoder hidden states
+         grasp_encoder_hidden_states = self.grasp_img_pos_embed(srcs.permute(0, 2, 3, 1))
+         # 2. Get query embeddings
+         grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
+         # repeat to batch size
+         grasp_query_pe = grasp_query_pe.repeat(len(batched_input), 1, 1)
+         grasp_decoder_outputs = self.grasp_decoder(
+             inputs_embeds=torch.zeros_like(grasp_query_pe),
+             attention_mask=None,
+             position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
+             query_position_embeddings=grasp_query_pe,
+             encoder_hidden_states=grasp_encoder_hidden_states,
+             encoder_attention_mask=None,
+             output_attentions=False,
+             output_hidden_states=False,
+             return_dict=True,
+         )
+         grasp_sequence_output = grasp_decoder_outputs[0]
+         grasp_logits = self.grasp_label_classifier(grasp_sequence_output)
+         pred_grasps = self.grasp_predictor(grasp_sequence_output).sigmoid()
+
+         # 3. Calculate loss
+         loss, loss_dict = 0, {}
+         if "grasp_labels" in batched_input[0]:
+             config = self.grasp_decoder_config
+             grasp_labels = [{
+                 "class_labels": torch.zeros([len(x["grasp_labels"])], dtype=torch.long).to(self.device),
+                 "boxes": x["grasp_labels"],
+             } for x in batched_input]
+             # First: create the matcher
+             matcher = DetrHungarianMatcher(
+                 class_cost=config.class_cost, bbox_cost=config.bbox_cost, giou_cost=config.giou_cost
+             )
+             matcher.forward = modify_matcher_forward(matcher)
+             # Second: create the criterion
+             losses = ["labels", "boxes"]
+             criterion = DetrLoss(
+                 matcher=matcher,
+                 num_classes=config.num_labels,
+                 eos_coef=config.eos_coefficient,
+                 losses=losses,
+             )
+             criterion.loss_labels, criterion.loss_boxes = modify_grasp_loss_forward(criterion)
+             criterion.to(self.device)
+             # Third: compute the losses, based on outputs and labels
+             outputs_loss = {}
+             outputs_loss["logits"] = grasp_logits
+             outputs_loss["pred_boxes"] = pred_grasps
+
+             grasp_loss_dict = criterion(outputs_loss, grasp_labels)
+             # Fourth: compute total loss, as a weighted sum of the various losses
+             weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
+             weight_dict["loss_giou"] = config.giou_loss_coefficient
+             if config.auxiliary_loss:
+                 aux_weight_dict = {}
+                 for i in range(config.decoder_layers - 1):
+                     aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
+                 weight_dict.update(aux_weight_dict)
+             grasp_loss = sum(grasp_loss_dict[k] * weight_dict[k] for k in grasp_loss_dict.keys() if k in weight_dict)
+
+             # merge grasp branch loss into variable loss & loss_dict
+             loss += grasp_loss
+             loss_dict.update(grasp_loss_dict)
+         pred_masks = self.postprocess_masks(
+             torch.cat([x['low_res_logits'] for x in outputs], dim=0),
+             input_size=image_record["image"].shape[-2:],
+             original_size=(1024, 1024),
+         )
+         if 'masks' in batched_input[0]:
+             # 4. Calculate segmentation loss
+             sf_loss = sigmoid_focal_loss(pred_masks.flatten(1),
+                 torch.stack([x['masks'] for x in batched_input], dim=0).unsqueeze(1).type(torch.float32).flatten(1), len(batched_input))
+             d_loss = dice_loss(pred_masks.flatten(1),
+                 torch.stack([x['masks'] for x in batched_input], dim=0).unsqueeze(1).type(torch.float32).flatten(1), len(batched_input))
+             loss += sf_loss + d_loss
+             loss_dict["sf_loss"] = sf_loss
+             loss_dict["d_loss"] = d_loss
+         return DetrSegmentationOutput(
+             loss=loss,
+             loss_dict=loss_dict,
+             logits=grasp_logits,
+             pred_boxes=pred_grasps,
+             pred_masks=pred_masks,
+         )
+
+     return modified_forward
+
+ def add_inference_method(self):
+     def infer(
+         batched_input: List[Dict[str, Any]],
+         multimask_output: bool,
+     ):
+         input_images = torch.stack([x["image"] for x in batched_input], dim=0)
+         image_embeddings = self.image_encoder(input_images)
+
+         outputs = []
+         srcs = []
+         curr_embedding = image_embeddings[0]
+         image_record = batched_input[0]
+
+         if "point_coords" in image_record:
+             points = (image_record["point_coords"], image_record["point_labels"])
+         else:
+             points = None
+         sparse_embeddings, dense_embeddings = self.prompt_encoder(
+             points=points,
+             boxes=image_record.get("boxes", None),
+             masks=image_record.get("mask_inputs", None),
+         )
+         low_res_masks, iou_predictions, src = self.mask_decoder(
+             image_embeddings=curr_embedding.unsqueeze(0),
+             image_pe=self.prompt_encoder.get_dense_pe(),
+             sparse_prompt_embeddings=sparse_embeddings,
+             dense_prompt_embeddings=dense_embeddings,
+             multimask_output=multimask_output,
+         )
+         outputs.append(
+             {
+                 "iou_predictions": iou_predictions,
+                 "low_res_logits": low_res_masks,
+             }
+         )
+         srcs.append(src[0])
+
+         n_queries = iou_predictions.size(0)
+
+         # forward grasp decoder here
+         # 1. Get encoder hidden states
+         grasp_encoder_hidden_states = self.grasp_img_pos_embed(src.permute(0, 2, 3, 1))
+         # 2. Get query embeddings
+         grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
+         # repeat to batch size
+         grasp_query_pe = grasp_query_pe.repeat(n_queries, 1, 1)
+         grasp_decoder_outputs = self.grasp_decoder(
+             inputs_embeds=torch.zeros_like(grasp_query_pe),
+             attention_mask=None,
+             position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
+             query_position_embeddings=grasp_query_pe,
+             encoder_hidden_states=grasp_encoder_hidden_states,
+             encoder_attention_mask=None,
+             output_attentions=False,
+             output_hidden_states=False,
+             return_dict=True,
+         )
+         grasp_sequence_output = grasp_decoder_outputs[0]
+         grasp_logits = self.grasp_label_classifier(grasp_sequence_output)
+         pred_grasps = self.grasp_predictor(grasp_sequence_output).sigmoid()
+         pred_masks = self.postprocess_masks(
+             torch.cat([x['low_res_logits'] for x in outputs], dim=0),
+             input_size=image_record["image"].shape[-2:],
+             original_size=(1024, 1024),
+         )
+         return DetrSegmentationOutput(
+             loss=0,
+             loss_dict={},
+             logits=grasp_logits,
+             pred_boxes=pred_grasps,
+             pred_masks=pred_masks,
+         )
+     return infer
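Putting the pieces together, app.py builds a stock SAM and monkey-patches it with the two functions above; modify_forward must run first because add_inference_method reuses the grasp decoder and heads it attaches. A minimal sketch with dummy shapes, mirroring the calls in app.py (untrained weights, so the outputs are meaningless; the 5 regressed values per grasp query are the encoded grasp box handled by GraspCoder in structures/grasp_box.py):

import torch
from models import sam_model_registry
from models.grasp_mods import modify_forward, add_inference_method

sam = sam_model_registry["vit_b"]()
sam.forward = modify_forward(sam)       # attaches grasp_decoder, grasp heads and positional embeddings
sam.infer = add_inference_method(sam)   # loss-free single-image inference path
sam.eval()

# One padded 1024x1024 image and one box prompt, as in app.py.
batched_input = [{
    "image": torch.zeros(3, 1024, 1024),
    "boxes": torch.tensor([[100.0, 100.0, 400.0, 400.0]]),
}]
with torch.no_grad():
    out = sam.infer(batched_input, multimask_output=False)
# out.logits: per-query grasp scores; out.pred_boxes: sigmoid-normalized 5-D grasps; out.pred_masks: mask logits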
models/modeling/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .sam import Sam
+ from .image_encoder import ImageEncoderViT
+ from .mask_decoder import MaskDecoder
+ from .prompt_encoder import PromptEncoder
+ from .transformer import TwoWayTransformer
models/modeling/common.py ADDED
@@ -0,0 +1,43 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ import torch.nn as nn
+
+ from typing import Type
+
+
+ class MLPBlock(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         mlp_dim: int,
+         act: Type[nn.Module] = nn.GELU,
+     ) -> None:
+         super().__init__()
+         self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+         self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+         self.act = act()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.lin2(self.act(self.lin1(x)))
+
+
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
+ class LayerNorm2d(nn.Module):
+     def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(num_channels))
+         self.bias = nn.Parameter(torch.zeros(num_channels))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         u = x.mean(1, keepdim=True)
+         s = (x - u).pow(2).mean(1, keepdim=True)
+         x = (x - u) / torch.sqrt(s + self.eps)
+         x = self.weight[:, None, None] * x + self.bias[:, None, None]
+         return x
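Unlike nn.LayerNorm, LayerNorm2d above normalizes over the channel dimension of an NCHW tensor, which is why the encoder neck can apply it directly to feature maps. A tiny shape-check sketch:

import torch
from models.modeling.common import LayerNorm2d, MLPBlock

x = torch.randn(2, 256, 64, 64)                                 # NCHW feature map
print(LayerNorm2d(256)(x).shape)                                # torch.Size([2, 256, 64, 64])

tokens = torch.randn(2, 4096, 256)                              # token embeddings
print(MLPBlock(embedding_dim=256, mlp_dim=2048)(tokens).shape)  # torch.Size([2, 4096, 256])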
models/modeling/image_encoder.py ADDED
@@ -0,0 +1,395 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from typing import Optional, Tuple, Type
12
+
13
+ from .common import LayerNorm2d, MLPBlock
14
+
15
+
16
+ # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
17
+ class ImageEncoderViT(nn.Module):
18
+ def __init__(
19
+ self,
20
+ img_size: int = 1024,
21
+ patch_size: int = 16,
22
+ in_chans: int = 3,
23
+ embed_dim: int = 768,
24
+ depth: int = 12,
25
+ num_heads: int = 12,
26
+ mlp_ratio: float = 4.0,
27
+ out_chans: int = 256,
28
+ qkv_bias: bool = True,
29
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
30
+ act_layer: Type[nn.Module] = nn.GELU,
31
+ use_abs_pos: bool = True,
32
+ use_rel_pos: bool = False,
33
+ rel_pos_zero_init: bool = True,
34
+ window_size: int = 0,
35
+ global_attn_indexes: Tuple[int, ...] = (),
36
+ ) -> None:
37
+ """
38
+ Args:
39
+ img_size (int): Input image size.
40
+ patch_size (int): Patch size.
41
+ in_chans (int): Number of input image channels.
42
+ embed_dim (int): Patch embedding dimension.
43
+ depth (int): Depth of ViT.
44
+ num_heads (int): Number of attention heads in each ViT block.
45
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
46
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
47
+ norm_layer (nn.Module): Normalization layer.
48
+ act_layer (nn.Module): Activation layer.
49
+ use_abs_pos (bool): If True, use absolute positional embeddings.
50
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
51
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
52
+ window_size (int): Window size for window attention blocks.
53
+ global_attn_indexes (list): Indexes for blocks using global attention.
54
+ """
55
+ super().__init__()
56
+ self.img_size = img_size
57
+
58
+ self.patch_embed = PatchEmbed(
59
+ kernel_size=(patch_size, patch_size),
60
+ stride=(patch_size, patch_size),
61
+ in_chans=in_chans,
62
+ embed_dim=embed_dim,
63
+ )
64
+
65
+ self.pos_embed: Optional[nn.Parameter] = None
66
+ if use_abs_pos:
67
+ # Initialize absolute positional embedding with pretrain image size.
68
+ self.pos_embed = nn.Parameter(
69
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
70
+ )
71
+
72
+ self.blocks = nn.ModuleList()
73
+ for i in range(depth):
74
+ block = Block(
75
+ dim=embed_dim,
76
+ num_heads=num_heads,
77
+ mlp_ratio=mlp_ratio,
78
+ qkv_bias=qkv_bias,
79
+ norm_layer=norm_layer,
80
+ act_layer=act_layer,
81
+ use_rel_pos=use_rel_pos,
82
+ rel_pos_zero_init=rel_pos_zero_init,
83
+ window_size=window_size if i not in global_attn_indexes else 0,
84
+ input_size=(img_size // patch_size, img_size // patch_size),
85
+ )
86
+ self.blocks.append(block)
87
+
88
+ self.neck = nn.Sequential(
89
+ nn.Conv2d(
90
+ embed_dim,
91
+ out_chans,
92
+ kernel_size=1,
93
+ bias=False,
94
+ ),
95
+ LayerNorm2d(out_chans),
96
+ nn.Conv2d(
97
+ out_chans,
98
+ out_chans,
99
+ kernel_size=3,
100
+ padding=1,
101
+ bias=False,
102
+ ),
103
+ LayerNorm2d(out_chans),
104
+ )
105
+
106
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
107
+ x = self.patch_embed(x)
108
+ if self.pos_embed is not None:
109
+ x = x + self.pos_embed
110
+
111
+ for blk in self.blocks:
112
+ x = blk(x)
113
+
114
+ x = self.neck(x.permute(0, 3, 1, 2))
115
+
116
+ return x
117
+
118
+
119
+ class Block(nn.Module):
120
+ """Transformer blocks with support of window attention and residual propagation blocks"""
121
+
122
+ def __init__(
123
+ self,
124
+ dim: int,
125
+ num_heads: int,
126
+ mlp_ratio: float = 4.0,
127
+ qkv_bias: bool = True,
128
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
129
+ act_layer: Type[nn.Module] = nn.GELU,
130
+ use_rel_pos: bool = False,
131
+ rel_pos_zero_init: bool = True,
132
+ window_size: int = 0,
133
+ input_size: Optional[Tuple[int, int]] = None,
134
+ ) -> None:
135
+ """
136
+ Args:
137
+ dim (int): Number of input channels.
138
+ num_heads (int): Number of attention heads in each ViT block.
139
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
140
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
141
+ norm_layer (nn.Module): Normalization layer.
142
+ act_layer (nn.Module): Activation layer.
143
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
144
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
145
+ window_size (int): Window size for window attention blocks. If it equals 0, then
146
+ use global attention.
147
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
148
+ positional parameter size.
149
+ """
150
+ super().__init__()
151
+ self.norm1 = norm_layer(dim)
152
+ self.attn = Attention(
153
+ dim,
154
+ num_heads=num_heads,
155
+ qkv_bias=qkv_bias,
156
+ use_rel_pos=use_rel_pos,
157
+ rel_pos_zero_init=rel_pos_zero_init,
158
+ input_size=input_size if window_size == 0 else (window_size, window_size),
159
+ )
160
+
161
+ self.norm2 = norm_layer(dim)
162
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
163
+
164
+ self.window_size = window_size
165
+
166
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
167
+ shortcut = x
168
+ x = self.norm1(x)
169
+ # Window partition
170
+ if self.window_size > 0:
171
+ H, W = x.shape[1], x.shape[2]
172
+ x, pad_hw = window_partition(x, self.window_size)
173
+
174
+ x = self.attn(x)
175
+ # Reverse window partition
176
+ if self.window_size > 0:
177
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
178
+
179
+ x = shortcut + x
180
+ x = x + self.mlp(self.norm2(x))
181
+
182
+ return x
183
+
184
+
185
+ class Attention(nn.Module):
186
+ """Multi-head Attention block with relative position embeddings."""
187
+
188
+ def __init__(
189
+ self,
190
+ dim: int,
191
+ num_heads: int = 8,
192
+ qkv_bias: bool = True,
193
+ use_rel_pos: bool = False,
194
+ rel_pos_zero_init: bool = True,
195
+ input_size: Optional[Tuple[int, int]] = None,
196
+ ) -> None:
197
+ """
198
+ Args:
199
+ dim (int): Number of input channels.
200
+ num_heads (int): Number of attention heads.
201
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
202
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
203
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
204
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
205
+ positional parameter size.
206
+ """
207
+ super().__init__()
208
+ self.num_heads = num_heads
209
+ head_dim = dim // num_heads
210
+ self.scale = head_dim**-0.5
211
+
212
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
213
+ self.proj = nn.Linear(dim, dim)
214
+
215
+ self.use_rel_pos = use_rel_pos
216
+ if self.use_rel_pos:
217
+ assert (
218
+ input_size is not None
219
+ ), "Input size must be provided if using relative positional encoding."
220
+ # initialize relative positional embeddings
221
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
222
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
223
+
224
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
225
+ B, H, W, _ = x.shape
226
+ # qkv with shape (3, B, nHead, H * W, C)
227
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
228
+ # q, k, v with shape (B * nHead, H * W, C)
229
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
230
+
231
+ attn = (q * self.scale) @ k.transpose(-2, -1)
232
+
233
+ if self.use_rel_pos:
234
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
235
+
236
+ attn = attn.softmax(dim=-1)
237
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
238
+ x = self.proj(x)
239
+
240
+ return x
241
+
242
+
243
+ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
244
+ """
245
+ Partition into non-overlapping windows with padding if needed.
246
+ Args:
247
+ x (tensor): input tokens with [B, H, W, C].
248
+ window_size (int): window size.
249
+
250
+ Returns:
251
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
252
+ (Hp, Wp): padded height and width before partition
253
+ """
254
+ B, H, W, C = x.shape
255
+
256
+ pad_h = (window_size - H % window_size) % window_size
257
+ pad_w = (window_size - W % window_size) % window_size
258
+ if pad_h > 0 or pad_w > 0:
259
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
260
+ Hp, Wp = H + pad_h, W + pad_w
261
+
262
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
263
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
264
+ return windows, (Hp, Wp)
265
+
266
+
267
+ def window_unpartition(
268
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
269
+ ) -> torch.Tensor:
270
+ """
271
+ Window unpartition into original sequences and removing padding.
272
+ Args:
273
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
274
+ window_size (int): window size.
275
+ pad_hw (Tuple): padded height and width (Hp, Wp).
276
+ hw (Tuple): original height and width (H, W) before padding.
277
+
278
+ Returns:
279
+ x: unpartitioned sequences with [B, H, W, C].
280
+ """
281
+ Hp, Wp = pad_hw
282
+ H, W = hw
283
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
284
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
285
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
286
+
287
+ if Hp > H or Wp > W:
288
+ x = x[:, :H, :W, :].contiguous()
289
+ return x
290
+
291
+
292
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
293
+ """
294
+ Get relative positional embeddings according to the relative positions of
295
+ query and key sizes.
296
+ Args:
297
+ q_size (int): size of query q.
298
+ k_size (int): size of key k.
299
+ rel_pos (Tensor): relative position embeddings (L, C).
300
+
301
+ Returns:
302
+ Extracted positional embeddings according to relative positions.
303
+ """
304
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
305
+ # Interpolate rel pos if needed.
306
+ if rel_pos.shape[0] != max_rel_dist:
307
+ # Interpolate rel pos.
308
+ rel_pos_resized = F.interpolate(
309
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
310
+ size=max_rel_dist,
311
+ mode="linear",
312
+ )
313
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
314
+ else:
315
+ rel_pos_resized = rel_pos
316
+
317
+ # Scale the coords with short length if shapes for q and k are different.
318
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
319
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
320
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
321
+
322
+ return rel_pos_resized[relative_coords.long()]
323
+
324
+
325
+ def add_decomposed_rel_pos(
326
+ attn: torch.Tensor,
327
+ q: torch.Tensor,
328
+ rel_pos_h: torch.Tensor,
329
+ rel_pos_w: torch.Tensor,
330
+ q_size: Tuple[int, int],
331
+ k_size: Tuple[int, int],
332
+ ) -> torch.Tensor:
333
+ """
334
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
335
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
336
+ Args:
337
+ attn (Tensor): attention map.
338
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
339
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
340
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
341
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
342
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
343
+
344
+ Returns:
345
+ attn (Tensor): attention map with added relative positional embeddings.
346
+ """
347
+ q_h, q_w = q_size
348
+ k_h, k_w = k_size
349
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
350
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
351
+
352
+ B, _, dim = q.shape
353
+ r_q = q.reshape(B, q_h, q_w, dim)
354
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
355
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
356
+
357
+ attn = (
358
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
359
+ ).view(B, q_h * q_w, k_h * k_w)
360
+
361
+ return attn
362
+
363
+
364
+ class PatchEmbed(nn.Module):
365
+ """
366
+ Image to Patch Embedding.
367
+ """
368
+
369
+ def __init__(
370
+ self,
371
+ kernel_size: Tuple[int, int] = (16, 16),
372
+ stride: Tuple[int, int] = (16, 16),
373
+ padding: Tuple[int, int] = (0, 0),
374
+ in_chans: int = 3,
375
+ embed_dim: int = 768,
376
+ ) -> None:
377
+ """
378
+ Args:
379
+ kernel_size (Tuple): kernel size of the projection layer.
380
+ stride (Tuple): stride of the projection layer.
381
+ padding (Tuple): padding size of the projection layer.
382
+ in_chans (int): Number of input image channels.
383
+ embed_dim (int): Patch embedding dimension.
384
+ """
385
+ super().__init__()
386
+
387
+ self.proj = nn.Conv2d(
388
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
389
+ )
390
+
391
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
392
+ x = self.proj(x)
393
+ # B C H W -> B H W C
394
+ x = x.permute(0, 2, 3, 1)
395
+ return x
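The relative-position helpers above are easiest to check with concrete shapes. A minimal sketch with hypothetical sizes (it assumes only the module-level functions defined in image_encoder.py):

import torch
from models.modeling.image_encoder import add_decomposed_rel_pos

# 14x14 query/key grid, per-head channel dim 32, batch*heads folded into dim 0
q_size, k_size, dim = (14, 14), (14, 14), 32
q = torch.randn(4, 14 * 14, dim)                 # (B*num_heads, q_h*q_w, C)
attn = torch.randn(4, 14 * 14, 14 * 14)          # raw attention logits
rel_pos_h = torch.randn(2 * 14 - 1, dim)         # (2*q_h - 1, C), so no interpolation is needed
rel_pos_w = torch.randn(2 * 14 - 1, dim)
attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)
print(attn.shape)                                # torch.Size([4, 196, 196])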
models/modeling/mask_decoder.py ADDED
@@ -0,0 +1,176 @@
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ from typing import List, Tuple, Type
12
+
13
+ from .common import LayerNorm2d
14
+
15
+
16
+ class MaskDecoder(nn.Module):
17
+ def __init__(
18
+ self,
19
+ *,
20
+ transformer_dim: int,
21
+ transformer: nn.Module,
22
+ num_multimask_outputs: int = 3,
23
+ activation: Type[nn.Module] = nn.GELU,
24
+ iou_head_depth: int = 3,
25
+ iou_head_hidden_dim: int = 256,
26
+ ) -> None:
27
+ """
28
+ Predicts masks given an image and prompt embeddings, using a
29
+ transformer architecture.
30
+
31
+ Arguments:
32
+ transformer_dim (int): the channel dimension of the transformer
33
+ transformer (nn.Module): the transformer used to predict masks
34
+ num_multimask_outputs (int): the number of masks to predict
35
+ when disambiguating masks
36
+ activation (nn.Module): the type of activation to use when
37
+ upscaling masks
38
+ iou_head_depth (int): the depth of the MLP used to predict
39
+ mask quality
40
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
41
+ used to predict mask quality
42
+ """
43
+ super().__init__()
44
+ self.transformer_dim = transformer_dim
45
+ self.transformer = transformer
46
+
47
+ self.num_multimask_outputs = num_multimask_outputs
48
+
49
+ self.iou_token = nn.Embedding(1, transformer_dim)
50
+ self.num_mask_tokens = num_multimask_outputs + 1
51
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52
+
53
+ self.output_upscaling = nn.Sequential(
54
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
55
+ LayerNorm2d(transformer_dim // 4),
56
+ activation(),
57
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
58
+ activation(),
59
+ )
60
+ self.output_hypernetworks_mlps = nn.ModuleList(
61
+ [
62
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
63
+ for i in range(self.num_mask_tokens)
64
+ ]
65
+ )
66
+
67
+ self.iou_prediction_head = MLP(
68
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
69
+ )
70
+
71
+ def forward(
72
+ self,
73
+ image_embeddings: torch.Tensor,
74
+ image_pe: torch.Tensor,
75
+ sparse_prompt_embeddings: torch.Tensor,
76
+ dense_prompt_embeddings: torch.Tensor,
77
+ multimask_output: bool,
78
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
79
+ """
80
+ Predict masks given image and prompt embeddings.
81
+
82
+ Arguments:
83
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
84
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
85
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
86
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
87
+ multimask_output (bool): Whether to return multiple masks or a single
88
+ mask.
89
+
90
+ Returns:
91
+ torch.Tensor: batched predicted masks
92
+ torch.Tensor: batched predictions of mask quality
+ torch.Tensor: the transformer-updated image embedding, returned in addition to the upstream SAM outputs
93
+ """
94
+ masks, iou_pred, src = self.predict_masks(
95
+ image_embeddings=image_embeddings,
96
+ image_pe=image_pe,
97
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
98
+ dense_prompt_embeddings=dense_prompt_embeddings,
99
+ )
100
+
101
+ # Select the correct mask or masks for output
102
+ if multimask_output:
103
+ mask_slice = slice(1, None)
104
+ else:
105
+ mask_slice = slice(0, 1)
106
+ masks = masks[:, mask_slice, :, :]
107
+ iou_pred = iou_pred[:, mask_slice]
108
+
109
+ # Prepare output
110
+ return masks, iou_pred, src
111
+
112
+ def predict_masks(
113
+ self,
114
+ image_embeddings: torch.Tensor,
115
+ image_pe: torch.Tensor,
116
+ sparse_prompt_embeddings: torch.Tensor,
117
+ dense_prompt_embeddings: torch.Tensor,
118
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
119
+ """Predicts masks. See 'forward' for more details."""
120
+ # Concatenate output tokens
121
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
122
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
123
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
124
+
125
+ # Expand per-image data in batch direction to be per-mask
126
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
127
+ src = src + dense_prompt_embeddings
128
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
129
+ b, c, h, w = src.shape
130
+
131
+ # Run the transformer
132
+ hs, src = self.transformer(src, pos_src, tokens)
133
+ iou_token_out = hs[:, 0, :]
134
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
135
+
136
+ # Upscale mask embeddings and predict masks using the mask tokens
137
+ src = src.transpose(1, 2).view(b, c, h, w)
138
+ upscaled_embedding = self.output_upscaling(src)
139
+ hyper_in_list: List[torch.Tensor] = []
140
+ for i in range(self.num_mask_tokens):
141
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
142
+ hyper_in = torch.stack(hyper_in_list, dim=1)
143
+ b, c, h, w = upscaled_embedding.shape
144
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
145
+
146
+ # Generate mask quality predictions
147
+ iou_pred = self.iou_prediction_head(iou_token_out)
148
+
149
+ return masks, iou_pred, src
150
+
151
+
152
+ # Lightly adapted from
153
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
154
+ class MLP(nn.Module):
155
+ def __init__(
156
+ self,
157
+ input_dim: int,
158
+ hidden_dim: int,
159
+ output_dim: int,
160
+ num_layers: int,
161
+ sigmoid_output: bool = False,
162
+ ) -> None:
163
+ super().__init__()
164
+ self.num_layers = num_layers
165
+ h = [hidden_dim] * (num_layers - 1)
166
+ self.layers = nn.ModuleList(
167
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
168
+ )
169
+ self.sigmoid_output = sigmoid_output
170
+
171
+ def forward(self, x):
172
+ for i, layer in enumerate(self.layers):
173
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
174
+ if self.sigmoid_output:
175
+ x = F.sigmoid(x)
176
+ return x
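Note that, unlike the upstream SAM decoder, forward and predict_masks here return the post-transformer image embedding as a third value, so callers unpack three tensors. The MLP above is a plain multi-layer perceptron used as the per-token hypernetwork head; a minimal sketch with hypothetical dimensions:

import torch
from models.modeling.mask_decoder import MLP

mlp = MLP(input_dim=256, hidden_dim=256, output_dim=32, num_layers=3)
mask_tokens = torch.randn(2, 256)     # two 256-dim mask tokens
weights = mlp(mask_tokens)            # per-token hypernetwork weights
print(weights.shape)                  # torch.Size([2, 32])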
models/modeling/prompt_encoder.py ADDED
@@ -0,0 +1,214 @@
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn
10
+
11
+ from typing import Any, Optional, Tuple, Type
12
+
13
+ from .common import LayerNorm2d
14
+
15
+
16
+ class PromptEncoder(nn.Module):
17
+ def __init__(
18
+ self,
19
+ embed_dim: int,
20
+ image_embedding_size: Tuple[int, int],
21
+ input_image_size: Tuple[int, int],
22
+ mask_in_chans: int,
23
+ activation: Type[nn.Module] = nn.GELU,
24
+ ) -> None:
25
+ """
26
+ Encodes prompts for input to SAM's mask decoder.
27
+
28
+ Arguments:
29
+ embed_dim (int): The prompts' embedding dimension
30
+ image_embedding_size (tuple(int, int)): The spatial size of the
31
+ image embedding, as (H, W).
32
+ input_image_size (tuple(int, int)): The padded size of the image as input
33
+ to the image encoder, as (H, W).
34
+ mask_in_chans (int): The number of hidden channels used for
35
+ encoding input masks.
36
+ activation (nn.Module): The activation to use when encoding
37
+ input masks.
38
+ """
39
+ super().__init__()
40
+ self.embed_dim = embed_dim
41
+ self.input_image_size = input_image_size
42
+ self.image_embedding_size = image_embedding_size
43
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44
+
45
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46
+ point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
47
+ self.point_embeddings = nn.ModuleList(point_embeddings)
48
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
49
+
50
+ self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
51
+ self.mask_downscaling = nn.Sequential(
52
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
53
+ LayerNorm2d(mask_in_chans // 4),
54
+ activation(),
55
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
56
+ LayerNorm2d(mask_in_chans),
57
+ activation(),
58
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
59
+ )
60
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
61
+
62
+ def get_dense_pe(self) -> torch.Tensor:
63
+ """
64
+ Returns the positional encoding used to encode point prompts,
65
+ applied to a dense set of points the shape of the image encoding.
66
+
67
+ Returns:
68
+ torch.Tensor: Positional encoding with shape
69
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
70
+ """
71
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
72
+
73
+ def _embed_points(
74
+ self,
75
+ points: torch.Tensor,
76
+ labels: torch.Tensor,
77
+ pad: bool,
78
+ ) -> torch.Tensor:
79
+ """Embeds point prompts."""
80
+ points = points + 0.5 # Shift to center of pixel
81
+ if pad:
82
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
83
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
84
+ points = torch.cat([points, padding_point], dim=1)
85
+ labels = torch.cat([labels, padding_label], dim=1)
86
+ point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
87
+ point_embedding[labels == -1] = 0.0
88
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
89
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
90
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
91
+ return point_embedding
92
+
93
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
94
+ """Embeds box prompts."""
95
+ boxes = boxes + 0.5 # Shift to center of pixel
96
+ coords = boxes.reshape(-1, 2, 2)
97
+ corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
98
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
99
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
100
+ return corner_embedding
101
+
102
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
103
+ """Embeds mask inputs."""
104
+ mask_embedding = self.mask_downscaling(masks)
105
+ return mask_embedding
106
+
107
+ def _get_batch_size(
108
+ self,
109
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
110
+ boxes: Optional[torch.Tensor],
111
+ masks: Optional[torch.Tensor],
112
+ ) -> int:
113
+ """
114
+ Gets the batch size of the output given the batch size of the input prompts.
115
+ """
116
+ if points is not None:
117
+ return points[0].shape[0]
118
+ elif boxes is not None:
119
+ return boxes.shape[0]
120
+ elif masks is not None:
121
+ return masks.shape[0]
122
+ else:
123
+ return 1
124
+
125
+ def _get_device(self) -> torch.device:
126
+ return self.point_embeddings[0].weight.device
127
+
128
+ def forward(
129
+ self,
130
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
131
+ boxes: Optional[torch.Tensor],
132
+ masks: Optional[torch.Tensor],
133
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
134
+ """
135
+ Embeds different types of prompts, returning both sparse and dense
136
+ embeddings.
137
+
138
+ Arguments:
139
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
140
+ and labels to embed.
141
+ boxes (torch.Tensor or none): boxes to embed
142
+ masks (torch.Tensor or none): masks to embed
143
+
144
+ Returns:
145
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
146
+ BxNx(embed_dim), where N is determined by the number of input points
147
+ and boxes.
148
+ torch.Tensor: dense embeddings for the masks, in the shape
149
+ Bx(embed_dim)x(embed_H)x(embed_W)
150
+ """
151
+ bs = self._get_batch_size(points, boxes, masks)
152
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
153
+ if points is not None:
154
+ coords, labels = points
155
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
156
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
157
+ if boxes is not None:
158
+ box_embeddings = self._embed_boxes(boxes)
159
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
160
+
161
+ if masks is not None:
162
+ dense_embeddings = self._embed_masks(masks)
163
+ else:
164
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
165
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
166
+ )
167
+
168
+ return sparse_embeddings, dense_embeddings
169
+
170
+
171
+ class PositionEmbeddingRandom(nn.Module):
172
+ """
173
+ Positional encoding using random spatial frequencies.
174
+ """
175
+
176
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
177
+ super().__init__()
178
+ if scale is None or scale <= 0.0:
179
+ scale = 1.0
180
+ self.register_buffer(
181
+ "positional_encoding_gaussian_matrix",
182
+ scale * torch.randn((2, num_pos_feats)),
183
+ )
184
+
185
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
186
+ """Positionally encode points that are normalized to [0,1]."""
187
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
188
+ coords = 2 * coords - 1
189
+ coords = coords @ self.positional_encoding_gaussian_matrix
190
+ coords = 2 * np.pi * coords
191
+ # outputs d_1 x ... x d_n x C shape
192
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
193
+
194
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
195
+ """Generate positional encoding for a grid of the specified size."""
196
+ h, w = size
197
+ device: Any = self.positional_encoding_gaussian_matrix.device
198
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
199
+ y_embed = grid.cumsum(dim=0) - 0.5
200
+ x_embed = grid.cumsum(dim=1) - 0.5
201
+ y_embed = y_embed / h
202
+ x_embed = x_embed / w
203
+
204
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
205
+ return pe.permute(2, 0, 1) # C x H x W
206
+
207
+ def forward_with_coords(
208
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
209
+ ) -> torch.Tensor:
210
+ """Positionally encode points that are not normalized to [0,1]."""
211
+ coords = coords_input.clone()
212
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
213
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
214
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
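A minimal sketch of driving the prompt encoder with a single box prompt; the sizes assume the ViT-B configuration (256-dim embeddings, a 64x64 image embedding, 1024x1024 model input):

import torch
from models.modeling.prompt_encoder import PromptEncoder

pe = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
boxes = torch.tensor([[100.0, 150.0, 400.0, 500.0]])  # one XYXY box in the model input frame
sparse, dense = pe(points=None, boxes=boxes, masks=None)
print(sparse.shape, dense.shape)  # torch.Size([1, 2, 256]) torch.Size([1, 256, 64, 64])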
models/modeling/sam.py ADDED
@@ -0,0 +1,174 @@
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ from typing import Any, Dict, List, Tuple
12
+
13
+ from .image_encoder import ImageEncoderViT
14
+ from .mask_decoder import MaskDecoder
15
+ from .prompt_encoder import PromptEncoder
16
+
17
+
18
+ class Sam(nn.Module):
19
+ mask_threshold: float = 0.0
20
+ image_format: str = "RGB"
21
+
22
+ def __init__(
23
+ self,
24
+ image_encoder: ImageEncoderViT,
25
+ prompt_encoder: PromptEncoder,
26
+ mask_decoder: MaskDecoder,
27
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
28
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
29
+ ) -> None:
30
+ """
31
+ SAM predicts object masks from an image and input prompts.
32
+
33
+ Arguments:
34
+ image_encoder (ImageEncoderViT): The backbone used to encode the
35
+ image into image embeddings that allow for efficient mask prediction.
36
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
37
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
38
+ and encoded prompts.
39
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
40
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
41
+ """
42
+ super().__init__()
43
+ self.image_encoder = image_encoder
44
+ self.prompt_encoder = prompt_encoder
45
+ self.mask_decoder = mask_decoder
46
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
47
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
48
+
49
+ @property
50
+ def device(self) -> Any:
51
+ return self.pixel_mean.device
52
+
53
+ @torch.no_grad()
54
+ def forward(
55
+ self,
56
+ batched_input: List[Dict[str, Any]],
57
+ multimask_output: bool,
58
+ ) -> List[Dict[str, torch.Tensor]]:
59
+ """
60
+ Predicts masks end-to-end from provided images and prompts.
61
+ If prompts are not known in advance, using SamPredictor is
62
+ recommended over calling the model directly.
63
+
64
+ Arguments:
65
+ batched_input (list(dict)): A list over input images, each a
66
+ dictionary with the following keys. A prompt key can be
67
+ excluded if it is not present.
68
+ 'image': The image as a torch tensor in 3xHxW format,
69
+ already transformed for input to the model.
70
+ 'original_size': (tuple(int, int)) The original size of
71
+ the image before transformation, as (H, W).
72
+ 'point_coords': (torch.Tensor) Batched point prompts for
73
+ this image, with shape BxNx2. Already transformed to the
74
+ input frame of the model.
75
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
76
+ with shape BxN.
77
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
78
+ Already transformed to the input frame of the model.
79
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
80
+ in the form Bx1xHxW.
81
+ multimask_output (bool): Whether the model should predict multiple
82
+ disambiguating masks, or return a single mask.
83
+
84
+ Returns:
85
+ (list(dict)): A list over input images, where each element is
86
+ a dictionary with the following keys.
87
+ 'masks': (torch.Tensor) Batched binary mask predictions,
88
+ with shape BxCxHxW, where B is the number of input prompts,
89
+ C is determined by multimask_output, and (H, W) is the
90
+ original size of the image.
91
+ 'iou_predictions': (torch.Tensor) The model's predictions
92
+ of mask quality, in shape BxC.
93
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
94
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
95
+ to subsequent iterations of prediction.
96
+ """
97
+ input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
98
+ image_embeddings = self.image_encoder(input_images)
99
+
100
+ outputs = []
101
+ for image_record, curr_embedding in zip(batched_input, image_embeddings):
102
+ if "point_coords" in image_record:
103
+ points = (image_record["point_coords"], image_record["point_labels"])
104
+ else:
105
+ points = None
106
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
107
+ points=points,
108
+ boxes=image_record.get("boxes", None),
109
+ masks=image_record.get("mask_inputs", None),
110
+ )
111
+ low_res_masks, iou_predictions, _ = self.mask_decoder(
112
+ image_embeddings=curr_embedding.unsqueeze(0),
113
+ image_pe=self.prompt_encoder.get_dense_pe(),
114
+ sparse_prompt_embeddings=sparse_embeddings,
115
+ dense_prompt_embeddings=dense_embeddings,
116
+ multimask_output=multimask_output,
117
+ )
118
+ masks = self.postprocess_masks(
119
+ low_res_masks,
120
+ input_size=image_record["image"].shape[-2:],
121
+ original_size=image_record["original_size"],
122
+ )
123
+ masks = masks > self.mask_threshold
124
+ outputs.append(
125
+ {
126
+ "masks": masks,
127
+ "iou_predictions": iou_predictions,
128
+ "low_res_logits": low_res_masks,
129
+ }
130
+ )
131
+ return outputs
132
+
133
+ def postprocess_masks(
134
+ self,
135
+ masks: torch.Tensor,
136
+ input_size: Tuple[int, ...],
137
+ original_size: Tuple[int, ...],
138
+ ) -> torch.Tensor:
139
+ """
140
+ Remove padding and upscale masks to the original image size.
141
+
142
+ Arguments:
143
+ masks (torch.Tensor): Batched masks from the mask_decoder,
144
+ in BxCxHxW format.
145
+ input_size (tuple(int, int)): The size of the image input to the
146
+ model, in (H, W) format. Used to remove padding.
147
+ original_size (tuple(int, int)): The original size of the image
148
+ before resizing for input to the model, in (H, W) format.
149
+
150
+ Returns:
151
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
152
+ is given by original_size.
153
+ """
154
+ masks = F.interpolate(
155
+ masks,
156
+ (self.image_encoder.img_size, self.image_encoder.img_size),
157
+ mode="bilinear",
158
+ align_corners=False,
159
+ )
160
+ masks = masks[..., : input_size[0], : input_size[1]]
161
+ masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
162
+ return masks
163
+
164
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
165
+ """Normalize pixel values and pad to a square input."""
166
+ # Normalize colors
167
+ x = (x - self.pixel_mean) / self.pixel_std
168
+
169
+ # Pad
170
+ h, w = x.shape[-2:]
171
+ padh = self.image_encoder.img_size - h
172
+ padw = self.image_encoder.img_size - w
173
+ x = F.pad(x, (0, padw, 0, padh))
174
+ return x
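The padding and resizing conventions in preprocess and postprocess_masks are easiest to see with numbers. A minimal sketch, assuming an already-built Sam instance named sam with the default 1024-pixel image encoder:

import torch

# One low-res logit map from the mask decoder (H = W = 256).
low_res = torch.randn(1, 1, 256, 256)
# The original image was 1500x2000; ResizeLongestSide maps it to 768x1024 before padding to 1024x1024.
masks = sam.postprocess_masks(low_res, input_size=(768, 1024), original_size=(1500, 2000))
print(masks.shape)  # torch.Size([1, 1, 1500, 2000])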
models/modeling/transformer.py ADDED
@@ -0,0 +1,240 @@
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import Tensor, nn
9
+
10
+ import math
11
+ from typing import Tuple, Type
12
+
13
+ from .common import MLPBlock
14
+
15
+
16
+ class TwoWayTransformer(nn.Module):
17
+ def __init__(
18
+ self,
19
+ depth: int,
20
+ embedding_dim: int,
21
+ num_heads: int,
22
+ mlp_dim: int,
23
+ activation: Type[nn.Module] = nn.ReLU,
24
+ attention_downsample_rate: int = 2,
25
+ ) -> None:
26
+ """
27
+ A transformer decoder that attends to an input image using
28
+ queries whose positional embedding is supplied.
29
+
30
+ Args:
31
+ depth (int): number of layers in the transformer
32
+ embedding_dim (int): the channel dimension for the input embeddings
33
+ num_heads (int): the number of heads for multihead attention. Must
34
+ divide embedding_dim
35
+ mlp_dim (int): the channel dimension internal to the MLP block
36
+ activation (nn.Module): the activation to use in the MLP block
37
+ """
38
+ super().__init__()
39
+ self.depth = depth
40
+ self.embedding_dim = embedding_dim
41
+ self.num_heads = num_heads
42
+ self.mlp_dim = mlp_dim
43
+ self.layers = nn.ModuleList()
44
+
45
+ for i in range(depth):
46
+ self.layers.append(
47
+ TwoWayAttentionBlock(
48
+ embedding_dim=embedding_dim,
49
+ num_heads=num_heads,
50
+ mlp_dim=mlp_dim,
51
+ activation=activation,
52
+ attention_downsample_rate=attention_downsample_rate,
53
+ skip_first_layer_pe=(i == 0),
54
+ )
55
+ )
56
+
57
+ self.final_attn_token_to_image = Attention(
58
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
59
+ )
60
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
61
+
62
+ def forward(
63
+ self,
64
+ image_embedding: Tensor,
65
+ image_pe: Tensor,
66
+ point_embedding: Tensor,
67
+ ) -> Tuple[Tensor, Tensor]:
68
+ """
69
+ Args:
70
+ image_embedding (torch.Tensor): image to attend to. Should be shape
71
+ B x embedding_dim x h x w for any h and w.
72
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
73
+ have the same shape as image_embedding.
74
+ point_embedding (torch.Tensor): the embedding to add to the query points.
75
+ Must have shape B x N_points x embedding_dim for any N_points.
76
+
77
+ Returns:
78
+ torch.Tensor: the processed point_embedding
79
+ torch.Tensor: the processed image_embedding
80
+ """
81
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
82
+ bs, c, h, w = image_embedding.shape
83
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
84
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
85
+
86
+ # Prepare queries
87
+ queries = point_embedding
88
+ keys = image_embedding
89
+
90
+ # Apply transformer blocks and final layernorm
91
+ for layer in self.layers:
92
+ queries, keys = layer(
93
+ queries=queries,
94
+ keys=keys,
95
+ query_pe=point_embedding,
96
+ key_pe=image_pe,
97
+ )
98
+
99
+ # Apply the final attention layer from the points to the image
100
+ q = queries + point_embedding
101
+ k = keys + image_pe
102
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
103
+ queries = queries + attn_out
104
+ queries = self.norm_final_attn(queries)
105
+
106
+ return queries, keys
107
+
108
+
109
+ class TwoWayAttentionBlock(nn.Module):
110
+ def __init__(
111
+ self,
112
+ embedding_dim: int,
113
+ num_heads: int,
114
+ mlp_dim: int = 2048,
115
+ activation: Type[nn.Module] = nn.ReLU,
116
+ attention_downsample_rate: int = 2,
117
+ skip_first_layer_pe: bool = False,
118
+ ) -> None:
119
+ """
120
+ A transformer block with four layers: (1) self-attention of sparse
121
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
122
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
123
+ inputs.
124
+
125
+ Arguments:
126
+ embedding_dim (int): the channel dimension of the embeddings
127
+ num_heads (int): the number of heads in the attention layers
128
+ mlp_dim (int): the hidden dimension of the mlp block
129
+ activation (nn.Module): the activation of the mlp block
130
+ skip_first_layer_pe (bool): skip the PE on the first layer
131
+ """
132
+ super().__init__()
133
+ self.self_attn = Attention(embedding_dim, num_heads)
134
+ self.norm1 = nn.LayerNorm(embedding_dim)
135
+
136
+ self.cross_attn_token_to_image = Attention(
137
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
138
+ )
139
+ self.norm2 = nn.LayerNorm(embedding_dim)
140
+
141
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
142
+ self.norm3 = nn.LayerNorm(embedding_dim)
143
+
144
+ self.norm4 = nn.LayerNorm(embedding_dim)
145
+ self.cross_attn_image_to_token = Attention(
146
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147
+ )
148
+
149
+ self.skip_first_layer_pe = skip_first_layer_pe
150
+
151
+ def forward(
152
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
153
+ ) -> Tuple[Tensor, Tensor]:
154
+ # Self attention block
155
+ if self.skip_first_layer_pe:
156
+ queries = self.self_attn(q=queries, k=queries, v=queries)
157
+ else:
158
+ q = queries + query_pe
159
+ attn_out = self.self_attn(q=q, k=q, v=queries)
160
+ queries = queries + attn_out
161
+ queries = self.norm1(queries)
162
+
163
+ # Cross attention block, tokens attending to image embedding
164
+ q = queries + query_pe
165
+ k = keys + key_pe
166
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
167
+ queries = queries + attn_out
168
+ queries = self.norm2(queries)
169
+
170
+ # MLP block
171
+ mlp_out = self.mlp(queries)
172
+ queries = queries + mlp_out
173
+ queries = self.norm3(queries)
174
+
175
+ # Cross attention block, image embedding attending to tokens
176
+ q = queries + query_pe
177
+ k = keys + key_pe
178
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
179
+ keys = keys + attn_out
180
+ keys = self.norm4(keys)
181
+
182
+ return queries, keys
183
+
184
+
185
+ class Attention(nn.Module):
186
+ """
187
+ An attention layer that allows for downscaling the size of the embedding
188
+ after projection to queries, keys, and values.
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ embedding_dim: int,
194
+ num_heads: int,
195
+ downsample_rate: int = 1,
196
+ ) -> None:
197
+ super().__init__()
198
+ self.embedding_dim = embedding_dim
199
+ self.internal_dim = embedding_dim // downsample_rate
200
+ self.num_heads = num_heads
201
+ assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
202
+
203
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
204
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
205
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
206
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
207
+
208
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
209
+ b, n, c = x.shape
210
+ x = x.reshape(b, n, num_heads, c // num_heads)
211
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
212
+
213
+ def _recombine_heads(self, x: Tensor) -> Tensor:
214
+ b, n_heads, n_tokens, c_per_head = x.shape
215
+ x = x.transpose(1, 2)
216
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
217
+
218
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
219
+ # Input projections
220
+ q = self.q_proj(q)
221
+ k = self.k_proj(k)
222
+ v = self.v_proj(v)
223
+
224
+ # Separate into heads
225
+ q = self._separate_heads(q, self.num_heads)
226
+ k = self._separate_heads(k, self.num_heads)
227
+ v = self._separate_heads(v, self.num_heads)
228
+
229
+ # Attention
230
+ _, _, _, c_per_head = q.shape
231
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
232
+ attn = attn / math.sqrt(c_per_head)
233
+ attn = torch.softmax(attn, dim=-1)
234
+
235
+ # Get output
236
+ out = attn @ v
237
+ out = self._recombine_heads(out)
238
+ out = self.out_proj(out)
239
+
240
+ return out
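A minimal shape sketch of the two-way transformer; the sizes mirror the SAM defaults (two layers, 256-dim embeddings, a 64x64 image embedding, and a handful of sparse tokens):

import torch
from models.modeling.transformer import TwoWayTransformer

t = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 64, 64)
image_pe = torch.randn(1, 256, 64, 64)
point_embedding = torch.randn(1, 7, 256)   # e.g. output tokens plus box-corner embeddings
queries, keys = t(image_embedding, image_pe, point_embedding)
print(queries.shape, keys.shape)  # torch.Size([1, 7, 256]) torch.Size([1, 4096, 256])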
models/predictor.py ADDED
@@ -0,0 +1,269 @@
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from models.modeling import Sam
11
+
12
+ from typing import Optional, Tuple
13
+
14
+ from .utils.transforms import ResizeLongestSide
15
+
16
+
17
+ class SamPredictor:
18
+ def __init__(
19
+ self,
20
+ sam_model: Sam,
21
+ ) -> None:
22
+ """
23
+ Uses SAM to calculate the image embedding for an image, and then
24
+ allows repeated, efficient mask prediction given prompts.
25
+
26
+ Arguments:
27
+ sam_model (Sam): The model to use for mask prediction.
28
+ """
29
+ super().__init__()
30
+ self.model = sam_model
31
+ self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
32
+ self.reset_image()
33
+
34
+ def set_image(
35
+ self,
36
+ image: np.ndarray,
37
+ image_format: str = "RGB",
38
+ ) -> None:
39
+ """
40
+ Calculates the image embeddings for the provided image, allowing
41
+ masks to be predicted with the 'predict' method.
42
+
43
+ Arguments:
44
+ image (np.ndarray): The image for calculating masks. Expects an
45
+ image in HWC uint8 format, with pixel values in [0, 255].
46
+ image_format (str): The color format of the image, in ['RGB', 'BGR'].
47
+ """
48
+ assert image_format in [
49
+ "RGB",
50
+ "BGR",
51
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
52
+ if image_format != self.model.image_format:
53
+ image = image[..., ::-1]
54
+
55
+ # Transform the image to the form expected by the model
56
+ input_image = self.transform.apply_image(image)
57
+ input_image_torch = torch.as_tensor(input_image, device=self.device)
58
+ input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
59
+
60
+ self.set_torch_image(input_image_torch, image.shape[:2])
61
+
62
+ @torch.no_grad()
63
+ def set_torch_image(
64
+ self,
65
+ transformed_image: torch.Tensor,
66
+ original_image_size: Tuple[int, ...],
67
+ ) -> None:
68
+ """
69
+ Calculates the image embeddings for the provided image, allowing
70
+ masks to be predicted with the 'predict' method. Expects the input
71
+ image to be already transformed to the format expected by the model.
72
+
73
+ Arguments:
74
+ transformed_image (torch.Tensor): The input image, with shape
75
+ 1x3xHxW, which has been transformed with ResizeLongestSide.
76
+ original_image_size (tuple(int, int)): The size of the image
77
+ before transformation, in (H, W) format.
78
+ """
79
+ assert (
80
+ len(transformed_image.shape) == 4
81
+ and transformed_image.shape[1] == 3
82
+ and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
83
+ ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
84
+ self.reset_image()
85
+
86
+ self.original_size = original_image_size
87
+ self.input_size = tuple(transformed_image.shape[-2:])
88
+ input_image = self.model.preprocess(transformed_image)
89
+ self.features = self.model.image_encoder(input_image)
90
+ self.is_image_set = True
91
+
92
+ def predict(
93
+ self,
94
+ point_coords: Optional[np.ndarray] = None,
95
+ point_labels: Optional[np.ndarray] = None,
96
+ box: Optional[np.ndarray] = None,
97
+ mask_input: Optional[np.ndarray] = None,
98
+ multimask_output: bool = True,
99
+ return_logits: bool = False,
100
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
101
+ """
102
+ Predict masks for the given input prompts, using the currently set image.
103
+
104
+ Arguments:
105
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
106
+ model. Each point is in (X,Y) in pixels.
107
+ point_labels (np.ndarray or None): A length N array of labels for the
108
+ point prompts. 1 indicates a foreground point and 0 indicates a
109
+ background point.
110
+ box (np.ndarray or None): A length 4 array giving a box prompt to the
111
+ model, in XYXY format.
112
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
113
+ coming from a previous prediction iteration. Has form 1xHxW, where
114
+ for SAM, H=W=256.
115
+ multimask_output (bool): If true, the model will return three masks.
116
+ For ambiguous input prompts (such as a single click), this will often
117
+ produce better masks than a single prediction. If only a single
118
+ mask is needed, the model's predicted quality score can be used
119
+ to select the best mask. For non-ambiguous prompts, such as multiple
120
+ input prompts, multimask_output=False can give better results.
121
+ return_logits (bool): If true, returns un-thresholded masks logits
122
+ instead of a binary mask.
123
+
124
+ Returns:
125
+ (np.ndarray): The output masks in CxHxW format, where C is the
126
+ number of masks, and (H, W) is the original image size.
127
+ (np.ndarray): An array of length C containing the model's
128
+ predictions for the quality of each mask.
129
+ (np.ndarray): An array of shape CxHxW, where C is the number
130
+ of masks and H=W=256. These low resolution logits can be passed to
131
+ a subsequent iteration as mask input.
132
+ """
133
+ if not self.is_image_set:
134
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
135
+
136
+ # Transform input prompts
137
+ coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
138
+ if point_coords is not None:
139
+ assert (
140
+ point_labels is not None
141
+ ), "point_labels must be supplied if point_coords is supplied."
142
+ point_coords = self.transform.apply_coords(point_coords, self.original_size)
143
+ coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
144
+ labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
145
+ coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
146
+ if box is not None:
147
+ box = self.transform.apply_boxes(box, self.original_size)
148
+ box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
149
+ box_torch = box_torch[None, :]
150
+ if mask_input is not None:
151
+ mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
152
+ mask_input_torch = mask_input_torch[None, :, :, :]
153
+
154
+ masks, iou_predictions, low_res_masks = self.predict_torch(
155
+ coords_torch,
156
+ labels_torch,
157
+ box_torch,
158
+ mask_input_torch,
159
+ multimask_output,
160
+ return_logits=return_logits,
161
+ )
162
+
163
+ masks_np = masks[0].detach().cpu().numpy()
164
+ iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
165
+ low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
166
+ return masks_np, iou_predictions_np, low_res_masks_np
167
+
168
+ @torch.no_grad()
169
+ def predict_torch(
170
+ self,
171
+ point_coords: Optional[torch.Tensor],
172
+ point_labels: Optional[torch.Tensor],
173
+ boxes: Optional[torch.Tensor] = None,
174
+ mask_input: Optional[torch.Tensor] = None,
175
+ multimask_output: bool = True,
176
+ return_logits: bool = False,
177
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
178
+ """
179
+ Predict masks for the given input prompts, using the currently set image.
180
+ Input prompts are batched torch tensors and are expected to already be
181
+ transformed to the input frame using ResizeLongestSide.
182
+
183
+ Arguments:
184
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
185
+ model. Each point is in (X,Y) in pixels.
186
+ point_labels (torch.Tensor or None): A BxN array of labels for the
187
+ point prompts. 1 indicates a foreground point and 0 indicates a
188
+ background point.
189
+ boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
190
+ model, in XYXY format.
191
+ mask_input (torch.Tensor or None): A low resolution mask input to the model, typically
192
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
193
+ for SAM, H=W=256. Masks returned by a previous iteration of the
194
+ predict method do not need further transformation.
195
+ multimask_output (bool): If true, the model will return three masks.
196
+ For ambiguous input prompts (such as a single click), this will often
197
+ produce better masks than a single prediction. If only a single
198
+ mask is needed, the model's predicted quality score can be used
199
+ to select the best mask. For non-ambiguous prompts, such as multiple
200
+ input prompts, multimask_output=False can give better results.
201
+ return_logits (bool): If true, returns un-thresholded masks logits
202
+ instead of a binary mask.
203
+
204
+ Returns:
205
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
206
+ number of masks, and (H, W) is the original image size.
207
+ (torch.Tensor): An array of shape BxC containing the model's
208
+ predictions for the quality of each mask.
209
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
210
+ of masks and H=W=256. These low res logits can be passed to
211
+ a subsequent iteration as mask input.
212
+ """
213
+ if not self.is_image_set:
214
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
215
+
216
+ if point_coords is not None:
217
+ points = (point_coords, point_labels)
218
+ else:
219
+ points = None
220
+
221
+ # Embed prompts
222
+ sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
223
+ points=points,
224
+ boxes=boxes,
225
+ masks=mask_input,
226
+ )
227
+
228
+ # Predict masks
229
+ low_res_masks, iou_predictions, _ = self.model.mask_decoder(
230
+ image_embeddings=self.features,
231
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
232
+ sparse_prompt_embeddings=sparse_embeddings,
233
+ dense_prompt_embeddings=dense_embeddings,
234
+ multimask_output=multimask_output,
235
+ )
236
+
237
+ # Upscale the masks to the original image resolution
238
+ masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
239
+
240
+ if not return_logits:
241
+ masks = masks > self.model.mask_threshold
242
+
243
+ return masks, iou_predictions, low_res_masks
244
+
245
+ def get_image_embedding(self) -> torch.Tensor:
246
+ """
247
+ Returns the image embeddings for the currently set image, with
248
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
249
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
250
+ """
251
+ if not self.is_image_set:
252
+ raise RuntimeError(
253
+ "An image must be set with .set_image(...) to generate an embedding."
254
+ )
255
+ assert self.features is not None, "Features must exist if an image has been set."
256
+ return self.features
257
+
258
+ @property
259
+ def device(self) -> torch.device:
260
+ return self.model.device
261
+
262
+ def reset_image(self) -> None:
263
+ """Resets the currently set image."""
264
+ self.is_image_set = False
265
+ self.features = None
266
+ self.orig_h = None
267
+ self.orig_w = None
268
+ self.input_h = None
269
+ self.input_w = None
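A minimal usage sketch for the predictor, assuming a Sam instance named sam built elsewhere (e.g. through the registry in models/build_sam.py); the printed shapes correspond to the ViT-B configuration. The extra image-embedding tensor returned by this repository's mask decoder is discarded inside predict_torch.

import numpy as np
from models.predictor import SamPredictor

predictor = SamPredictor(sam)
image = np.zeros((480, 640, 3), dtype=np.uint8)        # HWC uint8, RGB
predictor.set_image(image)
print(predictor.get_image_embedding().shape)           # torch.Size([1, 256, 64, 64])

masks, scores, low_res = predictor.predict(
    box=np.array([100, 100, 400, 300]),                # XYXY in original pixels
    multimask_output=False,
)
print(masks.shape, scores.shape)                       # (1, 480, 640) (1,)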
models/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
models/utils/amg.py ADDED
@@ -0,0 +1,346 @@
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ import math
11
+ from copy import deepcopy
12
+ from itertools import product
13
+ from typing import Any, Dict, Generator, ItemsView, List, Tuple
14
+
15
+
16
+ class MaskData:
17
+ """
18
+ A structure for storing masks and their related data in batched format.
19
+ Implements basic filtering and concatenation.
20
+ """
21
+
22
+ def __init__(self, **kwargs) -> None:
23
+ for v in kwargs.values():
24
+ assert isinstance(
25
+ v, (list, np.ndarray, torch.Tensor)
26
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
27
+ self._stats = dict(**kwargs)
28
+
29
+ def __setitem__(self, key: str, item: Any) -> None:
30
+ assert isinstance(
31
+ item, (list, np.ndarray, torch.Tensor)
32
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
33
+ self._stats[key] = item
34
+
35
+ def __delitem__(self, key: str) -> None:
36
+ del self._stats[key]
37
+
38
+ def __getitem__(self, key: str) -> Any:
39
+ return self._stats[key]
40
+
41
+ def items(self) -> ItemsView[str, Any]:
42
+ return self._stats.items()
43
+
44
+ def filter(self, keep: torch.Tensor) -> None:
45
+ for k, v in self._stats.items():
46
+ if v is None:
47
+ self._stats[k] = None
48
+ elif isinstance(v, torch.Tensor):
49
+ self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
50
+ elif isinstance(v, np.ndarray):
51
+ self._stats[k] = v[keep.detach().cpu().numpy()]
52
+ elif isinstance(v, list) and keep.dtype == torch.bool:
53
+ self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
54
+ elif isinstance(v, list):
55
+ self._stats[k] = [v[i] for i in keep]
56
+ else:
57
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
58
+
59
+ def cat(self, new_stats: "MaskData") -> None:
60
+ for k, v in new_stats.items():
61
+ if k not in self._stats or self._stats[k] is None:
62
+ self._stats[k] = deepcopy(v)
63
+ elif isinstance(v, torch.Tensor):
64
+ self._stats[k] = torch.cat([self._stats[k], v], dim=0)
65
+ elif isinstance(v, np.ndarray):
66
+ self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
67
+ elif isinstance(v, list):
68
+ self._stats[k] = self._stats[k] + deepcopy(v)
69
+ else:
70
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
71
+
72
+ def to_numpy(self) -> None:
73
+ for k, v in self._stats.items():
74
+ if isinstance(v, torch.Tensor):
75
+ self._stats[k] = v.detach().cpu().numpy()
76
+
77
+
78
+ def is_box_near_crop_edge(
79
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
80
+ ) -> torch.Tensor:
81
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
82
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
83
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
84
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
85
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
86
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
87
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
88
+ return torch.any(near_crop_edge, dim=1)
89
+
90
+
91
+ def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
92
+ box_xywh = deepcopy(box_xyxy)
93
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
94
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
95
+ return box_xywh
96
+
97
+
98
+ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
99
+ assert len(args) > 0 and all(
100
+ len(a) == len(args[0]) for a in args
101
+ ), "Batched iteration must have inputs of all the same size."
102
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
103
+ for b in range(n_batches):
104
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
105
+
106
+
107
+ def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
108
+ """
109
+ Encodes masks to an uncompressed RLE, in the format expected by
110
+ pycoco tools.
111
+ """
112
+ # Put in fortran order and flatten h,w
113
+ b, h, w = tensor.shape
114
+ tensor = tensor.permute(0, 2, 1).flatten(1)
115
+
116
+ # Compute change indices
117
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
118
+ change_indices = diff.nonzero()
119
+
120
+ # Encode run length
121
+ out = []
122
+ for i in range(b):
123
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
124
+ cur_idxs = torch.cat(
125
+ [
126
+ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
127
+ cur_idxs + 1,
128
+ torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
129
+ ]
130
+ )
131
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
132
+ counts = [] if tensor[i, 0] == 0 else [0]
133
+ counts.extend(btw_idxs.detach().cpu().tolist())
134
+ out.append({"size": [h, w], "counts": counts})
135
+ return out
136
+
137
+
138
+ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
139
+ """Compute a binary mask from an uncompressed RLE."""
140
+ h, w = rle["size"]
141
+ mask = np.empty(h * w, dtype=bool)
142
+ idx = 0
143
+ parity = False
144
+ for count in rle["counts"]:
145
+ mask[idx : idx + count] = parity
146
+ idx += count
147
+ parity ^= True
148
+ mask = mask.reshape(w, h)
149
+ return mask.transpose() # Put in C order
150
+
151
+
152
+ def area_from_rle(rle: Dict[str, Any]) -> int:
153
+ return sum(rle["counts"][1::2])
154
+
155
+
156
+ def calculate_stability_score(
157
+ masks: torch.Tensor, mask_threshold: float, threshold_offset: float
158
+ ) -> torch.Tensor:
159
+ """
160
+ Computes the stability score for a batch of masks. The stability
161
+ score is the IoU between the binary masks obtained by thresholding
162
+ the predicted mask logits at high and low values.
163
+ """
164
+ # One mask is always contained inside the other.
165
+ # Save memory by preventing unnecessary cast to torch.int64
166
+ intersections = (
167
+ (masks > (mask_threshold + threshold_offset))
168
+ .sum(-1, dtype=torch.int16)
169
+ .sum(-1, dtype=torch.int32)
170
+ )
171
+ unions = (
172
+ (masks > (mask_threshold - threshold_offset))
173
+ .sum(-1, dtype=torch.int16)
174
+ .sum(-1, dtype=torch.int32)
175
+ )
176
+ return intersections / unions
177
+
178
+
179
+ def build_point_grid(n_per_side: int) -> np.ndarray:
180
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
181
+ offset = 1 / (2 * n_per_side)
182
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
183
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
184
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
185
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
186
+ return points
187
+
188
+
189
+ def build_all_layer_point_grids(
190
+ n_per_side: int, n_layers: int, scale_per_layer: int
191
+ ) -> List[np.ndarray]:
192
+ """Generates point grids for all crop layers."""
193
+ points_by_layer = []
194
+ for i in range(n_layers + 1):
195
+ n_points = int(n_per_side / (scale_per_layer**i))
196
+ points_by_layer.append(build_point_grid(n_points))
197
+ return points_by_layer
198
+
199
+
200
+ def generate_crop_boxes(
201
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
202
+ ) -> Tuple[List[List[int]], List[int]]:
203
+ """
204
+ Generates a list of crop boxes of different sizes. Each layer
205
+ has (2**i)**2 boxes for the ith layer.
206
+ """
207
+ crop_boxes, layer_idxs = [], []
208
+ im_h, im_w = im_size
209
+ short_side = min(im_h, im_w)
210
+
211
+ # Original image
212
+ crop_boxes.append([0, 0, im_w, im_h])
213
+ layer_idxs.append(0)
214
+
215
+ def crop_len(orig_len, n_crops, overlap):
216
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
217
+
218
+ for i_layer in range(n_layers):
219
+ n_crops_per_side = 2 ** (i_layer + 1)
220
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
221
+
222
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
223
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
224
+
225
+ crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
226
+ crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
227
+
228
+ # Crops in XYXY format
229
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
230
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
231
+ crop_boxes.append(box)
232
+ layer_idxs.append(i_layer + 1)
233
+
234
+ return crop_boxes, layer_idxs
235
+
236
+
237
+ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
238
+ x0, y0, _, _ = crop_box
239
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
240
+ # Check if boxes has a channel dimension
241
+ if len(boxes.shape) == 3:
242
+ offset = offset.unsqueeze(1)
243
+ return boxes + offset
244
+
245
+
246
+ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
247
+ x0, y0, _, _ = crop_box
248
+ offset = torch.tensor([[x0, y0]], device=points.device)
249
+ # Check if points has a channel dimension
250
+ if len(points.shape) == 3:
251
+ offset = offset.unsqueeze(1)
252
+ return points + offset
253
+
254
+
255
+ def uncrop_masks(
256
+ masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
257
+ ) -> torch.Tensor:
258
+ x0, y0, x1, y1 = crop_box
259
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
260
+ return masks
261
+ # Coordinate transform masks
262
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
263
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
264
+ return torch.nn.functional.pad(masks, pad, value=0)
265
+
266
+
267
+ def remove_small_regions(
268
+ mask: np.ndarray, area_thresh: float, mode: str
269
+ ) -> Tuple[np.ndarray, bool]:
270
+ """
271
+ Removes small disconnected regions and holes in a mask. Returns the
272
+ mask and an indicator of whether the mask has been modified.
273
+ """
274
+ import cv2 # type: ignore
275
+
276
+ assert mode in ["holes", "islands"]
277
+ correct_holes = mode == "holes"
278
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
279
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
280
+ sizes = stats[:, -1][1:] # Row 0 is background label
281
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
282
+ if len(small_regions) == 0:
283
+ return mask, False
284
+ fill_labels = [0] + small_regions
285
+ if not correct_holes:
286
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
287
+ # If every region is below threshold, keep largest
288
+ if len(fill_labels) == 0:
289
+ fill_labels = [int(np.argmax(sizes)) + 1]
290
+ mask = np.isin(regions, fill_labels)
291
+ return mask, True
292
+
293
+
294
+ def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
295
+ from pycocotools import mask as mask_utils # type: ignore
296
+
297
+ h, w = uncompressed_rle["size"]
298
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
299
+ rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
300
+ return rle
301
+
302
+
303
+ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
304
+ """
305
+ Calculates boxes in XYXY format around masks. Returns [0,0,0,0] for
306
+ an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
307
+ """
308
+ # torch.max below raises an error on empty inputs, just skip in this case
309
+ if torch.numel(masks) == 0:
310
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
311
+
312
+ # Normalize shape to CxHxW
313
+ shape = masks.shape
314
+ h, w = shape[-2:]
315
+ if len(shape) > 2:
316
+ masks = masks.flatten(0, -3)
317
+ else:
318
+ masks = masks.unsqueeze(0)
319
+
320
+ # Get top and bottom edges
321
+ in_height, _ = torch.max(masks, dim=-1)
322
+ in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
323
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
324
+ in_height_coords = in_height_coords + h * (~in_height)
325
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
326
+
327
+ # Get left and right edges
328
+ in_width, _ = torch.max(masks, dim=-2)
329
+ in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
330
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
331
+ in_width_coords = in_width_coords + w * (~in_width)
332
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
333
+
334
+ # If the mask is empty the right edge will be to the left of the left edge.
335
+ # Replace these boxes with [0, 0, 0, 0]
336
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
337
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
338
+ out = out * (~empty_filter).unsqueeze(-1)
339
+
340
+ # Return to original shape
341
+ if len(shape) > 2:
342
+ out = out.reshape(*shape[:-2], 4)
343
+ else:
344
+ out = out[0]
345
+
346
+ return out
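
The helpers above close out the automatic-mask-generation utilities carried over from SAM. The sketch below shows one plausible way to chain them on a single binary mask; the synthetic mask and the 100-pixel area threshold are illustrative assumptions, not values used elsewhere in this repo.

```python
# Illustrative sketch only: clean up a predicted mask, then get its XYXY box.
import numpy as np
import torch

from models.utils.amg import batched_mask_to_box, remove_small_regions

mask = np.zeros((256, 256), dtype=bool)
mask[50:120, 60:200] = True      # synthetic foreground region
mask[55:58, 65:68] = False       # small hole that should be filled

# Fill small holes, then drop small disconnected islands (threshold is arbitrary).
mask, _ = remove_small_regions(mask, area_thresh=100.0, mode="holes")
mask, _ = remove_small_regions(mask, area_thresh=100.0, mode="islands")

# batched_mask_to_box accepts ...xHxW boolean tensors and returns ...x4 boxes.
box = batched_mask_to_box(torch.from_numpy(mask)[None])
print(box.tolist())  # [[60, 50, 199, 119]]
```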
models/utils/onnx.py ADDED
@@ -0,0 +1,144 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+
11
+ from typing import Tuple
12
+
13
+ from ..modeling import Sam
14
+ from .amg import calculate_stability_score
15
+
16
+
17
+ class SamOnnxModel(nn.Module):
18
+ """
19
+ This model should not be called directly, but is used in ONNX export.
20
+ It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21
+ with some functions modified to enable model tracing. Also supports extra
22
+ options controlling what information is returned. See the ONNX export script for details.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ model: Sam,
28
+ return_single_mask: bool,
29
+ use_stability_score: bool = False,
30
+ return_extra_metrics: bool = False,
31
+ ) -> None:
32
+ super().__init__()
33
+ self.mask_decoder = model.mask_decoder
34
+ self.model = model
35
+ self.img_size = model.image_encoder.img_size
36
+ self.return_single_mask = return_single_mask
37
+ self.use_stability_score = use_stability_score
38
+ self.stability_score_offset = 1.0
39
+ self.return_extra_metrics = return_extra_metrics
40
+
41
+ @staticmethod
42
+ def resize_longest_image_size(
43
+ input_image_size: torch.Tensor, longest_side: int
44
+ ) -> torch.Tensor:
45
+ input_image_size = input_image_size.to(torch.float32)
46
+ scale = longest_side / torch.max(input_image_size)
47
+ transformed_size = scale * input_image_size
48
+ transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
49
+ return transformed_size
50
+
51
+ def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
52
+ point_coords = point_coords + 0.5
53
+ point_coords = point_coords / self.img_size
54
+ point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
55
+ point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
56
+
57
+ point_embedding = point_embedding * (point_labels != -1)
58
+ point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
59
+ point_labels == -1
60
+ )
61
+
62
+ for i in range(self.model.prompt_encoder.num_point_embeddings):
63
+ point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
64
+ i
65
+ ].weight * (point_labels == i)
66
+
67
+ return point_embedding
68
+
69
+ def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
70
+ mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
71
+ mask_embedding = mask_embedding + (
72
+ 1 - has_mask_input
73
+ ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
74
+ return mask_embedding
75
+
76
+ def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
77
+ masks = F.interpolate(
78
+ masks,
79
+ size=(self.img_size, self.img_size),
80
+ mode="bilinear",
81
+ align_corners=False,
82
+ )
83
+
84
+ prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
85
+ masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
86
+
87
+ orig_im_size = orig_im_size.to(torch.int64)
88
+ h, w = orig_im_size[0], orig_im_size[1]
89
+ masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
90
+ return masks
91
+
92
+ def select_masks(
93
+ self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
94
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
95
+ # Determine if we should return the multiclick mask or not from the number of points.
96
+ # The reweighting is used to avoid control flow.
97
+ score_reweight = torch.tensor(
98
+ [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
99
+ ).to(iou_preds.device)
100
+ score = iou_preds + (num_points - 2.5) * score_reweight
101
+ best_idx = torch.argmax(score, dim=1)
102
+ masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
103
+ iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
104
+
105
+ return masks, iou_preds
106
+
107
+ @torch.no_grad()
108
+ def forward(
109
+ self,
110
+ image_embeddings: torch.Tensor,
111
+ point_coords: torch.Tensor,
112
+ point_labels: torch.Tensor,
113
+ mask_input: torch.Tensor,
114
+ has_mask_input: torch.Tensor,
115
+ orig_im_size: torch.Tensor,
116
+ ):
117
+ sparse_embedding = self._embed_points(point_coords, point_labels)
118
+ dense_embedding = self._embed_masks(mask_input, has_mask_input)
119
+
120
+ masks, scores = self.model.mask_decoder.predict_masks(
121
+ image_embeddings=image_embeddings,
122
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
123
+ sparse_prompt_embeddings=sparse_embedding,
124
+ dense_prompt_embeddings=dense_embedding,
125
+ )
126
+
127
+ if self.use_stability_score:
128
+ scores = calculate_stability_score(
129
+ masks, self.model.mask_threshold, self.stability_score_offset
130
+ )
131
+
132
+ if self.return_single_mask:
133
+ masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
134
+
135
+ upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
136
+
137
+ if self.return_extra_metrics:
138
+ stability_scores = calculate_stability_score(
139
+ upscaled_masks, self.model.mask_threshold, self.stability_score_offset
140
+ )
141
+ areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
142
+ return upscaled_masks, scores, stability_scores, areas, masks
143
+
144
+ return upscaled_masks, scores, masks
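
SamOnnxModel exists so the prompt encoder, mask decoder, and mask postprocessing can be traced into a single ONNX graph. The sketch below is one plausible export call in the spirit of SAM's official export script; the checkpoint path, output file name, dummy shapes, and opset are assumptions rather than settings shipped with this repo.

```python
# Illustrative sketch only: trace SamOnnxModel to an ONNX file.
import torch

from models import sam_model_registry
from models.utils.onnx import SamOnnxModel

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")  # assumed local checkpoint path
onnx_model = SamOnnxModel(sam, return_single_mask=True)

# Assumes the prompt encoder keeps SAM's embed_dim / image_embedding_size attributes.
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, 4 * embed_size[0], 4 * embed_size[1], dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}

torch.onnx.export(
    onnx_model,
    tuple(dummy_inputs.values()),
    "sam_prompt_decoder.onnx",  # assumed output path
    input_names=list(dummy_inputs.keys()),
    output_names=["masks", "iou_predictions", "low_res_masks"],
    dynamic_axes={"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}},
    opset_version=17,  # assumed opset
)
```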
models/utils/transforms.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.nn import functional as F
10
+ from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11
+
12
+ from copy import deepcopy
13
+ from typing import Tuple
14
+
15
+
16
+ class ResizeLongestSide:
17
+ """
18
+ Resizes images to the longest side 'target_length', as well as provides
19
+ methods for resizing coordinates and boxes. Provides methods for
20
+ transforming both numpy array and batched torch tensors.
21
+ """
22
+
23
+ def __init__(self, target_length: int) -> None:
24
+ self.target_length = target_length
25
+
26
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
27
+ """
28
+ Expects a numpy array with shape HxWxC in uint8 format.
29
+ """
30
+ target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31
+ return np.array(resize(to_pil_image(image), target_size))
32
+
33
+ def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34
+ """
35
+ Expects a numpy array of length 2 in the final dimension. Requires the
36
+ original image size in (H, W) format.
37
+ """
38
+ old_h, old_w = original_size
39
+ new_h, new_w = self.get_preprocess_shape(
40
+ original_size[0], original_size[1], self.target_length
41
+ )
42
+ coords = deepcopy(coords).astype(float)
43
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
44
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
45
+ return coords
46
+
47
+ def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
48
+ """
49
+ Expects a numpy array shape Bx4. Requires the original image size
50
+ in (H, W) format.
51
+ """
52
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
53
+ return boxes.reshape(-1, 4)
54
+
55
+ def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
56
+ """
57
+ Expects batched images with shape BxCxHxW and float format. This
58
+ transformation may not exactly match apply_image. apply_image is
59
+ the transformation expected by the model.
60
+ """
61
+ # Expects an image in BCHW format. May not exactly match apply_image.
62
+ target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
63
+ return F.interpolate(
64
+ image, target_size, mode="bilinear", align_corners=False, antialias=True
65
+ )
66
+
67
+ def apply_coords_torch(
68
+ self, coords: torch.Tensor, original_size: Tuple[int, ...]
69
+ ) -> torch.Tensor:
70
+ """
71
+ Expects a torch tensor with length 2 in the last dimension. Requires the
72
+ original image size in (H, W) format.
73
+ """
74
+ old_h, old_w = original_size
75
+ new_h, new_w = self.get_preprocess_shape(
76
+ original_size[0], original_size[1], self.target_length
77
+ )
78
+ coords = deepcopy(coords).to(torch.float)
79
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
80
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
81
+ return coords
82
+
83
+ def apply_boxes_torch(
84
+ self, boxes: torch.Tensor, original_size: Tuple[int, ...]
85
+ ) -> torch.Tensor:
86
+ """
87
+ Expects a torch tensor with shape Bx4. Requires the original image
88
+ size in (H, W) format.
89
+ """
90
+ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
91
+ return boxes.reshape(-1, 4)
92
+
93
+ @staticmethod
94
+ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
95
+ """
96
+ Compute the output size given input size and target long side length.
97
+ """
98
+ scale = long_side_length * 1.0 / max(oldh, oldw)
99
+ newh, neww = oldh * scale, oldw * scale
100
+ neww = int(neww + 0.5)
101
+ newh = int(newh + 0.5)
102
+ return (newh, neww)
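
ResizeLongestSide(1024) is the same transform that the demo's app.py instantiates as img_resize. A small sketch of the intended round trip, using an arbitrary 480x640 image and a hypothetical XYXY box:

```python
# Illustrative sketch only: resize an image and its box prompt consistently.
import numpy as np

from models.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(1024)

image = np.zeros((480, 640, 3), dtype=np.uint8)   # HxWxC uint8, as apply_image expects
boxes = np.array([[100.0, 120.0, 300.0, 360.0]])  # hypothetical XYXY box in original pixels

resized = transform.apply_image(image)                        # longest side becomes 1024
resized_boxes = transform.apply_boxes(boxes, image.shape[:2])

print(resized.shape)    # (768, 1024, 3)
print(resized_boxes)    # [[160. 192. 480. 576.]] -- scaled by 1024 / 640
```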
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ torch
2
+ torchvision
3
+ pycocotools
4
+ transformers
5
+ gradio_image_prompter-0.1.0-py3-none-any.whl
src/.gitignore ADDED
@@ -0,0 +1,9 @@
1
+ .eggs/
2
+ dist/
3
+ *.pyc
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+ __tmp/*
8
+ *.pyi
9
+ node_modules
src/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
src/README.md ADDED
@@ -0,0 +1,48 @@
1
+ # Image Prompter for Gradio
2
+ A gradio component to upload images and process point/box prompts.
3
+
4
+ This custom component is developed for [Tokenize Anything](https://github.com/baaivision/tokenize-anything) gradio demo.
5
+
6
+ ## Installation
7
+
8
+ ### Preliminaries
9
+
10
+ ``gradio`` >= 4.0.0
11
+
12
+ ### Installing Package
13
+
14
+ ```bash
15
+ pip install gradio-image-prompter
16
+ ```
17
+
18
+ ## Quick Start
19
+
20
+ ### Development
21
+
22
+ ```bash
23
+ cd gradio-image-prompter
24
+ gradio cc install
25
+ gradio cc dev
26
+ ```
27
+
28
+ ### Example
29
+
30
+ ```python
31
+ import gradio as gr
32
+ from gradio_image_prompter import ImagePrompter
33
+
34
+ demo = gr.Interface(
35
+ lambda prompts: (prompts["image"], prompts["points"]),
36
+ ImagePrompter(show_label=False),
37
+ [gr.Image(show_label=False), gr.Dataframe(label="Points")],
38
+ )
39
+ demo.launch()
40
+
41
+ ```
42
+
43
+ ## License
44
+ [Apache License 2.0](LICENSE)
45
+
46
+ ## Acknowledgement
47
+
48
+ We thank the repositories: [SAM](https://github.com/facebookresearch/segment-anything), [GradioBox](https://github.com/ShoufaChen/gradio-box) and [Gradio](https://github.com/gradio-app/gradio).
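
Beyond the gr.Interface one-liner in the README above, the component also drops into gr.Blocks. The sketch below assumes only the dict contract established by ImagePrompter.preprocess, namely {"image": ..., "points": ...}:

```python
# Illustrative sketch only: ImagePrompter inside gr.Blocks rather than gr.Interface.
import gradio as gr
from gradio_image_prompter import ImagePrompter


def inspect(prompts):
    # `prompts` is the dict produced by ImagePrompter.preprocess:
    # an image (numpy array by default) plus the raw point/box rows.
    return prompts["image"], prompts["points"]


with gr.Blocks() as demo:
    prompter = ImagePrompter(show_label=False)
    run = gr.Button("Run")
    image_out = gr.Image(show_label=False)
    points_out = gr.Dataframe(label="Points")
    run.click(inspect, inputs=prompter, outputs=[image_out, points_out])

demo.launch()
```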
src/backend/gradio_image_prompter/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .image_prompter import ImagePrompter
2
+
3
+ __all__ = ["ImagePrompter"]
src/backend/gradio_image_prompter/image_prompter.py ADDED
@@ -0,0 +1,133 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, PhyscalX. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Gradio ``ImagePrompter`` component."""
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import Optional, List, TypedDict, Union, Literal
21
+
22
+ import numpy as np
23
+ import gradio
24
+ from gradio.data_classes import FileData, GradioModel
25
+ from gradio_client.documentation import document, set_documentation_group
26
+ from PIL import Image as _Image # using _ to minimize namespace pollution
27
+
28
+ set_documentation_group("component")
29
+
30
+
31
+ class PromptData(GradioModel):
32
+ image: FileData
33
+ points: List[List[float]]
34
+
35
+
36
+ class PromptValue(TypedDict):
37
+ image: Optional[Union[np.ndarray, _Image.Image, str]]
38
+ points: Optional[List[List[float]]]
39
+
40
+
41
+ @document()
42
+ class ImagePrompter(gradio.Image):
43
+ """Create an image prompter to upload images and process point/box prompts."""
44
+
45
+ data_model = PromptData
46
+
47
+ def __init__(
48
+ self,
49
+ value: str | _Image.Image | np.ndarray | None = None,
50
+ *,
51
+ height: int | None = None,
52
+ width: int | None = None,
53
+ image_mode: Literal[
54
+ "1", "L", "P", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"
55
+ ] = "RGB",
56
+ sources: list[Literal["upload", "clipboard"]] | None = None,
57
+ type: Literal["numpy", "pil", "filepath"] = "numpy",
58
+ label: str | None = None,
59
+ every: float | None = None,
60
+ show_label: bool | None = None,
61
+ show_download_button: bool = True,
62
+ container: bool = True,
63
+ scale: int | None = None,
64
+ min_width: int = 160,
65
+ interactive: bool | None = None,
66
+ visible: bool = True,
67
+ elem_id: str | None = None,
68
+ elem_classes: list[str] | str | None = None,
69
+ render: bool = True,
70
+ show_share_button: bool | None = None,
71
+ ):
72
+ """
73
+ Parameters:
74
+ value: A PIL Image, numpy array, path or URL for the default value. If callable, it will be called to set the initial value.
75
+ height: Height of the displayed image in pixels.
76
+ width: Width of the displayed image in pixels.
77
+ image_mode: "RGB" if color, or "L" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html.
78
+ sources: List of sources for the image.
79
+ type: The format the image is converted before being passed into the prediction function.
80
+ label: The label for this component.
81
+ every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open.
82
+ show_label: if True, will display label.
83
+ show_download_button: If True, will display button to download image.
84
+ container: If True, will place the component in a container - providing some extra padding around the border.
85
+ scale: relative width compared to adjacent Components in a Row. Should be an integer.
86
+ min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value.
87
+ interactive: if True, will allow users to upload and edit an image; if False, can only be used to display images.
88
+ visible: If False, component will be hidden.
89
+ streaming: If True when used in a `live` interface, will automatically stream webcam feed. Only valid if source is 'webcam'.
90
+ elem_id: An optional string that is assigned as the id of this component in the HTML DOM.
91
+ elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM.
92
+ render: If False, component will not be rendered in the Blocks context.
93
+ mirror_webcam: If True webcam will be mirrored. Default is True.
94
+ show_share_button: If True, show a share icon that allows user to share outputs to Hugging Face Spaces Discussions.
95
+ """
96
+ super(ImagePrompter, self).__init__(
97
+ value=value,
98
+ height=height,
99
+ width=width,
100
+ image_mode=image_mode,
101
+ sources=["upload", "clipboard"] if sources is None else sources,
102
+ type=type,
103
+ label=label,
104
+ every=every,
105
+ show_label=show_label,
106
+ show_download_button=show_download_button,
107
+ container=container,
108
+ scale=scale,
109
+ min_width=min_width,
110
+ interactive=interactive,
111
+ visible=visible,
112
+ elem_id=elem_id,
113
+ elem_classes=elem_classes,
114
+ render=render,
115
+ show_share_button=show_share_button,
116
+ )
117
+
118
+ def preprocess(self, x: PromptData) -> PromptValue | None:
119
+ if x is None:
120
+ return x
121
+ im = super().preprocess(x.image)
122
+ return {"image": im, "points": x.points}
123
+
124
+ def postprocess(self, y: PromptValue) -> PromptData | None:
125
+ if y is None:
126
+ return None
127
+ image, points = y.get("image", None), y.get("points", [])
128
+ return PromptData(image=super().postprocess(image), points=points)
129
+
130
+ def as_example(self, y: PromptValue) -> str | None:
131
+ if y is None:
132
+ return None
133
+ return self.move_resource_to_block_cache(y.get("image", None))
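
For reference, the preprocess/postprocess pair above defines the value an event handler sees and what it can hand back. A minimal sketch of that round trip; the six-number row is an assumed encoding of one drawn box (the exact row layout is produced by the frontend, not by this file):

```python
# Illustrative sketch only: the dict shapes flowing through ImagePrompter.
import numpy as np

# What an event handler receives after preprocess (with type="numpy"):
received = {
    "image": np.zeros((480, 640, 3), dtype=np.uint8),
    "points": [[100.0, 120.0, 2.0, 300.0, 360.0, 3.0]],  # assumed box-row layout
}

# What a handler can return to an ImagePrompter output slot; postprocess
# wraps it back into PromptData(image=..., points=...).
returned = {"image": received["image"], "points": received["points"]}
```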
src/backend/gradio_image_prompter/image_prompter.pyi ADDED
@@ -0,0 +1,134 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, PhyscalX. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Gradio ``ImagePrompter`` component."""
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import Optional, List, TypedDict, Union, Literal
21
+
22
+ import numpy as np
23
+ import gradio
24
+ from gradio.data_classes import FileData, GradioModel
25
+ from gradio_client.documentation import document, set_documentation_group
26
+ from PIL import Image as _Image # using _ to minimize namespace pollution
27
+
28
+ set_documentation_group("component")
29
+
30
+
31
+ class PromptData(GradioModel):
32
+ image: FileData
33
+ points: List[List[float]]
34
+
35
+
36
+ class PromptValue(TypedDict):
37
+ image: Optional[Union[np.ndarray, _Image.Image, str]]
38
+ points: Optional[list[list[float]]]
39
+
40
+ from gradio.events import Dependency
41
+
42
+ @document()
43
+ class ImagePrompter(gradio.Image):
44
+ """Create an image prompter to upload images and process point/box prompts."""
45
+
46
+ data_model = PromptData
47
+
48
+ def __init__(
49
+ self,
50
+ value: str | _Image.Image | np.ndarray | None = None,
51
+ *,
52
+ height: int | None = None,
53
+ width: int | None = None,
54
+ image_mode: Literal[
55
+ "1", "L", "P", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"
56
+ ] = "RGB",
57
+ sources: list[Literal["upload", "clipboard"]] | None = None,
58
+ type: Literal["numpy", "pil", "filepath"] = "numpy",
59
+ label: str | None = None,
60
+ every: float | None = None,
61
+ show_label: bool | None = None,
62
+ show_download_button: bool = True,
63
+ container: bool = True,
64
+ scale: int | None = None,
65
+ min_width: int = 160,
66
+ interactive: bool | None = None,
67
+ visible: bool = True,
68
+ elem_id: str | None = None,
69
+ elem_classes: list[str] | str | None = None,
70
+ render: bool = True,
71
+ show_share_button: bool | None = None,
72
+ ):
73
+ """
74
+ Parameters:
75
+ value: A PIL Image, numpy array, path or URL for the default value. If callable, it will be called to set the initial value.
76
+ height: Height of the displayed image in pixels.
77
+ width: Width of the displayed image in pixels.
78
+ image_mode: "RGB" if color, or "L" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html.
79
+ sources: List of sources for the image.
80
+ type: The format the image is converted before being passed into the prediction function.
81
+ label: The label for this component.
82
+ every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open.
83
+ show_label: if True, will display label.
84
+ show_download_button: If True, will display button to download image.
85
+ container: If True, will place the component in a container - providing some extra padding around the border.
86
+ scale: relative width compared to adjacent Components in a Row. Should be an integer.
87
+ min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value.
88
+ interactive: if True, will allow users to upload and edit an image; if False, can only be used to display images.
89
+ visible: If False, component will be hidden.
90
+ streaming: If True when used in a `live` interface, will automatically stream webcam feed. Only valid if source is 'webcam'.
91
+ elem_id: An optional string that is assigned as the id of this component in the HTML DOM.
92
+ elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM.
93
+ render: If False, component will not be rendered in the Blocks context.
94
+ mirror_webcam: If True webcam will be mirrored. Default is True.
95
+ show_share_button: If True, show a share icon that allows user to share outputs to Hugging Face Spaces Discussions.
96
+ """
97
+ super(ImagePrompter, self).__init__(
98
+ value=value,
99
+ height=height,
100
+ width=width,
101
+ image_mode=image_mode,
102
+ sources=["upload", "clipboard"] if sources is None else sources,
103
+ type=type,
104
+ label=label,
105
+ every=every,
106
+ show_label=show_label,
107
+ show_download_button=show_download_button,
108
+ container=container,
109
+ scale=scale,
110
+ min_width=min_width,
111
+ interactive=interactive,
112
+ visible=visible,
113
+ elem_id=elem_id,
114
+ elem_classes=elem_classes,
115
+ render=render,
116
+ show_share_button=show_share_button,
117
+ )
118
+
119
+ def preprocess(self, x: PromptData) -> PromptValue | None:
120
+ if x is None:
121
+ return x
122
+ im = super().preprocess(x.image)
123
+ return {"image": im, "points": x.points}
124
+
125
+ def postprocess(self, y: PromptValue) -> PromptData | None:
126
+ if y is None:
127
+ return None
128
+ image, points = y.get("image", None), y.get("points", [])
129
+ return PromptData(image=super().postprocess(image), points=points)
130
+
131
+ def as_example(self, y: PromptValue) -> str | None:
132
+ if y is None:
133
+ return None
134
+ return self.move_resource_to_block_cache(y.get("image", None))
src/backend/gradio_image_prompter/templates/component/__vite-browser-external-2447137e.js ADDED
@@ -0,0 +1,4 @@
1
+ const e = {};
2
+ export {
3
+ e as default
4
+ };
src/backend/gradio_image_prompter/templates/component/index.js ADDED
The diff for this file is too large to render. See raw diff
 
src/backend/gradio_image_prompter/templates/component/style.css ADDED
@@ -0,0 +1 @@
1
+ .block.svelte-1t38q2d{position:relative;margin:0;box-shadow:var(--block-shadow);border-width:var(--block-border-width);border-color:var(--block-border-color);border-radius:var(--block-radius);background:var(--block-background-fill);width:100%;line-height:var(--line-sm)}.block.border_focus.svelte-1t38q2d{border-color:var(--color-accent)}.padded.svelte-1t38q2d{padding:var(--block-padding)}.hidden.svelte-1t38q2d{display:none}.hide-container.svelte-1t38q2d{margin:0;box-shadow:none;--block-border-width:0;background:transparent;padding:0;overflow:visible}div.svelte-1hnfib2{margin-bottom:var(--spacing-lg);color:var(--block-info-text-color);font-weight:var(--block-info-text-weight);font-size:var(--block-info-text-size);line-height:var(--line-sm)}span.has-info.svelte-22c38v{margin-bottom:var(--spacing-xs)}span.svelte-22c38v:not(.has-info){margin-bottom:var(--spacing-lg)}span.svelte-22c38v{display:inline-block;position:relative;z-index:var(--layer-4);border:solid var(--block-title-border-width) var(--block-title-border-color);border-radius:var(--block-title-radius);background:var(--block-title-background-fill);padding:var(--block-title-padding);color:var(--block-title-text-color);font-weight:var(--block-title-text-weight);font-size:var(--block-title-text-size);line-height:var(--line-sm)}.hide.svelte-22c38v{margin:0;height:0}label.svelte-9gxdi0{display:inline-flex;align-items:center;z-index:var(--layer-2);box-shadow:var(--block-label-shadow);border:var(--block-label-border-width) solid var(--border-color-primary);border-top:none;border-left:none;border-radius:var(--block-label-radius);background:var(--block-label-background-fill);padding:var(--block-label-padding);pointer-events:none;color:var(--block-label-text-color);font-weight:var(--block-label-text-weight);font-size:var(--block-label-text-size);line-height:var(--line-sm)}.gr-group label.svelte-9gxdi0{border-top-left-radius:0}label.float.svelte-9gxdi0{position:absolute;top:var(--block-label-margin);left:var(--block-label-margin)}label.svelte-9gxdi0:not(.float){position:static;margin-top:var(--block-label-margin);margin-left:var(--block-label-margin)}.hide.svelte-9gxdi0{height:0}span.svelte-9gxdi0{opacity:.8;margin-right:var(--size-2);width:calc(var(--block-label-text-size) - 1px);height:calc(var(--block-label-text-size) - 1px)}.hide-label.svelte-9gxdi0{box-shadow:none;border-width:0;background:transparent;overflow:visible}button.svelte-lpi64a{display:flex;justify-content:center;align-items:center;gap:1px;z-index:var(--layer-2);border-radius:var(--radius-sm);color:var(--block-label-text-color);border:1px solid transparent}button[disabled].svelte-lpi64a{opacity:.5;box-shadow:none}button[disabled].svelte-lpi64a:hover{cursor:not-allowed}.padded.svelte-lpi64a{padding:2px;background:var(--bg-color);box-shadow:var(--shadow-drop);border:1px solid var(--button-secondary-border-color)}button.svelte-lpi64a:hover,button.highlight.svelte-lpi64a{cursor:pointer;color:var(--color-accent)}.padded.svelte-lpi64a:hover{border:2px solid var(--button-secondary-border-color-hover);padding:1px;color:var(--block-label-text-color)}span.svelte-lpi64a{padding:0 1px;font-size:10px}div.svelte-lpi64a{padding:2px;display:flex;align-items:flex-end}.small.svelte-lpi64a{width:14px;height:14px}.large.svelte-lpi64a{width:22px;height:22px}.pending.svelte-lpi64a{animation:svelte-lpi64a-flash .5s infinite}@keyframes 
svelte-lpi64a-flash{0%{opacity:.5}50%{opacity:1}to{opacity:.5}}.transparent.svelte-lpi64a{background:transparent;border:none;box-shadow:none}.empty.svelte-3w3rth{display:flex;justify-content:center;align-items:center;margin-top:calc(0px - var(--size-6));height:var(--size-full)}.icon.svelte-3w3rth{opacity:.5;height:var(--size-5);color:var(--body-text-color)}.small.svelte-3w3rth{min-height:calc(var(--size-32) - 20px)}.large.svelte-3w3rth{min-height:calc(var(--size-64) - 20px)}.unpadded_box.svelte-3w3rth{margin-top:0}.small_parent.svelte-3w3rth{min-height:100%!important}.dropdown-arrow.svelte-145leq6{fill:currentColor}.wrap.svelte-kzcjhc{display:flex;flex-direction:column;justify-content:center;align-items:center;min-height:var(--size-60);color:var(--block-label-text-color);line-height:var(--line-md);height:100%;padding-top:var(--size-3)}.or.svelte-kzcjhc{color:var(--body-text-color-subdued);display:flex}.icon-wrap.svelte-kzcjhc{width:30px;margin-bottom:var(--spacing-lg)}@media (--screen-md){.wrap.svelte-kzcjhc{font-size:var(--text-lg)}}.hovered.svelte-kzcjhc{color:var(--color-accent)}div.svelte-ipfyu7{border-top:1px solid transparent;display:flex;max-height:100%;justify-content:center;gap:var(--spacing-sm);height:auto;align-items:flex-end;padding-bottom:var(--spacing-xl);color:var(--block-label-text-color);flex-shrink:0;width:95%}.show_border.svelte-ipfyu7{border-top:1px solid var(--block-border-color);margin-top:var(--spacing-xxl);box-shadow:var(--shadow-drop)}.source-selection.svelte-lde7lt{display:flex;align-items:center;justify-content:center;border-top:1px solid var(--border-color-primary);width:95%;bottom:0;left:0;right:0;margin-left:auto;margin-right:auto;align-self:flex-end}.icon.svelte-lde7lt{width:22px;height:22px;margin:var(--spacing-lg) var(--spacing-xs);padding:var(--spacing-xs);color:var(--neutral-400);border-radius:var(--radius-md)}.selected.svelte-lde7lt{color:var(--color-accent)}.icon.svelte-lde7lt:hover,.icon.svelte-lde7lt:focus{color:var(--color-accent)}img.svelte-1e0ed51,button.svelte-1e0ed51{width:var(--size-full);height:var(--size-full);object-fit:contain;display:block;border-radius:var(--radius-lg)}.selectable.svelte-1e0ed51{cursor:crosshair}.icon-buttons.svelte-1e0ed51{display:flex;position:absolute;top:6px;right:6px;gap:var(--size-1)}.wrap.svelte-12ckl9l.svelte-12ckl9l{overflow-y:auto;transition:opacity .5s ease-in-out;background:var(--block-background-fill);position:relative;display:flex;flex-direction:column;align-items:center;justify-content:center;min-height:var(--size-40)}.wrap.svelte-12ckl9l.svelte-12ckl9l:after{content:"";position:absolute;top:0;left:0;width:var(--upload-progress-width);height:100%;transition:all .5s ease-in-out;z-index:1}.uploading.svelte-12ckl9l.svelte-12ckl9l{font-size:var(--text-lg);font-family:var(--font);z-index:2}.file-name.svelte-12ckl9l.svelte-12ckl9l{margin:var(--spacing-md);font-size:var(--text-lg);color:var(--body-text-color-subdued)}.file.svelte-12ckl9l.svelte-12ckl9l{font-size:var(--text-md);z-index:2;display:flex;align-items:center}.file.svelte-12ckl9l progress.svelte-12ckl9l{display:inline;height:var(--size-1);width:100%;transition:all .5s ease-in-out;color:var(--color-accent);border:none}.file.svelte-12ckl9l progress[value].svelte-12ckl9l::-webkit-progress-value{background-color:var(--color-accent);border-radius:20px}.file.svelte-12ckl9l 
progress[value].svelte-12ckl9l::-webkit-progress-bar{background-color:var(--border-color-accent);border-radius:20px}.progress-bar.svelte-12ckl9l.svelte-12ckl9l{width:14px;height:14px;border-radius:50%;background:radial-gradient(closest-side,var(--block-background-fill) 64%,transparent 53% 100%),conic-gradient(var(--color-accent) var(--upload-progress-width),var(--border-color-accent) 0);transition:all .5s ease-in-out}button.svelte-1aq8tno{cursor:pointer;width:var(--size-full)}.hidden.svelte-1aq8tno{display:none;height:0!important;position:absolute;width:0;flex-grow:0}.center.svelte-1aq8tno{display:flex;justify-content:center}.flex.svelte-1aq8tno{display:flex;justify-content:center;align-items:center}input.svelte-1aq8tno{display:none}div.svelte-1wj0ocy{display:flex;top:var(--size-2);right:var(--size-2);justify-content:flex-end;gap:var(--spacing-sm);z-index:var(--layer-1)}.not-absolute.svelte-1wj0ocy{margin:var(--size-1)}div.svelte-1o7cyxy{display:flex;position:absolute;top:var(--size-2);right:var(--size-2);justify-content:flex-end;gap:var(--spacing-sm);z-index:var(--layer-5)}canvas.svelte-1mnpmgt{display:block;position:absolute;top:0;right:0;bottom:0;left:0;margin:auto}.wrap.svelte-1mnpmgt{position:relative;width:var(--size-full);height:var(--size-full);touch-action:none}img.svelte-1qm7xww{width:var(--size-full);height:var(--size-full)}.upload-container.svelte-1qm7xww{height:100%;flex-shrink:1;max-height:100%}.image-container.svelte-1qm7xww{display:flex;height:100%;flex-direction:column;justify-content:center;align-items:center;max-height:100%}svg.svelte-43sxxs.svelte-43sxxs{width:var(--size-20);height:var(--size-20)}svg.svelte-43sxxs path.svelte-43sxxs{fill:var(--loader-color)}div.svelte-43sxxs.svelte-43sxxs{z-index:var(--layer-2)}.margin.svelte-43sxxs.svelte-43sxxs{margin:var(--size-4)}.wrap.svelte-1txqlrd.svelte-1txqlrd{display:flex;flex-direction:column;justify-content:center;align-items:center;z-index:var(--layer-top);transition:opacity .1s ease-in-out;border-radius:var(--block-radius);background:var(--block-background-fill);padding:0 var(--size-6);max-height:var(--size-screen-h);overflow:hidden;pointer-events:none}.wrap.center.svelte-1txqlrd.svelte-1txqlrd{top:0;right:0;left:0}.wrap.default.svelte-1txqlrd.svelte-1txqlrd{top:0;right:0;bottom:0;left:0}.hide.svelte-1txqlrd.svelte-1txqlrd{opacity:0;pointer-events:none}.generating.svelte-1txqlrd.svelte-1txqlrd{animation:svelte-1txqlrd-pulse 2s cubic-bezier(.4,0,.6,1) infinite;border:2px solid var(--color-accent);background:transparent}.translucent.svelte-1txqlrd.svelte-1txqlrd{background:none}@keyframes svelte-1txqlrd-pulse{0%,to{opacity:1}50%{opacity:.5}}.loading.svelte-1txqlrd.svelte-1txqlrd{z-index:var(--layer-2);color:var(--body-text-color)}.eta-bar.svelte-1txqlrd.svelte-1txqlrd{position:absolute;top:0;right:0;bottom:0;left:0;transform-origin:left;opacity:.8;z-index:var(--layer-1);transition:10ms;background:var(--background-fill-secondary)}.progress-bar-wrap.svelte-1txqlrd.svelte-1txqlrd{border:1px solid var(--border-color-primary);background:var(--background-fill-primary);width:55.5%;height:var(--size-4)}.progress-bar.svelte-1txqlrd.svelte-1txqlrd{transform-origin:left;background-color:var(--loader-color);width:var(--size-full);height:var(--size-full)}.progress-level.svelte-1txqlrd.svelte-1txqlrd{display:flex;flex-direction:column;align-items:center;gap:1;z-index:var(--layer-2);width:var(--size-full)}.progress-level-inner.svelte-1txqlrd.svelte-1txqlrd{margin:var(--size-2) 
auto;color:var(--body-text-color);font-size:var(--text-sm);font-family:var(--font-mono)}.meta-text.svelte-1txqlrd.svelte-1txqlrd{position:absolute;top:0;right:0;z-index:var(--layer-2);padding:var(--size-1) var(--size-2);font-size:var(--text-sm);font-family:var(--font-mono)}.meta-text-center.svelte-1txqlrd.svelte-1txqlrd{display:flex;position:absolute;top:0;right:0;justify-content:center;align-items:center;transform:translateY(var(--size-6));z-index:var(--layer-2);padding:var(--size-1) var(--size-2);font-size:var(--text-sm);font-family:var(--font-mono);text-align:center}.error.svelte-1txqlrd.svelte-1txqlrd{box-shadow:var(--shadow-drop);border:solid 1px var(--error-border-color);border-radius:var(--radius-full);background:var(--error-background-fill);padding-right:var(--size-4);padding-left:var(--size-4);color:var(--error-text-color);font-weight:var(--weight-semibold);font-size:var(--text-lg);line-height:var(--line-lg);font-family:var(--font)}.minimal.svelte-1txqlrd .progress-text.svelte-1txqlrd{background:var(--block-background-fill)}.border.svelte-1txqlrd.svelte-1txqlrd{border:1px solid var(--border-color-primary)}.toast-body.svelte-solcu7{display:flex;position:relative;right:0;left:0;align-items:center;margin:var(--size-6) var(--size-4);margin:auto;border-radius:var(--container-radius);overflow:hidden;pointer-events:auto}.toast-body.error.svelte-solcu7{border:1px solid var(--color-red-700);background:var(--color-red-50)}.dark .toast-body.error.svelte-solcu7{border:1px solid var(--color-red-500);background-color:var(--color-grey-950)}.toast-body.warning.svelte-solcu7{border:1px solid var(--color-yellow-700);background:var(--color-yellow-50)}.dark .toast-body.warning.svelte-solcu7{border:1px solid var(--color-yellow-500);background-color:var(--color-grey-950)}.toast-body.info.svelte-solcu7{border:1px solid var(--color-grey-700);background:var(--color-grey-50)}.dark .toast-body.info.svelte-solcu7{border:1px solid var(--color-grey-500);background-color:var(--color-grey-950)}.toast-title.svelte-solcu7{display:flex;align-items:center;font-weight:var(--weight-bold);font-size:var(--text-lg);line-height:var(--line-sm);text-transform:capitalize}.toast-title.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-title.error.svelte-solcu7{color:var(--color-red-50)}.toast-title.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-title.warning.svelte-solcu7{color:var(--color-yellow-50)}.toast-title.info.svelte-solcu7{color:var(--color-grey-700)}.dark .toast-title.info.svelte-solcu7{color:var(--color-grey-50)}.toast-close.svelte-solcu7{margin:0 var(--size-3);border-radius:var(--size-3);padding:0px var(--size-1-5);font-size:var(--size-5);line-height:var(--size-5)}.toast-close.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-close.error.svelte-solcu7{color:var(--color-red-500)}.toast-close.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-close.warning.svelte-solcu7{color:var(--color-yellow-500)}.toast-close.info.svelte-solcu7{color:var(--color-grey-700)}.dark .toast-close.info.svelte-solcu7{color:var(--color-grey-500)}.toast-text.svelte-solcu7{font-size:var(--text-lg)}.toast-text.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-text.error.svelte-solcu7{color:var(--color-red-50)}.toast-text.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-text.warning.svelte-solcu7{color:var(--color-yellow-50)}.toast-text.info.svelte-solcu7{color:var(--color-grey-700)}.dark 
.toast-text.info.svelte-solcu7{color:var(--color-grey-50)}.toast-details.svelte-solcu7{margin:var(--size-3) var(--size-3) var(--size-3) 0;width:100%}.toast-icon.svelte-solcu7{display:flex;position:absolute;position:relative;flex-shrink:0;justify-content:center;align-items:center;margin:var(--size-2);border-radius:var(--radius-full);padding:var(--size-1);padding-left:calc(var(--size-1) - 1px);width:35px;height:35px}.toast-icon.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-icon.error.svelte-solcu7{color:var(--color-red-500)}.toast-icon.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-icon.warning.svelte-solcu7{color:var(--color-yellow-500)}.toast-icon.info.svelte-solcu7{color:var(--color-grey-700)}.dark .toast-icon.info.svelte-solcu7{color:var(--color-grey-500)}@keyframes svelte-solcu7-countdown{0%{transform:scaleX(1)}to{transform:scaleX(0)}}.timer.svelte-solcu7{position:absolute;bottom:0;left:0;transform-origin:0 0;animation:svelte-solcu7-countdown 10s linear forwards;width:100%;height:var(--size-1)}.timer.error.svelte-solcu7{background:var(--color-red-700)}.dark .timer.error.svelte-solcu7{background:var(--color-red-500)}.timer.warning.svelte-solcu7{background:var(--color-yellow-700)}.dark .timer.warning.svelte-solcu7{background:var(--color-yellow-500)}.timer.info.svelte-solcu7{background:var(--color-grey-700)}.dark .timer.info.svelte-solcu7{background:var(--color-grey-500)}.toast-wrap.svelte-gatr8h{display:flex;position:fixed;top:var(--size-4);right:var(--size-4);flex-direction:column;align-items:end;gap:var(--size-2);z-index:var(--layer-top);width:calc(100% - var(--size-8))}@media (--screen-sm){.toast-wrap.svelte-gatr8h{width:calc(var(--size-96) + var(--size-10))}}.container.svelte-h11ksk img{width:100%;height:100%}.container.selected.svelte-h11ksk{border-color:var(--border-color-accent)}.container.table.svelte-h11ksk{margin:0 auto;border:2px solid var(--border-color-primary);border-radius:var(--radius-lg);overflow:hidden;width:var(--size-20);height:var(--size-20);object-fit:cover}.container.gallery.svelte-h11ksk{height:var(--size-20);max-height:var(--size-20);object-fit:cover}
src/backend/gradio_image_prompter/templates/component/wrapper-6f348d45-f837cf34.js ADDED
@@ -0,0 +1,2455 @@
1
+ import S from "./__vite-browser-external-2447137e.js";
2
+ function z(s) {
3
+ return s && s.__esModule && Object.prototype.hasOwnProperty.call(s, "default") ? s.default : s;
4
+ }
5
+ function gt(s) {
6
+ if (s.__esModule)
7
+ return s;
8
+ var e = s.default;
9
+ if (typeof e == "function") {
10
+ var t = function r() {
11
+ if (this instanceof r) {
12
+ var i = [null];
13
+ i.push.apply(i, arguments);
14
+ var n = Function.bind.apply(e, i);
15
+ return new n();
16
+ }
17
+ return e.apply(this, arguments);
18
+ };
19
+ t.prototype = e.prototype;
20
+ } else
21
+ t = {};
22
+ return Object.defineProperty(t, "__esModule", { value: !0 }), Object.keys(s).forEach(function(r) {
23
+ var i = Object.getOwnPropertyDescriptor(s, r);
24
+ Object.defineProperty(t, r, i.get ? i : {
25
+ enumerable: !0,
26
+ get: function() {
27
+ return s[r];
28
+ }
29
+ });
30
+ }), t;
31
+ }
32
+ const { Duplex: yt } = S;
33
+ function Oe(s) {
34
+ s.emit("close");
35
+ }
36
+ function vt() {
37
+ !this.destroyed && this._writableState.finished && this.destroy();
38
+ }
39
+ function Qe(s) {
40
+ this.removeListener("error", Qe), this.destroy(), this.listenerCount("error") === 0 && this.emit("error", s);
41
+ }
42
+ function St(s, e) {
43
+ let t = !0;
44
+ const r = new yt({
45
+ ...e,
46
+ autoDestroy: !1,
47
+ emitClose: !1,
48
+ objectMode: !1,
49
+ writableObjectMode: !1
50
+ });
51
+ return s.on("message", function(n, o) {
52
+ const l = !o && r._readableState.objectMode ? n.toString() : n;
53
+ r.push(l) || s.pause();
54
+ }), s.once("error", function(n) {
55
+ r.destroyed || (t = !1, r.destroy(n));
56
+ }), s.once("close", function() {
57
+ r.destroyed || r.push(null);
58
+ }), r._destroy = function(i, n) {
59
+ if (s.readyState === s.CLOSED) {
60
+ n(i), process.nextTick(Oe, r);
61
+ return;
62
+ }
63
+ let o = !1;
64
+ s.once("error", function(f) {
65
+ o = !0, n(f);
66
+ }), s.once("close", function() {
67
+ o || n(i), process.nextTick(Oe, r);
68
+ }), t && s.terminate();
69
+ }, r._final = function(i) {
70
+ if (s.readyState === s.CONNECTING) {
71
+ s.once("open", function() {
72
+ r._final(i);
73
+ });
74
+ return;
75
+ }
76
+ s._socket !== null && (s._socket._writableState.finished ? (i(), r._readableState.endEmitted && r.destroy()) : (s._socket.once("finish", function() {
77
+ i();
78
+ }), s.close()));
79
+ }, r._read = function() {
80
+ s.isPaused && s.resume();
81
+ }, r._write = function(i, n, o) {
82
+ if (s.readyState === s.CONNECTING) {
83
+ s.once("open", function() {
84
+ r._write(i, n, o);
85
+ });
86
+ return;
87
+ }
88
+ s.send(i, o);
89
+ }, r.on("end", vt), r.on("error", Qe), r;
90
+ }
91
+ var Et = St;
92
+ const Vs = /* @__PURE__ */ z(Et);
93
+ var te = { exports: {} }, U = {
94
+ BINARY_TYPES: ["nodebuffer", "arraybuffer", "fragments"],
95
+ EMPTY_BUFFER: Buffer.alloc(0),
96
+ GUID: "258EAFA5-E914-47DA-95CA-C5AB0DC85B11",
97
+ kForOnEventAttribute: Symbol("kIsForOnEventAttribute"),
98
+ kListener: Symbol("kListener"),
99
+ kStatusCode: Symbol("status-code"),
100
+ kWebSocket: Symbol("websocket"),
101
+ NOOP: () => {
102
+ }
103
+ }, bt, xt;
104
+ const { EMPTY_BUFFER: kt } = U, Se = Buffer[Symbol.species];
105
+ function wt(s, e) {
106
+ if (s.length === 0)
107
+ return kt;
108
+ if (s.length === 1)
109
+ return s[0];
110
+ const t = Buffer.allocUnsafe(e);
111
+ let r = 0;
112
+ for (let i = 0; i < s.length; i++) {
113
+ const n = s[i];
114
+ t.set(n, r), r += n.length;
115
+ }
116
+ return r < e ? new Se(t.buffer, t.byteOffset, r) : t;
117
+ }
118
+ function Je(s, e, t, r, i) {
119
+ for (let n = 0; n < i; n++)
120
+ t[r + n] = s[n] ^ e[n & 3];
121
+ }
122
+ function et(s, e) {
123
+ for (let t = 0; t < s.length; t++)
124
+ s[t] ^= e[t & 3];
125
+ }
126
+ function Ot(s) {
127
+ return s.length === s.buffer.byteLength ? s.buffer : s.buffer.slice(s.byteOffset, s.byteOffset + s.length);
128
+ }
129
+ function Ee(s) {
130
+ if (Ee.readOnly = !0, Buffer.isBuffer(s))
131
+ return s;
132
+ let e;
133
+ return s instanceof ArrayBuffer ? e = new Se(s) : ArrayBuffer.isView(s) ? e = new Se(s.buffer, s.byteOffset, s.byteLength) : (e = Buffer.from(s), Ee.readOnly = !1), e;
134
+ }
135
+ te.exports = {
136
+ concat: wt,
137
+ mask: Je,
138
+ toArrayBuffer: Ot,
139
+ toBuffer: Ee,
140
+ unmask: et
141
+ };
142
+ if (!process.env.WS_NO_BUFFER_UTIL)
143
+ try {
144
+ const s = require("bufferutil");
145
+ xt = te.exports.mask = function(e, t, r, i, n) {
146
+ n < 48 ? Je(e, t, r, i, n) : s.mask(e, t, r, i, n);
147
+ }, bt = te.exports.unmask = function(e, t) {
148
+ e.length < 32 ? et(e, t) : s.unmask(e, t);
149
+ };
150
+ } catch {
151
+ }
152
+ var ne = te.exports;
153
+ const Ce = Symbol("kDone"), ue = Symbol("kRun");
154
+ let Ct = class {
155
+ /**
156
+ * Creates a new `Limiter`.
157
+ *
158
+ * @param {Number} [concurrency=Infinity] The maximum number of jobs allowed
159
+ * to run concurrently
160
+ */
161
+ constructor(e) {
162
+ this[Ce] = () => {
163
+ this.pending--, this[ue]();
164
+ }, this.concurrency = e || 1 / 0, this.jobs = [], this.pending = 0;
165
+ }
166
+ /**
167
+ * Adds a job to the queue.
168
+ *
169
+ * @param {Function} job The job to run
170
+ * @public
171
+ */
172
+ add(e) {
173
+ this.jobs.push(e), this[ue]();
174
+ }
175
+ /**
176
+ * Removes a job from the queue and runs it if possible.
177
+ *
178
+ * @private
179
+ */
180
+ [ue]() {
181
+ if (this.pending !== this.concurrency && this.jobs.length) {
182
+ const e = this.jobs.shift();
183
+ this.pending++, e(this[Ce]);
184
+ }
185
+ }
186
+ };
187
+ var Tt = Ct;
188
+ const W = S, Te = ne, Lt = Tt, { kStatusCode: tt } = U, Nt = Buffer[Symbol.species], Pt = Buffer.from([0, 0, 255, 255]), se = Symbol("permessage-deflate"), w = Symbol("total-length"), V = Symbol("callback"), C = Symbol("buffers"), J = Symbol("error");
189
+ let K, Rt = class {
190
+ /**
191
+ * Creates a PerMessageDeflate instance.
192
+ *
193
+ * @param {Object} [options] Configuration options
194
+ * @param {(Boolean|Number)} [options.clientMaxWindowBits] Advertise support
195
+ * for, or request, a custom client window size
196
+ * @param {Boolean} [options.clientNoContextTakeover=false] Advertise/
197
+ * acknowledge disabling of client context takeover
198
+ * @param {Number} [options.concurrencyLimit=10] The number of concurrent
199
+ * calls to zlib
200
+ * @param {(Boolean|Number)} [options.serverMaxWindowBits] Request/confirm the
201
+ * use of a custom server window size
202
+ * @param {Boolean} [options.serverNoContextTakeover=false] Request/accept
203
+ * disabling of server context takeover
204
+ * @param {Number} [options.threshold=1024] Size (in bytes) below which
205
+ * messages should not be compressed if context takeover is disabled
206
+ * @param {Object} [options.zlibDeflateOptions] Options to pass to zlib on
207
+ * deflate
208
+ * @param {Object} [options.zlibInflateOptions] Options to pass to zlib on
209
+ * inflate
210
+ * @param {Boolean} [isServer=false] Create the instance in either server or
211
+ * client mode
212
+ * @param {Number} [maxPayload=0] The maximum allowed message length
213
+ */
214
+ constructor(e, t, r) {
215
+ if (this._maxPayload = r | 0, this._options = e || {}, this._threshold = this._options.threshold !== void 0 ? this._options.threshold : 1024, this._isServer = !!t, this._deflate = null, this._inflate = null, this.params = null, !K) {
216
+ const i = this._options.concurrencyLimit !== void 0 ? this._options.concurrencyLimit : 10;
217
+ K = new Lt(i);
218
+ }
219
+ }
220
+ /**
221
+ * @type {String}
222
+ */
223
+ static get extensionName() {
224
+ return "permessage-deflate";
225
+ }
226
+ /**
227
+ * Create an extension negotiation offer.
228
+ *
229
+ * @return {Object} Extension parameters
230
+ * @public
231
+ */
232
+ offer() {
233
+ const e = {};
234
+ return this._options.serverNoContextTakeover && (e.server_no_context_takeover = !0), this._options.clientNoContextTakeover && (e.client_no_context_takeover = !0), this._options.serverMaxWindowBits && (e.server_max_window_bits = this._options.serverMaxWindowBits), this._options.clientMaxWindowBits ? e.client_max_window_bits = this._options.clientMaxWindowBits : this._options.clientMaxWindowBits == null && (e.client_max_window_bits = !0), e;
235
+ }
236
+ /**
237
+ * Accept an extension negotiation offer/response.
238
+ *
239
+ * @param {Array} configurations The extension negotiation offers/response
240
+ * @return {Object} Accepted configuration
241
+ * @public
242
+ */
243
+ accept(e) {
244
+ return e = this.normalizeParams(e), this.params = this._isServer ? this.acceptAsServer(e) : this.acceptAsClient(e), this.params;
245
+ }
246
+ /**
247
+ * Releases all resources used by the extension.
248
+ *
249
+ * @public
250
+ */
251
+ cleanup() {
252
+ if (this._inflate && (this._inflate.close(), this._inflate = null), this._deflate) {
253
+ const e = this._deflate[V];
254
+ this._deflate.close(), this._deflate = null, e && e(
255
+ new Error(
256
+ "The deflate stream was closed while data was being processed"
257
+ )
258
+ );
259
+ }
260
+ }
261
+ /**
262
+ * Accept an extension negotiation offer.
263
+ *
264
+ * @param {Array} offers The extension negotiation offers
265
+ * @return {Object} Accepted configuration
266
+ * @private
267
+ */
268
+ acceptAsServer(e) {
269
+ const t = this._options, r = e.find((i) => !(t.serverNoContextTakeover === !1 && i.server_no_context_takeover || i.server_max_window_bits && (t.serverMaxWindowBits === !1 || typeof t.serverMaxWindowBits == "number" && t.serverMaxWindowBits > i.server_max_window_bits) || typeof t.clientMaxWindowBits == "number" && !i.client_max_window_bits));
270
+ if (!r)
271
+ throw new Error("None of the extension offers can be accepted");
272
+ return t.serverNoContextTakeover && (r.server_no_context_takeover = !0), t.clientNoContextTakeover && (r.client_no_context_takeover = !0), typeof t.serverMaxWindowBits == "number" && (r.server_max_window_bits = t.serverMaxWindowBits), typeof t.clientMaxWindowBits == "number" ? r.client_max_window_bits = t.clientMaxWindowBits : (r.client_max_window_bits === !0 || t.clientMaxWindowBits === !1) && delete r.client_max_window_bits, r;
273
+ }
274
+ /**
275
+ * Accept the extension negotiation response.
276
+ *
277
+ * @param {Array} response The extension negotiation response
278
+ * @return {Object} Accepted configuration
279
+ * @private
280
+ */
281
+ acceptAsClient(e) {
282
+ const t = e[0];
283
+ if (this._options.clientNoContextTakeover === !1 && t.client_no_context_takeover)
284
+ throw new Error('Unexpected parameter "client_no_context_takeover"');
285
+ if (!t.client_max_window_bits)
286
+ typeof this._options.clientMaxWindowBits == "number" && (t.client_max_window_bits = this._options.clientMaxWindowBits);
287
+ else if (this._options.clientMaxWindowBits === !1 || typeof this._options.clientMaxWindowBits == "number" && t.client_max_window_bits > this._options.clientMaxWindowBits)
288
+ throw new Error(
289
+ 'Unexpected or invalid parameter "client_max_window_bits"'
290
+ );
291
+ return t;
292
+ }
293
+ /**
294
+ * Normalize parameters.
295
+ *
296
+ * @param {Array} configurations The extension negotiation offers/response
297
+ * @return {Array} The offers/response with normalized parameters
298
+ * @private
299
+ */
300
+ normalizeParams(e) {
301
+ return e.forEach((t) => {
302
+ Object.keys(t).forEach((r) => {
303
+ let i = t[r];
304
+ if (i.length > 1)
305
+ throw new Error(`Parameter "${r}" must have only a single value`);
306
+ if (i = i[0], r === "client_max_window_bits") {
307
+ if (i !== !0) {
308
+ const n = +i;
309
+ if (!Number.isInteger(n) || n < 8 || n > 15)
310
+ throw new TypeError(
311
+ `Invalid value for parameter "${r}": ${i}`
312
+ );
313
+ i = n;
314
+ } else if (!this._isServer)
315
+ throw new TypeError(
316
+ `Invalid value for parameter "${r}": ${i}`
317
+ );
318
+ } else if (r === "server_max_window_bits") {
319
+ const n = +i;
320
+ if (!Number.isInteger(n) || n < 8 || n > 15)
321
+ throw new TypeError(
322
+ `Invalid value for parameter "${r}": ${i}`
323
+ );
324
+ i = n;
325
+ } else if (r === "client_no_context_takeover" || r === "server_no_context_takeover") {
326
+ if (i !== !0)
327
+ throw new TypeError(
328
+ `Invalid value for parameter "${r}": ${i}`
329
+ );
330
+ } else
331
+ throw new Error(`Unknown parameter "${r}"`);
332
+ t[r] = i;
333
+ });
334
+ }), e;
335
+ }
336
+ /**
337
+ * Decompress data. Concurrency limited.
338
+ *
339
+ * @param {Buffer} data Compressed data
340
+ * @param {Boolean} fin Specifies whether or not this is the last fragment
341
+ * @param {Function} callback Callback
342
+ * @public
343
+ */
344
+ decompress(e, t, r) {
345
+ K.add((i) => {
346
+ this._decompress(e, t, (n, o) => {
347
+ i(), r(n, o);
348
+ });
349
+ });
350
+ }
351
+ /**
352
+ * Compress data. Concurrency limited.
353
+ *
354
+ * @param {(Buffer|String)} data Data to compress
355
+ * @param {Boolean} fin Specifies whether or not this is the last fragment
356
+ * @param {Function} callback Callback
357
+ * @public
358
+ */
359
+ compress(e, t, r) {
360
+ K.add((i) => {
361
+ this._compress(e, t, (n, o) => {
362
+ i(), r(n, o);
363
+ });
364
+ });
365
+ }
366
+ /**
367
+ * Decompress data.
368
+ *
369
+ * @param {Buffer} data Compressed data
370
+ * @param {Boolean} fin Specifies whether or not this is the last fragment
371
+ * @param {Function} callback Callback
372
+ * @private
373
+ */
374
+ _decompress(e, t, r) {
375
+ const i = this._isServer ? "client" : "server";
376
+ if (!this._inflate) {
377
+ const n = `${i}_max_window_bits`, o = typeof this.params[n] != "number" ? W.Z_DEFAULT_WINDOWBITS : this.params[n];
378
+ this._inflate = W.createInflateRaw({
379
+ ...this._options.zlibInflateOptions,
380
+ windowBits: o
381
+ }), this._inflate[se] = this, this._inflate[w] = 0, this._inflate[C] = [], this._inflate.on("error", Bt), this._inflate.on("data", st);
382
+ }
383
+ this._inflate[V] = r, this._inflate.write(e), t && this._inflate.write(Pt), this._inflate.flush(() => {
384
+ const n = this._inflate[J];
385
+ if (n) {
386
+ this._inflate.close(), this._inflate = null, r(n);
387
+ return;
388
+ }
389
+ const o = Te.concat(
390
+ this._inflate[C],
391
+ this._inflate[w]
392
+ );
393
+ this._inflate._readableState.endEmitted ? (this._inflate.close(), this._inflate = null) : (this._inflate[w] = 0, this._inflate[C] = [], t && this.params[`${i}_no_context_takeover`] && this._inflate.reset()), r(null, o);
394
+ });
395
+ }
396
+ /**
397
+ * Compress data.
398
+ *
399
+ * @param {(Buffer|String)} data Data to compress
400
+ * @param {Boolean} fin Specifies whether or not this is the last fragment
401
+ * @param {Function} callback Callback
402
+ * @private
403
+ */
404
+ _compress(e, t, r) {
405
+ const i = this._isServer ? "server" : "client";
406
+ if (!this._deflate) {
407
+ const n = `${i}_max_window_bits`, o = typeof this.params[n] != "number" ? W.Z_DEFAULT_WINDOWBITS : this.params[n];
408
+ this._deflate = W.createDeflateRaw({
409
+ ...this._options.zlibDeflateOptions,
410
+ windowBits: o
411
+ }), this._deflate[w] = 0, this._deflate[C] = [], this._deflate.on("data", Ut);
412
+ }
413
+ this._deflate[V] = r, this._deflate.write(e), this._deflate.flush(W.Z_SYNC_FLUSH, () => {
414
+ if (!this._deflate)
415
+ return;
416
+ let n = Te.concat(
417
+ this._deflate[C],
418
+ this._deflate[w]
419
+ );
420
+ t && (n = new Nt(n.buffer, n.byteOffset, n.length - 4)), this._deflate[V] = null, this._deflate[w] = 0, this._deflate[C] = [], t && this.params[`${i}_no_context_takeover`] && this._deflate.reset(), r(null, n);
421
+ });
422
+ }
423
+ };
424
+ var oe = Rt;
425
+ function Ut(s) {
426
+ this[C].push(s), this[w] += s.length;
427
+ }
428
+ function st(s) {
429
+ if (this[w] += s.length, this[se]._maxPayload < 1 || this[w] <= this[se]._maxPayload) {
430
+ this[C].push(s);
431
+ return;
432
+ }
433
+ this[J] = new RangeError("Max payload size exceeded"), this[J].code = "WS_ERR_UNSUPPORTED_MESSAGE_LENGTH", this[J][tt] = 1009, this.removeListener("data", st), this.reset();
434
+ }
435
+ function Bt(s) {
436
+ this[se]._inflate = null, s[tt] = 1007, this[V](s);
437
+ }
438
+ var re = { exports: {} };
439
+ const $t = {}, Mt = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
440
+ __proto__: null,
441
+ default: $t
442
+ }, Symbol.toStringTag, { value: "Module" })), It = /* @__PURE__ */ gt(Mt);
443
+ var Le;
444
+ const { isUtf8: Ne } = S, Dt = [
445
+ 0,
446
+ 0,
447
+ 0,
448
+ 0,
449
+ 0,
450
+ 0,
451
+ 0,
452
+ 0,
453
+ 0,
454
+ 0,
455
+ 0,
456
+ 0,
457
+ 0,
458
+ 0,
459
+ 0,
460
+ 0,
461
+ // 0 - 15
462
+ 0,
463
+ 0,
464
+ 0,
465
+ 0,
466
+ 0,
467
+ 0,
468
+ 0,
469
+ 0,
470
+ 0,
471
+ 0,
472
+ 0,
473
+ 0,
474
+ 0,
475
+ 0,
476
+ 0,
477
+ 0,
478
+ // 16 - 31
479
+ 0,
480
+ 1,
481
+ 0,
482
+ 1,
483
+ 1,
484
+ 1,
485
+ 1,
486
+ 1,
487
+ 0,
488
+ 0,
489
+ 1,
490
+ 1,
491
+ 0,
492
+ 1,
493
+ 1,
494
+ 0,
495
+ // 32 - 47
496
+ 1,
497
+ 1,
498
+ 1,
499
+ 1,
500
+ 1,
501
+ 1,
502
+ 1,
503
+ 1,
504
+ 1,
505
+ 1,
506
+ 0,
507
+ 0,
508
+ 0,
509
+ 0,
510
+ 0,
511
+ 0,
512
+ // 48 - 63
513
+ 0,
514
+ 1,
515
+ 1,
516
+ 1,
517
+ 1,
518
+ 1,
519
+ 1,
520
+ 1,
521
+ 1,
522
+ 1,
523
+ 1,
524
+ 1,
525
+ 1,
526
+ 1,
527
+ 1,
528
+ 1,
529
+ // 64 - 79
530
+ 1,
531
+ 1,
532
+ 1,
533
+ 1,
534
+ 1,
535
+ 1,
536
+ 1,
537
+ 1,
538
+ 1,
539
+ 1,
540
+ 1,
541
+ 0,
542
+ 0,
543
+ 0,
544
+ 1,
545
+ 1,
546
+ // 80 - 95
547
+ 1,
548
+ 1,
549
+ 1,
550
+ 1,
551
+ 1,
552
+ 1,
553
+ 1,
554
+ 1,
555
+ 1,
556
+ 1,
557
+ 1,
558
+ 1,
559
+ 1,
560
+ 1,
561
+ 1,
562
+ 1,
563
+ // 96 - 111
564
+ 1,
565
+ 1,
566
+ 1,
567
+ 1,
568
+ 1,
569
+ 1,
570
+ 1,
571
+ 1,
572
+ 1,
573
+ 1,
574
+ 1,
575
+ 0,
576
+ 1,
577
+ 0,
578
+ 1,
579
+ 0
580
+ // 112 - 127
581
+ ];
582
+ function Wt(s) {
583
+ return s >= 1e3 && s <= 1014 && s !== 1004 && s !== 1005 && s !== 1006 || s >= 3e3 && s <= 4999;
584
+ }
585
+ function be(s) {
586
+ const e = s.length;
587
+ let t = 0;
588
+ for (; t < e; )
589
+ if (!(s[t] & 128))
590
+ t++;
591
+ else if ((s[t] & 224) === 192) {
592
+ if (t + 1 === e || (s[t + 1] & 192) !== 128 || (s[t] & 254) === 192)
593
+ return !1;
594
+ t += 2;
595
+ } else if ((s[t] & 240) === 224) {
596
+ if (t + 2 >= e || (s[t + 1] & 192) !== 128 || (s[t + 2] & 192) !== 128 || s[t] === 224 && (s[t + 1] & 224) === 128 || // Overlong
597
+ s[t] === 237 && (s[t + 1] & 224) === 160)
598
+ return !1;
599
+ t += 3;
600
+ } else if ((s[t] & 248) === 240) {
601
+ if (t + 3 >= e || (s[t + 1] & 192) !== 128 || (s[t + 2] & 192) !== 128 || (s[t + 3] & 192) !== 128 || s[t] === 240 && (s[t + 1] & 240) === 128 || // Overlong
602
+ s[t] === 244 && s[t + 1] > 143 || s[t] > 244)
603
+ return !1;
604
+ t += 4;
605
+ } else
606
+ return !1;
607
+ return !0;
608
+ }
609
+ re.exports = {
610
+ isValidStatusCode: Wt,
611
+ isValidUTF8: be,
612
+ tokenChars: Dt
613
+ };
614
+ if (Ne)
615
+ Le = re.exports.isValidUTF8 = function(s) {
616
+ return s.length < 24 ? be(s) : Ne(s);
617
+ };
618
+ else if (!process.env.WS_NO_UTF_8_VALIDATE)
619
+ try {
620
+ const s = It;
621
+ Le = re.exports.isValidUTF8 = function(e) {
622
+ return e.length < 32 ? be(e) : s(e);
623
+ };
624
+ } catch {
625
+ }
626
+ var ae = re.exports;
627
+ const { Writable: At } = S, Pe = oe, {
628
+ BINARY_TYPES: Ft,
629
+ EMPTY_BUFFER: Re,
630
+ kStatusCode: jt,
631
+ kWebSocket: Gt
632
+ } = U, { concat: de, toArrayBuffer: Vt, unmask: Ht } = ne, { isValidStatusCode: zt, isValidUTF8: Ue } = ae, X = Buffer[Symbol.species], A = 0, Be = 1, $e = 2, Me = 3, _e = 4, Yt = 5;
633
+ let qt = class extends At {
634
+ /**
635
+ * Creates a Receiver instance.
636
+ *
637
+ * @param {Object} [options] Options object
638
+ * @param {String} [options.binaryType=nodebuffer] The type for binary data
639
+ * @param {Object} [options.extensions] An object containing the negotiated
640
+ * extensions
641
+ * @param {Boolean} [options.isServer=false] Specifies whether to operate in
642
+ * client or server mode
643
+ * @param {Number} [options.maxPayload=0] The maximum allowed message length
644
+ * @param {Boolean} [options.skipUTF8Validation=false] Specifies whether or
645
+ * not to skip UTF-8 validation for text and close messages
646
+ */
647
+ constructor(e = {}) {
648
+ super(), this._binaryType = e.binaryType || Ft[0], this._extensions = e.extensions || {}, this._isServer = !!e.isServer, this._maxPayload = e.maxPayload | 0, this._skipUTF8Validation = !!e.skipUTF8Validation, this[Gt] = void 0, this._bufferedBytes = 0, this._buffers = [], this._compressed = !1, this._payloadLength = 0, this._mask = void 0, this._fragmented = 0, this._masked = !1, this._fin = !1, this._opcode = 0, this._totalPayloadLength = 0, this._messageLength = 0, this._fragments = [], this._state = A, this._loop = !1;
649
+ }
650
+ /**
651
+ * Implements `Writable.prototype._write()`.
652
+ *
653
+ * @param {Buffer} chunk The chunk of data to write
654
+ * @param {String} encoding The character encoding of `chunk`
655
+ * @param {Function} cb Callback
656
+ * @private
657
+ */
658
+ _write(e, t, r) {
659
+ if (this._opcode === 8 && this._state == A)
660
+ return r();
661
+ this._bufferedBytes += e.length, this._buffers.push(e), this.startLoop(r);
662
+ }
663
+ /**
664
+ * Consumes `n` bytes from the buffered data.
665
+ *
666
+ * @param {Number} n The number of bytes to consume
667
+ * @return {Buffer} The consumed bytes
668
+ * @private
669
+ */
670
+ consume(e) {
671
+ if (this._bufferedBytes -= e, e === this._buffers[0].length)
672
+ return this._buffers.shift();
673
+ if (e < this._buffers[0].length) {
674
+ const r = this._buffers[0];
675
+ return this._buffers[0] = new X(
676
+ r.buffer,
677
+ r.byteOffset + e,
678
+ r.length - e
679
+ ), new X(r.buffer, r.byteOffset, e);
680
+ }
681
+ const t = Buffer.allocUnsafe(e);
682
+ do {
683
+ const r = this._buffers[0], i = t.length - e;
684
+ e >= r.length ? t.set(this._buffers.shift(), i) : (t.set(new Uint8Array(r.buffer, r.byteOffset, e), i), this._buffers[0] = new X(
685
+ r.buffer,
686
+ r.byteOffset + e,
687
+ r.length - e
688
+ )), e -= r.length;
689
+ } while (e > 0);
690
+ return t;
691
+ }
692
+ /**
693
+ * Starts the parsing loop.
694
+ *
695
+ * @param {Function} cb Callback
696
+ * @private
697
+ */
698
+ startLoop(e) {
699
+ let t;
700
+ this._loop = !0;
701
+ do
702
+ switch (this._state) {
703
+ case A:
704
+ t = this.getInfo();
705
+ break;
706
+ case Be:
707
+ t = this.getPayloadLength16();
708
+ break;
709
+ case $e:
710
+ t = this.getPayloadLength64();
711
+ break;
712
+ case Me:
713
+ this.getMask();
714
+ break;
715
+ case _e:
716
+ t = this.getData(e);
717
+ break;
718
+ default:
719
+ this._loop = !1;
720
+ return;
721
+ }
722
+ while (this._loop);
723
+ e(t);
724
+ }
725
+ /**
726
+ * Reads the first two bytes of a frame.
727
+ *
728
+ * @return {(RangeError|undefined)} A possible error
729
+ * @private
730
+ */
731
+ getInfo() {
732
+ if (this._bufferedBytes < 2) {
733
+ this._loop = !1;
734
+ return;
735
+ }
736
+ const e = this.consume(2);
737
+ if (e[0] & 48)
738
+ return this._loop = !1, g(
739
+ RangeError,
740
+ "RSV2 and RSV3 must be clear",
741
+ !0,
742
+ 1002,
743
+ "WS_ERR_UNEXPECTED_RSV_2_3"
744
+ );
745
+ const t = (e[0] & 64) === 64;
746
+ if (t && !this._extensions[Pe.extensionName])
747
+ return this._loop = !1, g(
748
+ RangeError,
749
+ "RSV1 must be clear",
750
+ !0,
751
+ 1002,
752
+ "WS_ERR_UNEXPECTED_RSV_1"
753
+ );
754
+ if (this._fin = (e[0] & 128) === 128, this._opcode = e[0] & 15, this._payloadLength = e[1] & 127, this._opcode === 0) {
755
+ if (t)
756
+ return this._loop = !1, g(
757
+ RangeError,
758
+ "RSV1 must be clear",
759
+ !0,
760
+ 1002,
761
+ "WS_ERR_UNEXPECTED_RSV_1"
762
+ );
763
+ if (!this._fragmented)
764
+ return this._loop = !1, g(
765
+ RangeError,
766
+ "invalid opcode 0",
767
+ !0,
768
+ 1002,
769
+ "WS_ERR_INVALID_OPCODE"
770
+ );
771
+ this._opcode = this._fragmented;
772
+ } else if (this._opcode === 1 || this._opcode === 2) {
773
+ if (this._fragmented)
774
+ return this._loop = !1, g(
775
+ RangeError,
776
+ `invalid opcode ${this._opcode}`,
777
+ !0,
778
+ 1002,
779
+ "WS_ERR_INVALID_OPCODE"
780
+ );
781
+ this._compressed = t;
782
+ } else if (this._opcode > 7 && this._opcode < 11) {
783
+ if (!this._fin)
784
+ return this._loop = !1, g(
785
+ RangeError,
786
+ "FIN must be set",
787
+ !0,
788
+ 1002,
789
+ "WS_ERR_EXPECTED_FIN"
790
+ );
791
+ if (t)
792
+ return this._loop = !1, g(
793
+ RangeError,
794
+ "RSV1 must be clear",
795
+ !0,
796
+ 1002,
797
+ "WS_ERR_UNEXPECTED_RSV_1"
798
+ );
799
+ if (this._payloadLength > 125 || this._opcode === 8 && this._payloadLength === 1)
800
+ return this._loop = !1, g(
801
+ RangeError,
802
+ `invalid payload length ${this._payloadLength}`,
803
+ !0,
804
+ 1002,
805
+ "WS_ERR_INVALID_CONTROL_PAYLOAD_LENGTH"
806
+ );
807
+ } else
808
+ return this._loop = !1, g(
809
+ RangeError,
810
+ `invalid opcode ${this._opcode}`,
811
+ !0,
812
+ 1002,
813
+ "WS_ERR_INVALID_OPCODE"
814
+ );
815
+ if (!this._fin && !this._fragmented && (this._fragmented = this._opcode), this._masked = (e[1] & 128) === 128, this._isServer) {
816
+ if (!this._masked)
817
+ return this._loop = !1, g(
818
+ RangeError,
819
+ "MASK must be set",
820
+ !0,
821
+ 1002,
822
+ "WS_ERR_EXPECTED_MASK"
823
+ );
824
+ } else if (this._masked)
825
+ return this._loop = !1, g(
826
+ RangeError,
827
+ "MASK must be clear",
828
+ !0,
829
+ 1002,
830
+ "WS_ERR_UNEXPECTED_MASK"
831
+ );
832
+ if (this._payloadLength === 126)
833
+ this._state = Be;
834
+ else if (this._payloadLength === 127)
835
+ this._state = $e;
836
+ else
837
+ return this.haveLength();
838
+ }
839
+ /**
840
+ * Gets extended payload length (7+16).
841
+ *
842
+ * @return {(RangeError|undefined)} A possible error
843
+ * @private
844
+ */
845
+ getPayloadLength16() {
846
+ if (this._bufferedBytes < 2) {
847
+ this._loop = !1;
848
+ return;
849
+ }
850
+ return this._payloadLength = this.consume(2).readUInt16BE(0), this.haveLength();
851
+ }
852
+ /**
853
+ * Gets extended payload length (7+64).
854
+ *
855
+ * @return {(RangeError|undefined)} A possible error
856
+ * @private
857
+ */
858
+ getPayloadLength64() {
859
+ if (this._bufferedBytes < 8) {
860
+ this._loop = !1;
861
+ return;
862
+ }
863
+ const e = this.consume(8), t = e.readUInt32BE(0);
864
+ return t > Math.pow(2, 53 - 32) - 1 ? (this._loop = !1, g(
865
+ RangeError,
866
+ "Unsupported WebSocket frame: payload length > 2^53 - 1",
867
+ !1,
868
+ 1009,
869
+ "WS_ERR_UNSUPPORTED_DATA_PAYLOAD_LENGTH"
870
+ )) : (this._payloadLength = t * Math.pow(2, 32) + e.readUInt32BE(4), this.haveLength());
871
+ }
872
+ /**
873
+ * Payload length has been read.
874
+ *
875
+ * @return {(RangeError|undefined)} A possible error
876
+ * @private
877
+ */
878
+ haveLength() {
879
+ if (this._payloadLength && this._opcode < 8 && (this._totalPayloadLength += this._payloadLength, this._totalPayloadLength > this._maxPayload && this._maxPayload > 0))
880
+ return this._loop = !1, g(
881
+ RangeError,
882
+ "Max payload size exceeded",
883
+ !1,
884
+ 1009,
885
+ "WS_ERR_UNSUPPORTED_MESSAGE_LENGTH"
886
+ );
887
+ this._masked ? this._state = Me : this._state = _e;
888
+ }
889
+ /**
890
+ * Reads mask bytes.
891
+ *
892
+ * @private
893
+ */
894
+ getMask() {
895
+ if (this._bufferedBytes < 4) {
896
+ this._loop = !1;
897
+ return;
898
+ }
899
+ this._mask = this.consume(4), this._state = _e;
900
+ }
901
+ /**
902
+ * Reads data bytes.
903
+ *
904
+ * @param {Function} cb Callback
905
+ * @return {(Error|RangeError|undefined)} A possible error
906
+ * @private
907
+ */
908
+ getData(e) {
909
+ let t = Re;
910
+ if (this._payloadLength) {
911
+ if (this._bufferedBytes < this._payloadLength) {
912
+ this._loop = !1;
913
+ return;
914
+ }
915
+ t = this.consume(this._payloadLength), this._masked && this._mask[0] | this._mask[1] | this._mask[2] | this._mask[3] && Ht(t, this._mask);
916
+ }
917
+ if (this._opcode > 7)
918
+ return this.controlMessage(t);
919
+ if (this._compressed) {
920
+ this._state = Yt, this.decompress(t, e);
921
+ return;
922
+ }
923
+ return t.length && (this._messageLength = this._totalPayloadLength, this._fragments.push(t)), this.dataMessage();
924
+ }
925
+ /**
926
+ * Decompresses data.
927
+ *
928
+ * @param {Buffer} data Compressed data
929
+ * @param {Function} cb Callback
930
+ * @private
931
+ */
932
+ decompress(e, t) {
933
+ this._extensions[Pe.extensionName].decompress(e, this._fin, (i, n) => {
934
+ if (i)
935
+ return t(i);
936
+ if (n.length) {
937
+ if (this._messageLength += n.length, this._messageLength > this._maxPayload && this._maxPayload > 0)
938
+ return t(
939
+ g(
940
+ RangeError,
941
+ "Max payload size exceeded",
942
+ !1,
943
+ 1009,
944
+ "WS_ERR_UNSUPPORTED_MESSAGE_LENGTH"
945
+ )
946
+ );
947
+ this._fragments.push(n);
948
+ }
949
+ const o = this.dataMessage();
950
+ if (o)
951
+ return t(o);
952
+ this.startLoop(t);
953
+ });
954
+ }
955
+ /**
956
+ * Handles a data message.
957
+ *
958
+ * @return {(Error|undefined)} A possible error
959
+ * @private
960
+ */
961
+ dataMessage() {
962
+ if (this._fin) {
963
+ const e = this._messageLength, t = this._fragments;
964
+ if (this._totalPayloadLength = 0, this._messageLength = 0, this._fragmented = 0, this._fragments = [], this._opcode === 2) {
965
+ let r;
966
+ this._binaryType === "nodebuffer" ? r = de(t, e) : this._binaryType === "arraybuffer" ? r = Vt(de(t, e)) : r = t, this.emit("message", r, !0);
967
+ } else {
968
+ const r = de(t, e);
969
+ if (!this._skipUTF8Validation && !Ue(r))
970
+ return this._loop = !1, g(
971
+ Error,
972
+ "invalid UTF-8 sequence",
973
+ !0,
974
+ 1007,
975
+ "WS_ERR_INVALID_UTF8"
976
+ );
977
+ this.emit("message", r, !1);
978
+ }
979
+ }
980
+ this._state = A;
981
+ }
982
+ /**
983
+ * Handles a control message.
984
+ *
985
+ * @param {Buffer} data Data to handle
986
+ * @return {(Error|RangeError|undefined)} A possible error
987
+ * @private
988
+ */
989
+ controlMessage(e) {
990
+ if (this._opcode === 8)
991
+ if (this._loop = !1, e.length === 0)
992
+ this.emit("conclude", 1005, Re), this.end();
993
+ else {
994
+ const t = e.readUInt16BE(0);
995
+ if (!zt(t))
996
+ return g(
997
+ RangeError,
998
+ `invalid status code ${t}`,
999
+ !0,
1000
+ 1002,
1001
+ "WS_ERR_INVALID_CLOSE_CODE"
1002
+ );
1003
+ const r = new X(
1004
+ e.buffer,
1005
+ e.byteOffset + 2,
1006
+ e.length - 2
1007
+ );
1008
+ if (!this._skipUTF8Validation && !Ue(r))
1009
+ return g(
1010
+ Error,
1011
+ "invalid UTF-8 sequence",
1012
+ !0,
1013
+ 1007,
1014
+ "WS_ERR_INVALID_UTF8"
1015
+ );
1016
+ this.emit("conclude", t, r), this.end();
1017
+ }
1018
+ else
1019
+ this._opcode === 9 ? this.emit("ping", e) : this.emit("pong", e);
1020
+ this._state = A;
1021
+ }
1022
+ };
1023
+ var rt = qt;
1024
+ function g(s, e, t, r, i) {
1025
+ const n = new s(
1026
+ t ? `Invalid WebSocket frame: ${e}` : e
1027
+ );
1028
+ return Error.captureStackTrace(n, g), n.code = i, n[jt] = r, n;
1029
+ }
1030
+ const qs = /* @__PURE__ */ z(rt), { randomFillSync: Kt } = S, Ie = oe, { EMPTY_BUFFER: Xt } = U, { isValidStatusCode: Zt } = ae, { mask: De, toBuffer: M } = ne, x = Symbol("kByteLength"), Qt = Buffer.alloc(4);
1031
+ let Jt = class P {
1032
+ /**
1033
+ * Creates a Sender instance.
1034
+ *
1035
+ * @param {(net.Socket|tls.Socket)} socket The connection socket
1036
+ * @param {Object} [extensions] An object containing the negotiated extensions
1037
+ * @param {Function} [generateMask] The function used to generate the masking
1038
+ * key
1039
+ */
1040
+ constructor(e, t, r) {
1041
+ this._extensions = t || {}, r && (this._generateMask = r, this._maskBuffer = Buffer.alloc(4)), this._socket = e, this._firstFragment = !0, this._compress = !1, this._bufferedBytes = 0, this._deflating = !1, this._queue = [];
1042
+ }
1043
+ /**
1044
+ * Frames a piece of data according to the HyBi WebSocket protocol.
1045
+ *
1046
+ * @param {(Buffer|String)} data The data to frame
1047
+ * @param {Object} options Options object
1048
+ * @param {Boolean} [options.fin=false] Specifies whether or not to set the
1049
+ * FIN bit
1050
+ * @param {Function} [options.generateMask] The function used to generate the
1051
+ * masking key
1052
+ * @param {Boolean} [options.mask=false] Specifies whether or not to mask
1053
+ * `data`
1054
+ * @param {Buffer} [options.maskBuffer] The buffer used to store the masking
1055
+ * key
1056
+ * @param {Number} options.opcode The opcode
1057
+ * @param {Boolean} [options.readOnly=false] Specifies whether `data` can be
1058
+ * modified
1059
+ * @param {Boolean} [options.rsv1=false] Specifies whether or not to set the
1060
+ * RSV1 bit
1061
+ * @return {(Buffer|String)[]} The framed data
1062
+ * @public
1063
+ */
1064
+ static frame(e, t) {
1065
+ let r, i = !1, n = 2, o = !1;
1066
+ t.mask && (r = t.maskBuffer || Qt, t.generateMask ? t.generateMask(r) : Kt(r, 0, 4), o = (r[0] | r[1] | r[2] | r[3]) === 0, n = 6);
1067
+ let l;
1068
+ typeof e == "string" ? (!t.mask || o) && t[x] !== void 0 ? l = t[x] : (e = Buffer.from(e), l = e.length) : (l = e.length, i = t.mask && t.readOnly && !o);
1069
+ let f = l;
1070
+ l >= 65536 ? (n += 8, f = 127) : l > 125 && (n += 2, f = 126);
1071
+ const a = Buffer.allocUnsafe(i ? l + n : n);
1072
+ return a[0] = t.fin ? t.opcode | 128 : t.opcode, t.rsv1 && (a[0] |= 64), a[1] = f, f === 126 ? a.writeUInt16BE(l, 2) : f === 127 && (a[2] = a[3] = 0, a.writeUIntBE(l, 4, 6)), t.mask ? (a[1] |= 128, a[n - 4] = r[0], a[n - 3] = r[1], a[n - 2] = r[2], a[n - 1] = r[3], o ? [a, e] : i ? (De(e, r, a, n, l), [a]) : (De(e, r, e, 0, l), [a, e])) : [a, e];
1073
+ }
1074
+ /**
1075
+ * Sends a close message to the other peer.
1076
+ *
1077
+ * @param {Number} [code] The status code component of the body
1078
+ * @param {(String|Buffer)} [data] The message component of the body
1079
+ * @param {Boolean} [mask=false] Specifies whether or not to mask the message
1080
+ * @param {Function} [cb] Callback
1081
+ * @public
1082
+ */
1083
+ close(e, t, r, i) {
1084
+ let n;
1085
+ if (e === void 0)
1086
+ n = Xt;
1087
+ else {
1088
+ if (typeof e != "number" || !Zt(e))
1089
+ throw new TypeError("First argument must be a valid error code number");
1090
+ if (t === void 0 || !t.length)
1091
+ n = Buffer.allocUnsafe(2), n.writeUInt16BE(e, 0);
1092
+ else {
1093
+ const l = Buffer.byteLength(t);
1094
+ if (l > 123)
1095
+ throw new RangeError("The message must not be greater than 123 bytes");
1096
+ n = Buffer.allocUnsafe(2 + l), n.writeUInt16BE(e, 0), typeof t == "string" ? n.write(t, 2) : n.set(t, 2);
1097
+ }
1098
+ }
1099
+ const o = {
1100
+ [x]: n.length,
1101
+ fin: !0,
1102
+ generateMask: this._generateMask,
1103
+ mask: r,
1104
+ maskBuffer: this._maskBuffer,
1105
+ opcode: 8,
1106
+ readOnly: !1,
1107
+ rsv1: !1
1108
+ };
1109
+ this._deflating ? this.enqueue([this.dispatch, n, !1, o, i]) : this.sendFrame(P.frame(n, o), i);
1110
+ }
1111
+ /**
1112
+ * Sends a ping message to the other peer.
1113
+ *
1114
+ * @param {*} data The message to send
1115
+ * @param {Boolean} [mask=false] Specifies whether or not to mask `data`
1116
+ * @param {Function} [cb] Callback
1117
+ * @public
1118
+ */
1119
+ ping(e, t, r) {
1120
+ let i, n;
1121
+ if (typeof e == "string" ? (i = Buffer.byteLength(e), n = !1) : (e = M(e), i = e.length, n = M.readOnly), i > 125)
1122
+ throw new RangeError("The data size must not be greater than 125 bytes");
1123
+ const o = {
1124
+ [x]: i,
1125
+ fin: !0,
1126
+ generateMask: this._generateMask,
1127
+ mask: t,
1128
+ maskBuffer: this._maskBuffer,
1129
+ opcode: 9,
1130
+ readOnly: n,
1131
+ rsv1: !1
1132
+ };
1133
+ this._deflating ? this.enqueue([this.dispatch, e, !1, o, r]) : this.sendFrame(P.frame(e, o), r);
1134
+ }
1135
+ /**
1136
+ * Sends a pong message to the other peer.
1137
+ *
1138
+ * @param {*} data The message to send
1139
+ * @param {Boolean} [mask=false] Specifies whether or not to mask `data`
1140
+ * @param {Function} [cb] Callback
1141
+ * @public
1142
+ */
1143
+ pong(e, t, r) {
1144
+ let i, n;
1145
+ if (typeof e == "string" ? (i = Buffer.byteLength(e), n = !1) : (e = M(e), i = e.length, n = M.readOnly), i > 125)
1146
+ throw new RangeError("The data size must not be greater than 125 bytes");
1147
+ const o = {
1148
+ [x]: i,
1149
+ fin: !0,
1150
+ generateMask: this._generateMask,
1151
+ mask: t,
1152
+ maskBuffer: this._maskBuffer,
1153
+ opcode: 10,
1154
+ readOnly: n,
1155
+ rsv1: !1
1156
+ };
1157
+ this._deflating ? this.enqueue([this.dispatch, e, !1, o, r]) : this.sendFrame(P.frame(e, o), r);
1158
+ }
1159
+ /**
1160
+ * Sends a data message to the other peer.
1161
+ *
1162
+ * @param {*} data The message to send
1163
+ * @param {Object} options Options object
1164
+ * @param {Boolean} [options.binary=false] Specifies whether `data` is binary
1165
+ * or text
1166
+ * @param {Boolean} [options.compress=false] Specifies whether or not to
1167
+ * compress `data`
1168
+ * @param {Boolean} [options.fin=false] Specifies whether the fragment is the
1169
+ * last one
1170
+ * @param {Boolean} [options.mask=false] Specifies whether or not to mask
1171
+ * `data`
1172
+ * @param {Function} [cb] Callback
1173
+ * @public
1174
+ */
1175
+ send(e, t, r) {
1176
+ const i = this._extensions[Ie.extensionName];
1177
+ let n = t.binary ? 2 : 1, o = t.compress, l, f;
1178
+ if (typeof e == "string" ? (l = Buffer.byteLength(e), f = !1) : (e = M(e), l = e.length, f = M.readOnly), this._firstFragment ? (this._firstFragment = !1, o && i && i.params[i._isServer ? "server_no_context_takeover" : "client_no_context_takeover"] && (o = l >= i._threshold), this._compress = o) : (o = !1, n = 0), t.fin && (this._firstFragment = !0), i) {
1179
+ const a = {
1180
+ [x]: l,
1181
+ fin: t.fin,
1182
+ generateMask: this._generateMask,
1183
+ mask: t.mask,
1184
+ maskBuffer: this._maskBuffer,
1185
+ opcode: n,
1186
+ readOnly: f,
1187
+ rsv1: o
1188
+ };
1189
+ this._deflating ? this.enqueue([this.dispatch, e, this._compress, a, r]) : this.dispatch(e, this._compress, a, r);
1190
+ } else
1191
+ this.sendFrame(
1192
+ P.frame(e, {
1193
+ [x]: l,
1194
+ fin: t.fin,
1195
+ generateMask: this._generateMask,
1196
+ mask: t.mask,
1197
+ maskBuffer: this._maskBuffer,
1198
+ opcode: n,
1199
+ readOnly: f,
1200
+ rsv1: !1
1201
+ }),
1202
+ r
1203
+ );
1204
+ }
1205
+ /**
1206
+ * Dispatches a message.
1207
+ *
1208
+ * @param {(Buffer|String)} data The message to send
1209
+ * @param {Boolean} [compress=false] Specifies whether or not to compress
1210
+ * `data`
1211
+ * @param {Object} options Options object
1212
+ * @param {Boolean} [options.fin=false] Specifies whether or not to set the
1213
+ * FIN bit
1214
+ * @param {Function} [options.generateMask] The function used to generate the
1215
+ * masking key
1216
+ * @param {Boolean} [options.mask=false] Specifies whether or not to mask
1217
+ * `data`
1218
+ * @param {Buffer} [options.maskBuffer] The buffer used to store the masking
1219
+ * key
1220
+ * @param {Number} options.opcode The opcode
1221
+ * @param {Boolean} [options.readOnly=false] Specifies whether `data` can be
1222
+ * modified
1223
+ * @param {Boolean} [options.rsv1=false] Specifies whether or not to set the
1224
+ * RSV1 bit
1225
+ * @param {Function} [cb] Callback
1226
+ * @private
1227
+ */
1228
+ dispatch(e, t, r, i) {
1229
+ if (!t) {
1230
+ this.sendFrame(P.frame(e, r), i);
1231
+ return;
1232
+ }
1233
+ const n = this._extensions[Ie.extensionName];
1234
+ this._bufferedBytes += r[x], this._deflating = !0, n.compress(e, r.fin, (o, l) => {
1235
+ if (this._socket.destroyed) {
1236
+ const f = new Error(
1237
+ "The socket was closed while data was being compressed"
1238
+ );
1239
+ typeof i == "function" && i(f);
1240
+ for (let a = 0; a < this._queue.length; a++) {
1241
+ const c = this._queue[a], h = c[c.length - 1];
1242
+ typeof h == "function" && h(f);
1243
+ }
1244
+ return;
1245
+ }
1246
+ this._bufferedBytes -= r[x], this._deflating = !1, r.readOnly = !1, this.sendFrame(P.frame(l, r), i), this.dequeue();
1247
+ });
1248
+ }
1249
+ /**
1250
+ * Executes queued send operations.
1251
+ *
1252
+ * @private
1253
+ */
1254
+ dequeue() {
1255
+ for (; !this._deflating && this._queue.length; ) {
1256
+ const e = this._queue.shift();
1257
+ this._bufferedBytes -= e[3][x], Reflect.apply(e[0], this, e.slice(1));
1258
+ }
1259
+ }
1260
+ /**
1261
+ * Enqueues a send operation.
1262
+ *
1263
+ * @param {Array} params Send operation parameters.
1264
+ * @private
1265
+ */
1266
+ enqueue(e) {
1267
+ this._bufferedBytes += e[3][x], this._queue.push(e);
1268
+ }
1269
+ /**
1270
+ * Sends a frame.
1271
+ *
1272
+ * @param {Buffer[]} list The frame to send
1273
+ * @param {Function} [cb] Callback
1274
+ * @private
1275
+ */
1276
+ sendFrame(e, t) {
1277
+ e.length === 2 ? (this._socket.cork(), this._socket.write(e[0]), this._socket.write(e[1], t), this._socket.uncork()) : this._socket.write(e[0], t);
1278
+ }
1279
+ };
1280
+ var it = Jt;
1281
+ const Ks = /* @__PURE__ */ z(it), { kForOnEventAttribute: F, kListener: pe } = U, We = Symbol("kCode"), Ae = Symbol("kData"), Fe = Symbol("kError"), je = Symbol("kMessage"), Ge = Symbol("kReason"), I = Symbol("kTarget"), Ve = Symbol("kType"), He = Symbol("kWasClean");
1282
+ class B {
1283
+ /**
1284
+ * Create a new `Event`.
1285
+ *
1286
+ * @param {String} type The name of the event
1287
+ * @throws {TypeError} If the `type` argument is not specified
1288
+ */
1289
+ constructor(e) {
1290
+ this[I] = null, this[Ve] = e;
1291
+ }
1292
+ /**
1293
+ * @type {*}
1294
+ */
1295
+ get target() {
1296
+ return this[I];
1297
+ }
1298
+ /**
1299
+ * @type {String}
1300
+ */
1301
+ get type() {
1302
+ return this[Ve];
1303
+ }
1304
+ }
1305
+ Object.defineProperty(B.prototype, "target", { enumerable: !0 });
1306
+ Object.defineProperty(B.prototype, "type", { enumerable: !0 });
1307
+ class Y extends B {
1308
+ /**
1309
+ * Create a new `CloseEvent`.
1310
+ *
1311
+ * @param {String} type The name of the event
1312
+ * @param {Object} [options] A dictionary object that allows for setting
1313
+ * attributes via object members of the same name
1314
+ * @param {Number} [options.code=0] The status code explaining why the
1315
+ * connection was closed
1316
+ * @param {String} [options.reason=''] A human-readable string explaining why
1317
+ * the connection was closed
1318
+ * @param {Boolean} [options.wasClean=false] Indicates whether or not the
1319
+ * connection was cleanly closed
1320
+ */
1321
+ constructor(e, t = {}) {
1322
+ super(e), this[We] = t.code === void 0 ? 0 : t.code, this[Ge] = t.reason === void 0 ? "" : t.reason, this[He] = t.wasClean === void 0 ? !1 : t.wasClean;
1323
+ }
1324
+ /**
1325
+ * @type {Number}
1326
+ */
1327
+ get code() {
1328
+ return this[We];
1329
+ }
1330
+ /**
1331
+ * @type {String}
1332
+ */
1333
+ get reason() {
1334
+ return this[Ge];
1335
+ }
1336
+ /**
1337
+ * @type {Boolean}
1338
+ */
1339
+ get wasClean() {
1340
+ return this[He];
1341
+ }
1342
+ }
1343
+ Object.defineProperty(Y.prototype, "code", { enumerable: !0 });
1344
+ Object.defineProperty(Y.prototype, "reason", { enumerable: !0 });
1345
+ Object.defineProperty(Y.prototype, "wasClean", { enumerable: !0 });
1346
+ class le extends B {
1347
+ /**
1348
+ * Create a new `ErrorEvent`.
1349
+ *
1350
+ * @param {String} type The name of the event
1351
+ * @param {Object} [options] A dictionary object that allows for setting
1352
+ * attributes via object members of the same name
1353
+ * @param {*} [options.error=null] The error that generated this event
1354
+ * @param {String} [options.message=''] The error message
1355
+ */
1356
+ constructor(e, t = {}) {
1357
+ super(e), this[Fe] = t.error === void 0 ? null : t.error, this[je] = t.message === void 0 ? "" : t.message;
1358
+ }
1359
+ /**
1360
+ * @type {*}
1361
+ */
1362
+ get error() {
1363
+ return this[Fe];
1364
+ }
1365
+ /**
1366
+ * @type {String}
1367
+ */
1368
+ get message() {
1369
+ return this[je];
1370
+ }
1371
+ }
1372
+ Object.defineProperty(le.prototype, "error", { enumerable: !0 });
1373
+ Object.defineProperty(le.prototype, "message", { enumerable: !0 });
1374
+ class xe extends B {
1375
+ /**
1376
+ * Create a new `MessageEvent`.
1377
+ *
1378
+ * @param {String} type The name of the event
1379
+ * @param {Object} [options] A dictionary object that allows for setting
1380
+ * attributes via object members of the same name
1381
+ * @param {*} [options.data=null] The message content
1382
+ */
1383
+ constructor(e, t = {}) {
1384
+ super(e), this[Ae] = t.data === void 0 ? null : t.data;
1385
+ }
1386
+ /**
1387
+ * @type {*}
1388
+ */
1389
+ get data() {
1390
+ return this[Ae];
1391
+ }
1392
+ }
1393
+ Object.defineProperty(xe.prototype, "data", { enumerable: !0 });
1394
+ const es = {
1395
+ /**
1396
+ * Register an event listener.
1397
+ *
1398
+ * @param {String} type A string representing the event type to listen for
1399
+ * @param {(Function|Object)} handler The listener to add
1400
+ * @param {Object} [options] An options object specifying characteristics about
1401
+ * the event listener
1402
+ * @param {Boolean} [options.once=false] A `Boolean` indicating that the
1403
+ * listener should be invoked at most once after being added. If `true`,
1404
+ * the listener would be automatically removed when invoked.
1405
+ * @public
1406
+ */
1407
+ addEventListener(s, e, t = {}) {
1408
+ for (const i of this.listeners(s))
1409
+ if (!t[F] && i[pe] === e && !i[F])
1410
+ return;
1411
+ let r;
1412
+ if (s === "message")
1413
+ r = function(n, o) {
1414
+ const l = new xe("message", {
1415
+ data: o ? n : n.toString()
1416
+ });
1417
+ l[I] = this, Z(e, this, l);
1418
+ };
1419
+ else if (s === "close")
1420
+ r = function(n, o) {
1421
+ const l = new Y("close", {
1422
+ code: n,
1423
+ reason: o.toString(),
1424
+ wasClean: this._closeFrameReceived && this._closeFrameSent
1425
+ });
1426
+ l[I] = this, Z(e, this, l);
1427
+ };
1428
+ else if (s === "error")
1429
+ r = function(n) {
1430
+ const o = new le("error", {
1431
+ error: n,
1432
+ message: n.message
1433
+ });
1434
+ o[I] = this, Z(e, this, o);
1435
+ };
1436
+ else if (s === "open")
1437
+ r = function() {
1438
+ const n = new B("open");
1439
+ n[I] = this, Z(e, this, n);
1440
+ };
1441
+ else
1442
+ return;
1443
+ r[F] = !!t[F], r[pe] = e, t.once ? this.once(s, r) : this.on(s, r);
1444
+ },
1445
+ /**
1446
+ * Remove an event listener.
1447
+ *
1448
+ * @param {String} type A string representing the event type to remove
1449
+ * @param {(Function|Object)} handler The listener to remove
1450
+ * @public
1451
+ */
1452
+ removeEventListener(s, e) {
1453
+ for (const t of this.listeners(s))
1454
+ if (t[pe] === e && !t[F]) {
1455
+ this.removeListener(s, t);
1456
+ break;
1457
+ }
1458
+ }
1459
+ };
1460
+ var ts = {
1461
+ CloseEvent: Y,
1462
+ ErrorEvent: le,
1463
+ Event: B,
1464
+ EventTarget: es,
1465
+ MessageEvent: xe
1466
+ };
1467
+ function Z(s, e, t) {
1468
+ typeof s == "object" && s.handleEvent ? s.handleEvent.call(s, t) : s.call(e, t);
1469
+ }
1470
+ const { tokenChars: j } = ae;
1471
+ function k(s, e, t) {
1472
+ s[e] === void 0 ? s[e] = [t] : s[e].push(t);
1473
+ }
1474
+ function ss(s) {
1475
+ const e = /* @__PURE__ */ Object.create(null);
1476
+ let t = /* @__PURE__ */ Object.create(null), r = !1, i = !1, n = !1, o, l, f = -1, a = -1, c = -1, h = 0;
1477
+ for (; h < s.length; h++)
1478
+ if (a = s.charCodeAt(h), o === void 0)
1479
+ if (c === -1 && j[a] === 1)
1480
+ f === -1 && (f = h);
1481
+ else if (h !== 0 && (a === 32 || a === 9))
1482
+ c === -1 && f !== -1 && (c = h);
1483
+ else if (a === 59 || a === 44) {
1484
+ if (f === -1)
1485
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1486
+ c === -1 && (c = h);
1487
+ const v = s.slice(f, c);
1488
+ a === 44 ? (k(e, v, t), t = /* @__PURE__ */ Object.create(null)) : o = v, f = c = -1;
1489
+ } else
1490
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1491
+ else if (l === void 0)
1492
+ if (c === -1 && j[a] === 1)
1493
+ f === -1 && (f = h);
1494
+ else if (a === 32 || a === 9)
1495
+ c === -1 && f !== -1 && (c = h);
1496
+ else if (a === 59 || a === 44) {
1497
+ if (f === -1)
1498
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1499
+ c === -1 && (c = h), k(t, s.slice(f, c), !0), a === 44 && (k(e, o, t), t = /* @__PURE__ */ Object.create(null), o = void 0), f = c = -1;
1500
+ } else if (a === 61 && f !== -1 && c === -1)
1501
+ l = s.slice(f, h), f = c = -1;
1502
+ else
1503
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1504
+ else if (i) {
1505
+ if (j[a] !== 1)
1506
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1507
+ f === -1 ? f = h : r || (r = !0), i = !1;
1508
+ } else if (n)
1509
+ if (j[a] === 1)
1510
+ f === -1 && (f = h);
1511
+ else if (a === 34 && f !== -1)
1512
+ n = !1, c = h;
1513
+ else if (a === 92)
1514
+ i = !0;
1515
+ else
1516
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1517
+ else if (a === 34 && s.charCodeAt(h - 1) === 61)
1518
+ n = !0;
1519
+ else if (c === -1 && j[a] === 1)
1520
+ f === -1 && (f = h);
1521
+ else if (f !== -1 && (a === 32 || a === 9))
1522
+ c === -1 && (c = h);
1523
+ else if (a === 59 || a === 44) {
1524
+ if (f === -1)
1525
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1526
+ c === -1 && (c = h);
1527
+ let v = s.slice(f, c);
1528
+ r && (v = v.replace(/\\/g, ""), r = !1), k(t, l, v), a === 44 && (k(e, o, t), t = /* @__PURE__ */ Object.create(null), o = void 0), l = void 0, f = c = -1;
1529
+ } else
1530
+ throw new SyntaxError(`Unexpected character at index ${h}`);
1531
+ if (f === -1 || n || a === 32 || a === 9)
1532
+ throw new SyntaxError("Unexpected end of input");
1533
+ c === -1 && (c = h);
1534
+ const p = s.slice(f, c);
1535
+ return o === void 0 ? k(e, p, t) : (l === void 0 ? k(t, p, !0) : r ? k(t, l, p.replace(/\\/g, "")) : k(t, l, p), k(e, o, t)), e;
1536
+ }
1537
+ function rs(s) {
1538
+ return Object.keys(s).map((e) => {
1539
+ let t = s[e];
1540
+ return Array.isArray(t) || (t = [t]), t.map((r) => [e].concat(
1541
+ Object.keys(r).map((i) => {
1542
+ let n = r[i];
1543
+ return Array.isArray(n) || (n = [n]), n.map((o) => o === !0 ? i : `${i}=${o}`).join("; ");
1544
+ })
1545
+ ).join("; ")).join(", ");
1546
+ }).join(", ");
1547
+ }
1548
+ var nt = { format: rs, parse: ss };
1549
+ const is = S, ns = S, os = S, ot = S, as = S, { randomBytes: ls, createHash: fs } = S, { URL: me } = S, T = oe, hs = rt, cs = it, {
1550
+ BINARY_TYPES: ze,
1551
+ EMPTY_BUFFER: Q,
1552
+ GUID: us,
1553
+ kForOnEventAttribute: ge,
1554
+ kListener: ds,
1555
+ kStatusCode: _s,
1556
+ kWebSocket: y,
1557
+ NOOP: at
1558
+ } = U, {
1559
+ EventTarget: { addEventListener: ps, removeEventListener: ms }
1560
+ } = ts, { format: gs, parse: ys } = nt, { toBuffer: vs } = ne, Ss = 30 * 1e3, lt = Symbol("kAborted"), ye = [8, 13], O = ["CONNECTING", "OPEN", "CLOSING", "CLOSED"], Es = /^[!#$%&'*+\-.0-9A-Z^_`|a-z~]+$/;
1561
+ let m = class d extends is {
1562
+ /**
1563
+ * Create a new `WebSocket`.
1564
+ *
1565
+ * @param {(String|URL)} address The URL to which to connect
1566
+ * @param {(String|String[])} [protocols] The subprotocols
1567
+ * @param {Object} [options] Connection options
1568
+ */
1569
+ constructor(e, t, r) {
1570
+ super(), this._binaryType = ze[0], this._closeCode = 1006, this._closeFrameReceived = !1, this._closeFrameSent = !1, this._closeMessage = Q, this._closeTimer = null, this._extensions = {}, this._paused = !1, this._protocol = "", this._readyState = d.CONNECTING, this._receiver = null, this._sender = null, this._socket = null, e !== null ? (this._bufferedAmount = 0, this._isServer = !1, this._redirects = 0, t === void 0 ? t = [] : Array.isArray(t) || (typeof t == "object" && t !== null ? (r = t, t = []) : t = [t]), ht(this, e, t, r)) : this._isServer = !0;
1571
+ }
1572
+ /**
1573
+ * This deviates from the WHATWG interface since ws doesn't support the
1574
+ * required default "blob" type (instead we define a custom "nodebuffer"
1575
+ * type).
1576
+ *
1577
+ * @type {String}
1578
+ */
1579
+ get binaryType() {
1580
+ return this._binaryType;
1581
+ }
1582
+ set binaryType(e) {
1583
+ ze.includes(e) && (this._binaryType = e, this._receiver && (this._receiver._binaryType = e));
1584
+ }
1585
+ /**
1586
+ * @type {Number}
1587
+ */
1588
+ get bufferedAmount() {
1589
+ return this._socket ? this._socket._writableState.length + this._sender._bufferedBytes : this._bufferedAmount;
1590
+ }
1591
+ /**
1592
+ * @type {String}
1593
+ */
1594
+ get extensions() {
1595
+ return Object.keys(this._extensions).join();
1596
+ }
1597
+ /**
1598
+ * @type {Boolean}
1599
+ */
1600
+ get isPaused() {
1601
+ return this._paused;
1602
+ }
1603
+ /**
1604
+ * @type {Function}
1605
+ */
1606
+ /* istanbul ignore next */
1607
+ get onclose() {
1608
+ return null;
1609
+ }
1610
+ /**
1611
+ * @type {Function}
1612
+ */
1613
+ /* istanbul ignore next */
1614
+ get onerror() {
1615
+ return null;
1616
+ }
1617
+ /**
1618
+ * @type {Function}
1619
+ */
1620
+ /* istanbul ignore next */
1621
+ get onopen() {
1622
+ return null;
1623
+ }
1624
+ /**
1625
+ * @type {Function}
1626
+ */
1627
+ /* istanbul ignore next */
1628
+ get onmessage() {
1629
+ return null;
1630
+ }
1631
+ /**
1632
+ * @type {String}
1633
+ */
1634
+ get protocol() {
1635
+ return this._protocol;
1636
+ }
1637
+ /**
1638
+ * @type {Number}
1639
+ */
1640
+ get readyState() {
1641
+ return this._readyState;
1642
+ }
1643
+ /**
1644
+ * @type {String}
1645
+ */
1646
+ get url() {
1647
+ return this._url;
1648
+ }
1649
+ /**
1650
+ * Set up the socket and the internal resources.
1651
+ *
1652
+ * @param {(net.Socket|tls.Socket)} socket The network socket between the
1653
+ * server and client
1654
+ * @param {Buffer} head The first packet of the upgraded stream
1655
+ * @param {Object} options Options object
1656
+ * @param {Function} [options.generateMask] The function used to generate the
1657
+ * masking key
1658
+ * @param {Number} [options.maxPayload=0] The maximum allowed message size
1659
+ * @param {Boolean} [options.skipUTF8Validation=false] Specifies whether or
1660
+ * not to skip UTF-8 validation for text and close messages
1661
+ * @private
1662
+ */
1663
+ setSocket(e, t, r) {
1664
+ const i = new hs({
1665
+ binaryType: this.binaryType,
1666
+ extensions: this._extensions,
1667
+ isServer: this._isServer,
1668
+ maxPayload: r.maxPayload,
1669
+ skipUTF8Validation: r.skipUTF8Validation
1670
+ });
1671
+ this._sender = new cs(e, this._extensions, r.generateMask), this._receiver = i, this._socket = e, i[y] = this, e[y] = this, i.on("conclude", ks), i.on("drain", ws), i.on("error", Os), i.on("message", Cs), i.on("ping", Ts), i.on("pong", Ls), e.setTimeout(0), e.setNoDelay(), t.length > 0 && e.unshift(t), e.on("close", ut), e.on("data", fe), e.on("end", dt), e.on("error", _t), this._readyState = d.OPEN, this.emit("open");
1672
+ }
1673
+ /**
1674
+ * Emit the `'close'` event.
1675
+ *
1676
+ * @private
1677
+ */
1678
+ emitClose() {
1679
+ if (!this._socket) {
1680
+ this._readyState = d.CLOSED, this.emit("close", this._closeCode, this._closeMessage);
1681
+ return;
1682
+ }
1683
+ this._extensions[T.extensionName] && this._extensions[T.extensionName].cleanup(), this._receiver.removeAllListeners(), this._readyState = d.CLOSED, this.emit("close", this._closeCode, this._closeMessage);
1684
+ }
1685
+ /**
1686
+ * Start a closing handshake.
1687
+ *
1688
+ * +----------+ +-----------+ +----------+
1689
+ * - - -|ws.close()|-->|close frame|-->|ws.close()|- - -
1690
+ * | +----------+ +-----------+ +----------+ |
1691
+ * +----------+ +-----------+ |
1692
+ * CLOSING |ws.close()|<--|close frame|<--+-----+ CLOSING
1693
+ * +----------+ +-----------+ |
1694
+ * | | | +---+ |
1695
+ * +------------------------+-->|fin| - - - -
1696
+ * | +---+ | +---+
1697
+ * - - - - -|fin|<---------------------+
1698
+ * +---+
1699
+ *
1700
+ * @param {Number} [code] Status code explaining why the connection is closing
1701
+ * @param {(String|Buffer)} [data] The reason why the connection is
1702
+ * closing
1703
+ * @public
1704
+ */
1705
+ close(e, t) {
1706
+ if (this.readyState !== d.CLOSED) {
1707
+ if (this.readyState === d.CONNECTING) {
1708
+ const r = "WebSocket was closed before the connection was established";
1709
+ b(this, this._req, r);
1710
+ return;
1711
+ }
1712
+ if (this.readyState === d.CLOSING) {
1713
+ this._closeFrameSent && (this._closeFrameReceived || this._receiver._writableState.errorEmitted) && this._socket.end();
1714
+ return;
1715
+ }
1716
+ this._readyState = d.CLOSING, this._sender.close(e, t, !this._isServer, (r) => {
1717
+ r || (this._closeFrameSent = !0, (this._closeFrameReceived || this._receiver._writableState.errorEmitted) && this._socket.end());
1718
+ }), this._closeTimer = setTimeout(
1719
+ this._socket.destroy.bind(this._socket),
1720
+ Ss
1721
+ );
1722
+ }
1723
+ }
1724
+ /**
1725
+ * Pause the socket.
1726
+ *
1727
+ * @public
1728
+ */
1729
+ pause() {
1730
+ this.readyState === d.CONNECTING || this.readyState === d.CLOSED || (this._paused = !0, this._socket.pause());
1731
+ }
1732
+ /**
1733
+ * Send a ping.
1734
+ *
1735
+ * @param {*} [data] The data to send
1736
+ * @param {Boolean} [mask] Indicates whether or not to mask `data`
1737
+ * @param {Function} [cb] Callback which is executed when the ping is sent
1738
+ * @public
1739
+ */
1740
+ ping(e, t, r) {
1741
+ if (this.readyState === d.CONNECTING)
1742
+ throw new Error("WebSocket is not open: readyState 0 (CONNECTING)");
1743
+ if (typeof e == "function" ? (r = e, e = t = void 0) : typeof t == "function" && (r = t, t = void 0), typeof e == "number" && (e = e.toString()), this.readyState !== d.OPEN) {
1744
+ ve(this, e, r);
1745
+ return;
1746
+ }
1747
+ t === void 0 && (t = !this._isServer), this._sender.ping(e || Q, t, r);
1748
+ }
1749
+ /**
1750
+ * Send a pong.
1751
+ *
1752
+ * @param {*} [data] The data to send
1753
+ * @param {Boolean} [mask] Indicates whether or not to mask `data`
1754
+ * @param {Function} [cb] Callback which is executed when the pong is sent
1755
+ * @public
1756
+ */
1757
+ pong(e, t, r) {
1758
+ if (this.readyState === d.CONNECTING)
1759
+ throw new Error("WebSocket is not open: readyState 0 (CONNECTING)");
1760
+ if (typeof e == "function" ? (r = e, e = t = void 0) : typeof t == "function" && (r = t, t = void 0), typeof e == "number" && (e = e.toString()), this.readyState !== d.OPEN) {
1761
+ ve(this, e, r);
1762
+ return;
1763
+ }
1764
+ t === void 0 && (t = !this._isServer), this._sender.pong(e || Q, t, r);
1765
+ }
1766
+ /**
1767
+ * Resume the socket.
1768
+ *
1769
+ * @public
1770
+ */
1771
+ resume() {
1772
+ this.readyState === d.CONNECTING || this.readyState === d.CLOSED || (this._paused = !1, this._receiver._writableState.needDrain || this._socket.resume());
1773
+ }
1774
+ /**
1775
+ * Send a data message.
1776
+ *
1777
+ * @param {*} data The message to send
1778
+ * @param {Object} [options] Options object
1779
+ * @param {Boolean} [options.binary] Specifies whether `data` is binary or
1780
+ * text
1781
+ * @param {Boolean} [options.compress] Specifies whether or not to compress
1782
+ * `data`
1783
+ * @param {Boolean} [options.fin=true] Specifies whether the fragment is the
1784
+ * last one
1785
+ * @param {Boolean} [options.mask] Specifies whether or not to mask `data`
1786
+ * @param {Function} [cb] Callback which is executed when data is written out
1787
+ * @public
1788
+ */
1789
+ send(e, t, r) {
1790
+ if (this.readyState === d.CONNECTING)
1791
+ throw new Error("WebSocket is not open: readyState 0 (CONNECTING)");
1792
+ if (typeof t == "function" && (r = t, t = {}), typeof e == "number" && (e = e.toString()), this.readyState !== d.OPEN) {
1793
+ ve(this, e, r);
1794
+ return;
1795
+ }
1796
+ const i = {
1797
+ binary: typeof e != "string",
1798
+ mask: !this._isServer,
1799
+ compress: !0,
1800
+ fin: !0,
1801
+ ...t
1802
+ };
1803
+ this._extensions[T.extensionName] || (i.compress = !1), this._sender.send(e || Q, i, r);
1804
+ }
1805
+ /**
1806
+ * Forcibly close the connection.
1807
+ *
1808
+ * @public
1809
+ */
1810
+ terminate() {
1811
+ if (this.readyState !== d.CLOSED) {
1812
+ if (this.readyState === d.CONNECTING) {
1813
+ const e = "WebSocket was closed before the connection was established";
1814
+ b(this, this._req, e);
1815
+ return;
1816
+ }
1817
+ this._socket && (this._readyState = d.CLOSING, this._socket.destroy());
1818
+ }
1819
+ }
1820
+ };
1821
+ Object.defineProperty(m, "CONNECTING", {
1822
+ enumerable: !0,
1823
+ value: O.indexOf("CONNECTING")
1824
+ });
1825
+ Object.defineProperty(m.prototype, "CONNECTING", {
1826
+ enumerable: !0,
1827
+ value: O.indexOf("CONNECTING")
1828
+ });
1829
+ Object.defineProperty(m, "OPEN", {
1830
+ enumerable: !0,
1831
+ value: O.indexOf("OPEN")
1832
+ });
1833
+ Object.defineProperty(m.prototype, "OPEN", {
1834
+ enumerable: !0,
1835
+ value: O.indexOf("OPEN")
1836
+ });
1837
+ Object.defineProperty(m, "CLOSING", {
1838
+ enumerable: !0,
1839
+ value: O.indexOf("CLOSING")
1840
+ });
1841
+ Object.defineProperty(m.prototype, "CLOSING", {
1842
+ enumerable: !0,
1843
+ value: O.indexOf("CLOSING")
1844
+ });
1845
+ Object.defineProperty(m, "CLOSED", {
1846
+ enumerable: !0,
1847
+ value: O.indexOf("CLOSED")
1848
+ });
1849
+ Object.defineProperty(m.prototype, "CLOSED", {
1850
+ enumerable: !0,
1851
+ value: O.indexOf("CLOSED")
1852
+ });
1853
+ [
1854
+ "binaryType",
1855
+ "bufferedAmount",
1856
+ "extensions",
1857
+ "isPaused",
1858
+ "protocol",
1859
+ "readyState",
1860
+ "url"
1861
+ ].forEach((s) => {
1862
+ Object.defineProperty(m.prototype, s, { enumerable: !0 });
1863
+ });
1864
+ ["open", "error", "close", "message"].forEach((s) => {
1865
+ Object.defineProperty(m.prototype, `on${s}`, {
1866
+ enumerable: !0,
1867
+ get() {
1868
+ for (const e of this.listeners(s))
1869
+ if (e[ge])
1870
+ return e[ds];
1871
+ return null;
1872
+ },
1873
+ set(e) {
1874
+ for (const t of this.listeners(s))
1875
+ if (t[ge]) {
1876
+ this.removeListener(s, t);
1877
+ break;
1878
+ }
1879
+ typeof e == "function" && this.addEventListener(s, e, {
1880
+ [ge]: !0
1881
+ });
1882
+ }
1883
+ });
1884
+ });
1885
+ m.prototype.addEventListener = ps;
1886
+ m.prototype.removeEventListener = ms;
1887
+ var ft = m;
1888
+ function ht(s, e, t, r) {
1889
+ const i = {
1890
+ protocolVersion: ye[1],
1891
+ maxPayload: 104857600,
1892
+ skipUTF8Validation: !1,
1893
+ perMessageDeflate: !0,
1894
+ followRedirects: !1,
1895
+ maxRedirects: 10,
1896
+ ...r,
1897
+ createConnection: void 0,
1898
+ socketPath: void 0,
1899
+ hostname: void 0,
1900
+ protocol: void 0,
1901
+ timeout: void 0,
1902
+ method: "GET",
1903
+ host: void 0,
1904
+ path: void 0,
1905
+ port: void 0
1906
+ };
1907
+ if (!ye.includes(i.protocolVersion))
1908
+ throw new RangeError(
1909
+ `Unsupported protocol version: ${i.protocolVersion} (supported versions: ${ye.join(", ")})`
1910
+ );
1911
+ let n;
1912
+ if (e instanceof me)
1913
+ n = e, s._url = e.href;
1914
+ else {
1915
+ try {
1916
+ n = new me(e);
1917
+ } catch {
1918
+ throw new SyntaxError(`Invalid URL: ${e}`);
1919
+ }
1920
+ s._url = e;
1921
+ }
1922
+ const o = n.protocol === "wss:", l = n.protocol === "ws+unix:";
1923
+ let f;
1924
+ if (n.protocol !== "ws:" && !o && !l ? f = `The URL's protocol must be one of "ws:", "wss:", or "ws+unix:"` : l && !n.pathname ? f = "The URL's pathname is empty" : n.hash && (f = "The URL contains a fragment identifier"), f) {
1925
+ const u = new SyntaxError(f);
1926
+ if (s._redirects === 0)
1927
+ throw u;
1928
+ ee(s, u);
1929
+ return;
1930
+ }
1931
+ const a = o ? 443 : 80, c = ls(16).toString("base64"), h = o ? ns.request : os.request, p = /* @__PURE__ */ new Set();
1932
+ let v;
1933
+ if (i.createConnection = o ? xs : bs, i.defaultPort = i.defaultPort || a, i.port = n.port || a, i.host = n.hostname.startsWith("[") ? n.hostname.slice(1, -1) : n.hostname, i.headers = {
1934
+ ...i.headers,
1935
+ "Sec-WebSocket-Version": i.protocolVersion,
1936
+ "Sec-WebSocket-Key": c,
1937
+ Connection: "Upgrade",
1938
+ Upgrade: "websocket"
1939
+ }, i.path = n.pathname + n.search, i.timeout = i.handshakeTimeout, i.perMessageDeflate && (v = new T(
1940
+ i.perMessageDeflate !== !0 ? i.perMessageDeflate : {},
1941
+ !1,
1942
+ i.maxPayload
1943
+ ), i.headers["Sec-WebSocket-Extensions"] = gs({
1944
+ [T.extensionName]: v.offer()
1945
+ })), t.length) {
1946
+ for (const u of t) {
1947
+ if (typeof u != "string" || !Es.test(u) || p.has(u))
1948
+ throw new SyntaxError(
1949
+ "An invalid or duplicated subprotocol was specified"
1950
+ );
1951
+ p.add(u);
1952
+ }
1953
+ i.headers["Sec-WebSocket-Protocol"] = t.join(",");
1954
+ }
1955
+ if (i.origin && (i.protocolVersion < 13 ? i.headers["Sec-WebSocket-Origin"] = i.origin : i.headers.Origin = i.origin), (n.username || n.password) && (i.auth = `${n.username}:${n.password}`), l) {
1956
+ const u = i.path.split(":");
1957
+ i.socketPath = u[0], i.path = u[1];
1958
+ }
1959
+ let _;
1960
+ if (i.followRedirects) {
1961
+ if (s._redirects === 0) {
1962
+ s._originalIpc = l, s._originalSecure = o, s._originalHostOrSocketPath = l ? i.socketPath : n.host;
1963
+ const u = r && r.headers;
1964
+ if (r = { ...r, headers: {} }, u)
1965
+ for (const [E, $] of Object.entries(u))
1966
+ r.headers[E.toLowerCase()] = $;
1967
+ } else if (s.listenerCount("redirect") === 0) {
1968
+ const u = l ? s._originalIpc ? i.socketPath === s._originalHostOrSocketPath : !1 : s._originalIpc ? !1 : n.host === s._originalHostOrSocketPath;
1969
+ (!u || s._originalSecure && !o) && (delete i.headers.authorization, delete i.headers.cookie, u || delete i.headers.host, i.auth = void 0);
1970
+ }
1971
+ i.auth && !r.headers.authorization && (r.headers.authorization = "Basic " + Buffer.from(i.auth).toString("base64")), _ = s._req = h(i), s._redirects && s.emit("redirect", s.url, _);
1972
+ } else
1973
+ _ = s._req = h(i);
1974
+ i.timeout && _.on("timeout", () => {
1975
+ b(s, _, "Opening handshake has timed out");
1976
+ }), _.on("error", (u) => {
1977
+ _ === null || _[lt] || (_ = s._req = null, ee(s, u));
1978
+ }), _.on("response", (u) => {
1979
+ const E = u.headers.location, $ = u.statusCode;
1980
+ if (E && i.followRedirects && $ >= 300 && $ < 400) {
1981
+ if (++s._redirects > i.maxRedirects) {
1982
+ b(s, _, "Maximum redirects exceeded");
1983
+ return;
1984
+ }
1985
+ _.abort();
1986
+ let q;
1987
+ try {
1988
+ q = new me(E, e);
1989
+ } catch {
1990
+ const L = new SyntaxError(`Invalid URL: ${E}`);
1991
+ ee(s, L);
1992
+ return;
1993
+ }
1994
+ ht(s, q, t, r);
1995
+ } else
1996
+ s.emit("unexpected-response", _, u) || b(
1997
+ s,
1998
+ _,
1999
+ `Unexpected server response: ${u.statusCode}`
2000
+ );
2001
+ }), _.on("upgrade", (u, E, $) => {
2002
+ if (s.emit("upgrade", u), s.readyState !== m.CONNECTING)
2003
+ return;
2004
+ if (_ = s._req = null, u.headers.upgrade.toLowerCase() !== "websocket") {
2005
+ b(s, E, "Invalid Upgrade header");
2006
+ return;
2007
+ }
2008
+ const q = fs("sha1").update(c + us).digest("base64");
2009
+ if (u.headers["sec-websocket-accept"] !== q) {
2010
+ b(s, E, "Invalid Sec-WebSocket-Accept header");
2011
+ return;
2012
+ }
2013
+ const D = u.headers["sec-websocket-protocol"];
2014
+ let L;
2015
+ if (D !== void 0 ? p.size ? p.has(D) || (L = "Server sent an invalid subprotocol") : L = "Server sent a subprotocol but none was requested" : p.size && (L = "Server sent no subprotocol"), L) {
2016
+ b(s, E, L);
2017
+ return;
2018
+ }
2019
+ D && (s._protocol = D);
2020
+ const ke = u.headers["sec-websocket-extensions"];
2021
+ if (ke !== void 0) {
2022
+ if (!v) {
2023
+ b(s, E, "Server sent a Sec-WebSocket-Extensions header but no extension was requested");
2024
+ return;
2025
+ }
2026
+ let he;
2027
+ try {
2028
+ he = ys(ke);
2029
+ } catch {
2030
+ b(s, E, "Invalid Sec-WebSocket-Extensions header");
2031
+ return;
2032
+ }
2033
+ const we = Object.keys(he);
2034
+ if (we.length !== 1 || we[0] !== T.extensionName) {
2035
+ b(s, E, "Server indicated an extension that was not requested");
2036
+ return;
2037
+ }
2038
+ try {
2039
+ v.accept(he[T.extensionName]);
2040
+ } catch {
2041
+ b(s, E, "Invalid Sec-WebSocket-Extensions header");
2042
+ return;
2043
+ }
2044
+ s._extensions[T.extensionName] = v;
2045
+ }
2046
+ s.setSocket(E, $, {
2047
+ generateMask: i.generateMask,
2048
+ maxPayload: i.maxPayload,
2049
+ skipUTF8Validation: i.skipUTF8Validation
2050
+ });
2051
+ }), i.finishRequest ? i.finishRequest(_, s) : _.end();
2052
+ }
2053
+ function ee(s, e) {
2054
+ s._readyState = m.CLOSING, s.emit("error", e), s.emitClose();
2055
+ }
2056
+ function bs(s) {
2057
+ return s.path = s.socketPath, ot.connect(s);
2058
+ }
2059
+ function xs(s) {
2060
+ return s.path = void 0, !s.servername && s.servername !== "" && (s.servername = ot.isIP(s.host) ? "" : s.host), as.connect(s);
2061
+ }
2062
+ function b(s, e, t) {
2063
+ s._readyState = m.CLOSING;
2064
+ const r = new Error(t);
2065
+ Error.captureStackTrace(r, b), e.setHeader ? (e[lt] = !0, e.abort(), e.socket && !e.socket.destroyed && e.socket.destroy(), process.nextTick(ee, s, r)) : (e.destroy(r), e.once("error", s.emit.bind(s, "error")), e.once("close", s.emitClose.bind(s)));
2066
+ }
2067
+ function ve(s, e, t) {
2068
+ if (e) {
2069
+ const r = vs(e).length;
2070
+ s._socket ? s._sender._bufferedBytes += r : s._bufferedAmount += r;
2071
+ }
2072
+ if (t) {
2073
+ const r = new Error(
2074
+ `WebSocket is not open: readyState ${s.readyState} (${O[s.readyState]})`
2075
+ );
2076
+ process.nextTick(t, r);
2077
+ }
2078
+ }
2079
+ function ks(s, e) {
2080
+ const t = this[y];
2081
+ t._closeFrameReceived = !0, t._closeMessage = e, t._closeCode = s, t._socket[y] !== void 0 && (t._socket.removeListener("data", fe), process.nextTick(ct, t._socket), s === 1005 ? t.close() : t.close(s, e));
2082
+ }
2083
+ function ws() {
2084
+ const s = this[y];
2085
+ s.isPaused || s._socket.resume();
2086
+ }
2087
+ function Os(s) {
2088
+ const e = this[y];
2089
+ e._socket[y] !== void 0 && (e._socket.removeListener("data", fe), process.nextTick(ct, e._socket), e.close(s[_s])), e.emit("error", s);
2090
+ }
2091
+ function Ye() {
2092
+ this[y].emitClose();
2093
+ }
2094
+ function Cs(s, e) {
2095
+ this[y].emit("message", s, e);
2096
+ }
2097
+ function Ts(s) {
2098
+ const e = this[y];
2099
+ e.pong(s, !e._isServer, at), e.emit("ping", s);
2100
+ }
2101
+ function Ls(s) {
2102
+ this[y].emit("pong", s);
2103
+ }
2104
+ function ct(s) {
2105
+ s.resume();
2106
+ }
2107
+ function ut() {
2108
+ const s = this[y];
2109
+ this.removeListener("close", ut), this.removeListener("data", fe), this.removeListener("end", dt), s._readyState = m.CLOSING;
2110
+ let e;
2111
+ !this._readableState.endEmitted && !s._closeFrameReceived && !s._receiver._writableState.errorEmitted && (e = s._socket.read()) !== null && s._receiver.write(e), s._receiver.end(), this[y] = void 0, clearTimeout(s._closeTimer), s._receiver._writableState.finished || s._receiver._writableState.errorEmitted ? s.emitClose() : (s._receiver.on("error", Ye), s._receiver.on("finish", Ye));
2112
+ }
2113
+ function fe(s) {
2114
+ this[y]._receiver.write(s) || this.pause();
2115
+ }
2116
+ function dt() {
2117
+ const s = this[y];
2118
+ s._readyState = m.CLOSING, s._receiver.end(), this.end();
2119
+ }
2120
+ function _t() {
2121
+ const s = this[y];
2122
+ this.removeListener("error", _t), this.on("error", at), s && (s._readyState = m.CLOSING, this.destroy());
2123
+ }
2124
+ const Xs = /* @__PURE__ */ z(ft), { tokenChars: Ns } = ae;
2125
+ function Ps(s) {
2126
+ const e = /* @__PURE__ */ new Set();
2127
+ let t = -1, r = -1, i = 0;
2128
+ for (i; i < s.length; i++) {
2129
+ const o = s.charCodeAt(i);
2130
+ if (r === -1 && Ns[o] === 1)
2131
+ t === -1 && (t = i);
2132
+ else if (i !== 0 && (o === 32 || o === 9))
2133
+ r === -1 && t !== -1 && (r = i);
2134
+ else if (o === 44) {
2135
+ if (t === -1)
2136
+ throw new SyntaxError(`Unexpected character at index ${i}`);
2137
+ r === -1 && (r = i);
2138
+ const l = s.slice(t, r);
2139
+ if (e.has(l))
2140
+ throw new SyntaxError(`The "${l}" subprotocol is duplicated`);
2141
+ e.add(l), t = r = -1;
2142
+ } else
2143
+ throw new SyntaxError(`Unexpected character at index ${i}`);
2144
+ }
2145
+ if (t === -1 || r !== -1)
2146
+ throw new SyntaxError("Unexpected end of input");
2147
+ const n = s.slice(t, i);
2148
+ if (e.has(n))
2149
+ throw new SyntaxError(`The "${n}" subprotocol is duplicated`);
2150
+ return e.add(n), e;
2151
+ }
2152
+ var Rs = { parse: Ps };
2153
+ const Us = S, ie = S, { createHash: Bs } = S, qe = nt, N = oe, $s = Rs, Ms = ft, { GUID: Is, kWebSocket: Ds } = U, Ws = /^[+/0-9A-Za-z]{22}==$/, Ke = 0, Xe = 1, pt = 2;
2154
+ class As extends Us {
2155
+ /**
2156
+ * Create a `WebSocketServer` instance.
2157
+ *
2158
+ * @param {Object} options Configuration options
2159
+ * @param {Number} [options.backlog=511] The maximum length of the queue of
2160
+ * pending connections
2161
+ * @param {Boolean} [options.clientTracking=true] Specifies whether or not to
2162
+ * track clients
2163
+ * @param {Function} [options.handleProtocols] A hook to handle protocols
2164
+ * @param {String} [options.host] The hostname where to bind the server
2165
+ * @param {Number} [options.maxPayload=104857600] The maximum allowed message
2166
+ * size
2167
+ * @param {Boolean} [options.noServer=false] Enable no server mode
2168
+ * @param {String} [options.path] Accept only connections matching this path
2169
+ * @param {(Boolean|Object)} [options.perMessageDeflate=false] Enable/disable
2170
+ * permessage-deflate
2171
+ * @param {Number} [options.port] The port where to bind the server
2172
+ * @param {(http.Server|https.Server)} [options.server] A pre-created HTTP/S
2173
+ * server to use
2174
+ * @param {Boolean} [options.skipUTF8Validation=false] Specifies whether or
2175
+ * not to skip UTF-8 validation for text and close messages
2176
+ * @param {Function} [options.verifyClient] A hook to reject connections
2177
+ * @param {Function} [options.WebSocket=WebSocket] Specifies the `WebSocket`
2178
+ * class to use. It must be the `WebSocket` class or class that extends it
2179
+ * @param {Function} [callback] A listener for the `listening` event
2180
+ */
2181
+ constructor(e, t) {
2182
+ if (super(), e = {
2183
+ maxPayload: 100 * 1024 * 1024,
2184
+ skipUTF8Validation: !1,
2185
+ perMessageDeflate: !1,
2186
+ handleProtocols: null,
2187
+ clientTracking: !0,
2188
+ verifyClient: null,
2189
+ noServer: !1,
2190
+ backlog: null,
2191
+ // use default (511 as implemented in net.js)
2192
+ server: null,
2193
+ host: null,
2194
+ path: null,
2195
+ port: null,
2196
+ WebSocket: Ms,
2197
+ ...e
2198
+ }, e.port == null && !e.server && !e.noServer || e.port != null && (e.server || e.noServer) || e.server && e.noServer)
2199
+ throw new TypeError(
2200
+ 'One and only one of the "port", "server", or "noServer" options must be specified'
2201
+ );
2202
+ if (e.port != null ? (this._server = ie.createServer((r, i) => {
2203
+ const n = ie.STATUS_CODES[426];
2204
+ i.writeHead(426, {
2205
+ "Content-Length": n.length,
2206
+ "Content-Type": "text/plain"
2207
+ }), i.end(n);
2208
+ }), this._server.listen(
2209
+ e.port,
2210
+ e.host,
2211
+ e.backlog,
2212
+ t
2213
+ )) : e.server && (this._server = e.server), this._server) {
2214
+ const r = this.emit.bind(this, "connection");
2215
+ this._removeListeners = js(this._server, {
2216
+ listening: this.emit.bind(this, "listening"),
2217
+ error: this.emit.bind(this, "error"),
2218
+ upgrade: (i, n, o) => {
2219
+ this.handleUpgrade(i, n, o, r);
2220
+ }
2221
+ });
2222
+ }
2223
+ e.perMessageDeflate === !0 && (e.perMessageDeflate = {}), e.clientTracking && (this.clients = /* @__PURE__ */ new Set(), this._shouldEmitClose = !1), this.options = e, this._state = Ke;
2224
+ }
2225
+ /**
2226
+ * Returns the bound address, the address family name, and port of the server
2227
+ * as reported by the operating system if listening on an IP socket.
2228
+ * If the server is listening on a pipe or UNIX domain socket, the name is
2229
+ * returned as a string.
2230
+ *
2231
+ * @return {(Object|String|null)} The address of the server
2232
+ * @public
2233
+ */
2234
+ address() {
2235
+ if (this.options.noServer)
2236
+ throw new Error('The server is operating in "noServer" mode');
2237
+ return this._server ? this._server.address() : null;
2238
+ }
2239
+ /**
2240
+ * Stop the server from accepting new connections and emit the `'close'` event
2241
+ * when all existing connections are closed.
2242
+ *
2243
+ * @param {Function} [cb] A one-time listener for the `'close'` event
2244
+ * @public
2245
+ */
2246
+ close(e) {
2247
+ if (this._state === pt) {
2248
+ e && this.once("close", () => {
2249
+ e(new Error("The server is not running"));
2250
+ }), process.nextTick(G, this);
2251
+ return;
2252
+ }
2253
+ if (e && this.once("close", e), this._state !== Xe)
2254
+ if (this._state = Xe, this.options.noServer || this.options.server)
2255
+ this._server && (this._removeListeners(), this._removeListeners = this._server = null), this.clients ? this.clients.size ? this._shouldEmitClose = !0 : process.nextTick(G, this) : process.nextTick(G, this);
2256
+ else {
2257
+ const t = this._server;
2258
+ this._removeListeners(), this._removeListeners = this._server = null, t.close(() => {
2259
+ G(this);
2260
+ });
2261
+ }
2262
+ }
2263
+ /**
2264
+ * See if a given request should be handled by this server instance.
2265
+ *
2266
+ * @param {http.IncomingMessage} req Request object to inspect
2267
+ * @return {Boolean} `true` if the request is valid, else `false`
2268
+ * @public
2269
+ */
2270
+ shouldHandle(e) {
2271
+ if (this.options.path) {
2272
+ const t = e.url.indexOf("?");
2273
+ if ((t !== -1 ? e.url.slice(0, t) : e.url) !== this.options.path)
2274
+ return !1;
2275
+ }
2276
+ return !0;
2277
+ }
2278
+ /**
2279
+ * Handle a HTTP Upgrade request.
2280
+ *
2281
+ * @param {http.IncomingMessage} req The request object
2282
+ * @param {(net.Socket|tls.Socket)} socket The network socket between the
2283
+ * server and client
2284
+ * @param {Buffer} head The first packet of the upgraded stream
2285
+ * @param {Function} cb Callback
2286
+ * @public
2287
+ */
2288
+ handleUpgrade(e, t, r, i) {
2289
+ t.on("error", Ze);
2290
+ const n = e.headers["sec-websocket-key"], o = +e.headers["sec-websocket-version"];
2291
+ if (e.method !== "GET") {
2292
+ R(this, e, t, 405, "Invalid HTTP method");
2293
+ return;
2294
+ }
2295
+ if (e.headers.upgrade.toLowerCase() !== "websocket") {
2296
+ R(this, e, t, 400, "Invalid Upgrade header");
2297
+ return;
2298
+ }
2299
+ if (!n || !Ws.test(n)) {
2300
+ R(this, e, t, 400, "Missing or invalid Sec-WebSocket-Key header");
2301
+ return;
2302
+ }
2303
+ if (o !== 8 && o !== 13) {
2304
+ R(this, e, t, 400, "Missing or invalid Sec-WebSocket-Version header");
2305
+ return;
2306
+ }
2307
+ if (!this.shouldHandle(e)) {
2308
+ H(t, 400);
2309
+ return;
2310
+ }
2311
+ const l = e.headers["sec-websocket-protocol"];
2312
+ let f = /* @__PURE__ */ new Set();
2313
+ if (l !== void 0)
2314
+ try {
2315
+ f = $s.parse(l);
2316
+ } catch {
2317
+ R(this, e, t, 400, "Invalid Sec-WebSocket-Protocol header");
2318
+ return;
2319
+ }
2320
+ const a = e.headers["sec-websocket-extensions"], c = {};
2321
+ if (this.options.perMessageDeflate && a !== void 0) {
2322
+ const h = new N(
2323
+ this.options.perMessageDeflate,
2324
+ !0,
2325
+ this.options.maxPayload
2326
+ );
2327
+ try {
2328
+ const p = qe.parse(a);
2329
+ p[N.extensionName] && (h.accept(p[N.extensionName]), c[N.extensionName] = h);
2330
+ } catch {
2331
+ R(this, e, t, 400, "Invalid or unacceptable Sec-WebSocket-Extensions header");
2332
+ return;
2333
+ }
2334
+ }
2335
+ if (this.options.verifyClient) {
2336
+ const h = {
2337
+ origin: e.headers[`${o === 8 ? "sec-websocket-origin" : "origin"}`],
2338
+ secure: !!(e.socket.authorized || e.socket.encrypted),
2339
+ req: e
2340
+ };
2341
+ if (this.options.verifyClient.length === 2) {
2342
+ this.options.verifyClient(h, (p, v, _, u) => {
2343
+ if (!p)
2344
+ return H(t, v || 401, _, u);
2345
+ this.completeUpgrade(
2346
+ c,
2347
+ n,
2348
+ f,
2349
+ e,
2350
+ t,
2351
+ r,
2352
+ i
2353
+ );
2354
+ });
2355
+ return;
2356
+ }
2357
+ if (!this.options.verifyClient(h))
2358
+ return H(t, 401);
2359
+ }
2360
+ this.completeUpgrade(c, n, f, e, t, r, i);
2361
+ }
2362
+ /**
2363
+ * Upgrade the connection to WebSocket.
2364
+ *
2365
+ * @param {Object} extensions The accepted extensions
2366
+ * @param {String} key The value of the `Sec-WebSocket-Key` header
2367
+ * @param {Set} protocols The subprotocols
2368
+ * @param {http.IncomingMessage} req The request object
2369
+ * @param {(net.Socket|tls.Socket)} socket The network socket between the
2370
+ * server and client
2371
+ * @param {Buffer} head The first packet of the upgraded stream
2372
+ * @param {Function} cb Callback
2373
+ * @throws {Error} If called more than once with the same socket
2374
+ * @private
2375
+ */
2376
+ completeUpgrade(e, t, r, i, n, o, l) {
2377
+ if (!n.readable || !n.writable)
2378
+ return n.destroy();
2379
+ if (n[Ds])
2380
+ throw new Error(
2381
+ "server.handleUpgrade() was called more than once with the same socket, possibly due to a misconfiguration"
2382
+ );
2383
+ if (this._state > Ke)
2384
+ return H(n, 503);
2385
+ const a = [
2386
+ "HTTP/1.1 101 Switching Protocols",
2387
+ "Upgrade: websocket",
2388
+ "Connection: Upgrade",
2389
+ `Sec-WebSocket-Accept: ${Bs("sha1").update(t + Is).digest("base64")}`
2390
+ ], c = new this.options.WebSocket(null);
2391
+ if (r.size) {
2392
+ const h = this.options.handleProtocols ? this.options.handleProtocols(r, i) : r.values().next().value;
2393
+ h && (a.push(`Sec-WebSocket-Protocol: ${h}`), c._protocol = h);
2394
+ }
2395
+ if (e[N.extensionName]) {
2396
+ const h = e[N.extensionName].params, p = qe.format({
2397
+ [N.extensionName]: [h]
2398
+ });
2399
+ a.push(`Sec-WebSocket-Extensions: ${p}`), c._extensions = e;
2400
+ }
2401
+ this.emit("headers", a, i), n.write(a.concat(`\r
2402
+ `).join(`\r
2403
+ `)), n.removeListener("error", Ze), c.setSocket(n, o, {
2404
+ maxPayload: this.options.maxPayload,
2405
+ skipUTF8Validation: this.options.skipUTF8Validation
2406
+ }), this.clients && (this.clients.add(c), c.on("close", () => {
2407
+ this.clients.delete(c), this._shouldEmitClose && !this.clients.size && process.nextTick(G, this);
2408
+ })), l(c, i);
2409
+ }
2410
+ }
2411
+ var Fs = As;
2412
+ function js(s, e) {
2413
+ for (const t of Object.keys(e))
2414
+ s.on(t, e[t]);
2415
+ return function() {
2416
+ for (const r of Object.keys(e))
2417
+ s.removeListener(r, e[r]);
2418
+ };
2419
+ }
2420
+ function G(s) {
2421
+ s._state = pt, s.emit("close");
2422
+ }
2423
+ function Ze() {
2424
+ this.destroy();
2425
+ }
2426
+ function H(s, e, t, r) {
2427
+ t = t || ie.STATUS_CODES[e], r = {
2428
+ Connection: "close",
2429
+ "Content-Type": "text/html",
2430
+ "Content-Length": Buffer.byteLength(t),
2431
+ ...r
2432
+ }, s.once("finish", s.destroy), s.end(
2433
+ `HTTP/1.1 ${e} ${ie.STATUS_CODES[e]}\r
2434
+ ` + Object.keys(r).map((i) => `${i}: ${r[i]}`).join(`\r
2435
+ `) + `\r
2436
+ \r
2437
+ ` + t
2438
+ );
2439
+ }
2440
+ function R(s, e, t, r, i) {
2441
+ if (s.listenerCount("wsClientError")) {
2442
+ const n = new Error(i);
2443
+ Error.captureStackTrace(n, R), s.emit("wsClientError", n, t, e);
2444
+ } else
2445
+ H(t, r, i);
2446
+ }
2447
+ const Zs = /* @__PURE__ */ z(Fs);
2448
+ export {
2449
+ qs as Receiver,
2450
+ Ks as Sender,
2451
+ Xs as WebSocket,
2452
+ Zs as WebSocketServer,
2453
+ Vs as createWebSocketStream,
2454
+ Xs as default
2455
+ };
src/backend/gradio_image_prompter/templates/example/index.js ADDED
@@ -0,0 +1,263 @@
1
+ const { setContext: ee, getContext: p } = window.__gradio__svelte__internal, v = "WORKER_PROXY_CONTEXT_KEY";
2
+ function y() {
3
+ return p(v);
4
+ }
5
+ function k(l) {
6
+ return l.host === window.location.host || l.host === "localhost:7860" || l.host === "127.0.0.1:7860" || // Ref: https://github.com/gradio-app/gradio/blob/v3.32.0/js/app/src/Index.svelte#L194
7
+ l.host === "lite.local";
8
+ }
9
+ async function f(l) {
10
+ if (l == null)
11
+ return l;
12
+ const e = new URL(l);
13
+ if (!k(e) || e.protocol !== "http:" && e.protocol !== "https:")
14
+ return l;
15
+ const r = y();
16
+ if (r == null)
17
+ return l;
18
+ const n = e.pathname;
19
+ return r.httpRequest({
20
+ method: "GET",
21
+ path: n,
22
+ headers: {},
23
+ query_string: ""
24
+ }).then((t) => {
25
+ if (t.status !== 200)
26
+ throw new Error(`Failed to get file ${n} from the Wasm worker.`);
27
+ const o = new Blob([t.body], {
28
+ type: t.headers["Content-Type"]
29
+ });
30
+ return URL.createObjectURL(o);
31
+ });
32
+ }
33
+ const {
34
+ SvelteComponent: w,
35
+ append: C,
36
+ assign: _,
37
+ compute_rest_props: d,
38
+ detach: u,
39
+ element: b,
40
+ empty: E,
41
+ exclude_internal_props: R,
42
+ get_spread_update: O,
43
+ handle_promise: h,
44
+ init: q,
45
+ insert: m,
46
+ noop: c,
47
+ safe_not_equal: T,
48
+ set_attributes: g,
49
+ set_data: P,
50
+ set_style: U,
51
+ src_url_equal: W,
52
+ text: K,
53
+ update_await_block_branch: X
54
+ } = window.__gradio__svelte__internal;
55
+ function Y(l) {
56
+ let e, r = (
57
+ /*error*/
58
+ l[3].message + ""
59
+ ), n;
60
+ return {
61
+ c() {
62
+ e = b("p"), n = K(r), U(e, "color", "red");
63
+ },
64
+ m(t, o) {
65
+ m(t, e, o), C(e, n);
66
+ },
67
+ p(t, o) {
68
+ o & /*src*/
69
+ 1 && r !== (r = /*error*/
70
+ t[3].message + "") && P(n, r);
71
+ },
72
+ d(t) {
73
+ t && u(e);
74
+ }
75
+ };
76
+ }
77
+ function L(l) {
78
+ let e, r, n = [
79
+ {
80
+ src: r = /*resolved_src*/
81
+ l[2]
82
+ },
83
+ /*$$restProps*/
84
+ l[1]
85
+ ], t = {};
86
+ for (let o = 0; o < n.length; o += 1)
87
+ t = _(t, n[o]);
88
+ return {
89
+ c() {
90
+ e = b("img"), g(e, t);
91
+ },
92
+ m(o, s) {
93
+ m(o, e, s);
94
+ },
95
+ p(o, s) {
96
+ g(e, t = O(n, [
97
+ s & /*src*/
98
+ 1 && !W(e.src, r = /*resolved_src*/
99
+ o[2]) && { src: r },
100
+ s & /*$$restProps*/
101
+ 2 && /*$$restProps*/
102
+ o[1]
103
+ ]));
104
+ },
105
+ d(o) {
106
+ o && u(e);
107
+ }
108
+ };
109
+ }
110
+ function N(l) {
111
+ return { c, m: c, p: c, d: c };
112
+ }
113
+ function S(l) {
114
+ let e, r, n = {
115
+ ctx: l,
116
+ current: null,
117
+ token: null,
118
+ hasCatch: !0,
119
+ pending: N,
120
+ then: L,
121
+ catch: Y,
122
+ value: 2,
123
+ error: 3
124
+ };
125
+ return h(r = f(
126
+ /*src*/
127
+ l[0]
128
+ ), n), {
129
+ c() {
130
+ e = E(), n.block.c();
131
+ },
132
+ m(t, o) {
133
+ m(t, e, o), n.block.m(t, n.anchor = o), n.mount = () => e.parentNode, n.anchor = e;
134
+ },
135
+ p(t, [o]) {
136
+ l = t, n.ctx = l, o & /*src*/
137
+ 1 && r !== (r = f(
138
+ /*src*/
139
+ l[0]
140
+ )) && h(r, n) || X(n, l, o);
141
+ },
142
+ i: c,
143
+ o: c,
144
+ d(t) {
145
+ t && u(e), n.block.d(t), n.token = null, n = null;
146
+ }
147
+ };
148
+ }
149
+ function j(l, e, r) {
150
+ const n = ["src"];
151
+ let t = d(e, n), { src: o = void 0 } = e;
152
+ return l.$$set = (s) => {
153
+ e = _(_({}, e), R(s)), r(1, t = d(e, n)), "src" in s && r(0, o = s.src);
154
+ }, [o, t];
155
+ }
156
+ class B extends w {
157
+ constructor(e) {
158
+ super(), q(this, e, j, S, T, { src: 0 });
159
+ }
160
+ }
161
+ const {
162
+ SvelteComponent: F,
163
+ attr: G,
164
+ create_component: I,
165
+ destroy_component: z,
166
+ detach: A,
167
+ element: D,
168
+ init: H,
169
+ insert: J,
170
+ mount_component: M,
171
+ safe_not_equal: Q,
172
+ toggle_class: i,
173
+ transition_in: V,
174
+ transition_out: Z
175
+ } = window.__gradio__svelte__internal;
176
+ function x(l) {
177
+ let e, r, n;
178
+ return r = new B({
179
+ props: {
180
+ src: (
181
+ /*samples_dir*/
182
+ l[1] + /*value*/
183
+ l[0]
184
+ ),
185
+ alt: ""
186
+ }
187
+ }), {
188
+ c() {
189
+ e = D("div"), I(r.$$.fragment), G(e, "class", "container svelte-h11ksk"), i(
190
+ e,
191
+ "table",
192
+ /*type*/
193
+ l[2] === "table"
194
+ ), i(
195
+ e,
196
+ "gallery",
197
+ /*type*/
198
+ l[2] === "gallery"
199
+ ), i(
200
+ e,
201
+ "selected",
202
+ /*selected*/
203
+ l[3]
204
+ );
205
+ },
206
+ m(t, o) {
207
+ J(t, e, o), M(r, e, null), n = !0;
208
+ },
209
+ p(t, [o]) {
210
+ const s = {};
211
+ o & /*samples_dir, value*/
212
+ 3 && (s.src = /*samples_dir*/
213
+ t[1] + /*value*/
214
+ t[0]), r.$set(s), (!n || o & /*type*/
215
+ 4) && i(
216
+ e,
217
+ "table",
218
+ /*type*/
219
+ t[2] === "table"
220
+ ), (!n || o & /*type*/
221
+ 4) && i(
222
+ e,
223
+ "gallery",
224
+ /*type*/
225
+ t[2] === "gallery"
226
+ ), (!n || o & /*selected*/
227
+ 8) && i(
228
+ e,
229
+ "selected",
230
+ /*selected*/
231
+ t[3]
232
+ );
233
+ },
234
+ i(t) {
235
+ n || (V(r.$$.fragment, t), n = !0);
236
+ },
237
+ o(t) {
238
+ Z(r.$$.fragment, t), n = !1;
239
+ },
240
+ d(t) {
241
+ t && A(e), z(r);
242
+ }
243
+ };
244
+ }
245
+ function $(l, e, r) {
246
+ let { value: n } = e, { samples_dir: t } = e, { type: o } = e, { selected: s = !1 } = e;
247
+ return l.$$set = (a) => {
248
+ "value" in a && r(0, n = a.value), "samples_dir" in a && r(1, t = a.samples_dir), "type" in a && r(2, o = a.type), "selected" in a && r(3, s = a.selected);
249
+ }, [n, t, o, s];
250
+ }
251
+ class te extends F {
252
+ constructor(e) {
253
+ super(), H(this, e, $, x, Q, {
254
+ value: 0,
255
+ samples_dir: 1,
256
+ type: 2,
257
+ selected: 3
258
+ });
259
+ }
260
+ }
261
+ export {
262
+ te as default
263
+ };
src/backend/gradio_image_prompter/templates/example/style.css ADDED
@@ -0,0 +1 @@
+ .container.svelte-h11ksk img{width:100%;height:100%}.container.selected.svelte-h11ksk{border-color:var(--border-color-accent)}.container.table.svelte-h11ksk{margin:0 auto;border:2px solid var(--border-color-primary);border-radius:var(--radius-lg);overflow:hidden;width:var(--size-20);height:var(--size-20);object-fit:cover}.container.gallery.svelte-h11ksk{height:var(--size-20);max-height:var(--size-20);object-fit:cover}
src/demo/__init__.py ADDED
File without changes
src/demo/app.py ADDED
@@ -0,0 +1,9 @@
+ import gradio as gr
+ from gradio_image_prompter import ImagePrompter
+
+ demo = gr.Interface(
+     lambda prompts: (prompts["image"], prompts["points"]),
+     ImagePrompter(show_label=False),
+     [gr.Image(show_label=False), gr.Dataframe(label="Points")],
+ )
+ demo.launch()
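For reference, the demo above simply echoes prompts["points"] into a Dataframe. Each row carries six numbers (cf. the points: number[][6] declaration in Index.svelte below); the sketch that follows assumes box prompts keep their two corners in columns 0-1 and 3-4, an inference from that type rather than documented behavior, and is only a minimal example of consuming those rows on the Python side.

# Minimal sketch, assuming each row is [x1, y1, label_a, x2, y2, label_b];
# rows whose second corner is all zeros are treated as plain point clicks.
import numpy as np

def points_to_boxes(points):
    boxes = []
    for x1, y1, _, x2, y2, _ in points:
        if x2 == 0 and y2 == 0:
            continue  # skip point-only rows (assumed encoding)
        boxes.append([min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)])
    return np.array(boxes, dtype=np.float32)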
src/frontend/Example.svelte ADDED
@@ -0,0 +1,44 @@
+ <script lang="ts">
+   import Image from "./shared/Image.svelte";
+
+   export let value: string;
+   export let samples_dir: string;
+   export let type: "gallery" | "table";
+   export let selected = false;
+ </script>
+
+ <div
+   class="container"
+   class:table={type === "table"}
+   class:gallery={type === "gallery"}
+   class:selected
+ >
+   <Image src={samples_dir + value} alt="" />
+ </div>
+
+ <style>
+   .container :global(img) {
+     width: 100%;
+     height: 100%;
+   }
+
+   .container.selected {
+     border-color: var(--border-color-accent);
+   }
+
+   .container.table {
+     margin: 0 auto;
+     border: 2px solid var(--border-color-primary);
+     border-radius: var(--radius-lg);
+     overflow: hidden;
+     width: var(--size-20);
+     height: var(--size-20);
+     object-fit: cover;
+   }
+
+   .container.gallery {
+     height: var(--size-20);
+     max-height: var(--size-20);
+     object-fit: cover;
+   }
+ </style>
src/frontend/Index.svelte ADDED
@@ -0,0 +1,167 @@
1
+ <svelte:options accessors={true} />
2
+
3
+ <script context="module" lang="ts">
4
+ export { default as BaseImageUploader } from "./shared/ImageUploader.svelte";
5
+ export { default as BaseStaticImage } from "./shared/ImagePreview.svelte";
6
+ export { default as BaseExample } from "./Example.svelte";
7
+ export { default as BaseImage } from "./shared/Image.svelte";
8
+ export { default as BoxDrawer } from "./shared/BoxDrawer.svelte";
9
+ </script>
10
+
11
+ <script lang="ts">
12
+ import type { Gradio, SelectData } from "@gradio/utils";
13
+ import StaticImage from "./shared/ImagePreview.svelte";
14
+ import ImageUploader from "./shared/ImageUploader.svelte";
15
+
16
+ import { Block, Empty, UploadText } from "@gradio/atoms";
17
+ import { Image } from "@gradio/icons";
18
+ import { StatusTracker } from "@gradio/statustracker";
19
+ import type { FileData } from "@gradio/client";
20
+ import type { LoadingStatus } from "@gradio/statustracker";
21
+ import { normalise_file } from "@gradio/client";
22
+
23
+ export let elem_id = "";
24
+ export let elem_classes: string[] = [];
25
+ export let visible = true;
26
+
27
+ export let value: { image: FileData; points: number[][6] } | null = null;
28
+ $: _image = value && normalise_file(value.image, root, proxy_url);
29
+ $: _points = value && value.points;
30
+
31
+ export let label: string;
32
+ export let show_label: boolean;
33
+ export let show_download_button: boolean;
34
+ export let root: string;
35
+ export let proxy_url: null | string;
36
+
37
+ export let height: number | undefined;
38
+ export let width: number | undefined;
39
+
40
+ export let _selectable = false;
41
+ export let container = true;
42
+ export let scale: number | null = null;
43
+ export let min_width: number | undefined = undefined;
44
+ export let loading_status: LoadingStatus;
45
+ export let show_share_button = false;
46
+ export let sources: "upload"[] = ["upload"];
47
+ export let interactive: boolean;
48
+ export let streaming: boolean;
49
+
50
+ export let gradio: Gradio<{
51
+ change: never;
52
+ error: string;
53
+ edit: never;
54
+ stream: never;
55
+ drag: never;
56
+ upload: never;
57
+ clear: never;
58
+ select: SelectData;
59
+ share: ShareData;
60
+ }>;
61
+
62
+ $: url = _image?.url;
63
+ $: url && gradio.dispatch("change");
64
+
65
+ let dragging: boolean;
66
+ let active_tool: null | "webcam" = null;
67
+ </script>
68
+
69
+ {#if !interactive}
70
+ <Block
71
+ {visible}
72
+ variant={"solid"}
73
+ border_mode={dragging ? "focus" : "base"}
74
+ padding={false}
75
+ {elem_id}
76
+ {elem_classes}
77
+ height={height || undefined}
78
+ {width}
79
+ allow_overflow={false}
80
+ {container}
81
+ {scale}
82
+ {min_width}
83
+ >
84
+ <StatusTracker
85
+ autoscroll={gradio.autoscroll}
86
+ i18n={gradio.i18n}
87
+ {...loading_status}
88
+ />
89
+ <StaticImage
90
+ on:select={({ detail }) => gradio.dispatch("select", detail)}
91
+ on:share={({ detail }) => gradio.dispatch("share", detail)}
92
+ on:error={({ detail }) => gradio.dispatch("error", detail)}
93
+ value={_image}
94
+ {label}
95
+ {show_label}
96
+ {show_download_button}
97
+ selectable={_selectable}
98
+ {show_share_button}
99
+ i18n={gradio.i18n}
100
+ />
101
+ </Block>
102
+ {:else}
103
+ <Block
104
+ {visible}
105
+ variant={_image === null ? "dashed" : "solid"}
106
+ border_mode={dragging ? "focus" : "base"}
107
+ padding={false}
108
+ {elem_id}
109
+ {elem_classes}
110
+ height={height || undefined}
111
+ {width}
112
+ allow_overflow={false}
113
+ {container}
114
+ {scale}
115
+ {min_width}
116
+ >
117
+ <StatusTracker
118
+ autoscroll={gradio.autoscroll}
119
+ i18n={gradio.i18n}
120
+ {...loading_status}
121
+ />
122
+
123
+ <ImageUploader
124
+ bind:active_tool
125
+ bind:value={_image}
126
+ bind:points={_points}
127
+ {root}
128
+ {sources}
129
+ on:points_change={({ detail }) => (value.points = detail)}
130
+ on:edit={() => gradio.dispatch("edit")}
131
+ on:clear={() => {
132
+ value = null;
133
+ gradio.dispatch("clear");
134
+ gradio.dispatch("change");
135
+ }}
136
+ on:stream={() => gradio.dispatch("stream")}
137
+ on:drag={({ detail }) => (dragging = detail)}
138
+ on:upload={({ detail }) => {
139
+ if (value == null) {
140
+ value = { image: detail, points: null };
141
+ } else {
142
+ value.image = detail;
143
+ }
144
+ gradio.dispatch("upload");
145
+ }}
146
+ on:select={({ detail }) => gradio.dispatch("select", detail)}
147
+ on:share={({ detail }) => gradio.dispatch("share", detail)}
148
+ on:error={({ detail }) => {
149
+ loading_status = loading_status;
150
+ loading_status.status = "error";
151
+ gradio.dispatch("error", detail);
152
+ }}
153
+ on:click={() => gradio.dispatch("error", "bad thing happened")}
154
+ on:error
155
+ {label}
156
+ {show_label}
157
+ {streaming}
158
+ i18n={gradio.i18n}
159
+ >
160
+ {#if sources.includes("upload")}
161
+ <UploadText i18n={gradio.i18n} type="image" mode="short" />
162
+ {:else}
163
+ <Empty unpadded_box={true} size="large"><Image /></Empty>
164
+ {/if}
165
+ </ImageUploader>
166
+ </Block>
167
+ {/if}
src/frontend/package-lock.json ADDED
@@ -0,0 +1,718 @@
1
+ {
2
+ "name": "gradio_image_prompter",
3
+ "version": "0.4.2",
4
+ "lockfileVersion": 3,
5
+ "requires": true,
6
+ "packages": {
7
+ "": {
8
+ "name": "gradio_image_prompter",
9
+ "version": "0.4.2",
10
+ "license": "ISC",
11
+ "dependencies": {
12
+ "@gradio/atoms": "0.3.1",
13
+ "@gradio/client": "0.8.2",
14
+ "@gradio/icons": "0.3.1",
15
+ "@gradio/statustracker": "0.4.1",
16
+ "@gradio/upload": "0.5.2",
17
+ "@gradio/utils": "0.2.0",
18
+ "@gradio/wasm": "0.3.0",
19
+ "cropperjs": "^1.5.12",
20
+ "lazy-brush": "^1.0.1",
21
+ "resize-observer-polyfill": "^1.5.1"
22
+ }
23
+ },
24
+ "node_modules/@ampproject/remapping": {
25
+ "version": "2.2.1",
26
+ "resolved": "https://registry.npmjs.org/@ampproject/remapping/-/remapping-2.2.1.tgz",
27
+ "integrity": "sha512-lFMjJTrFL3j7L9yBxwYfCq2k6qqwHyzuUl/XBnif78PWTJYyL/dfowQHWE3sp6U6ZzqWiiIZnpTMO96zhkjwtg==",
28
+ "peer": true,
29
+ "dependencies": {
30
+ "@jridgewell/gen-mapping": "^0.3.0",
31
+ "@jridgewell/trace-mapping": "^0.3.9"
32
+ },
33
+ "engines": {
34
+ "node": ">=6.0.0"
35
+ }
36
+ },
37
+ "node_modules/@esbuild/darwin-arm64": {
38
+ "version": "0.19.8",
39
+ "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.19.8.tgz",
40
+ "integrity": "sha512-RQw9DemMbIq35Bprbboyf8SmOr4UXsRVxJ97LgB55VKKeJOOdvsIPy0nFyF2l8U+h4PtBx/1kRf0BelOYCiQcw==",
41
+ "cpu": [
42
+ "arm64"
43
+ ],
44
+ "optional": true,
45
+ "os": [
46
+ "darwin"
47
+ ],
48
+ "engines": {
49
+ "node": ">=12"
50
+ }
51
+ },
52
+ "node_modules/@formatjs/ecma402-abstract": {
53
+ "version": "1.11.4",
54
+ "resolved": "https://registry.npmjs.org/@formatjs/ecma402-abstract/-/ecma402-abstract-1.11.4.tgz",
55
+ "integrity": "sha512-EBikYFp2JCdIfGEb5G9dyCkTGDmC57KSHhRQOC3aYxoPWVZvfWCDjZwkGYHN7Lis/fmuWl906bnNTJifDQ3sXw==",
56
+ "dependencies": {
57
+ "@formatjs/intl-localematcher": "0.2.25",
58
+ "tslib": "^2.1.0"
59
+ }
60
+ },
61
+ "node_modules/@formatjs/fast-memoize": {
62
+ "version": "1.2.1",
63
+ "resolved": "https://registry.npmjs.org/@formatjs/fast-memoize/-/fast-memoize-1.2.1.tgz",
64
+ "integrity": "sha512-Rg0e76nomkz3vF9IPlKeV+Qynok0r7YZjL6syLz4/urSg0IbjPZCB/iYUMNsYA643gh4mgrX3T7KEIFIxJBQeg==",
65
+ "dependencies": {
66
+ "tslib": "^2.1.0"
67
+ }
68
+ },
69
+ "node_modules/@formatjs/icu-messageformat-parser": {
70
+ "version": "2.1.0",
71
+ "resolved": "https://registry.npmjs.org/@formatjs/icu-messageformat-parser/-/icu-messageformat-parser-2.1.0.tgz",
72
+ "integrity": "sha512-Qxv/lmCN6hKpBSss2uQ8IROVnta2r9jd3ymUEIjm2UyIkUCHVcbUVRGL/KS/wv7876edvsPe+hjHVJ4z8YuVaw==",
73
+ "dependencies": {
74
+ "@formatjs/ecma402-abstract": "1.11.4",
75
+ "@formatjs/icu-skeleton-parser": "1.3.6",
76
+ "tslib": "^2.1.0"
77
+ }
78
+ },
79
+ "node_modules/@formatjs/icu-skeleton-parser": {
80
+ "version": "1.3.6",
81
+ "resolved": "https://registry.npmjs.org/@formatjs/icu-skeleton-parser/-/icu-skeleton-parser-1.3.6.tgz",
82
+ "integrity": "sha512-I96mOxvml/YLrwU2Txnd4klA7V8fRhb6JG/4hm3VMNmeJo1F03IpV2L3wWt7EweqNLES59SZ4d6hVOPCSf80Bg==",
83
+ "dependencies": {
84
+ "@formatjs/ecma402-abstract": "1.11.4",
85
+ "tslib": "^2.1.0"
86
+ }
87
+ },
88
+ "node_modules/@formatjs/intl-localematcher": {
89
+ "version": "0.2.25",
90
+ "resolved": "https://registry.npmjs.org/@formatjs/intl-localematcher/-/intl-localematcher-0.2.25.tgz",
91
+ "integrity": "sha512-YmLcX70BxoSopLFdLr1Ds99NdlTI2oWoLbaUW2M406lxOIPzE1KQhRz2fPUkq34xVZQaihCoU29h0KK7An3bhA==",
92
+ "dependencies": {
93
+ "tslib": "^2.1.0"
94
+ }
95
+ },
96
+ "node_modules/@gradio/atoms": {
97
+ "version": "0.3.1",
98
+ "resolved": "https://registry.npmjs.org/@gradio/atoms/-/atoms-0.3.1.tgz",
99
+ "integrity": "sha512-P2u1Qud/EmwfGMD9HZdSkw4L3RznGUE3owBx4lRY7JP/1J3sDqy/wN8pZFex+kPKripX29+IiH6+4TRqSs2zFw==",
100
+ "dependencies": {
101
+ "@gradio/icons": "^0.3.1",
102
+ "@gradio/utils": "^0.2.0"
103
+ }
104
+ },
105
+ "node_modules/@gradio/client": {
106
+ "version": "0.8.2",
107
+ "resolved": "https://registry.npmjs.org/@gradio/client/-/client-0.8.2.tgz",
108
+ "integrity": "sha512-ZWrkJBsVg7ioIHhGV1pqIo4MBL0GPn0SHLeA04cqrsxkWiZHZz9CB5wFtm1kaFtd68ERAgEzR8OYVzzlBd2pyQ==",
109
+ "dependencies": {
110
+ "bufferutil": "^4.0.7",
111
+ "semiver": "^1.1.0",
112
+ "ws": "^8.13.0"
113
+ },
114
+ "engines": {
115
+ "node": ">=18.0.0"
116
+ }
117
+ },
118
+ "node_modules/@gradio/column": {
119
+ "version": "0.1.0",
120
+ "resolved": "https://registry.npmjs.org/@gradio/column/-/column-0.1.0.tgz",
121
+ "integrity": "sha512-P24nqqVnMXBaDA1f/zSN5HZRho4PxP8Dq+7VltPHlmxIEiZYik2AJ4J0LeuIha34FDO0guu/16evdrpvGIUAfw=="
122
+ },
123
+ "node_modules/@gradio/icons": {
124
+ "version": "0.3.1",
125
+ "resolved": "https://registry.npmjs.org/@gradio/icons/-/icons-0.3.1.tgz",
126
+ "integrity": "sha512-ZwgXODKa7irD+spE0RCae8fyixgwKOtds6wHL300n9pIRYzL9QkvS1cQJbz0C6NupFCYRSGTQrV5hoLo7yQCew=="
127
+ },
128
+ "node_modules/@gradio/statustracker": {
129
+ "version": "0.4.1",
130
+ "resolved": "https://registry.npmjs.org/@gradio/statustracker/-/statustracker-0.4.1.tgz",
131
+ "integrity": "sha512-6YV5UDzau/nNid5D25YLZyPGm/tFd9b0a+x0OCHY+aE3cez7PD4v6hWGuQXPNwa/69viRm8YyoQ2Vex7/3updA==",
132
+ "dependencies": {
133
+ "@gradio/atoms": "^0.3.1",
134
+ "@gradio/column": "^0.1.0",
135
+ "@gradio/icons": "^0.3.1",
136
+ "@gradio/utils": "^0.2.0"
137
+ }
138
+ },
139
+ "node_modules/@gradio/theme": {
140
+ "version": "0.2.0",
141
+ "resolved": "https://registry.npmjs.org/@gradio/theme/-/theme-0.2.0.tgz",
142
+ "integrity": "sha512-33c68Nk7oRXLn08OxPfjcPm7S4tXGOUV1I1bVgzdM2YV5o1QBOS1GEnXPZPu/CEYPePLMB6bsDwffrLEyLGWVQ=="
143
+ },
144
+ "node_modules/@gradio/upload": {
145
+ "version": "0.5.2",
146
+ "resolved": "https://registry.npmjs.org/@gradio/upload/-/upload-0.5.2.tgz",
147
+ "integrity": "sha512-IXQZ/+0TG/FSOSjJKE28lUG+vGGboD+YQswyvSK6lOpRHvixiqK+eJo0g3jHvmWO9wZLBrEx3XRv8LSgnVHHzw==",
148
+ "dependencies": {
149
+ "@gradio/atoms": "^0.3.1",
150
+ "@gradio/client": "^0.8.2",
151
+ "@gradio/icons": "^0.3.1",
152
+ "@gradio/upload": "^0.5.2",
153
+ "@gradio/utils": "^0.2.0"
154
+ }
155
+ },
156
+ "node_modules/@gradio/utils": {
157
+ "version": "0.2.0",
158
+ "resolved": "https://registry.npmjs.org/@gradio/utils/-/utils-0.2.0.tgz",
159
+ "integrity": "sha512-YkwzXufi6IxQrlMW+1sFo8Yn6F9NLL69ZoBsbo7QEhms0v5L7pmOTw+dfd7M3dwbRP2lgjrb52i1kAIN3n6aqQ==",
160
+ "dependencies": {
161
+ "@gradio/theme": "^0.2.0",
162
+ "svelte-i18n": "^3.6.0"
163
+ }
164
+ },
165
+ "node_modules/@gradio/wasm": {
166
+ "version": "0.3.0",
167
+ "resolved": "https://registry.npmjs.org/@gradio/wasm/-/wasm-0.3.0.tgz",
168
+ "integrity": "sha512-avgMFBrHUUDzQraBMW9mNgiQMMkObsPzDap0PZV6FgzfDpW8K+R4BBcl+gClq82jRi3ulDjtISTXriUrNNfkrg==",
169
+ "dependencies": {
170
+ "@types/path-browserify": "^1.0.0",
171
+ "path-browserify": "^1.0.1"
172
+ }
173
+ },
174
+ "node_modules/@jridgewell/gen-mapping": {
175
+ "version": "0.3.3",
176
+ "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.3.tgz",
177
+ "integrity": "sha512-HLhSWOLRi875zjjMG/r+Nv0oCW8umGb0BgEhyX3dDX3egwZtB8PqLnjz3yedt8R5StBrzcg4aBpnh8UA9D1BoQ==",
178
+ "peer": true,
179
+ "dependencies": {
180
+ "@jridgewell/set-array": "^1.0.1",
181
+ "@jridgewell/sourcemap-codec": "^1.4.10",
182
+ "@jridgewell/trace-mapping": "^0.3.9"
183
+ },
184
+ "engines": {
185
+ "node": ">=6.0.0"
186
+ }
187
+ },
188
+ "node_modules/@jridgewell/resolve-uri": {
189
+ "version": "3.1.1",
190
+ "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.1.tgz",
191
+ "integrity": "sha512-dSYZh7HhCDtCKm4QakX0xFpsRDqjjtZf/kjI/v3T3Nwt5r8/qz/M19F9ySyOqU94SXBmeG9ttTul+YnR4LOxFA==",
192
+ "peer": true,
193
+ "engines": {
194
+ "node": ">=6.0.0"
195
+ }
196
+ },
197
+ "node_modules/@jridgewell/set-array": {
198
+ "version": "1.1.2",
199
+ "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.1.2.tgz",
200
+ "integrity": "sha512-xnkseuNADM0gt2bs+BvhO0p78Mk762YnZdsuzFV018NoG1Sj1SCQvpSqa7XUaTam5vAGasABV9qXASMKnFMwMw==",
201
+ "peer": true,
202
+ "engines": {
203
+ "node": ">=6.0.0"
204
+ }
205
+ },
206
+ "node_modules/@jridgewell/sourcemap-codec": {
207
+ "version": "1.4.15",
208
+ "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.15.tgz",
209
+ "integrity": "sha512-eF2rxCRulEKXHTRiDrDy6erMYWqNw4LPdQ8UQA4huuxaQsVeRPFl2oM8oDGxMFhJUWZf9McpLtJasDDZb/Bpeg==",
210
+ "peer": true
211
+ },
212
+ "node_modules/@jridgewell/trace-mapping": {
213
+ "version": "0.3.20",
214
+ "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.20.tgz",
215
+ "integrity": "sha512-R8LcPeWZol2zR8mmH3JeKQ6QRCFb7XgUhV9ZlGhHLGyg4wpPiPZNQOOWhFZhxKw8u//yTbNGI42Bx/3paXEQ+Q==",
216
+ "peer": true,
217
+ "dependencies": {
218
+ "@jridgewell/resolve-uri": "^3.1.0",
219
+ "@jridgewell/sourcemap-codec": "^1.4.14"
220
+ }
221
+ },
222
+ "node_modules/@types/estree": {
223
+ "version": "1.0.5",
224
+ "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.5.tgz",
225
+ "integrity": "sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==",
226
+ "peer": true
227
+ },
228
+ "node_modules/@types/path-browserify": {
229
+ "version": "1.0.2",
230
+ "resolved": "https://registry.npmjs.org/@types/path-browserify/-/path-browserify-1.0.2.tgz",
231
+ "integrity": "sha512-ZkC5IUqqIFPXx3ASTTybTzmQdwHwe2C0u3eL75ldQ6T9E9IWFJodn6hIfbZGab73DfyiHN4Xw15gNxUq2FbvBA=="
232
+ },
233
+ "node_modules/acorn": {
234
+ "version": "8.11.2",
235
+ "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.11.2.tgz",
236
+ "integrity": "sha512-nc0Axzp/0FILLEVsm4fNwLCwMttvhEI263QtVPQcbpfZZ3ts0hLsZGOpE6czNlid7CJ9MlyH8reXkpsf3YUY4w==",
237
+ "peer": true,
238
+ "bin": {
239
+ "acorn": "bin/acorn"
240
+ },
241
+ "engines": {
242
+ "node": ">=0.4.0"
243
+ }
244
+ },
245
+ "node_modules/aria-query": {
246
+ "version": "5.3.0",
247
+ "resolved": "https://registry.npmjs.org/aria-query/-/aria-query-5.3.0.tgz",
248
+ "integrity": "sha512-b0P0sZPKtyu8HkeRAfCq0IfURZK+SuwMjY1UXGBU27wpAiTwQAIlq56IbIO+ytk/JjS1fMR14ee5WBBfKi5J6A==",
249
+ "peer": true,
250
+ "dependencies": {
251
+ "dequal": "^2.0.3"
252
+ }
253
+ },
254
+ "node_modules/axobject-query": {
255
+ "version": "3.2.1",
256
+ "resolved": "https://registry.npmjs.org/axobject-query/-/axobject-query-3.2.1.tgz",
257
+ "integrity": "sha512-jsyHu61e6N4Vbz/v18DHwWYKK0bSWLqn47eeDSKPB7m8tqMHF9YJ+mhIk2lVteyZrY8tnSj/jHOv4YiTCuCJgg==",
258
+ "peer": true,
259
+ "dependencies": {
260
+ "dequal": "^2.0.3"
261
+ }
262
+ },
263
+ "node_modules/bufferutil": {
264
+ "version": "4.0.8",
265
+ "resolved": "https://registry.npmjs.org/bufferutil/-/bufferutil-4.0.8.tgz",
266
+ "integrity": "sha512-4T53u4PdgsXqKaIctwF8ifXlRTTmEPJ8iEPWFdGZvcf7sbwYo6FKFEX9eNNAnzFZ7EzJAQ3CJeOtCRA4rDp7Pw==",
267
+ "hasInstallScript": true,
268
+ "dependencies": {
269
+ "node-gyp-build": "^4.3.0"
270
+ },
271
+ "engines": {
272
+ "node": ">=6.14.2"
273
+ }
274
+ },
275
+ "node_modules/cli-color": {
276
+ "version": "2.0.3",
277
+ "resolved": "https://registry.npmjs.org/cli-color/-/cli-color-2.0.3.tgz",
278
+ "integrity": "sha512-OkoZnxyC4ERN3zLzZaY9Emb7f/MhBOIpePv0Ycok0fJYT+Ouo00UBEIwsVsr0yoow++n5YWlSUgST9GKhNHiRQ==",
279
+ "dependencies": {
280
+ "d": "^1.0.1",
281
+ "es5-ext": "^0.10.61",
282
+ "es6-iterator": "^2.0.3",
283
+ "memoizee": "^0.4.15",
284
+ "timers-ext": "^0.1.7"
285
+ },
286
+ "engines": {
287
+ "node": ">=0.10"
288
+ }
289
+ },
290
+ "node_modules/code-red": {
291
+ "version": "1.0.4",
292
+ "resolved": "https://registry.npmjs.org/code-red/-/code-red-1.0.4.tgz",
293
+ "integrity": "sha512-7qJWqItLA8/VPVlKJlFXU+NBlo/qyfs39aJcuMT/2ere32ZqvF5OSxgdM5xOfJJ7O429gg2HM47y8v9P+9wrNw==",
294
+ "peer": true,
295
+ "dependencies": {
296
+ "@jridgewell/sourcemap-codec": "^1.4.15",
297
+ "@types/estree": "^1.0.1",
298
+ "acorn": "^8.10.0",
299
+ "estree-walker": "^3.0.3",
300
+ "periscopic": "^3.1.0"
301
+ }
302
+ },
303
+ "node_modules/cropperjs": {
304
+ "version": "1.6.1",
305
+ "resolved": "https://registry.npmjs.org/cropperjs/-/cropperjs-1.6.1.tgz",
306
+ "integrity": "sha512-F4wsi+XkDHCOMrHMYjrTEE4QBOrsHHN5/2VsVAaRq8P7E5z7xQpT75S+f/9WikmBEailas3+yo+6zPIomW+NOA=="
307
+ },
308
+ "node_modules/css-tree": {
309
+ "version": "2.3.1",
310
+ "resolved": "https://registry.npmjs.org/css-tree/-/css-tree-2.3.1.tgz",
311
+ "integrity": "sha512-6Fv1DV/TYw//QF5IzQdqsNDjx/wc8TrMBZsqjL9eW01tWb7R7k/mq+/VXfJCl7SoD5emsJop9cOByJZfs8hYIw==",
312
+ "peer": true,
313
+ "dependencies": {
314
+ "mdn-data": "2.0.30",
315
+ "source-map-js": "^1.0.1"
316
+ },
317
+ "engines": {
318
+ "node": "^10 || ^12.20.0 || ^14.13.0 || >=15.0.0"
319
+ }
320
+ },
321
+ "node_modules/d": {
322
+ "version": "1.0.1",
323
+ "resolved": "https://registry.npmjs.org/d/-/d-1.0.1.tgz",
324
+ "integrity": "sha512-m62ShEObQ39CfralilEQRjH6oAMtNCV1xJyEx5LpRYUVN+EviphDgUc/F3hnYbADmkiNs67Y+3ylmlG7Lnu+FA==",
325
+ "dependencies": {
326
+ "es5-ext": "^0.10.50",
327
+ "type": "^1.0.1"
328
+ }
329
+ },
330
+ "node_modules/deepmerge": {
331
+ "version": "4.3.1",
332
+ "resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-4.3.1.tgz",
333
+ "integrity": "sha512-3sUqbMEc77XqpdNO7FRyRog+eW3ph+GYCbj+rK+uYyRMuwsVy0rMiVtPn+QJlKFvWP/1PYpapqYn0Me2knFn+A==",
334
+ "engines": {
335
+ "node": ">=0.10.0"
336
+ }
337
+ },
338
+ "node_modules/dequal": {
339
+ "version": "2.0.3",
340
+ "resolved": "https://registry.npmjs.org/dequal/-/dequal-2.0.3.tgz",
341
+ "integrity": "sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==",
342
+ "peer": true,
343
+ "engines": {
344
+ "node": ">=6"
345
+ }
346
+ },
347
+ "node_modules/es5-ext": {
348
+ "version": "0.10.62",
349
+ "resolved": "https://registry.npmjs.org/es5-ext/-/es5-ext-0.10.62.tgz",
350
+ "integrity": "sha512-BHLqn0klhEpnOKSrzn/Xsz2UIW8j+cGmo9JLzr8BiUapV8hPL9+FliFqjwr9ngW7jWdnxv6eO+/LqyhJVqgrjA==",
351
+ "hasInstallScript": true,
352
+ "dependencies": {
353
+ "es6-iterator": "^2.0.3",
354
+ "es6-symbol": "^3.1.3",
355
+ "next-tick": "^1.1.0"
356
+ },
357
+ "engines": {
358
+ "node": ">=0.10"
359
+ }
360
+ },
361
+ "node_modules/es6-iterator": {
362
+ "version": "2.0.3",
363
+ "resolved": "https://registry.npmjs.org/es6-iterator/-/es6-iterator-2.0.3.tgz",
364
+ "integrity": "sha512-zw4SRzoUkd+cl+ZoE15A9o1oQd920Bb0iOJMQkQhl3jNc03YqVjAhG7scf9C5KWRU/R13Orf588uCC6525o02g==",
365
+ "dependencies": {
366
+ "d": "1",
367
+ "es5-ext": "^0.10.35",
368
+ "es6-symbol": "^3.1.1"
369
+ }
370
+ },
371
+ "node_modules/es6-symbol": {
372
+ "version": "3.1.3",
373
+ "resolved": "https://registry.npmjs.org/es6-symbol/-/es6-symbol-3.1.3.tgz",
374
+ "integrity": "sha512-NJ6Yn3FuDinBaBRWl/q5X/s4koRHBrgKAu+yGI6JCBeiu3qrcbJhwT2GeR/EXVfylRk8dpQVJoLEFhK+Mu31NA==",
375
+ "dependencies": {
376
+ "d": "^1.0.1",
377
+ "ext": "^1.1.2"
378
+ }
379
+ },
380
+ "node_modules/es6-weak-map": {
381
+ "version": "2.0.3",
382
+ "resolved": "https://registry.npmjs.org/es6-weak-map/-/es6-weak-map-2.0.3.tgz",
383
+ "integrity": "sha512-p5um32HOTO1kP+w7PRnB+5lQ43Z6muuMuIMffvDN8ZB4GcnjLBV6zGStpbASIMk4DCAvEaamhe2zhyCb/QXXsA==",
384
+ "dependencies": {
385
+ "d": "1",
386
+ "es5-ext": "^0.10.46",
387
+ "es6-iterator": "^2.0.3",
388
+ "es6-symbol": "^3.1.1"
389
+ }
390
+ },
391
+ "node_modules/esbuild": {
392
+ "version": "0.19.8",
393
+ "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.19.8.tgz",
394
+ "integrity": "sha512-l7iffQpT2OrZfH2rXIp7/FkmaeZM0vxbxN9KfiCwGYuZqzMg/JdvX26R31Zxn/Pxvsrg3Y9N6XTcnknqDyyv4w==",
395
+ "hasInstallScript": true,
396
+ "bin": {
397
+ "esbuild": "bin/esbuild"
398
+ },
399
+ "engines": {
400
+ "node": ">=12"
401
+ },
402
+ "optionalDependencies": {
403
+ "@esbuild/android-arm": "0.19.8",
404
+ "@esbuild/android-arm64": "0.19.8",
405
+ "@esbuild/android-x64": "0.19.8",
406
+ "@esbuild/darwin-arm64": "0.19.8",
407
+ "@esbuild/darwin-x64": "0.19.8",
408
+ "@esbuild/freebsd-arm64": "0.19.8",
409
+ "@esbuild/freebsd-x64": "0.19.8",
410
+ "@esbuild/linux-arm": "0.19.8",
411
+ "@esbuild/linux-arm64": "0.19.8",
412
+ "@esbuild/linux-ia32": "0.19.8",
413
+ "@esbuild/linux-loong64": "0.19.8",
414
+ "@esbuild/linux-mips64el": "0.19.8",
415
+ "@esbuild/linux-ppc64": "0.19.8",
416
+ "@esbuild/linux-riscv64": "0.19.8",
417
+ "@esbuild/linux-s390x": "0.19.8",
418
+ "@esbuild/linux-x64": "0.19.8",
419
+ "@esbuild/netbsd-x64": "0.19.8",
420
+ "@esbuild/openbsd-x64": "0.19.8",
421
+ "@esbuild/sunos-x64": "0.19.8",
422
+ "@esbuild/win32-arm64": "0.19.8",
423
+ "@esbuild/win32-ia32": "0.19.8",
424
+ "@esbuild/win32-x64": "0.19.8"
425
+ }
426
+ },
427
+ "node_modules/estree-walker": {
428
+ "version": "3.0.3",
429
+ "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-3.0.3.tgz",
430
+ "integrity": "sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==",
431
+ "peer": true,
432
+ "dependencies": {
433
+ "@types/estree": "^1.0.0"
434
+ }
435
+ },
436
+ "node_modules/event-emitter": {
437
+ "version": "0.3.5",
438
+ "resolved": "https://registry.npmjs.org/event-emitter/-/event-emitter-0.3.5.tgz",
439
+ "integrity": "sha512-D9rRn9y7kLPnJ+hMq7S/nhvoKwwvVJahBi2BPmx3bvbsEdK3W9ii8cBSGjP+72/LnM4n6fo3+dkCX5FeTQruXA==",
440
+ "dependencies": {
441
+ "d": "1",
442
+ "es5-ext": "~0.10.14"
443
+ }
444
+ },
445
+ "node_modules/ext": {
446
+ "version": "1.7.0",
447
+ "resolved": "https://registry.npmjs.org/ext/-/ext-1.7.0.tgz",
448
+ "integrity": "sha512-6hxeJYaL110a9b5TEJSj0gojyHQAmA2ch5Os+ySCiA1QGdS697XWY1pzsrSjqA9LDEEgdB/KypIlR59RcLuHYw==",
449
+ "dependencies": {
450
+ "type": "^2.7.2"
451
+ }
452
+ },
453
+ "node_modules/ext/node_modules/type": {
454
+ "version": "2.7.2",
455
+ "resolved": "https://registry.npmjs.org/type/-/type-2.7.2.tgz",
456
+ "integrity": "sha512-dzlvlNlt6AXU7EBSfpAscydQ7gXB+pPGsPnfJnZpiNJBDj7IaJzQlBZYGdEi4R9HmPdBv2XmWJ6YUtoTa7lmCw=="
457
+ },
458
+ "node_modules/globalyzer": {
459
+ "version": "0.1.0",
460
+ "resolved": "https://registry.npmjs.org/globalyzer/-/globalyzer-0.1.0.tgz",
461
+ "integrity": "sha512-40oNTM9UfG6aBmuKxk/giHn5nQ8RVz/SS4Ir6zgzOv9/qC3kKZ9v4etGTcJbEl/NyVQH7FGU7d+X1egr57Md2Q=="
462
+ },
463
+ "node_modules/globrex": {
464
+ "version": "0.1.2",
465
+ "resolved": "https://registry.npmjs.org/globrex/-/globrex-0.1.2.tgz",
466
+ "integrity": "sha512-uHJgbwAMwNFf5mLst7IWLNg14x1CkeqglJb/K3doi4dw6q2IvAAmM/Y81kevy83wP+Sst+nutFTYOGg3d1lsxg=="
467
+ },
468
+ "node_modules/intl-messageformat": {
469
+ "version": "9.13.0",
470
+ "resolved": "https://registry.npmjs.org/intl-messageformat/-/intl-messageformat-9.13.0.tgz",
471
+ "integrity": "sha512-7sGC7QnSQGa5LZP7bXLDhVDtQOeKGeBFGHF2Y8LVBwYZoQZCgWeKoPGTa5GMG8g/TzDgeXuYJQis7Ggiw2xTOw==",
472
+ "dependencies": {
473
+ "@formatjs/ecma402-abstract": "1.11.4",
474
+ "@formatjs/fast-memoize": "1.2.1",
475
+ "@formatjs/icu-messageformat-parser": "2.1.0",
476
+ "tslib": "^2.1.0"
477
+ }
478
+ },
479
+ "node_modules/is-promise": {
480
+ "version": "2.2.2",
481
+ "resolved": "https://registry.npmjs.org/is-promise/-/is-promise-2.2.2.tgz",
482
+ "integrity": "sha512-+lP4/6lKUBfQjZ2pdxThZvLUAafmZb8OAxFb8XXtiQmS35INgr85hdOGoEs124ez1FCnZJt6jau/T+alh58QFQ=="
483
+ },
484
+ "node_modules/is-reference": {
485
+ "version": "3.0.2",
486
+ "resolved": "https://registry.npmjs.org/is-reference/-/is-reference-3.0.2.tgz",
487
+ "integrity": "sha512-v3rht/LgVcsdZa3O2Nqs+NMowLOxeOm7Ay9+/ARQ2F+qEoANRcqrjAZKGN0v8ymUetZGgkp26LTnGT7H0Qo9Pg==",
488
+ "peer": true,
489
+ "dependencies": {
490
+ "@types/estree": "*"
491
+ }
492
+ },
493
+ "node_modules/lazy-brush": {
494
+ "version": "1.0.1",
495
+ "resolved": "https://registry.npmjs.org/lazy-brush/-/lazy-brush-1.0.1.tgz",
496
+ "integrity": "sha512-xT/iSClTVi7vLoF8dCWTBhCuOWqsLXCMPa6ucVmVAk6hyNCM5JeS1NLhXqIrJktUg+caEYKlqSOUU4u3cpXzKg=="
497
+ },
498
+ "node_modules/locate-character": {
499
+ "version": "3.0.0",
500
+ "resolved": "https://registry.npmjs.org/locate-character/-/locate-character-3.0.0.tgz",
501
+ "integrity": "sha512-SW13ws7BjaeJ6p7Q6CO2nchbYEc3X3J6WrmTTDto7yMPqVSZTUyY5Tjbid+Ab8gLnATtygYtiDIJGQRRn2ZOiA==",
502
+ "peer": true
503
+ },
504
+ "node_modules/lru-queue": {
505
+ "version": "0.1.0",
506
+ "resolved": "https://registry.npmjs.org/lru-queue/-/lru-queue-0.1.0.tgz",
507
+ "integrity": "sha512-BpdYkt9EvGl8OfWHDQPISVpcl5xZthb+XPsbELj5AQXxIC8IriDZIQYjBJPEm5rS420sjZ0TLEzRcq5KdBhYrQ==",
508
+ "dependencies": {
509
+ "es5-ext": "~0.10.2"
510
+ }
511
+ },
512
+ "node_modules/magic-string": {
513
+ "version": "0.30.5",
514
+ "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.30.5.tgz",
515
+ "integrity": "sha512-7xlpfBaQaP/T6Vh8MO/EqXSW5En6INHEvEXQiuff7Gku0PWjU3uf6w/j9o7O+SpB5fOAkrI5HeoNgwjEO0pFsA==",
516
+ "peer": true,
517
+ "dependencies": {
518
+ "@jridgewell/sourcemap-codec": "^1.4.15"
519
+ },
520
+ "engines": {
521
+ "node": ">=12"
522
+ }
523
+ },
524
+ "node_modules/mdn-data": {
525
+ "version": "2.0.30",
526
+ "resolved": "https://registry.npmjs.org/mdn-data/-/mdn-data-2.0.30.tgz",
527
+ "integrity": "sha512-GaqWWShW4kv/G9IEucWScBx9G1/vsFZZJUO+tD26M8J8z3Kw5RDQjaoZe03YAClgeS/SWPOcb4nkFBTEi5DUEA==",
528
+ "peer": true
529
+ },
530
+ "node_modules/memoizee": {
531
+ "version": "0.4.15",
532
+ "resolved": "https://registry.npmjs.org/memoizee/-/memoizee-0.4.15.tgz",
533
+ "integrity": "sha512-UBWmJpLZd5STPm7PMUlOw/TSy972M+z8gcyQ5veOnSDRREz/0bmpyTfKt3/51DhEBqCZQn1udM/5flcSPYhkdQ==",
534
+ "dependencies": {
535
+ "d": "^1.0.1",
536
+ "es5-ext": "^0.10.53",
537
+ "es6-weak-map": "^2.0.3",
538
+ "event-emitter": "^0.3.5",
539
+ "is-promise": "^2.2.2",
540
+ "lru-queue": "^0.1.0",
541
+ "next-tick": "^1.1.0",
542
+ "timers-ext": "^0.1.7"
543
+ }
544
+ },
545
+ "node_modules/mri": {
546
+ "version": "1.2.0",
547
+ "resolved": "https://registry.npmjs.org/mri/-/mri-1.2.0.tgz",
548
+ "integrity": "sha512-tzzskb3bG8LvYGFF/mDTpq3jpI6Q9wc3LEmBaghu+DdCssd1FakN7Bc0hVNmEyGq1bq3RgfkCb3cmQLpNPOroA==",
549
+ "engines": {
550
+ "node": ">=4"
551
+ }
552
+ },
553
+ "node_modules/next-tick": {
554
+ "version": "1.1.0",
555
+ "resolved": "https://registry.npmjs.org/next-tick/-/next-tick-1.1.0.tgz",
556
+ "integrity": "sha512-CXdUiJembsNjuToQvxayPZF9Vqht7hewsvy2sOWafLvi2awflj9mOC6bHIg50orX8IJvWKY9wYQ/zB2kogPslQ=="
557
+ },
558
+ "node_modules/node-gyp-build": {
559
+ "version": "4.7.1",
560
+ "resolved": "https://registry.npmjs.org/node-gyp-build/-/node-gyp-build-4.7.1.tgz",
561
+ "integrity": "sha512-wTSrZ+8lsRRa3I3H8Xr65dLWSgCvY2l4AOnaeKdPA9TB/WYMPaTcrzf3rXvFoVvjKNVnu0CcWSx54qq9GKRUYg==",
562
+ "bin": {
563
+ "node-gyp-build": "bin.js",
564
+ "node-gyp-build-optional": "optional.js",
565
+ "node-gyp-build-test": "build-test.js"
566
+ }
567
+ },
568
+ "node_modules/path-browserify": {
569
+ "version": "1.0.1",
570
+ "resolved": "https://registry.npmjs.org/path-browserify/-/path-browserify-1.0.1.tgz",
571
+ "integrity": "sha512-b7uo2UCUOYZcnF/3ID0lulOJi/bafxa1xPe7ZPsammBSpjSWQkjNxlt635YGS2MiR9GjvuXCtz2emr3jbsz98g=="
572
+ },
573
+ "node_modules/periscopic": {
574
+ "version": "3.1.0",
575
+ "resolved": "https://registry.npmjs.org/periscopic/-/periscopic-3.1.0.tgz",
576
+ "integrity": "sha512-vKiQ8RRtkl9P+r/+oefh25C3fhybptkHKCZSPlcXiJux2tJF55GnEj3BVn4A5gKfq9NWWXXrxkHBwVPUfH0opw==",
577
+ "peer": true,
578
+ "dependencies": {
579
+ "@types/estree": "^1.0.0",
580
+ "estree-walker": "^3.0.0",
581
+ "is-reference": "^3.0.0"
582
+ }
583
+ },
584
+ "node_modules/resize-observer-polyfill": {
585
+ "version": "1.5.1",
586
+ "resolved": "https://registry.npmjs.org/resize-observer-polyfill/-/resize-observer-polyfill-1.5.1.tgz",
587
+ "integrity": "sha512-LwZrotdHOo12nQuZlHEmtuXdqGoOD0OhaxopaNFxWzInpEgaLWoVuAMbTzixuosCx2nEG58ngzW3vxdWoxIgdg=="
588
+ },
589
+ "node_modules/sade": {
590
+ "version": "1.8.1",
591
+ "resolved": "https://registry.npmjs.org/sade/-/sade-1.8.1.tgz",
592
+ "integrity": "sha512-xal3CZX1Xlo/k4ApwCFrHVACi9fBqJ7V+mwhBsuf/1IOKbBy098Fex+Wa/5QMubw09pSZ/u8EY8PWgevJsXp1A==",
593
+ "dependencies": {
594
+ "mri": "^1.1.0"
595
+ },
596
+ "engines": {
597
+ "node": ">=6"
598
+ }
599
+ },
600
+ "node_modules/semiver": {
601
+ "version": "1.1.0",
602
+ "resolved": "https://registry.npmjs.org/semiver/-/semiver-1.1.0.tgz",
603
+ "integrity": "sha512-QNI2ChmuioGC1/xjyYwyZYADILWyW6AmS1UH6gDj/SFUUUS4MBAWs/7mxnkRPc/F4iHezDP+O8t0dO8WHiEOdg==",
604
+ "engines": {
605
+ "node": ">=6"
606
+ }
607
+ },
608
+ "node_modules/source-map-js": {
609
+ "version": "1.0.2",
610
+ "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.0.2.tgz",
611
+ "integrity": "sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==",
612
+ "peer": true,
613
+ "engines": {
614
+ "node": ">=0.10.0"
615
+ }
616
+ },
617
+ "node_modules/svelte": {
618
+ "version": "4.2.8",
619
+ "resolved": "https://registry.npmjs.org/svelte/-/svelte-4.2.8.tgz",
620
+ "integrity": "sha512-hU6dh1MPl8gh6klQZwK/n73GiAHiR95IkFsesLPbMeEZi36ydaXL/ZAb4g9sayT0MXzpxyZjR28yderJHxcmYA==",
621
+ "peer": true,
622
+ "dependencies": {
623
+ "@ampproject/remapping": "^2.2.1",
624
+ "@jridgewell/sourcemap-codec": "^1.4.15",
625
+ "@jridgewell/trace-mapping": "^0.3.18",
626
+ "acorn": "^8.9.0",
627
+ "aria-query": "^5.3.0",
628
+ "axobject-query": "^3.2.1",
629
+ "code-red": "^1.0.3",
630
+ "css-tree": "^2.3.1",
631
+ "estree-walker": "^3.0.3",
632
+ "is-reference": "^3.0.1",
633
+ "locate-character": "^3.0.0",
634
+ "magic-string": "^0.30.4",
635
+ "periscopic": "^3.1.0"
636
+ },
637
+ "engines": {
638
+ "node": ">=16"
639
+ }
640
+ },
641
+ "node_modules/svelte-i18n": {
642
+ "version": "3.7.4",
643
+ "resolved": "https://registry.npmjs.org/svelte-i18n/-/svelte-i18n-3.7.4.tgz",
644
+ "integrity": "sha512-yGRCNo+eBT4cPuU7IVsYTYjxB7I2V8qgUZPlHnNctJj5IgbJgV78flsRzpjZ/8iUYZrS49oCt7uxlU3AZv/N5Q==",
645
+ "dependencies": {
646
+ "cli-color": "^2.0.3",
647
+ "deepmerge": "^4.2.2",
648
+ "esbuild": "^0.19.2",
649
+ "estree-walker": "^2",
650
+ "intl-messageformat": "^9.13.0",
651
+ "sade": "^1.8.1",
652
+ "tiny-glob": "^0.2.9"
653
+ },
654
+ "bin": {
655
+ "svelte-i18n": "dist/cli.js"
656
+ },
657
+ "engines": {
658
+ "node": ">= 16"
659
+ },
660
+ "peerDependencies": {
661
+ "svelte": "^3 || ^4"
662
+ }
663
+ },
664
+ "node_modules/svelte-i18n/node_modules/estree-walker": {
665
+ "version": "2.0.2",
666
+ "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-2.0.2.tgz",
667
+ "integrity": "sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w=="
668
+ },
669
+ "node_modules/timers-ext": {
670
+ "version": "0.1.7",
671
+ "resolved": "https://registry.npmjs.org/timers-ext/-/timers-ext-0.1.7.tgz",
672
+ "integrity": "sha512-b85NUNzTSdodShTIbky6ZF02e8STtVVfD+fu4aXXShEELpozH+bCpJLYMPZbsABN2wDH7fJpqIoXxJpzbf0NqQ==",
673
+ "dependencies": {
674
+ "es5-ext": "~0.10.46",
675
+ "next-tick": "1"
676
+ }
677
+ },
678
+ "node_modules/tiny-glob": {
679
+ "version": "0.2.9",
680
+ "resolved": "https://registry.npmjs.org/tiny-glob/-/tiny-glob-0.2.9.tgz",
681
+ "integrity": "sha512-g/55ssRPUjShh+xkfx9UPDXqhckHEsHr4Vd9zX55oSdGZc/MD0m3sferOkwWtp98bv+kcVfEHtRJgBVJzelrzg==",
682
+ "dependencies": {
683
+ "globalyzer": "0.1.0",
684
+ "globrex": "^0.1.2"
685
+ }
686
+ },
687
+ "node_modules/tslib": {
688
+ "version": "2.6.2",
689
+ "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz",
690
+ "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q=="
691
+ },
692
+ "node_modules/type": {
693
+ "version": "1.2.0",
694
+ "resolved": "https://registry.npmjs.org/type/-/type-1.2.0.tgz",
695
+ "integrity": "sha512-+5nt5AAniqsCnu2cEQQdpzCAh33kVx8n0VoFidKpB1dVVLAN/F+bgVOqOJqOnEnrhp222clB5p3vUlD+1QAnfg=="
696
+ },
697
+ "node_modules/ws": {
698
+ "version": "8.14.2",
699
+ "resolved": "https://registry.npmjs.org/ws/-/ws-8.14.2.tgz",
700
+ "integrity": "sha512-wEBG1ftX4jcglPxgFCMJmZ2PLtSbJ2Peg6TmpJFTbe9GZYOQCDPdMYu/Tm0/bGZkw8paZnJY45J4K2PZrLYq8g==",
701
+ "engines": {
702
+ "node": ">=10.0.0"
703
+ },
704
+ "peerDependencies": {
705
+ "bufferutil": "^4.0.1",
706
+ "utf-8-validate": ">=5.0.2"
707
+ },
708
+ "peerDependenciesMeta": {
709
+ "bufferutil": {
710
+ "optional": true
711
+ },
712
+ "utf-8-validate": {
713
+ "optional": true
714
+ }
715
+ }
716
+ }
717
+ }
718
+ }
src/frontend/package.json ADDED
@@ -0,0 +1,28 @@
 
 
1
+ {
2
+ "name": "gradio_image_prompter",
3
+ "version": "0.4.2",
4
+ "description": "Gradio UI packages",
5
+ "type": "module",
6
+ "author": "",
7
+ "license": "ISC",
8
+ "private": false,
9
+ "dependencies": {
10
+ "@gradio/atoms": "0.3.1",
11
+ "@gradio/client": "0.8.2",
12
+ "@gradio/icons": "0.3.1",
13
+ "@gradio/statustracker": "0.4.1",
14
+ "@gradio/upload": "0.5.2",
15
+ "@gradio/utils": "0.2.0",
16
+ "@gradio/wasm": "0.3.0",
17
+ "cropperjs": "^1.5.12",
18
+ "lazy-brush": "^1.0.1",
19
+ "resize-observer-polyfill": "^1.5.1"
20
+ },
21
+ "main_changeset": true,
22
+ "main": "./Index.svelte",
23
+ "exports": {
24
+ ".": "./Index.svelte",
25
+ "./example": "./Example.svelte",
26
+ "./package.json": "./package.json"
27
+ }
28
+ }
src/frontend/shared/BoxDrawer.svelte ADDED
@@ -0,0 +1,237 @@
 
 
1
+ <svelte:options accessors={true} />
2
+
3
+ <script lang="ts">
4
+ import { createEventDispatcher, onDestroy, onMount, tick } from "svelte";
5
+
6
+ const dispatch = createEventDispatcher();
7
+
8
+ export let width = 0;
9
+ export let height = 0;
10
+ export let natural_width = 0;
11
+ export let natural_height = 0;
12
+
13
+ let boxes: Array<Array<number>> = [];
14
+ let points: Array<Array<number>> = [];
15
+
16
+ let canvas_container: HTMLElement;
17
+ let canvas: HTMLCanvasElement;
18
+ let ctx: CanvasRenderingContext2D | null;
19
+
20
+ let mouse_pressing: boolean = false;
21
+ let mouse_button: number;
22
+ let prev_x: number, prev_y: number;
23
+ let cur_x: number, cur_y: number;
24
+
25
+ let old_width = 0;
26
+ let old_height = 0;
27
+ let canvasObserver: ResizeObserver;
28
+
29
+ async function set_canvas_size(dimensions: {
30
+ width: number;
31
+ height: number;
32
+ }) {
33
+ await tick();
34
+ canvas.width = dimensions.width;
35
+ canvas.height = dimensions.height;
36
+ canvas.style.width = `${dimensions.width}px`;
37
+ canvas.style.height = `${dimensions.height}px`;
38
+ canvas.style.marginTop = `-${dimensions.height}px`;
39
+ }
40
+
41
+ export async function resize_canvas() {
42
+ if (width === old_width && height === old_height) return;
43
+ await set_canvas_size({ width: width, height: height });
44
+ draw_canvas();
45
+ setTimeout(() => {
46
+ old_height = height;
47
+ old_width = width;
48
+ }, 100);
49
+ clear();
50
+ }
51
+
52
+ export function clear() {
53
+ boxes = [];
54
+ points = [];
55
+ draw_canvas();
56
+ dispatch("change", points);
57
+ return true;
58
+ }
59
+
60
+ export function undo() {
61
+ boxes.pop();
62
+ points.pop();
63
+ draw_canvas();
64
+ dispatch("change", points);
65
+ return true;
66
+ }
67
+
68
+ onMount(async () => {
69
+ ctx = canvas.getContext("2d");
70
+ if (ctx) {
71
+ (ctx.lineJoin = "round"), (ctx.lineCap = "round");
72
+ ctx.strokeStyle = "#000";
73
+ }
74
+ canvasObserver = new ResizeObserver(() => {
75
+ resize_canvas();
76
+ });
77
+ canvasObserver.observe(canvas_container);
78
+ draw_loop();
79
+ clear();
80
+ });
81
+
82
+ onDestroy(() => {
83
+ canvasObserver.unobserve(canvas_container);
84
+ });
85
+
86
+ function get_mouse_pos(e: MouseEvent | TouchEvent | FocusEvent) {
87
+ const rect = canvas.getBoundingClientRect();
88
+ let screenX, screenY: number;
89
+ if (e instanceof MouseEvent) {
90
+ screenX = e.clientX;
91
+ screenY = e.clientY;
92
+ } else if (e instanceof TouchEvent) {
93
+ screenX = e.changedTouches[0].clientX;
94
+ screenY = e.changedTouches[0].clientY;
95
+ } else {
96
+ return { x: prev_x, y: prev_y };
97
+ }
98
+ return { x: screenX - rect.left, y: screenY - rect.top };
99
+ }
100
+
101
+ function handle_draw_start(e: MouseEvent | TouchEvent) {
102
+ e.preventDefault();
103
+ (mouse_pressing = true), (mouse_button = 0);
104
+ if (e instanceof MouseEvent) mouse_button = e.button;
105
+ const { x, y } = get_mouse_pos(e);
106
+ (prev_x = x), (prev_y = y);
107
+ }
108
+
109
+ function handle_draw_move(e: MouseEvent | TouchEvent) {
110
+ e.preventDefault();
111
+ const { x, y } = get_mouse_pos(e);
112
+ (cur_x = x), (cur_y = y);
113
+ }
114
+
115
+ function handle_draw_end(e: MouseEvent | TouchEvent | FocusEvent) {
116
+ e.preventDefault();
117
+ if (mouse_pressing) {
118
+ const { x, y } = get_mouse_pos(e);
119
+ let x1 = Math.min(prev_x, x);
120
+ let y1 = Math.min(prev_y, y);
121
+ let x2 = Math.max(prev_x, x);
122
+ let y2 = Math.max(prev_y, y);
123
+ boxes.push([x1, y1, x2, y2]);
124
+ let scale_x = natural_width / width;
125
+ let scale_y = natural_height / height;
126
+ let is_point = x1 == x2 && y1 == y2;
127
+ points.push([
128
+ Math.round(x1 * scale_x),
129
+ Math.round(y1 * scale_y),
130
+ is_point ? (mouse_button == 0 ? 1 : 0) : 2, // label1
131
+ is_point ? 0 : Math.round(x2 * scale_x),
132
+ is_point ? 0 : Math.round(y2 * scale_y),
133
+ is_point ? 4 : 3, // label2
134
+ ]);
135
+ dispatch("change", points);
136
+ }
137
+ mouse_pressing = false;
138
+ }
139
+
140
+ function draw_loop() {
141
+ draw_canvas();
142
+ window.requestAnimationFrame(() => {
143
+ draw_loop();
144
+ });
145
+ }
146
+
147
+ function draw_canvas() {
148
+ if (!ctx) return;
149
+ ctx.clearRect(0, 0, width, height);
150
+ if (mouse_pressing && cur_x != prev_x && prev_y != cur_y) {
151
+ let boxes_temp = boxes.slice();
152
+ boxes_temp.push([prev_x, prev_y, cur_x, cur_y]);
153
+ draw_boxes(boxes_temp);
154
+ draw_points(boxes);
155
+ } else {
156
+ draw_boxes(boxes);
157
+ draw_points(boxes);
158
+ }
159
+ }
160
+
161
+ function draw_boxes(boxes: Array<Array<number>>) {
162
+ if (!ctx) return;
163
+ ctx.fillStyle = "rgba(0, 0, 0, 0.1)";
164
+ ctx.beginPath();
165
+ boxes.forEach((box: Array<number>) => {
166
+ if (box[0] != box[2] && box[1] != box[3]) {
167
+ ctx.rect(box[0], box[1], box[2] - box[0], box[3] - box[1]);
168
+ }
169
+ });
170
+ ctx.fill();
171
+ ctx.stroke();
172
+ }
173
+
174
+ function draw_points(boxes: Array<Array<number>>) {
175
+ if (!ctx) return;
176
+ // Draw foreground points.
177
+ ctx.beginPath();
178
+ ctx.fillStyle = "rgba(0, 255, 255, 1.0)"; // Cyan.
179
+ boxes.forEach((box: Array<number>, index: number) => {
180
+ if (points[index][2] == 1) {
181
+ let radius = Math.sqrt(width * height) * 0.01;
182
+ ctx.moveTo(box[0] + radius, box[1]);
183
+ ctx.arc(box[0], box[1], radius, 0, 2 * Math.PI, false);
184
+ }
185
+ });
186
+ ctx.fill();
187
+ ctx.stroke();
188
+ // Draw background points.
189
+ ctx.beginPath();
190
+ ctx.fillStyle = "rgba(255, 192, 203, 1.0)"; // Pink.
191
+ boxes.forEach((box: Array<number>, index: number) => {
192
+ if (points[index][2] == 0) {
193
+ let radius = Math.sqrt(width * height) * 0.01;
194
+ ctx.moveTo(box[0] + radius, box[1]);
195
+ ctx.arc(box[0], box[1], radius, 0, 2 * Math.PI, false);
196
+ }
197
+ });
198
+ ctx.fill();
199
+ ctx.stroke();
200
+ }
201
+ </script>
202
+
203
+ <div class="wrap" bind:this={canvas_container}>
204
+ <canvas
205
+ bind:this={canvas}
206
+ on:mousedown={handle_draw_start}
207
+ on:mousemove={handle_draw_move}
208
+ on:mouseout={handle_draw_move}
209
+ on:mouseup={handle_draw_end}
210
+ on:touchstart={handle_draw_start}
211
+ on:touchmove={handle_draw_move}
212
+ on:touchend={handle_draw_end}
213
+ on:touchcancel={handle_draw_end}
214
+ on:blur={handle_draw_end}
215
+ on:click|stopPropagation
216
+ style=" z-index: 15"
217
+ />
218
+ </div>
219
+
220
+ <style>
221
+ canvas {
222
+ display: block;
223
+ position: absolute;
224
+ top: 0;
225
+ right: 0;
226
+ bottom: 0;
227
+ left: 0;
228
+ margin: auto;
229
+ }
230
+
231
+ .wrap {
232
+ position: relative;
233
+ width: var(--size-full);
234
+ height: var(--size-full);
235
+ touch-action: none;
236
+ }
237
+ </style>
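Editor's note: each prompt that BoxDrawer dispatches is a 6-tuple [x1, y1, label1, x2, y2, label2] in natural-image coordinates (see handle_draw_end above): a click becomes (x, y, 1 or 0, 0, 0, 4) for a foreground/background point, and a drag becomes (x1, y1, 2, x2, y2, 3) for a box. Below is a minimal sketch of how a backend could split these tuples into point and box prompts; the helper name split_prompts is made up for illustration and is not part of this commit.

import numpy as np

def split_prompts(points):
    """points: iterable of [x1, y1, label1, x2, y2, label2] rows from BoxDrawer."""
    clicks, click_labels, boxes = [], [], []
    for x1, y1, l1, x2, y2, l2 in points:
        if l2 == 4:                        # single click: label1 is 1 (foreground) or 0 (background)
            clicks.append([x1, y1])
            click_labels.append(int(l1))
        elif l1 == 2 and l2 == 3:          # dragged rectangle: two opposite corners
            boxes.append([x1, y1, x2, y2])
    return np.asarray(clicks), np.asarray(click_labels), np.asarray(boxes)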
src/frontend/shared/ClearImage.svelte ADDED
@@ -0,0 +1,48 @@
 
 
1
+ <script lang="ts">
2
+ import { createEventDispatcher } from "svelte";
3
+ import { IconButton } from "@gradio/atoms";
4
+ import { Undo, Erase, Clear } from "@gradio/icons";
5
+
6
+ const dispatch = createEventDispatcher();
7
+ </script>
8
+
9
+ <div>
10
+ <IconButton
11
+ Icon={Undo}
12
+ label="Remove Last Box"
13
+ on:click={(event) => {
14
+ dispatch("remove_box");
15
+ event.stopPropagation();
16
+ }}
17
+ />
18
+
19
+ <IconButton
20
+ Icon={Erase}
21
+ label="Remove All boxes"
22
+ on:click={(event) => {
23
+ dispatch("remove_boxes");
24
+ event.stopPropagation();
25
+ }}
26
+ />
27
+
28
+ <IconButton
29
+ Icon={Clear}
30
+ label="Remove Image"
31
+ on:click={(event) => {
32
+ dispatch("remove_image");
33
+ event.stopPropagation();
34
+ }}
35
+ />
36
+ </div>
37
+
38
+ <style>
39
+ div {
40
+ display: flex;
41
+ position: absolute;
42
+ top: var(--size-2);
43
+ right: var(--size-2);
44
+ justify-content: flex-end;
45
+ gap: var(--spacing-sm);
46
+ z-index: var(--layer-5);
47
+ }
48
+ </style>
src/frontend/shared/Image.svelte ADDED
@@ -0,0 +1,15 @@
 
 
1
+ <script lang="ts">
2
+ import type { HTMLImgAttributes } from "svelte/elements";
3
+ type $$Props = HTMLImgAttributes;
4
+
5
+ import { resolve_wasm_src } from "@gradio/wasm/svelte";
6
+
7
+ export let src: HTMLImgAttributes["src"] = undefined;
8
+ </script>
9
+
10
+ {#await resolve_wasm_src(src) then resolved_src}
11
+ <!-- svelte-ignore a11y-missing-attribute -->
12
+ <img src={resolved_src} {...$$restProps} />
13
+ {:catch error}
14
+ <p style="color: red;">{error.message}</p>
15
+ {/await}
src/frontend/shared/ImagePreview.svelte ADDED
@@ -0,0 +1,88 @@
 
 
1
+ <script lang="ts">
2
+ import { createEventDispatcher } from "svelte";
3
+ import type { SelectData } from "@gradio/utils";
4
+ import { uploadToHuggingFace } from "@gradio/utils";
5
+ import { BlockLabel, Empty, IconButton, ShareButton } from "@gradio/atoms";
6
+ import { Download } from "@gradio/icons";
7
+ import { get_coordinates_of_clicked_image } from "./utils";
8
+
9
+ import { Image } from "@gradio/icons";
10
+ import { type FileData } from "@gradio/client";
11
+ import type { I18nFormatter } from "@gradio/utils";
12
+
13
+ export let value: null | FileData;
14
+ export let label: string | undefined = undefined;
15
+ export let show_label: boolean;
16
+ export let show_download_button = true;
17
+ export let selectable = false;
18
+ export let show_share_button = false;
19
+ export let i18n: I18nFormatter;
20
+
21
+ const dispatch = createEventDispatcher<{
22
+ change: string;
23
+ select: SelectData;
24
+ }>();
25
+
26
+ const handle_click = (evt: MouseEvent): void => {
27
+ let coordinates = get_coordinates_of_clicked_image(evt);
28
+ if (coordinates) {
29
+ dispatch("select", { index: coordinates, value: null });
30
+ }
31
+ };
32
+ </script>
33
+
34
+ <BlockLabel {show_label} Icon={Image} label={label || i18n("image.image")} />
35
+ {#if value === null || !value.url}
36
+ <Empty unpadded_box={true} size="large"><Image /></Empty>
37
+ {:else}
38
+ <div class="icon-buttons">
39
+ {#if show_download_button}
40
+ <a
41
+ href={value.url}
42
+ target={window.__is_colab__ ? "_blank" : null}
43
+ download={value.orig_name || "image"}
44
+ >
45
+ <IconButton Icon={Download} label={i18n("common.download")} />
46
+ </a>
47
+ {/if}
48
+ {#if show_share_button}
49
+ <ShareButton
50
+ {i18n}
51
+ on:share
52
+ on:error
53
+ formatter={async (value) => {
54
+ if (!value) return "";
55
+ let url = await uploadToHuggingFace(value, "base64");
56
+ return `<img src="${url}" />`;
57
+ }}
58
+ {value}
59
+ />
60
+ {/if}
61
+ </div>
62
+ <button on:click={handle_click}>
63
+ <img src={value.url} alt="" class:selectable loading="lazy" />
64
+ </button>
65
+ {/if}
66
+
67
+ <style>
68
+ img,
69
+ button {
70
+ width: var(--size-full);
71
+ height: var(--size-full);
72
+ object-fit: contain;
73
+ display: block;
74
+ border-radius: var(--radius-lg);
75
+ }
76
+
77
+ .selectable {
78
+ cursor: crosshair;
79
+ }
80
+
81
+ .icon-buttons {
82
+ display: flex;
83
+ position: absolute;
84
+ top: 6px;
85
+ right: 6px;
86
+ gap: var(--size-1);
87
+ }
88
+ </style>
src/frontend/shared/ImageUploader.svelte ADDED
@@ -0,0 +1,192 @@
 
 
1
+ <script lang="ts">
2
+ import { createEventDispatcher } from "svelte";
3
+ import { BlockLabel } from "@gradio/atoms";
4
+ import { Image } from "@gradio/icons";
5
+ import type { I18nFormatter } from "@gradio/utils";
6
+ import { get_coordinates_of_clicked_image } from "./utils";
7
+ import { ImagePaste, Upload as UploadIcon } from "@gradio/icons";
8
+ import { Toolbar, IconButton } from "@gradio/atoms";
9
+
10
+ import { Upload } from "@gradio/upload";
11
+ import { type FileData, normalise_file } from "@gradio/client";
12
+ import ClearImage from "./ClearImage.svelte";
13
+ import BoxDrawer from "./BoxDrawer.svelte";
14
+
15
+ const dispatch = createEventDispatcher();
16
+ let box_drawer: BoxDrawer;
17
+
18
+ export let value: null | FileData;
19
+ export let points: null | number[][6];
20
+ export let label: string | undefined = undefined;
21
+ export let show_label: boolean;
22
+
23
+ function handle_image_load(event: Event) {
24
+ const element = event.currentTarget as HTMLImageElement;
25
+ box_drawer.width = element.width;
26
+ box_drawer.height = element.height;
27
+ box_drawer.natural_width = element.naturalWidth;
28
+ box_drawer.natural_height = element.naturalHeight;
29
+ box_drawer.resize_canvas();
30
+ }
31
+
32
+ function handle_points_change({ detail }: { detail: number[][6] }) {
33
+ points = detail;
34
+ dispatch("points_change", detail);
35
+ }
36
+
37
+ export let sources: ("clipboard" | "upload")[] = ["upload", "clipboard"];
38
+ export let streaming = false;
39
+ export let root: string;
40
+ export let i18n: I18nFormatter;
41
+
42
+ let upload: Upload;
43
+ let uploading = false;
44
+ export let active_tool: "webcam" | null = null;
45
+
46
+ function handle_upload({ detail }: CustomEvent<FileData>): void {
47
+ value = normalise_file(detail, root, null);
48
+ dispatch("upload", detail);
49
+ }
50
+
51
+ $: if (uploading) value = null;
52
+ $: value && !value.url && (value = normalise_file(value, root, null));
53
+
54
+ let dragging = false;
55
+ $: dispatch("drag", dragging);
56
+
57
+ function handle_click(evt: MouseEvent): void {
58
+ let coordinates = get_coordinates_of_clicked_image(evt);
59
+ if (coordinates) {
60
+ dispatch("select", { index: coordinates, value: null });
61
+ }
62
+ }
63
+
64
+ const sources_meta = {
65
+ upload: {
66
+ icon: UploadIcon,
67
+ label: i18n("Upload"),
68
+ order: 0,
69
+ },
70
+ clipboard: {
71
+ icon: ImagePaste,
72
+ label: i18n("Paste"),
73
+ order: 2,
74
+ },
75
+ };
76
+
77
+ $: sources_list = sources.sort(
78
+ (a, b) => sources_meta[a].order - sources_meta[b].order,
79
+ );
80
+
81
+ async function handle_toolbar(
82
+ source: (typeof sources)[number],
83
+ ): Promise<void> {
84
+ switch (source) {
85
+ case "clipboard":
86
+ navigator.clipboard.read().then(async (items) => {
87
+ for (let i = 0; i < items.length; i++) {
88
+ const type = items[i].types.find((t) => t.startsWith("image/"));
89
+ if (type) {
90
+ value = null;
91
+ items[i].getType(type).then(async (blob) => {
92
+ const f = await upload.load_files([
93
+ new File([blob], `clipboard.${type.replace("image/", "")}`),
94
+ ]);
95
+ f;
96
+ value = f?.[0] || null;
97
+ });
98
+ break;
99
+ }
100
+ }
101
+ });
102
+ break;
103
+ case "upload":
104
+ upload.open_file_upload();
105
+ break;
106
+ default:
107
+ break;
108
+ }
109
+ }
110
+ </script>
111
+
112
+ <BlockLabel {show_label} Icon={Image} label={label || "Image"} />
113
+
114
+ <div data-testid="image" class="image-container">
115
+ {#if value?.url}
116
+ <ClearImage
117
+ on:remove_box={() => {
118
+ box_drawer.undo();
119
+ }}
120
+ on:remove_boxes={() => {
121
+ box_drawer.clear();
122
+ }}
123
+ on:remove_image={() => {
124
+ value = null;
125
+ dispatch("clear");
126
+ }}
127
+ />
128
+ {/if}
129
+ <div class="upload-container">
130
+ <Upload
131
+ hidden={value !== null || active_tool === "webcam"}
132
+ bind:this={upload}
133
+ bind:uploading
134
+ bind:dragging
135
+ filetype="image/*"
136
+ on:load={handle_upload}
137
+ on:error
138
+ {root}
139
+ disable_click={!sources.includes("upload")}
140
+ >
141
+ {#if value === null && !active_tool}
142
+ <slot />
143
+ {/if}
144
+ </Upload>
145
+ {#if value !== null && !streaming}
146
+ <!-- svelte-ignore a11y-click-events-have-key-events-->
147
+ <!-- svelte-ignore a11y-no-noninteractive-element-interactions-->
148
+ <img
149
+ src={value.url}
150
+ alt={value.alt_text}
151
+ on:click={handle_click}
152
+ on:load={handle_image_load}
153
+ />
154
+ <BoxDrawer bind:this={box_drawer} on:change={handle_points_change} />
155
+ {/if}
156
+ </div>
157
+ {#if sources.length > 1 || sources.includes("clipboard")}
158
+ <Toolbar show_border={!value?.url}>
159
+ {#each sources_list as source}
160
+ <IconButton
161
+ on:click={() => handle_toolbar(source)}
162
+ Icon={sources_meta[source].icon}
163
+ size="large"
164
+ label="{source}-image-toolbar-btn"
165
+ padded={false}
166
+ />
167
+ {/each}
168
+ </Toolbar>
169
+ {/if}
170
+ </div>
171
+
172
+ <style>
173
+ img {
174
+ width: var(--size-full);
175
+ height: var(--size-full);
176
+ }
177
+
178
+ .upload-container {
179
+ height: 100%;
180
+ flex-shrink: 1;
181
+ max-height: 100%;
182
+ }
183
+
184
+ .image-container {
185
+ display: flex;
186
+ height: 100%;
187
+ flex-direction: column;
188
+ justify-content: center;
189
+ align-items: center;
190
+ max-height: 100%;
191
+ }
192
+ </style>
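Editor's note: a minimal usage sketch of the assembled component (the callback is illustrative and not this repo's demo; it assumes the backend image_prompter.py exposes the upload under "image" and the BoxDrawer tuples under "points").

import gradio as gr
from gradio_image_prompter import ImagePrompter

def inspect(prompts):
    # prompts["points"] is the list of [x1, y1, label1, x2, y2, label2] rows described above
    return prompts["points"]

demo = gr.Interface(fn=inspect, inputs=ImagePrompter(show_label=False), outputs=gr.JSON())

if __name__ == "__main__":
    demo.launch()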
src/frontend/shared/utils.ts ADDED
@@ -0,0 +1,24 @@
 
 
1
+ export const get_coordinates_of_clicked_image = (
2
+ evt: MouseEvent
3
+ ): [number, number] | null => {
4
+ let image = evt.currentTarget as HTMLImageElement;
5
+
6
+ const imageRect = image.getBoundingClientRect();
7
+ const xScale = image.naturalWidth / imageRect.width;
8
+ const yScale = image.naturalHeight / imageRect.height;
9
+ if (xScale > yScale) {
10
+ const displayed_height = image.naturalHeight / xScale;
11
+ const y_offset = (imageRect.height - displayed_height) / 2;
12
+ var x = Math.round((evt.clientX - imageRect.left) * xScale);
13
+ var y = Math.round((evt.clientY - imageRect.top - y_offset) * xScale);
14
+ } else {
15
+ const displayed_width = image.naturalWidth / yScale;
16
+ const x_offset = (imageRect.width - displayed_width) / 2;
17
+ var x = Math.round((evt.clientX - imageRect.left - x_offset) * yScale);
18
+ var y = Math.round((evt.clientY - imageRect.top) * yScale);
19
+ }
20
+ if (x < 0 || x >= image.naturalWidth || y < 0 || y >= image.naturalHeight) {
21
+ return null;
22
+ }
23
+ return [x, y];
24
+ };
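Editor's note: get_coordinates_of_clicked_image undoes the object-fit: contain letterboxing: the click offset inside the rendered rect is shifted by the letterbox margin, then scaled by the larger natural/displayed ratio, and clicks on the padding return null. A rough Python equivalent (a hypothetical helper, not part of this commit) for replaying recorded clicks against the original image:

def click_to_natural(dx, dy, rect_w, rect_h, nat_w, nat_h):
    """dx, dy: click offsets relative to the rendered image rect."""
    x_scale, y_scale = nat_w / rect_w, nat_h / rect_h
    if x_scale > y_scale:                  # width-limited: letterbox bars above and below
        y_off = (rect_h - nat_h / x_scale) / 2
        x, y = round(dx * x_scale), round((dy - y_off) * x_scale)
    else:                                  # height-limited: letterbox bars left and right
        x_off = (rect_w - nat_w / y_scale) / 2
        x, y = round((dx - x_off) * y_scale), round(dy * y_scale)
    if not (0 <= x < nat_w and 0 <= y < nat_h):
        return None                        # click fell on the padding
    return x, y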
src/pyproject.toml ADDED
@@ -0,0 +1,43 @@
 
 
1
+ [build-system]
2
+ requires = [
3
+ "hatchling",
4
+ "hatch-requirements-txt",
5
+ "hatch-fancy-pypi-readme>=22.5.0",
6
+ ]
7
+ build-backend = "hatchling.build"
8
+
9
+ [project]
10
+ name = "gradio_image_prompter"
11
+ version = "0.1.0"
12
+ description = "A gradio component to upload images and process point/box prompts."
13
+ readme = "README.md"
14
+ license = "apache-2.0"
15
+ requires-python = ">=3.8"
16
+ url = "https://github.com/PhyscalX/gradio-image-prompter"
17
+ authors = [{ name = "PhyscalX", email = "neopenx@gmail.com" }]
18
+ keywords = ["gradio-custom-component", "gradio-template-Image"]
19
+ # Add dependencies here
20
+ dependencies = ["gradio>=4.0,<5.0"]
21
+ classifiers = [
22
+ 'Development Status :: 3 - Alpha',
23
+ 'License :: OSI Approved :: Apache Software License',
24
+ 'Operating System :: OS Independent',
25
+ 'Programming Language :: Python :: 3',
26
+ 'Programming Language :: Python :: 3 :: Only',
27
+ 'Programming Language :: Python :: 3.8',
28
+ 'Programming Language :: Python :: 3.9',
29
+ 'Programming Language :: Python :: 3.10',
30
+ 'Programming Language :: Python :: 3.11',
31
+ 'Topic :: Scientific/Engineering',
32
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
33
+ 'Topic :: Scientific/Engineering :: Visualization',
34
+ ]
35
+
36
+ [project.optional-dependencies]
37
+ dev = ["build", "twine"]
38
+
39
+ [tool.hatch.build]
40
+ artifacts = ["/backend/gradio_image_prompter/templates", "*.pyi", "backend/gradio_image_prompter/templates"]
41
+
42
+ [tool.hatch.build.targets.wheel]
43
+ packages = ["/backend/gradio_image_prompter"]
structures/__init__.py ADDED
File without changes
structures/bounding_box.py ADDED
@@ -0,0 +1,323 @@
 
 
1
+ import torch
2
+
3
+ # transpose
4
+ FLIP_LEFT_RIGHT = 0
5
+ FLIP_TOP_BOTTOM = 1
6
+
7
+
8
+ class BoxList(object):
9
+ """
10
+ This class represents a set of bounding boxes.
11
+ The bounding boxes are represented as a Nx4 Tensor.
12
+ In order to uniquely determine the bounding boxes with respect
13
+ to an image, we also store the corresponding image dimensions.
14
+ They can contain extra information that is specific to each bounding box, such as
15
+ labels.
16
+ """
17
+
18
+ def __init__(self, bbox, image_size, mode="xyxy"):
19
+ device = bbox.device if isinstance(bbox, torch.Tensor) else torch.device("cpu")
20
+ # only do as_tensor if isn't a "no-op", because it hurts JIT tracing
21
+ if (not isinstance(bbox, torch.Tensor)
22
+ or bbox.dtype != torch.float32 or bbox.device != device):
23
+ bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device)
24
+ if bbox.ndimension() == 1 and bbox.size(-1) == 4:
25
+ bbox = bbox.unsqueeze(0)
26
+ if bbox.ndimension() != 2:
27
+ raise ValueError(
28
+ "bbox should have 2 dimensions, got {}".format(bbox.ndimension())
29
+ )
30
+ if bbox.size(-1) != 4:
31
+ raise ValueError(
32
+ "last dimenion of bbox should have a "
33
+ "size of 4, got {}".format(bbox.size(-1))
34
+ )
35
+ if mode not in ("xyxy", "xywh"):
36
+ raise ValueError("mode should be 'xyxy' or 'xywh'")
37
+
38
+ self.bbox = bbox
39
+ self.size = image_size # (image_width, image_height)
40
+ self.mode = mode
41
+ self.extra_fields = {}
42
+
43
+ # note: _jit_wrap/_jit_unwrap only work if the keys and the sizes don't change in between
44
+ def _jit_unwrap(self):
45
+ return (self.bbox,) + tuple(f for f in (self.get_field(field)
46
+ for field in sorted(self.fields()))
47
+ if isinstance(f, torch.Tensor))
48
+
49
+ def _jit_wrap(self, input_stream):
50
+ self.bbox = input_stream[0]
51
+ num_consumed = 1
52
+ for f in sorted(self.fields()):
53
+ if isinstance(self.extra_fields[f], torch.Tensor):
54
+ self.extra_fields[f] = input_stream[num_consumed]
55
+ num_consumed += 1
56
+ return self, input_stream[num_consumed:]
57
+
58
+ def add_field(self, field, field_data):
59
+ self.extra_fields[field] = field_data
60
+
61
+ def get_field(self, field):
62
+ return self.extra_fields[field]
63
+
64
+ def has_field(self, field):
65
+ return field in self.extra_fields
66
+
67
+ def fields(self):
68
+ return list(self.extra_fields.keys())
69
+
70
+ def _copy_extra_fields(self, bbox):
71
+ for k, v in bbox.extra_fields.items():
72
+ self.extra_fields[k] = v
73
+
74
+ def convert(self, mode):
75
+ if mode not in ("xyxy", "xywh"):
76
+ raise ValueError("mode should be 'xyxy' or 'xywh'")
77
+ if mode == self.mode:
78
+ return self
79
+ # we only have two modes, so don't need to check
80
+ # self.mode
81
+ xmin, ymin, xmax, ymax = self._split_into_xyxy()
82
+ if mode == "xyxy":
83
+ bbox = torch.cat((xmin, ymin, xmax, ymax), dim=-1)
84
+ bbox = BoxList(bbox, self.size, mode=mode)
85
+ else:
86
+ TO_REMOVE = 1
87
+ # NOTE: explicitly specify dim to avoid tracing error in GPU
88
+ bbox = torch.cat(
89
+ (xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=1
90
+ )
91
+ bbox = BoxList(bbox, self.size, mode=mode)
92
+ bbox._copy_extra_fields(self)
93
+ return bbox
94
+
95
+ def _split_into_xyxy(self):
96
+ if self.mode == "xyxy":
97
+ xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1)
98
+ return xmin, ymin, xmax, ymax
99
+ elif self.mode == "xywh":
100
+ TO_REMOVE = 1
101
+ xmin, ymin, w, h = self.bbox.split(1, dim=-1)
102
+ return (
103
+ xmin,
104
+ ymin,
105
+ xmin + (w - TO_REMOVE).clamp(min=0),
106
+ ymin + (h - TO_REMOVE).clamp(min=0),
107
+ )
108
+ else:
109
+ raise RuntimeError("Should not be here")
110
+
111
+ def resize(self, size, *args, **kwargs):
112
+ """
113
+ Returns a resized copy of this bounding box
114
+
115
+ :param size: The requested size in pixels, as a 2-tuple:
116
+ (width, height).
117
+ """
118
+
119
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
120
+ if ratios[0] == ratios[1]:
121
+ ratio = ratios[0]
122
+ scaled_box = self.bbox * ratio
123
+ bbox = BoxList(scaled_box, size, mode=self.mode)
124
+ # bbox._copy_extra_fields(self)
125
+ for k, v in self.extra_fields.items():
126
+ if not isinstance(v, torch.Tensor):
127
+ v = v.resize(size, *args, **kwargs)
128
+ bbox.add_field(k, v)
129
+ return bbox
130
+
131
+ ratio_width, ratio_height = ratios
132
+ xmin, ymin, xmax, ymax = self._split_into_xyxy()
133
+ scaled_xmin = xmin * ratio_width
134
+ scaled_xmax = xmax * ratio_width
135
+ scaled_ymin = ymin * ratio_height
136
+ scaled_ymax = ymax * ratio_height
137
+ scaled_box = torch.cat(
138
+ (scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1
139
+ )
140
+ bbox = BoxList(scaled_box, size, mode="xyxy")
141
+ # bbox._copy_extra_fields(self)
142
+ for k, v in self.extra_fields.items():
143
+ if not isinstance(v, torch.Tensor):
144
+ v = v.resize(size, *args, **kwargs)
145
+ bbox.add_field(k, v)
146
+
147
+ return bbox.convert(self.mode)
148
+
149
+ def transpose(self, method):
150
+ """
151
+ Transpose bounding box (flip or rotate in 90 degree steps)
152
+ :param method: One of :py:attr:`PIL.Image.FLIP_LEFT_RIGHT`,
153
+ :py:attr:`PIL.Image.FLIP_TOP_BOTTOM`, :py:attr:`PIL.Image.ROTATE_90`,
154
+ :py:attr:`PIL.Image.ROTATE_180`, :py:attr:`PIL.Image.ROTATE_270`,
155
+ :py:attr:`PIL.Image.TRANSPOSE` or :py:attr:`PIL.Image.TRANSVERSE`.
156
+ """
157
+ if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
158
+ raise NotImplementedError(
159
+ "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
160
+ )
161
+
162
+ image_width, image_height = self.size
163
+ xmin, ymin, xmax, ymax = self._split_into_xyxy()
164
+ if method == FLIP_LEFT_RIGHT:
165
+ TO_REMOVE = 1
166
+ transposed_xmin = image_width - xmax - TO_REMOVE
167
+ transposed_xmax = image_width - xmin - TO_REMOVE
168
+ transposed_ymin = ymin
169
+ transposed_ymax = ymax
170
+ elif method == FLIP_TOP_BOTTOM:
171
+ transposed_xmin = xmin
172
+ transposed_xmax = xmax
173
+ transposed_ymin = image_height - ymax
174
+ transposed_ymax = image_height - ymin
175
+
176
+ transposed_boxes = torch.cat(
177
+ (transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
178
+ )
179
+ bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
180
+ # bbox._copy_extra_fields(self)
181
+ for k, v in self.extra_fields.items():
182
+ if not isinstance(v, torch.Tensor):
183
+ v = v.transpose(method)
184
+ bbox.add_field(k, v)
185
+ return bbox.convert(self.mode)
186
+
187
+ def crop(self, box):
188
+ """
189
+ Crops a rectangular region from this bounding box. The box is a
190
+ 4-tuple defining the left, upper, right, and lower pixel
191
+ coordinate.
192
+ """
193
+ xmin, ymin, xmax, ymax = self._split_into_xyxy()
194
+ w, h = box[2] - box[0], box[3] - box[1]
195
+ cropped_xmin = (xmin - box[0]).clamp(min=0, max=w)
196
+ cropped_ymin = (ymin - box[1]).clamp(min=0, max=h)
197
+ cropped_xmax = (xmax - box[0]).clamp(min=0, max=w)
198
+ cropped_ymax = (ymax - box[1]).clamp(min=0, max=h)
199
+
200
+ # TODO should I filter empty boxes here?
201
+ cropped_box = torch.cat(
202
+ (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1
203
+ )
204
+ bbox = BoxList(cropped_box, (w, h), mode="xyxy")
205
+ # bbox._copy_extra_fields(self)
206
+ for k, v in self.extra_fields.items():
207
+ if not isinstance(v, torch.Tensor):
208
+ v = v.crop(box)
209
+ bbox.add_field(k, v)
210
+ return bbox.convert(self.mode)
211
+
212
+ # Tensor-like methods
213
+
214
+ def to(self, device):
215
+ bbox = BoxList(self.bbox.to(device), self.size, self.mode)
216
+ for k, v in self.extra_fields.items():
217
+ if hasattr(v, "to"):
218
+ v = v.to(device)
219
+ bbox.add_field(k, v)
220
+ return bbox
221
+
222
+ def __getitem__(self, item):
223
+ bbox = BoxList(self.bbox[item], self.size, self.mode)
224
+ for k, v in self.extra_fields.items():
225
+ bbox.add_field(k, v[item])
226
+ return bbox
227
+
228
+ def __len__(self):
229
+ return self.bbox.shape[0]
230
+
231
+ def clip_to_image(self, remove_empty=True):
232
+ TO_REMOVE = 1
233
+ x1s = self.bbox[:, 0].clamp(min=0, max=self.size[0] - TO_REMOVE)
234
+ y1s = self.bbox[:, 1].clamp(min=0, max=self.size[1] - TO_REMOVE)
235
+ x2s = self.bbox[:, 2].clamp(min=0, max=self.size[0] - TO_REMOVE)
236
+ y2s = self.bbox[:, 3].clamp(min=0, max=self.size[1] - TO_REMOVE)
237
+ self.bbox = torch.stack((x1s, y1s, x2s, y2s), dim=-1)
238
+ if remove_empty:
239
+ box = self.bbox
240
+ keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
241
+ return self[keep]
242
+ return self
243
+
244
+ def area(self):
245
+ if self.mode == 'xyxy':
246
+ TO_REMOVE = 1
247
+ box = self.bbox
248
+ area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
249
+ elif self.mode == 'xywh':
250
+ box = self.bbox
251
+ area = box[:, 2] * box[:, 3]
252
+ else:
253
+ raise RuntimeError("Should not be here")
254
+
255
+ return area
256
+
257
+ def copy_with_fields(self, fields):
258
+ bbox = BoxList(self.bbox, self.size, self.mode)
259
+ if not isinstance(fields, (list, tuple)):
260
+ fields = [fields]
261
+ for field in fields:
262
+ bbox.add_field(field, self.get_field(field))
263
+ return bbox
264
+
265
+ def __repr__(self):
266
+ s = self.__class__.__name__ + "("
267
+ s += "num_boxes={}, ".format(len(self))
268
+ s += "image_width={}, ".format(self.size[0])
269
+ s += "image_height={}, ".format(self.size[1])
270
+ s += "mode={})".format(self.mode)
271
+ return s
272
+
273
+ @staticmethod
274
+ def concate_box_list(list_of_boxes):
275
+ boxes = torch.cat([i.bbox for i in list_of_boxes], dim=0)
276
+ extra_fields_keys = list(list_of_boxes[0].extra_fields.keys())
277
+ extra_fields = {}
278
+ for key in extra_fields_keys:
279
+ extra_fields[key] = torch.cat([i.extra_fields[key] for i in list_of_boxes], dim=0)
280
+
281
+ final = list_of_boxes[0].copy_with_fields(extra_fields_keys)
282
+
283
+ final.bbox = boxes
284
+ final.extra_fields = extra_fields
285
+ return final
286
+
287
+
288
+ @torch.jit.unused
289
+ def _onnx_clip_boxes_to_image(boxes, size):
290
+ # type: (Tensor, Tuple[int, int])
291
+ """
292
+ Clip boxes so that they lie inside an image of size `size`.
293
+ Clip's min and max are traced as constants; use torch.min/max to work around this issue
294
+ Arguments:
295
+ boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
296
+ size (Tuple[height, width]): size of the image
297
+ Returns:
298
+ clipped_boxes (Tensor[N, 4])
299
+ """
300
+ TO_REMOVE = 1
301
+ device = boxes.device
302
+ dim = boxes.dim()
303
+ boxes_x = boxes[..., 0::2]
304
+ boxes_y = boxes[..., 1::2]
305
+
306
+ boxes_x = torch.max(boxes_x, torch.tensor(0., dtype=torch.float).to(device))
307
+ boxes_x = torch.min(boxes_x, torch.tensor(size[1] - TO_REMOVE, dtype=torch.float).to(device))
308
+ boxes_y = torch.max(boxes_y, torch.tensor(0., dtype=torch.float).to(device))
309
+ boxes_y = torch.min(boxes_y, torch.tensor(size[0] - TO_REMOVE, dtype=torch.float).to(device))
310
+
311
+ clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
312
+ return clipped_boxes.reshape(boxes.shape)
313
+
314
+
315
+ if __name__ == "__main__":
316
+ bbox = BoxList([[0, 0, 10, 10], [0, 0, 5, 5]], (10, 10))
317
+ s_bbox = bbox.resize((5, 5))
318
+ print(s_bbox)
319
+ print(s_bbox.bbox)
320
+
321
+ t_bbox = bbox.transpose(0)
322
+ print(t_bbox)
323
+ print(t_bbox.bbox)
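Editor's note: the __main__ block above only exercises resize and transpose. A short sketch of the extra-field plumbing the rest of the code relies on (the "labels" values are made up for illustration):

import torch

boxes = BoxList(torch.tensor([[0., 0., 20., 10.], [5., 5., 8., 30.]]), image_size=(25, 25), mode="xyxy")
boxes.add_field("labels", torch.tensor([1, 2]))   # per-box payload travels with the boxes
boxes = boxes.clip_to_image(remove_empty=True)    # clamps coordinates to the 25x25 canvas
print(boxes.convert("xywh").bbox)                 # same boxes in (x, y, w, h) form
print(boxes.get_field("labels"))                  # fields are indexed together with the boxes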
structures/grasp_box.py ADDED
@@ -0,0 +1,127 @@
 
 
1
+ import numpy as np
2
+ class GraspCoder:
3
+ """
4
+ This class encodes grasp annotations, similar to the BoxCoder class
5
+ It is supposed to support the following functions:
6
+ 1. Encode grasp annotations:
7
+ (x1, y1, x2, y2, x3, y3, x4, y4) -> (x_center, y_center, width, height, sine(theta))
8
+ 2. Decode grasp annotations:
9
+ (x_center, y_center, width, height, sine(theta)) -> (x1, y1, x2, y2, x3, y3, x4, y4)
10
+ 3. Resize grasp annotations when the image is resized
11
+ 4. Transform grasp annotations according to various image augmentations
12
+ One GraspCoder instance should encode annotations of one image only
13
+ """
14
+ def __init__(self, height, width, grasp_annos, grasp_annos_reformat=None):
15
+ """
16
+
17
+ Args:
18
+ height: height of image
19
+ width: width of image
20
+ grasp_annos: list of numpy.arrays, each of length 8, in format of (x1, y1, x2, y2, x3, y3, x4, y4)
21
+ """
22
+ self.height = height
23
+ self.width = width
24
+ self.grasp_annos = grasp_annos
25
+ self.grasp_annos_reformat = grasp_annos_reformat
26
+ def __len__(self):
27
+ return len(self.grasp_annos)
28
+ def encode(self, normalize=True):
29
+ """
30
+ (x1, y1, x2, y2, x3, y3, x4, y4) -> (x_center, y_center, width, height, sine(theta))
31
+ Args:
32
+ normalize -> bool: return values normalized to 0~1 or not
33
+ Returns:
34
+ grasp_annos_reformat: List of numpy.array
35
+ """
36
+ grasp_annos_reformat = []
37
+ for grasp in self.grasp_annos:
38
+ x1, y1, x2, y2, x3, y3, x4, y4 = tuple(grasp)
39
+ if (x1 + x2) < (x3 + x4):
40
+ x1, y1, x2, y2, x3, y3, x4, y4 = x3, y3, x4, y4, x1, y1, x2, y2
41
+ x_center = (x1 + x3)/2
42
+ y_center = (y1 + y3)/2
43
+ width = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
44
+ height = np.sqrt((x2 - x3)**2 + (y2 - y3)**2)
45
+ sine = ((y1 + y2)/2 - y_center) / (height / 2)
46
+ if normalize:
47
+ x_center /= self.width
48
+ y_center /= self.height
49
+ width /= self.width
50
+ height /= self.height
51
+ sine = (sine + 1) / 2
52
+ grasp_annos_reformat.append(np.array([x_center, y_center, width, height, sine]))
53
+ self.grasp_annos_reformat = grasp_annos_reformat
54
+ return grasp_annos_reformat
55
+ def decode(self):
56
+ """
57
+ Decode normalized grasp_annos_reformat, will overwrite self.grasp_annos, and return the overwritten value
58
+ (x_center, y_center, width, height, sine(theta)) -> (x1, y1, x2, y2, x3, y3, x4, y4)
59
+ Returns:
60
+ grasp_annos: List of numpy.array
61
+ """
62
+ grasp_annos = []
63
+ for grasp in self.grasp_annos_reformat:
64
+ x_center, y_center, width, height, sine = tuple(grasp)
65
+ x_center *= self.width
66
+ y_center *= self.height
67
+ width *= self.width
68
+ height *= self.height
69
+ sine = sine * 2 - 1
70
+ cosine = np.sqrt(1 - sine ** 2)
71
+ angle = np.arcsin(sine)
72
+ x1 = x_center + cosine * height / 2 + sine * width / 2
73
+ x2 = x_center + cosine * height / 2 - sine * width / 2
74
+ y1 = y_center + sine * height / 2 - cosine * width / 2
75
+ y2 = y_center + sine * height / 2 + cosine * width / 2
76
+ x3 = x_center * 2 - x1
77
+ x4 = x_center * 2 - x2
78
+ y3 = y_center * 2 - y1
79
+ y4 = y_center * 2 - y2
80
+ grasp_annos.append(np.array([x1, y1, x2, y2, x3, y3, x4, y4]))
81
+ self.grasp_annos = grasp_annos
82
+ return grasp_annos
83
+
84
+ def resize(self, new_size):
85
+ """
86
+ Resize the grasp annotations according to resized image
87
+ Args:
88
+ new_size -> Tuple: (new_width, new_height)
89
+ new_height: The resized image height
90
+ new_width: The resized image width
91
+
92
+ Returns:
93
+ self
94
+ """
95
+ new_width, new_height = new_size
96
+ grasp_annos = self.grasp_annos
97
+ old_height, old_width = self.height, self.width
98
+ resized_grasp_annos = []
99
+ for grasp in grasp_annos:
100
+ grasp[0::2] = grasp[0::2] / old_width * new_width
101
+ grasp[1::2] = grasp[1::2] / old_height * new_height
102
+ resized_grasp_annos.append(grasp)
103
+ self.grasp_annos = resized_grasp_annos
104
+ self.height, self.width = new_height, new_width
105
+
106
+ return self
107
+ def transpose(self, axis):
108
+ """
109
+ For Horizontal/Vertical flip
110
+ Args:
111
+ axis: 0 represents the X axis, 1 represents the Y axis
112
+
113
+ Returns:
114
+ self
115
+ """
116
+ grasp_annos = self.grasp_annos
117
+ flipped_grasp_annos = []
118
+ if axis == 0:
119
+ for grasp in grasp_annos:
120
+ grasp[0::2] = self.width - grasp[0::2]
121
+ flipped_grasp_annos.append(grasp)
122
+ elif axis == 1:
123
+ for grasp in grasp_annos:
124
+ grasp[1::2] = self.height - grasp[1::2]
125
+ flipped_grasp_annos.append(grasp)
126
+ self.grasp_annos = flipped_grasp_annos
127
+ return self
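As a quick sanity check of the encode/decode round trip above, here is a minimal usage sketch. The 640x480 image size and the corner coordinates are made up for illustration; the import path follows app.py.

```python
import numpy as np
from structures.grasp_box import GraspCoder

# One axis-aligned 40x20 rectangle centered at (100, 60), flattened as
# (x1, y1, x2, y2, x3, y3, x4, y4); the values are arbitrary examples.
grasp = np.array([120., 50., 80., 50., 80., 70., 120., 70.])

coder = GraspCoder(height=480, width=640, grasp_annos=[grasp])
encoded = coder.encode(normalize=True)   # [x_c, y_c, w, h, sine], all scaled to 0~1
decoded = coder.decode()                 # back to 8 corner coordinates

print(encoded[0])   # [0.15625, 0.125, 0.0625, ~0.0417, 0.0]
print(decoded[0])   # same rectangle as `grasp`; the corner ordering may differ
```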
structures/image_list.py ADDED
@@ -0,0 +1,67 @@
1
+
2
+ import torch
3
+
4
+ class ImageList(object):
5
+ """
6
+ Structure that holds a list of images (of possibly
7
+ varying sizes) as a single tensor.
8
+ This works by padding the images to the same size,
9
+ and storing in a field the original sizes of each image
10
+ """
11
+
12
+ def __init__(self, tensors, image_sizes):
13
+ """
14
+ Arguments:
15
+ tensors (tensor)
16
+ image_sizes (list[tuple[int, int]])
17
+ """
18
+ self.tensors = tensors
19
+ self.image_sizes = image_sizes
20
+
21
+ def to(self, *args, **kwargs):
22
+ cast_tensor = self.tensors.to(*args, **kwargs)
23
+ return ImageList(cast_tensor, self.image_sizes)
24
+
25
+
26
+ def to_image_list(tensors, size_divisible=0):
27
+ """
28
+ tensors can be an ImageList, a torch.Tensor or
29
+ an iterable of Tensors. It can't be a numpy array.
30
+ When tensors is an iterable of Tensors, it pads
31
+ the Tensors with zeros so that they have the same
32
+ shape
33
+ """
34
+ if isinstance(tensors, torch.Tensor) and size_divisible > 0:
35
+ tensors = [tensors]
36
+
37
+ if isinstance(tensors, ImageList):
38
+ return tensors
39
+ elif isinstance(tensors, torch.Tensor):
40
+ # single tensor shape can be inferred
41
+ assert tensors.dim() == 4
42
+ image_sizes = [tensor.shape[-2:] for tensor in tensors]
43
+ return ImageList(tensors, image_sizes)
44
+ elif isinstance(tensors, (tuple, list)):
45
+ max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
46
+
47
+ # TODO Ideally, just remove this and let the model handle arbitrary
48
+ # input sizes
49
+ if size_divisible > 0:
50
+ import math
51
+
52
+ stride = size_divisible
53
+ max_size = list(max_size)
54
+ max_size[1] = int(math.ceil(max_size[1] / stride) * stride)
55
+ max_size[2] = int(math.ceil(max_size[2] / stride) * stride)
56
+ max_size = tuple(max_size)
57
+
58
+ batch_shape = (len(tensors),) + max_size
59
+ batched_imgs = tensors[0].new(*batch_shape).zero_()
60
+ for img, pad_img in zip(tensors, batched_imgs):
61
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
62
+
63
+ image_sizes = [im.shape[-2:] for im in tensors]
64
+
65
+ return ImageList(batched_imgs, image_sizes)
66
+ else:
67
+ raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors)))
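A short, illustrative sketch of how `to_image_list` zero-pads a batch of differently sized images; the tensor shapes and `size_divisible=32` are arbitrary example values, and the import path is assumed to mirror the file layout above.

```python
import torch
from structures.image_list import to_image_list

# Two CHW float images with different spatial sizes (example shapes only).
imgs = [torch.rand(3, 200, 300), torch.rand(3, 180, 320)]

batch = to_image_list(imgs, size_divisible=32)
print(batch.tensors.shape)   # torch.Size([2, 3, 224, 320]) -- padded up to multiples of 32
print(batch.image_sizes)     # [torch.Size([200, 300]), torch.Size([180, 320])] -- original (H, W) per image
```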
structures/segmentation_mask.py ADDED
@@ -0,0 +1,298 @@
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ import pycocotools.mask as mask_utils
5
+
6
+ # transpose
7
+ FLIP_LEFT_RIGHT = 0
8
+ FLIP_TOP_BOTTOM = 1
9
+
10
+
11
+ class MaskList(object):
12
+ """
13
+ This class is unfinished and not meant for use yet
14
+ It is supposed to contain the binary masks for all instances in a list of 2D tensors (H, W)
15
+ """
16
+
17
+ def __init__(self, masks, size, mode):
18
+ assert(isinstance(masks, list))
19
+ assert(mode in ['mask', 'rle'])
20
+ self.masks = masks
21
+ self.size = size # (image_width, image_height)
22
+ self.mode = mode
23
+
24
+ def transpose(self, method):
25
+ assert (self.mode == "mask"), "RLE masks cannot be transposed. Please convert them to binary first."
26
+ if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
27
+ raise NotImplementedError(
28
+ "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
29
+ )
30
+
31
+ # width, height = self.size
32
+ masks = np.array(self.masks)
33
+ if masks.ndim == 2:
34
+ masks = np.expand_dims(masks, axis=0)
35
+ if method == FLIP_LEFT_RIGHT:
36
+ masks = np.flip(masks, axis=2)
37
+ elif method == FLIP_TOP_BOTTOM:
38
+ masks = np.flip(masks, axis=1)
39
+ flipped_masks = np.split(masks, masks.shape[0])
40
+ flipped_masks = [mask.squeeze(0) for mask in flipped_masks]
41
+ return MaskList(flipped_masks, self.size, self.mode)
42
+
43
+ def resize(self, size, *args, **kwargs):
44
+ """
45
+ Resize the binary mask.
46
+ :param size: tuple, (image_width, image_height)
47
+ :param args:
48
+ :param kwargs:
49
+ :return:
50
+ """
51
+ assert(self.mode == "mask"), "RLE masks cannot be resized. Please convert them to binary first."
52
+ cat_mask = np.array(self.masks)
53
+
54
+ cat_mask = cat_mask.transpose(1, 2, 0)
55
+ cat_mask *= 255
56
+ cat_mask = cat_mask.astype(np.uint8)
57
+ resized_mask = cv2.resize(cat_mask, size)
58
+ if resized_mask.ndim == 2:
59
+ resized_mask = np.expand_dims(resized_mask, axis=2)
60
+ try:
61
+ resized_mask = resized_mask.transpose(2, 0, 1)
62
+ except ValueError:
63
+ raise  # re-raise instead of silently continuing with a mis-shaped mask
64
+ resized_mask = resized_mask.astype(int)
65
+ resized_mask = resized_mask // 255
66
+ # # visualize to check mask correctness
67
+ # from matplotlib import pyplot as plt
68
+ # plt.figure()
69
+ # plt.imshow(resized_mask[0]*255, cmap='gray')
70
+ # plt.show()
71
+ mask_list = np.split(resized_mask, resized_mask.shape[0])
72
+ mask_list = [mask.squeeze(0) for mask in mask_list]
73
+ return MaskList(mask_list, size, "mask")
74
+
75
+ def pad(self, size):
76
+ """
77
+ Pad the binary masks according to the new size. The new size must be at least as large as the original size in all dimensions
78
+ :param size: New image size, (image_width, image_height)
79
+ :return:
80
+ """
81
+ assert(size[0] >= self.size[0] and size[1] >= self.size[1]), "New size must be larger than original size in all dimensions"
82
+ cat_mask = np.array(self.masks)
83
+ if cat_mask.ndim == 2:
84
+ cat_mask = np.expand_dims(cat_mask, axis=0)
85
+ padded_mask = np.zeros([len(self.masks), size[1], size[0]])
86
+ padded_mask[:, :cat_mask.shape[1], :cat_mask.shape[2]] = cat_mask
87
+ # # visualize to check mask correctness
88
+ # from matplotlib import pyplot as plt
89
+ # plt.figure()
90
+ # plt.imshow(padded_mask[1]*255, cmap='gray')
91
+ # plt.show()
92
+ mask_list = np.split(padded_mask, padded_mask.shape[0])
93
+ mask_list = [mask.squeeze(0) for mask in mask_list]
94
+ return MaskList(mask_list, size, "mask")
95
+
96
+ def convert(self, mode):
97
+ """
98
+ Convert mask from between mode "mask" and mode "rle"
99
+ :param mode:
100
+ :return:
101
+ """
102
+ if mode == self.mode:
103
+ return self
104
+ elif mode == "rle" and self.mode == "mask":
105
+ # use pycocotools to encode binary masks to rle
106
+ rle_mask_list = mask_utils.encode(np.asfortranarray(np.array(self.masks).transpose(1, 2, 0).astype(np.uint8)))
107
+ return MaskList(rle_mask_list, self.size, "rle")
108
+ elif mode == "mask" and self.mode == "rle":
109
+ # use pycocotools to decode rle to binary masks
110
+ bimasks = mask_utils.decode(self.masks)
111
+ mask_list = np.split(bimasks.transpose(2, 0, 1), bimasks.shape[2])
112
+ mask_list = [mask.squeeze(0) for mask in mask_list]
113
+ return MaskList(mask_list, self.size, "mask")
114
+
115
+ def bbox(self, bbox_mode="xyxy"):
116
+ """
117
+ Generate a bounding box according to the binary mask
118
+ :param bbox_mode:
119
+ :return:
120
+ """
121
+ pass
122
+
123
+ def __len__(self):
124
+ return len(self.masks)
125
+
126
+ def __repr__(self):
127
+ s = self.__class__.__name__ + "("
128
+ s += "num_masks={}, ".format(len(self))
129
+ s += "image_width={}, ".format(self.size[0])
130
+ s += "image_height={}, ".format(self.size[1])
131
+ s += "mode={})".format(self.mode)
132
+ return s
133
+
134
+
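The class above is explicitly marked unfinished, so the following is only an illustrative sketch of the mask <-> RLE round trip it already supports; the mask contents are arbitrary and the import path is assumed to mirror the file layout.

```python
import numpy as np
from structures.segmentation_mask import MaskList

h, w = 60, 80
m1 = np.zeros((h, w), dtype=np.uint8); m1[10:30, 20:50] = 1
m2 = np.zeros((h, w), dtype=np.uint8); m2[35:55, 5:40] = 1

masks = MaskList([m1, m2], size=(w, h), mode="mask")
rle = masks.convert("rle")    # pycocotools RLE encodings, one per instance
back = rle.convert("mask")    # decoded back to binary masks
print(len(back), back.size)   # 2 (80, 60)
```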
135
+ class Polygons(object):
136
+ """
137
+ This class holds a set of polygons that represents a single instance
138
+ of an object mask. The object can be represented as a set of
139
+ polygons
140
+ """
141
+
142
+ def __init__(self, polygons, size, mode):
143
+ # assert isinstance(polygons, list), '{}'.format(polygons)
144
+ if isinstance(polygons, list):
145
+ polygons = [torch.as_tensor(p, dtype=torch.float32) for p in polygons]
146
+ elif isinstance(polygons, Polygons):
147
+ polygons = polygons.polygons
148
+
149
+ self.polygons = polygons
150
+ self.size = size
151
+ self.mode = mode
152
+
153
+ def transpose(self, method):
154
+ if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
155
+ raise NotImplementedError(
156
+ "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
157
+ )
158
+
159
+ flipped_polygons = []
160
+ width, height = self.size
161
+ if method == FLIP_LEFT_RIGHT:
162
+ dim = width
163
+ idx = 0
164
+ elif method == FLIP_TOP_BOTTOM:
165
+ dim = height
166
+ idx = 1
167
+
168
+ for poly in self.polygons:
169
+ p = poly.clone()
170
+ TO_REMOVE = 1
171
+ p[idx::2] = dim - poly[idx::2] - TO_REMOVE
172
+ flipped_polygons.append(p)
173
+
174
+ return Polygons(flipped_polygons, size=self.size, mode=self.mode)
175
+
176
+ def crop(self, box):
177
+ w, h = box[2] - box[0], box[3] - box[1]
178
+
179
+ # TODO check if necessary
180
+ w = max(w, 1)
181
+ h = max(h, 1)
182
+
183
+ cropped_polygons = []
184
+ for poly in self.polygons:
185
+ p = poly.clone()
186
+ p[0::2] = p[0::2] - box[0] # .clamp(min=0, max=w)
187
+ p[1::2] = p[1::2] - box[1] # .clamp(min=0, max=h)
188
+ cropped_polygons.append(p)
189
+
190
+ return Polygons(cropped_polygons, size=(w, h), mode=self.mode)
191
+
192
+ def resize(self, size, *args, **kwargs):
193
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
194
+ if ratios[0] == ratios[1]:
195
+ ratio = ratios[0]
196
+ scaled_polys = [p * ratio for p in self.polygons]
197
+ return Polygons(scaled_polys, size, mode=self.mode)
198
+
199
+ ratio_w, ratio_h = ratios
200
+ scaled_polygons = []
201
+ for poly in self.polygons:
202
+ p = poly.clone()
203
+ p[0::2] *= ratio_w
204
+ p[1::2] *= ratio_h
205
+ scaled_polygons.append(p)
206
+
207
+ return Polygons(scaled_polygons, size=size, mode=self.mode)
208
+
209
+ def convert(self, mode):
210
+ width, height = self.size
211
+ if mode == "mask":
212
+ rles = mask_utils.frPyObjects(
213
+ [p.detach().numpy() for p in self.polygons], height, width
214
+ )
215
+ rle = mask_utils.merge(rles)
216
+ mask = mask_utils.decode(rle)
217
+ mask = torch.from_numpy(mask)
218
+ # TODO add squeeze?
219
+ return mask
220
+
221
+ def __repr__(self):
222
+ s = self.__class__.__name__ + "("
223
+ s += "num_polygons={}, ".format(len(self.polygons))
224
+ s += "image_width={}, ".format(self.size[0])
225
+ s += "image_height={}, ".format(self.size[1])
226
+ s += "mode={})".format(self.mode)
227
+ return s
228
+
229
+
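A brief sketch of Polygons.convert rasterizing a polygon into a binary mask via pycocotools; the triangle coordinates and the (80, 60) size are invented for illustration, and "poly" is just a placeholder mode string (the mode is stored but not interpreted by convert).

```python
import torch
from structures.segmentation_mask import Polygons

# A single triangle, flattened as [x0, y0, x1, y1, x2, y2]; size is (width, height).
poly = Polygons([[10.0, 10.0, 70.0, 10.0, 40.0, 50.0]], size=(80, 60), mode="poly")

mask = poly.convert("mask")          # torch.uint8 tensor of shape (60, 80)
print(mask.shape, int(mask.sum()))   # torch.Size([60, 80]) and roughly the triangle's pixel area
```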
230
+ class SegmentationMask(object):
231
+ """
232
+ This class stores the segmentations for all objects in the image
233
+ """
234
+
235
+ def __init__(self, polygons, size, mode=None):
236
+ """
237
+ Arguments:
238
+ polygons: a list of list of lists of numbers. The first
239
+ level of the list corresponds to individual instances,
240
+ the second level to all the polygons that compose the
241
+ object, and the third level to the polygon coordinates.
242
+ """
243
+ assert isinstance(polygons, list)
244
+
245
+ self.polygons = [Polygons(p, size, mode) for p in polygons]
246
+ self.size = size
247
+ self.mode = mode
248
+
249
+ def transpose(self, method):
250
+ if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
251
+ raise NotImplementedError(
252
+ "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
253
+ )
254
+
255
+ flipped = []
256
+ for polygon in self.polygons:
257
+ flipped.append(polygon.transpose(method))
258
+ return SegmentationMask(flipped, size=self.size, mode=self.mode)
259
+
260
+ def crop(self, box):
261
+ w, h = box[2] - box[0], box[3] - box[1]
262
+ cropped = []
263
+ for polygon in self.polygons:
264
+ cropped.append(polygon.crop(box))
265
+ return SegmentationMask(cropped, size=(w, h), mode=self.mode)
266
+
267
+ def resize(self, size, *args, **kwargs):
268
+ scaled = []
269
+ for polygon in self.polygons:
270
+ scaled.append(polygon.resize(size, *args, **kwargs))
271
+ return SegmentationMask(scaled, size=size, mode=self.mode)
272
+
273
+ def to(self, *args, **kwargs):
274
+ return self
275
+
276
+ def __getitem__(self, item):
277
+ if isinstance(item, (int, slice)):
278
+ selected_polygons = [self.polygons[item]]
279
+ else:
280
+ # advanced indexing on a single dimension
281
+ selected_polygons = []
282
+ if isinstance(item, torch.Tensor) and item.dtype == torch.bool:
283
+ item = item.nonzero()
284
+ item = item.squeeze(1) if item.numel() > 0 else item
285
+ item = item.tolist()
286
+ for i in item:
287
+ selected_polygons.append(self.polygons[i])
288
+ return SegmentationMask(selected_polygons, size=self.size, mode=self.mode)
289
+
290
+ def __iter__(self):
291
+ return iter(self.polygons)
292
+
293
+ def __repr__(self):
294
+ s = self.__class__.__name__ + "("
295
+ s += "num_instances={}, ".format(len(self.polygons))
296
+ s += "image_width={}, ".format(self.size[0])
297
+ s += "image_height={})".format(self.size[1])
298
+ return s
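Finally, an end-to-end sketch of SegmentationMask holding per-instance polygon lists and chaining resize and flip; the polygon coordinates and image sizes are arbitrary examples, and the import path is assumed to mirror the file layout.

```python
from structures.segmentation_mask import SegmentationMask, FLIP_LEFT_RIGHT

# Two instances for a 100x80 (width x height) image, each given as a list of flattened polygons.
polys = [
    [[10.0, 10.0, 40.0, 10.0, 40.0, 40.0, 10.0, 40.0]],  # instance 0: a square
    [[50.0, 20.0, 90.0, 20.0, 70.0, 60.0]],              # instance 1: a triangle
]
segm = SegmentationMask(polys, size=(100, 80))

segm = segm.resize((200, 160))            # scale both axes by 2
segm = segm.transpose(FLIP_LEFT_RIGHT)    # mirror horizontally
print(segm)                               # SegmentationMask(num_instances=2, image_width=200, image_height=160)
print(segm[0])                            # indexing returns a SegmentationMask with a single instance
```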