pg56714 commited on
Commit
a375a27
1 Parent(s): d23df68

Upload 110 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. efficient_sam/__init__.py +7 -0
  2. efficient_sam/__pycache__/__init__.cpython-310.pyc +0 -0
  3. efficient_sam/__pycache__/build_efficient_sam.cpython-310.pyc +0 -0
  4. efficient_sam/__pycache__/efficient_sam.cpython-310.pyc +0 -0
  5. efficient_sam/__pycache__/efficient_sam_decoder.cpython-310.pyc +0 -0
  6. efficient_sam/__pycache__/efficient_sam_encoder.cpython-310.pyc +0 -0
  7. efficient_sam/__pycache__/mlp.cpython-310.pyc +0 -0
  8. efficient_sam/__pycache__/two_way_transformer.cpython-310.pyc +0 -0
  9. efficient_sam/build_efficient_sam.py +22 -0
  10. efficient_sam/efficient_sam.py +305 -0
  11. efficient_sam/efficient_sam_decoder.py +315 -0
  12. efficient_sam/efficient_sam_encoder.py +257 -0
  13. efficient_sam/mlp.py +29 -0
  14. efficient_sam/two_way_transformer.py +266 -0
  15. efficientvit/__init__.py +0 -0
  16. efficientvit/__pycache__/__init__.cpython-310.pyc +0 -0
  17. efficientvit/__pycache__/sam_model_zoo.cpython-310.pyc +0 -0
  18. efficientvit/apps/__init__.py +0 -0
  19. efficientvit/apps/__pycache__/__init__.cpython-310.pyc +0 -0
  20. efficientvit/apps/data_provider/__init__.py +7 -0
  21. efficientvit/apps/data_provider/__pycache__/__init__.cpython-310.pyc +0 -0
  22. efficientvit/apps/data_provider/__pycache__/base.cpython-310.pyc +0 -0
  23. efficientvit/apps/data_provider/augment/__init__.py +6 -0
  24. efficientvit/apps/data_provider/augment/__pycache__/__init__.cpython-310.pyc +0 -0
  25. efficientvit/apps/data_provider/augment/__pycache__/bbox.cpython-310.pyc +0 -0
  26. efficientvit/apps/data_provider/augment/__pycache__/color_aug.cpython-310.pyc +0 -0
  27. efficientvit/apps/data_provider/augment/bbox.py +30 -0
  28. efficientvit/apps/data_provider/augment/color_aug.py +78 -0
  29. efficientvit/apps/data_provider/base.py +199 -0
  30. efficientvit/apps/data_provider/random_resolution/__init__.py +7 -0
  31. efficientvit/apps/data_provider/random_resolution/__pycache__/__init__.cpython-310.pyc +0 -0
  32. efficientvit/apps/data_provider/random_resolution/__pycache__/controller.cpython-310.pyc +0 -0
  33. efficientvit/apps/data_provider/random_resolution/_data_loader.py +1538 -0
  34. efficientvit/apps/data_provider/random_resolution/_data_worker.py +358 -0
  35. efficientvit/apps/data_provider/random_resolution/controller.py +92 -0
  36. efficientvit/apps/setup.py +135 -0
  37. efficientvit/apps/trainer/__init__.py +6 -0
  38. efficientvit/apps/trainer/__pycache__/__init__.cpython-310.pyc +0 -0
  39. efficientvit/apps/trainer/__pycache__/base.cpython-310.pyc +0 -0
  40. efficientvit/apps/trainer/__pycache__/run_config.cpython-310.pyc +0 -0
  41. efficientvit/apps/trainer/base.py +299 -0
  42. efficientvit/apps/trainer/run_config.py +115 -0
  43. efficientvit/apps/utils/__init__.py +12 -0
  44. efficientvit/apps/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  45. efficientvit/apps/utils/__pycache__/dist.cpython-310.pyc +0 -0
  46. efficientvit/apps/utils/__pycache__/ema.cpython-310.pyc +0 -0
  47. efficientvit/apps/utils/__pycache__/export.cpython-310.pyc +0 -0
  48. efficientvit/apps/utils/__pycache__/init.cpython-310.pyc +0 -0
  49. efficientvit/apps/utils/__pycache__/lr.cpython-310.pyc +0 -0
  50. efficientvit/apps/utils/__pycache__/metric.cpython-310.pyc +0 -0
efficient_sam/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ from .build_efficient_sam import (
5
+ build_efficient_sam_vitt,
6
+ build_efficient_sam_vits,
7
+ )
efficient_sam/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (274 Bytes). View file
 
efficient_sam/__pycache__/build_efficient_sam.cpython-310.pyc ADDED
Binary file (650 Bytes). View file
 
efficient_sam/__pycache__/efficient_sam.cpython-310.pyc ADDED
Binary file (8.13 kB). View file
 
efficient_sam/__pycache__/efficient_sam_decoder.cpython-310.pyc ADDED
Binary file (9.79 kB). View file
 
efficient_sam/__pycache__/efficient_sam_encoder.cpython-310.pyc ADDED
Binary file (7.34 kB). View file
 
efficient_sam/__pycache__/mlp.cpython-310.pyc ADDED
Binary file (1.24 kB). View file
 
efficient_sam/__pycache__/two_way_transformer.cpython-310.pyc ADDED
Binary file (7.34 kB). View file
 
efficient_sam/build_efficient_sam.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .efficient_sam import build_efficient_sam
8
+
9
+ def build_efficient_sam_vitt():
10
+ return build_efficient_sam(
11
+ encoder_patch_embed_dim=192,
12
+ encoder_num_heads=3,
13
+ checkpoint="weights/efficient_sam_vitt.pt",
14
+ ).eval()
15
+
16
+
17
+ def build_efficient_sam_vits():
18
+ return build_efficient_sam(
19
+ encoder_patch_embed_dim=384,
20
+ encoder_num_heads=6,
21
+ checkpoint="weights/efficient_sam_vits.pt",
22
+ ).eval()
efficient_sam/efficient_sam.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Any, List, Tuple, Type
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+
13
+ from torch import nn, Tensor
14
+
15
+ from .efficient_sam_decoder import MaskDecoder, PromptEncoder
16
+ from .efficient_sam_encoder import ImageEncoderViT
17
+ from .two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer
18
+
19
+ class EfficientSam(nn.Module):
20
+ mask_threshold: float = 0.0
21
+ image_format: str = "RGB"
22
+
23
+ def __init__(
24
+ self,
25
+ image_encoder: ImageEncoderViT,
26
+ prompt_encoder: PromptEncoder,
27
+ decoder_max_num_input_points: int,
28
+ mask_decoder: MaskDecoder,
29
+ pixel_mean: List[float] = [0.485, 0.456, 0.406],
30
+ pixel_std: List[float] = [0.229, 0.224, 0.225],
31
+ ) -> None:
32
+ """
33
+ SAM predicts object masks from an image and input prompts.
34
+
35
+ Arguments:
36
+ image_encoder (ImageEncoderViT): The backbone used to encode the
37
+ image into image embeddings that allow for efficient mask prediction.
38
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
39
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
40
+ and encoded prompts.
41
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
42
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
43
+ """
44
+ super().__init__()
45
+ self.image_encoder = image_encoder
46
+ self.prompt_encoder = prompt_encoder
47
+ self.decoder_max_num_input_points = decoder_max_num_input_points
48
+ self.mask_decoder = mask_decoder
49
+ self.register_buffer(
50
+ "pixel_mean", torch.Tensor(pixel_mean).view(1, 3, 1, 1), False
51
+ )
52
+ self.register_buffer(
53
+ "pixel_std", torch.Tensor(pixel_std).view(1, 3, 1, 1), False
54
+ )
55
+
56
+ @torch.jit.export
57
+ def predict_masks(
58
+ self,
59
+ image_embeddings: torch.Tensor,
60
+ batched_points: torch.Tensor,
61
+ batched_point_labels: torch.Tensor,
62
+ multimask_output: bool,
63
+ input_h: int,
64
+ input_w: int,
65
+ output_h: int = -1,
66
+ output_w: int = -1,
67
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
68
+ """
69
+ Predicts masks given image embeddings and prompts. This only runs the decoder.
70
+
71
+ Arguments:
72
+ image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W]
73
+ batched_points: A tensor of shape [B, max_num_queries, num_pts, 2]
74
+ batched_point_labels: A tensor of shape [B, max_num_queries, num_pts]
75
+ Returns:
76
+ A tuple of two tensors:
77
+ low_res_mask: A tensor of shape [B, max_num_queries, 256, 256] of predicted masks
78
+ iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores
79
+ """
80
+
81
+ batch_size, max_num_queries, num_pts, _ = batched_points.shape
82
+ num_pts = batched_points.shape[2]
83
+ rescaled_batched_points = self.get_rescaled_pts(batched_points, input_h, input_w)
84
+
85
+ if num_pts > self.decoder_max_num_input_points:
86
+ rescaled_batched_points = rescaled_batched_points[
87
+ :, :, : self.decoder_max_num_input_points, :
88
+ ]
89
+ batched_point_labels = batched_point_labels[
90
+ :, :, : self.decoder_max_num_input_points
91
+ ]
92
+ elif num_pts < self.decoder_max_num_input_points:
93
+ rescaled_batched_points = F.pad(
94
+ rescaled_batched_points,
95
+ (0, 0, 0, self.decoder_max_num_input_points - num_pts),
96
+ value=-1.0,
97
+ )
98
+ batched_point_labels = F.pad(
99
+ batched_point_labels,
100
+ (0, self.decoder_max_num_input_points - num_pts),
101
+ value=-1.0,
102
+ )
103
+
104
+ sparse_embeddings = self.prompt_encoder(
105
+ rescaled_batched_points.reshape(
106
+ batch_size * max_num_queries, self.decoder_max_num_input_points, 2
107
+ ),
108
+ batched_point_labels.reshape(
109
+ batch_size * max_num_queries, self.decoder_max_num_input_points
110
+ ),
111
+ )
112
+ sparse_embeddings = sparse_embeddings.view(
113
+ batch_size,
114
+ max_num_queries,
115
+ sparse_embeddings.shape[1],
116
+ sparse_embeddings.shape[2],
117
+ )
118
+ low_res_masks, iou_predictions = self.mask_decoder(
119
+ image_embeddings,
120
+ self.prompt_encoder.get_dense_pe(),
121
+ sparse_prompt_embeddings=sparse_embeddings,
122
+ multimask_output=multimask_output,
123
+ )
124
+ _, num_predictions, low_res_size, _ = low_res_masks.shape
125
+
126
+ if output_w > 0 and output_h > 0:
127
+ output_masks = F.interpolate(
128
+ low_res_masks, (output_h, output_w), mode="bicubic"
129
+ )
130
+ output_masks = torch.reshape(
131
+ output_masks,
132
+ (batch_size, max_num_queries, num_predictions, output_h, output_w),
133
+ )
134
+ else:
135
+ output_masks = torch.reshape(
136
+ low_res_masks,
137
+ (
138
+ batch_size,
139
+ max_num_queries,
140
+ num_predictions,
141
+ low_res_size,
142
+ low_res_size,
143
+ ),
144
+ )
145
+ iou_predictions = torch.reshape(
146
+ iou_predictions, (batch_size, max_num_queries, num_predictions)
147
+ )
148
+ return output_masks, iou_predictions
149
+
150
+ def get_rescaled_pts(self, batched_points: torch.Tensor, input_h: int, input_w: int):
151
+ return torch.stack(
152
+ [
153
+ torch.where(
154
+ batched_points[..., 0] >= 0,
155
+ batched_points[..., 0] * self.image_encoder.img_size / input_w,
156
+ -1.0,
157
+ ),
158
+ torch.where(
159
+ batched_points[..., 1] >= 0,
160
+ batched_points[..., 1] * self.image_encoder.img_size / input_h,
161
+ -1.0,
162
+ ),
163
+ ],
164
+ dim=-1,
165
+ )
166
+
167
+ @torch.jit.export
168
+ def get_image_embeddings(self, batched_images) -> torch.Tensor:
169
+ """
170
+ Predicts masks end-to-end from provided images and prompts.
171
+ If prompts are not known in advance, using SamPredictor is
172
+ recommended over calling the model directly.
173
+
174
+ Arguments:
175
+ batched_images: A tensor of shape [B, 3, H, W]
176
+ Returns:
177
+ List of image embeddings each of of shape [B, C(i), H(i), W(i)].
178
+ The last embedding corresponds to the final layer.
179
+ """
180
+ batched_images = self.preprocess(batched_images)
181
+ return self.image_encoder(batched_images)
182
+
183
+ def forward(
184
+ self,
185
+ batched_images: torch.Tensor,
186
+ batched_points: torch.Tensor,
187
+ batched_point_labels: torch.Tensor,
188
+ scale_to_original_image_size: bool = True,
189
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
190
+ """
191
+ Predicts masks end-to-end from provided images and prompts.
192
+ If prompts are not known in advance, using SamPredictor is
193
+ recommended over calling the model directly.
194
+
195
+ Arguments:
196
+ batched_images: A tensor of shape [B, 3, H, W]
197
+ batched_points: A tensor of shape [B, num_queries, max_num_pts, 2]
198
+ batched_point_labels: A tensor of shape [B, num_queries, max_num_pts]
199
+
200
+ Returns:
201
+ A list tuples of two tensors where the ith element is by considering the first i+1 points.
202
+ low_res_mask: A tensor of shape [B, 256, 256] of predicted masks
203
+ iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores
204
+ """
205
+ batch_size, _, input_h, input_w = batched_images.shape
206
+ image_embeddings = self.get_image_embeddings(batched_images)
207
+ return self.predict_masks(
208
+ image_embeddings,
209
+ batched_points,
210
+ batched_point_labels,
211
+ multimask_output=True,
212
+ input_h=input_h,
213
+ input_w=input_w,
214
+ output_h=input_h if scale_to_original_image_size else -1,
215
+ output_w=input_w if scale_to_original_image_size else -1,
216
+ )
217
+
218
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
219
+ """Normalize pixel values and pad to a square input."""
220
+ if (
221
+ x.shape[2] != self.image_encoder.img_size
222
+ or x.shape[3] != self.image_encoder.img_size
223
+ ):
224
+ x = F.interpolate(
225
+ x,
226
+ (self.image_encoder.img_size, self.image_encoder.img_size),
227
+ mode="bilinear",
228
+ )
229
+ return (x - self.pixel_mean) / self.pixel_std
230
+
231
+
232
+ def build_efficient_sam(encoder_patch_embed_dim, encoder_num_heads, checkpoint=None):
233
+ img_size = 1024
234
+ encoder_patch_size = 16
235
+ encoder_depth = 12
236
+ encoder_mlp_ratio = 4.0
237
+ encoder_neck_dims = [256, 256]
238
+ decoder_max_num_input_points = 6
239
+ decoder_transformer_depth = 2
240
+ decoder_transformer_mlp_dim = 2048
241
+ decoder_num_heads = 8
242
+ decoder_upscaling_layer_dims = [64, 32]
243
+ num_multimask_outputs = 3
244
+ iou_head_depth = 3
245
+ iou_head_hidden_dim = 256
246
+ activation = "gelu"
247
+ normalization_type = "layer_norm"
248
+ normalize_before_activation = False
249
+
250
+ assert activation == "relu" or activation == "gelu"
251
+ if activation == "relu":
252
+ activation_fn = nn.ReLU
253
+ else:
254
+ activation_fn = nn.GELU
255
+
256
+ image_encoder = ImageEncoderViT(
257
+ img_size=img_size,
258
+ patch_size=encoder_patch_size,
259
+ in_chans=3,
260
+ patch_embed_dim=encoder_patch_embed_dim,
261
+ normalization_type=normalization_type,
262
+ depth=encoder_depth,
263
+ num_heads=encoder_num_heads,
264
+ mlp_ratio=encoder_mlp_ratio,
265
+ neck_dims=encoder_neck_dims,
266
+ act_layer=activation_fn,
267
+ )
268
+
269
+ image_embedding_size = image_encoder.image_embedding_size
270
+ encoder_transformer_output_dim = image_encoder.transformer_output_dim
271
+
272
+ sam = EfficientSam(
273
+ image_encoder=image_encoder,
274
+ prompt_encoder=PromptEncoder(
275
+ embed_dim=encoder_transformer_output_dim,
276
+ image_embedding_size=(image_embedding_size, image_embedding_size),
277
+ input_image_size=(img_size, img_size),
278
+ ),
279
+ decoder_max_num_input_points=decoder_max_num_input_points,
280
+ mask_decoder=MaskDecoder(
281
+ transformer_dim=encoder_transformer_output_dim,
282
+ transformer=TwoWayTransformer(
283
+ depth=decoder_transformer_depth,
284
+ embedding_dim=encoder_transformer_output_dim,
285
+ num_heads=decoder_num_heads,
286
+ mlp_dim=decoder_transformer_mlp_dim,
287
+ activation=activation_fn,
288
+ normalize_before_activation=normalize_before_activation,
289
+ ),
290
+ num_multimask_outputs=num_multimask_outputs,
291
+ activation=activation_fn,
292
+ normalization_type=normalization_type,
293
+ normalize_before_activation=normalize_before_activation,
294
+ iou_head_depth=iou_head_depth - 1,
295
+ iou_head_hidden_dim=iou_head_hidden_dim,
296
+ upscaling_layer_dims=decoder_upscaling_layer_dims,
297
+ ),
298
+ pixel_mean=[0.485, 0.456, 0.406],
299
+ pixel_std=[0.229, 0.224, 0.225],
300
+ )
301
+ if checkpoint is not None:
302
+ with open(checkpoint, "rb") as f:
303
+ state_dict = torch.load(f, map_location="cpu")
304
+ sam.load_state_dict(state_dict["model"])
305
+ return sam
efficient_sam/efficient_sam_decoder.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Tuple, Type
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from .mlp import MLPBlock
15
+
16
+
17
+ class PromptEncoder(nn.Module):
18
+ def __init__(
19
+ self,
20
+ embed_dim: int,
21
+ image_embedding_size: Tuple[int, int],
22
+ input_image_size: Tuple[int, int],
23
+ ) -> None:
24
+ """
25
+ Encodes prompts for input to SAM's mask decoder.
26
+
27
+ Arguments:
28
+ embed_dim (int): The prompts' embedding dimension
29
+ image_embedding_size (tuple(int, int)): The spatial size of the
30
+ image embedding, as (H, W).
31
+ input_image_size (int): The padded size of the image as input
32
+ to the image encoder, as (H, W).
33
+ """
34
+ super().__init__()
35
+ self.embed_dim = embed_dim
36
+ self.input_image_size = input_image_size
37
+ self.image_embedding_size = image_embedding_size
38
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
39
+ self.invalid_points = nn.Embedding(1, embed_dim)
40
+ self.point_embeddings = nn.Embedding(1, embed_dim)
41
+ self.bbox_top_left_embeddings = nn.Embedding(1, embed_dim)
42
+ self.bbox_bottom_right_embeddings = nn.Embedding(1, embed_dim)
43
+
44
+ def get_dense_pe(self) -> torch.Tensor:
45
+ """
46
+ Returns the positional encoding used to encode point prompts,
47
+ applied to a dense set of points the shape of the image encoding.
48
+
49
+ Returns:
50
+ torch.Tensor: Positional encoding with shape
51
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
52
+ """
53
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
54
+
55
+ def _embed_points(
56
+ self,
57
+ points: torch.Tensor,
58
+ labels: torch.Tensor,
59
+ ) -> torch.Tensor:
60
+ """Embeds point prompts."""
61
+ points = points + 0.5 # Shift to center of pixel
62
+ point_embedding = self.pe_layer.forward_with_coords(
63
+ points, self.input_image_size
64
+ )
65
+ invalid_label_ids = torch.eq(labels, -1)[:,:,None]
66
+ point_label_ids = torch.eq(labels, 1)[:,:,None]
67
+ topleft_label_ids = torch.eq(labels, 2)[:,:,None]
68
+ bottomright_label_ids = torch.eq(labels, 3)[:,:,None]
69
+ point_embedding = point_embedding + self.invalid_points.weight[:,None,:] * invalid_label_ids
70
+ point_embedding = point_embedding + self.point_embeddings.weight[:,None,:] * point_label_ids
71
+ point_embedding = point_embedding + self.bbox_top_left_embeddings.weight[:,None,:] * topleft_label_ids
72
+ point_embedding = point_embedding + self.bbox_bottom_right_embeddings.weight[:,None,:] * bottomright_label_ids
73
+ return point_embedding
74
+
75
+ def forward(
76
+ self,
77
+ coords,
78
+ labels,
79
+ ) -> torch.Tensor:
80
+ """
81
+ Embeds different types of prompts, returning both sparse and dense
82
+ embeddings.
83
+
84
+ Arguments:
85
+ points: A tensor of shape [B, 2]
86
+ labels: An integer tensor of shape [B] where each element is 1,2 or 3.
87
+
88
+ Returns:
89
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
90
+ BxNx(embed_dim), where N is determined by the number of input points
91
+ and boxes.
92
+ """
93
+ return self._embed_points(coords, labels)
94
+
95
+
96
+ class PositionEmbeddingRandom(nn.Module):
97
+ """
98
+ Positional encoding using random spatial frequencies.
99
+ """
100
+
101
+ def __init__(self, num_pos_feats: int) -> None:
102
+ super().__init__()
103
+ self.register_buffer(
104
+ "positional_encoding_gaussian_matrix", torch.randn((2, num_pos_feats))
105
+ )
106
+
107
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
108
+ """Positionally encode points that are normalized to [0,1]."""
109
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
110
+ coords = 2 * coords - 1
111
+ coords = coords @ self.positional_encoding_gaussian_matrix
112
+ coords = 2 * np.pi * coords
113
+ # outputs d_1 x ... x d_n x C shape
114
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
115
+
116
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
117
+ """Generate positional encoding for a grid of the specified size."""
118
+ h, w = size
119
+ device = self.positional_encoding_gaussian_matrix.device
120
+ grid = torch.ones([h, w], device=device, dtype=torch.float32)
121
+ y_embed = grid.cumsum(dim=0) - 0.5
122
+ x_embed = grid.cumsum(dim=1) - 0.5
123
+ y_embed = y_embed / h
124
+ x_embed = x_embed / w
125
+
126
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
127
+ return pe.permute(2, 0, 1) # C x H x W
128
+
129
+ def forward_with_coords(
130
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
131
+ ) -> torch.Tensor:
132
+ """Positionally encode points that are not normalized to [0,1]."""
133
+ coords = coords_input.clone()
134
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
135
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
136
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
137
+
138
+
139
+ class MaskDecoder(nn.Module):
140
+ def __init__(
141
+ self,
142
+ *,
143
+ transformer_dim: int,
144
+ transformer: nn.Module,
145
+ num_multimask_outputs: int,
146
+ activation: Type[nn.Module],
147
+ normalization_type: str,
148
+ normalize_before_activation: bool,
149
+ iou_head_depth: int,
150
+ iou_head_hidden_dim: int,
151
+ upscaling_layer_dims: List[int],
152
+ ) -> None:
153
+ """
154
+ Predicts masks given an image and prompt embeddings, using a
155
+ transformer architecture.
156
+
157
+ Arguments:
158
+ transformer_dim (int): the channel dimension of the transformer
159
+ transformer (nn.Module): the transformer used to predict masks
160
+ num_multimask_outputs (int): the number of masks to predict
161
+ when disambiguating masks
162
+ activation (nn.Module): the type of activation to use when
163
+ upscaling masks
164
+ iou_head_depth (int): the depth of the MLP used to predict
165
+ mask quality
166
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
167
+ used to predict mask quality
168
+ """
169
+ super().__init__()
170
+ self.transformer_dim = transformer_dim
171
+ self.transformer = transformer
172
+
173
+ self.num_multimask_outputs = num_multimask_outputs
174
+
175
+ self.iou_token = nn.Embedding(1, transformer_dim)
176
+ if num_multimask_outputs > 1:
177
+ self.num_mask_tokens = num_multimask_outputs + 1
178
+ else:
179
+ self.num_mask_tokens = 1
180
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
181
+ output_dim_after_upscaling = transformer_dim
182
+
183
+ self.final_output_upscaling_layers = nn.ModuleList([])
184
+ for idx, layer_dims in enumerate(upscaling_layer_dims):
185
+ self.final_output_upscaling_layers.append(
186
+ nn.Sequential(
187
+ nn.ConvTranspose2d(
188
+ output_dim_after_upscaling,
189
+ layer_dims,
190
+ kernel_size=2,
191
+ stride=2,
192
+ ),
193
+ nn.GroupNorm(1, layer_dims)
194
+ if idx < len(upscaling_layer_dims) - 1
195
+ else nn.Identity(),
196
+ activation(),
197
+ )
198
+ )
199
+ output_dim_after_upscaling = layer_dims
200
+
201
+ self.output_hypernetworks_mlps = nn.ModuleList(
202
+ [
203
+ MLPBlock(
204
+ input_dim=transformer_dim,
205
+ hidden_dim=transformer_dim,
206
+ output_dim=output_dim_after_upscaling,
207
+ num_layers=2,
208
+ act=activation,
209
+ )
210
+ for i in range(self.num_mask_tokens)
211
+ ]
212
+ )
213
+
214
+ self.iou_prediction_head = MLPBlock(
215
+ input_dim=transformer_dim,
216
+ hidden_dim=iou_head_hidden_dim,
217
+ output_dim=self.num_mask_tokens,
218
+ num_layers=iou_head_depth,
219
+ act=activation,
220
+ )
221
+
222
+ def forward(
223
+ self,
224
+ image_embeddings: torch.Tensor,
225
+ image_pe: torch.Tensor,
226
+ sparse_prompt_embeddings: torch.Tensor,
227
+ multimask_output: bool,
228
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
229
+ """
230
+ Predict masks given image and prompt embeddings.
231
+
232
+ Arguments:
233
+ image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W]
234
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings (the batch dimension is broadcastable).
235
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
236
+ multimask_output (bool): Whether to return multiple masks or a single
237
+ mask.
238
+
239
+ Returns:
240
+ torch.Tensor: batched predicted masks
241
+ torch.Tensor: batched predictions of mask quality
242
+ """
243
+
244
+ (
245
+ batch_size,
246
+ max_num_queries,
247
+ sparse_embed_dim_1,
248
+ sparse_embed_dim_2,
249
+ ) = sparse_prompt_embeddings.shape
250
+
251
+ (
252
+ _,
253
+ image_embed_dim_c,
254
+ image_embed_dim_h,
255
+ image_embed_dim_w,
256
+ ) = image_embeddings.shape
257
+
258
+ # Tile the image embedding for all queries.
259
+ image_embeddings_tiled = torch.tile(
260
+ image_embeddings[:, None, :, :, :], [1, max_num_queries, 1, 1, 1]
261
+ ).view(
262
+ batch_size * max_num_queries,
263
+ image_embed_dim_c,
264
+ image_embed_dim_h,
265
+ image_embed_dim_w,
266
+ )
267
+ sparse_prompt_embeddings = sparse_prompt_embeddings.reshape(
268
+ batch_size * max_num_queries, sparse_embed_dim_1, sparse_embed_dim_2
269
+ )
270
+ masks, iou_pred = self.predict_masks(
271
+ image_embeddings=image_embeddings_tiled,
272
+ image_pe=image_pe,
273
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
274
+ )
275
+ if multimask_output and self.num_multimask_outputs > 1:
276
+ return masks[:, 1:, :], iou_pred[:, 1:]
277
+ else:
278
+ return masks[:, :1, :], iou_pred[:, :1]
279
+
280
+ def predict_masks(
281
+ self,
282
+ image_embeddings: torch.Tensor,
283
+ image_pe: torch.Tensor,
284
+ sparse_prompt_embeddings: torch.Tensor,
285
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
286
+ """Predicts masks. See 'forward' for more details."""
287
+ # Concatenate output tokens
288
+ output_tokens = torch.cat(
289
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
290
+ )
291
+ output_tokens = output_tokens.unsqueeze(0).expand(
292
+ sparse_prompt_embeddings.size(0), -1, -1
293
+ )
294
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
295
+ # Expand per-image data in batch direction to be per-mask
296
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
297
+ b, c, h, w = image_embeddings.shape
298
+ hs, src = self.transformer(image_embeddings, pos_src, tokens)
299
+ iou_token_out = hs[:, 0, :]
300
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
301
+
302
+ # Upscale mask embeddings and predict masks using the mask tokens
303
+ upscaled_embedding = src.transpose(1, 2).view(b, c, h, w)
304
+
305
+ for upscaling_layer in self.final_output_upscaling_layers:
306
+ upscaled_embedding = upscaling_layer(upscaled_embedding)
307
+ hyper_in_list: List[torch.Tensor] = []
308
+ for i, output_hypernetworks_mlp in enumerate(self.output_hypernetworks_mlps):
309
+ hyper_in_list.append(output_hypernetworks_mlp(mask_tokens_out[:, i, :]))
310
+ hyper_in = torch.stack(hyper_in_list, dim=1)
311
+ b, c, h, w = upscaled_embedding.shape
312
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
313
+ # Generate mask quality predictions
314
+ iou_pred = self.iou_prediction_head(iou_token_out)
315
+ return masks, iou_pred
efficient_sam/efficient_sam_encoder.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import List, Optional, Tuple, Type
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ class LayerNorm2d(nn.Module):
16
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
17
+ super().__init__()
18
+ self.weight = nn.Parameter(torch.ones(num_channels))
19
+ self.bias = nn.Parameter(torch.zeros(num_channels))
20
+ self.eps = eps
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ u = x.mean(1, keepdim=True)
24
+ s = (x - u).pow(2).mean(1, keepdim=True)
25
+ x = (x - u) / torch.sqrt(s + self.eps)
26
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
27
+ return x
28
+
29
+
30
+ class PatchEmbed(nn.Module):
31
+ """2D Image to Patch Embedding"""
32
+
33
+ def __init__(
34
+ self,
35
+ img_size,
36
+ patch_size,
37
+ in_chans,
38
+ embed_dim,
39
+ ):
40
+ super().__init__()
41
+ self.proj = nn.Conv2d(
42
+ in_chans,
43
+ embed_dim,
44
+ kernel_size=(patch_size, patch_size),
45
+ stride=(patch_size, patch_size),
46
+ bias=True,
47
+ )
48
+
49
+ def forward(self, x):
50
+ B, C, H, W = x.shape
51
+ x = self.proj(x)
52
+ return x
53
+
54
+
55
+ class Attention(nn.Module):
56
+ def __init__(
57
+ self,
58
+ dim,
59
+ num_heads,
60
+ qkv_bias,
61
+ qk_scale=None,
62
+ ):
63
+ super().__init__()
64
+ self.num_heads = num_heads
65
+ head_dim = dim // num_heads
66
+ self.scale = qk_scale or head_dim**-0.5
67
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
68
+ self.proj = nn.Linear(dim, dim)
69
+
70
+ def forward(self, x):
71
+ B, N, C = x.shape
72
+ qkv = (
73
+ self.qkv(x)
74
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
75
+ .permute(2, 0, 3, 1, 4)
76
+ )
77
+ q, k, v = (
78
+ qkv[0],
79
+ qkv[1],
80
+ qkv[2],
81
+ )
82
+ attn = (q @ k.transpose(-2, -1)) * self.scale
83
+ attn = attn.softmax(dim=-1)
84
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
85
+ x = self.proj(x)
86
+ return x
87
+
88
+
89
+ class Mlp(nn.Module):
90
+ def __init__(
91
+ self,
92
+ in_features,
93
+ hidden_features=None,
94
+ out_features=None,
95
+ act_layer=nn.GELU,
96
+ ):
97
+ super().__init__()
98
+ out_features = out_features or in_features
99
+ hidden_features = hidden_features or in_features
100
+ self.fc1 = nn.Linear(in_features, hidden_features)
101
+ self.act = act_layer()
102
+ self.fc2 = nn.Linear(hidden_features, out_features)
103
+
104
+ def forward(self, x):
105
+ x = self.fc1(x)
106
+ x = self.act(x)
107
+ x = self.fc2(x)
108
+ return x
109
+
110
+
111
+ class Block(nn.Module):
112
+ def __init__(
113
+ self,
114
+ dim,
115
+ num_heads,
116
+ mlp_ratio=4.0,
117
+ qkv_bias=False,
118
+ qk_scale=None,
119
+ act_layer=nn.GELU,
120
+ ):
121
+ super().__init__()
122
+ self.norm1 = nn.LayerNorm(dim, eps=1e-6)
123
+ self.attn = Attention(
124
+ dim,
125
+ num_heads=num_heads,
126
+ qkv_bias=qkv_bias,
127
+ qk_scale=qk_scale,
128
+ )
129
+ self.norm2 = nn.LayerNorm(dim, eps=1e-6)
130
+ mlp_hidden_dim = int(dim * mlp_ratio)
131
+ self.mlp = Mlp(
132
+ in_features=dim,
133
+ hidden_features=mlp_hidden_dim,
134
+ act_layer=act_layer,
135
+ )
136
+
137
+ def forward(self, x):
138
+ x = x + self.attn(self.norm1(x))
139
+ x = x + self.mlp(self.norm2(x))
140
+ return x
141
+
142
+
143
+ @torch.jit.export
144
+ def get_abs_pos(
145
+ abs_pos: torch.Tensor, has_cls_token: bool, hw: List[int]
146
+ ) -> torch.Tensor:
147
+ """
148
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
149
+ dimension for the original embeddings.
150
+ Args:
151
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
152
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
153
+ hw (Tuple): size of input image tokens.
154
+
155
+ Returns:
156
+ Absolute positional embeddings after processing with shape (1, H, W, C)
157
+ """
158
+ h = hw[0]
159
+ w = hw[1]
160
+ if has_cls_token:
161
+ abs_pos = abs_pos[:, 1:]
162
+ xy_num = abs_pos.shape[1]
163
+ size = int(math.sqrt(xy_num))
164
+ assert size * size == xy_num
165
+
166
+ if size != h or size != w:
167
+ new_abs_pos = F.interpolate(
168
+ abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
169
+ size=(h, w),
170
+ mode="bicubic",
171
+ align_corners=False,
172
+ )
173
+ return new_abs_pos.permute(0, 2, 3, 1)
174
+ else:
175
+ return abs_pos.reshape(1, h, w, -1)
176
+
177
+
178
+ # Image encoder for efficient SAM.
179
+ class ImageEncoderViT(nn.Module):
180
+ def __init__(
181
+ self,
182
+ img_size: int,
183
+ patch_size: int,
184
+ in_chans: int,
185
+ patch_embed_dim: int,
186
+ normalization_type: str,
187
+ depth: int,
188
+ num_heads: int,
189
+ mlp_ratio: float,
190
+ neck_dims: List[int],
191
+ act_layer: Type[nn.Module],
192
+ ) -> None:
193
+ """
194
+ Args:
195
+ img_size (int): Input image size.
196
+ patch_size (int): Patch size.
197
+ in_chans (int): Number of input image channels.
198
+ patch_embed_dim (int): Patch embedding dimension.
199
+ depth (int): Depth of ViT.
200
+ num_heads (int): Number of attention heads in each ViT block.
201
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
202
+ act_layer (nn.Module): Activation layer.
203
+ """
204
+ super().__init__()
205
+
206
+ self.img_size = img_size
207
+ self.image_embedding_size = img_size // ((patch_size if patch_size > 0 else 1))
208
+ self.transformer_output_dim = ([patch_embed_dim] + neck_dims)[-1]
209
+ self.pretrain_use_cls_token = True
210
+ pretrain_img_size = 224
211
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, patch_embed_dim)
212
+ # Initialize absolute positional embedding with pretrain image size.
213
+ num_patches = (pretrain_img_size // patch_size) * (
214
+ pretrain_img_size // patch_size
215
+ )
216
+ num_positions = num_patches + 1
217
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, patch_embed_dim))
218
+ self.blocks = nn.ModuleList()
219
+ for i in range(depth):
220
+ vit_block = Block(patch_embed_dim, num_heads, mlp_ratio, True)
221
+ self.blocks.append(vit_block)
222
+ self.neck = nn.Sequential(
223
+ nn.Conv2d(
224
+ patch_embed_dim,
225
+ neck_dims[0],
226
+ kernel_size=1,
227
+ bias=False,
228
+ ),
229
+ LayerNorm2d(neck_dims[0]),
230
+ nn.Conv2d(
231
+ neck_dims[0],
232
+ neck_dims[0],
233
+ kernel_size=3,
234
+ padding=1,
235
+ bias=False,
236
+ ),
237
+ LayerNorm2d(neck_dims[0]),
238
+ )
239
+
240
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
241
+ assert (
242
+ x.shape[2] == self.img_size and x.shape[3] == self.img_size
243
+ ), "input image size must match self.img_size"
244
+ x = self.patch_embed(x)
245
+ # B C H W -> B H W C
246
+ x = x.permute(0, 2, 3, 1)
247
+ x = x + get_abs_pos(
248
+ self.pos_embed, self.pretrain_use_cls_token, [x.shape[1], x.shape[2]]
249
+ )
250
+ num_patches = x.shape[1]
251
+ assert x.shape[2] == num_patches
252
+ x = x.reshape(x.shape[0], num_patches * num_patches, x.shape[3])
253
+ for blk in self.blocks:
254
+ x = blk(x)
255
+ x = x.reshape(x.shape[0], num_patches, num_patches, x.shape[2])
256
+ x = self.neck(x.permute(0, 3, 1, 2))
257
+ return x
efficient_sam/mlp.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Type
2
+
3
+ from torch import nn
4
+
5
+
6
+ # Lightly adapted from
7
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
8
+ class MLPBlock(nn.Module):
9
+ def __init__(
10
+ self,
11
+ input_dim: int,
12
+ hidden_dim: int,
13
+ output_dim: int,
14
+ num_layers: int,
15
+ act: Type[nn.Module],
16
+ ) -> None:
17
+ super().__init__()
18
+ self.num_layers = num_layers
19
+ h = [hidden_dim] * (num_layers - 1)
20
+ self.layers = nn.ModuleList(
21
+ nn.Sequential(nn.Linear(n, k), act())
22
+ for n, k in zip([input_dim] + h, [hidden_dim] * num_layers)
23
+ )
24
+ self.fc = nn.Linear(hidden_dim, output_dim)
25
+
26
+ def forward(self, x):
27
+ for layer in self.layers:
28
+ x = layer(x)
29
+ return self.fc(x)
efficient_sam/two_way_transformer.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple, Type
3
+ import torch
4
+ from torch import nn, Tensor
5
+ from .mlp import MLPBlock
6
+
7
+
8
+
9
+
10
+ class TwoWayTransformer(nn.Module):
11
+ def __init__(
12
+ self,
13
+ depth: int,
14
+ embedding_dim: int,
15
+ num_heads: int,
16
+ mlp_dim: int,
17
+ activation: Type[nn.Module],
18
+ normalize_before_activation: bool,
19
+ attention_downsample_rate: int = 2,
20
+ ) -> None:
21
+ """
22
+ A transformer decoder that attends to an input image using
23
+ queries whose positional embedding is supplied.
24
+
25
+ Args:
26
+ depth (int): number of layers in the transformer
27
+ embedding_dim (int): the channel dimension for the input embeddings
28
+ num_heads (int): the number of heads for multihead attention. Must
29
+ divide embedding_dim
30
+ mlp_dim (int): the channel dimension internal to the MLP block
31
+ activation (nn.Module): the activation to use in the MLP block
32
+ """
33
+ super().__init__()
34
+ self.depth = depth
35
+ self.embedding_dim = embedding_dim
36
+ self.num_heads = num_heads
37
+ self.mlp_dim = mlp_dim
38
+ self.layers = nn.ModuleList()
39
+
40
+ for i in range(depth):
41
+ curr_layer = TwoWayAttentionBlock(
42
+ embedding_dim=embedding_dim,
43
+ num_heads=num_heads,
44
+ mlp_dim=mlp_dim,
45
+ activation=activation,
46
+ normalize_before_activation=normalize_before_activation,
47
+ attention_downsample_rate=attention_downsample_rate,
48
+ skip_first_layer_pe=(i == 0),
49
+ )
50
+ self.layers.append(curr_layer)
51
+
52
+ self.final_attn_token_to_image = AttentionForTwoWayAttentionBlock(
53
+ embedding_dim,
54
+ num_heads,
55
+ downsample_rate=attention_downsample_rate,
56
+ )
57
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
58
+
59
+ def forward(
60
+ self,
61
+ image_embedding: Tensor,
62
+ image_pe: Tensor,
63
+ point_embedding: Tensor,
64
+ ) -> Tuple[Tensor, Tensor]:
65
+ """
66
+ Args:
67
+ image_embedding (torch.Tensor): image to attend to. Should be shape
68
+ B x embedding_dim x h x w for any h and w.
69
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
70
+ have the same shape as image_embedding.
71
+ point_embedding (torch.Tensor): the embedding to add to the query points.
72
+ Must have shape B x N_points x embedding_dim for any N_points.
73
+
74
+ Returns:
75
+ torch.Tensor: the processed point_embedding
76
+ torch.Tensor: the processed image_embedding
77
+ """
78
+
79
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
80
+ bs, c, h, w = image_embedding.shape
81
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
82
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
83
+
84
+ # Prepare queries
85
+ queries = point_embedding
86
+ keys = image_embedding
87
+
88
+ # Apply transformer blocks and final layernorm
89
+ for idx, layer in enumerate(self.layers):
90
+ queries, keys = layer(
91
+ queries=queries,
92
+ keys=keys,
93
+ query_pe=point_embedding,
94
+ key_pe=image_pe,
95
+ )
96
+
97
+ # Apply the final attention layer from the points to the image
98
+ q = queries + point_embedding
99
+ k = keys + image_pe
100
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
101
+ queries = queries + attn_out
102
+ queries = self.norm_final_attn(queries)
103
+ return queries, keys
104
+
105
+
106
+ class TwoWayAttentionBlock(nn.Module):
107
+ def __init__(
108
+ self,
109
+ embedding_dim: int,
110
+ num_heads: int,
111
+ mlp_dim: int,
112
+ activation: Type[nn.Module],
113
+ normalize_before_activation: bool,
114
+ attention_downsample_rate: int = 2,
115
+ skip_first_layer_pe: bool = False,
116
+ ) -> None:
117
+ """
118
+ A transformer block with four layers: (1) self-attention of sparse
119
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
120
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
121
+ inputs.
122
+
123
+ Arguments:
124
+ embedding_dim (int): the channel dimension of the embeddings
125
+ num_heads (int): the number of heads in the attention layers
126
+ mlp_dim (int): the hidden dimension of the mlp block
127
+ activation (nn.Module): the activation of the mlp block
128
+ skip_first_layer_pe (bool): skip the PE on the first layer
129
+ """
130
+ super().__init__()
131
+ self.self_attn = AttentionForTwoWayAttentionBlock(embedding_dim, num_heads)
132
+ self.norm1 = nn.LayerNorm(embedding_dim)
133
+
134
+ self.cross_attn_token_to_image = AttentionForTwoWayAttentionBlock(
135
+ embedding_dim,
136
+ num_heads,
137
+ downsample_rate=attention_downsample_rate,
138
+ )
139
+ self.norm2 = nn.LayerNorm(embedding_dim)
140
+
141
+ self.mlp = MLPBlock(
142
+ embedding_dim,
143
+ mlp_dim,
144
+ embedding_dim,
145
+ 1,
146
+ activation,
147
+ )
148
+
149
+ self.norm3 = nn.LayerNorm(embedding_dim)
150
+
151
+ self.norm4 = nn.LayerNorm(embedding_dim)
152
+ self.cross_attn_image_to_token = AttentionForTwoWayAttentionBlock(
153
+ embedding_dim,
154
+ num_heads,
155
+ downsample_rate=attention_downsample_rate,
156
+ )
157
+
158
+ self.skip_first_layer_pe = skip_first_layer_pe
159
+
160
+ def forward(
161
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
162
+ ) -> Tuple[Tensor, Tensor]:
163
+ # Self attention block
164
+ if not self.skip_first_layer_pe:
165
+ queries = queries + query_pe
166
+ attn_out = self.self_attn(q=queries, k=queries, v=queries)
167
+ queries = queries + attn_out
168
+ queries = self.norm1(queries)
169
+
170
+ # Cross attention block, tokens attending to image embedding
171
+ q = queries + query_pe
172
+ k = keys + key_pe
173
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
174
+ queries = queries + attn_out
175
+ queries = self.norm2(queries)
176
+
177
+ # MLP block
178
+ mlp_out = self.mlp(queries)
179
+ queries = queries + mlp_out
180
+ queries = self.norm3(queries)
181
+
182
+ # Cross attention block, image embedding attending to tokens
183
+ q = queries + query_pe
184
+ k = keys + key_pe
185
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
186
+ keys = keys + attn_out
187
+ keys = self.norm4(keys)
188
+
189
+ return queries, keys
190
+
191
+
192
+ class AttentionForTwoWayAttentionBlock(nn.Module):
193
+ """
194
+ An attention layer that allows for downscaling the size of the embedding
195
+ after projection to queries, keys, and values.
196
+ """
197
+
198
+ def __init__(
199
+ self,
200
+ embedding_dim: int,
201
+ num_heads: int,
202
+ downsample_rate: int = 1,
203
+ ) -> None:
204
+ super().__init__()
205
+ self.embedding_dim = embedding_dim
206
+ self.internal_dim = embedding_dim // downsample_rate
207
+ self.num_heads = num_heads
208
+ assert (
209
+ self.internal_dim % num_heads == 0
210
+ ), "num_heads must divide embedding_dim."
211
+ self.c_per_head = self.internal_dim / num_heads
212
+ self.inv_sqrt_c_per_head = 1.0 / math.sqrt(self.c_per_head)
213
+
214
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
215
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
216
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
217
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
218
+ self._reset_parameters()
219
+
220
+ def _reset_parameters(self) -> None:
221
+ # The fan_out is incorrect, but matches pytorch's initialization
222
+ # for which qkv is a single 3*embedding_dim x embedding_dim matrix
223
+ fan_in = self.embedding_dim
224
+ fan_out = 3 * self.internal_dim
225
+ # Xavier uniform with our custom fan_out
226
+ bnd = math.sqrt(6 / (fan_in + fan_out))
227
+ nn.init.uniform_(self.q_proj.weight, -bnd, bnd)
228
+ nn.init.uniform_(self.k_proj.weight, -bnd, bnd)
229
+ nn.init.uniform_(self.v_proj.weight, -bnd, bnd)
230
+ # out_proj.weight is left with default initialization, like pytorch attention
231
+ nn.init.zeros_(self.q_proj.bias)
232
+ nn.init.zeros_(self.k_proj.bias)
233
+ nn.init.zeros_(self.v_proj.bias)
234
+ nn.init.zeros_(self.out_proj.bias)
235
+
236
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
237
+ b, n, c = x.shape
238
+ x = x.reshape(b, n, num_heads, c // num_heads)
239
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
240
+
241
+ def _recombine_heads(self, x: Tensor) -> Tensor:
242
+ b, n_heads, n_tokens, c_per_head = x.shape
243
+ x = x.transpose(1, 2)
244
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
245
+
246
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
247
+ # Input projections
248
+ q = self.q_proj(q)
249
+ k = self.k_proj(k)
250
+ v = self.v_proj(v)
251
+
252
+ # Separate into heads
253
+ q = self._separate_heads(q, self.num_heads)
254
+ k = self._separate_heads(k, self.num_heads)
255
+ v = self._separate_heads(v, self.num_heads)
256
+
257
+ # Attention
258
+ _, _, _, c_per_head = q.shape
259
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
260
+ attn = attn * self.inv_sqrt_c_per_head
261
+ attn = torch.softmax(attn, dim=-1)
262
+ # Get output
263
+ out = attn @ v
264
+ out = self._recombine_heads(out)
265
+ out = self.out_proj(out)
266
+ return out
efficientvit/__init__.py ADDED
File without changes
efficientvit/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (170 Bytes). View file
 
efficientvit/__pycache__/sam_model_zoo.cpython-310.pyc ADDED
Binary file (1.46 kB). View file
 
efficientvit/apps/__init__.py ADDED
File without changes
efficientvit/apps/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (175 Bytes). View file
 
efficientvit/apps/data_provider/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from .augment import *
6
+ from .base import *
7
+ from .random_resolution import *
efficientvit/apps/data_provider/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (258 Bytes). View file
 
efficientvit/apps/data_provider/__pycache__/base.cpython-310.pyc ADDED
Binary file (6.34 kB). View file
 
efficientvit/apps/data_provider/augment/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from .bbox import *
6
+ from .color_aug import *
efficientvit/apps/data_provider/augment/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (239 Bytes). View file
 
efficientvit/apps/data_provider/augment/__pycache__/bbox.cpython-310.pyc ADDED
Binary file (802 Bytes). View file
 
efficientvit/apps/data_provider/augment/__pycache__/color_aug.cpython-310.pyc ADDED
Binary file (3.13 kB). View file
 
efficientvit/apps/data_provider/augment/bbox.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import numpy as np
6
+
7
+ __all__ = ["rand_bbox"]
8
+
9
+
10
+ def rand_bbox(
11
+ h: int,
12
+ w: int,
13
+ lam: float,
14
+ rand_func: callable = np.random.uniform,
15
+ ) -> tuple[int, int, int, int]:
16
+ """randomly sample bbox, used in cutmix"""
17
+ cut_rat = np.sqrt(1.0 - lam)
18
+ cut_w = w * cut_rat
19
+ cut_h = h * cut_rat
20
+
21
+ # uniform
22
+ cx = rand_func(0, w)
23
+ cy = rand_func(0, h)
24
+
25
+ bbx1 = int(np.clip(cx - cut_w / 2, 0, w))
26
+ bby1 = int(np.clip(cy - cut_h / 2, 0, h))
27
+ bbx2 = int(np.clip(cx + cut_w / 2, 0, w))
28
+ bby2 = int(np.clip(cy + cut_h / 2, 0, h))
29
+
30
+ return bbx1, bby1, bbx2, bby2
efficientvit/apps/data_provider/augment/color_aug.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import numpy as np
6
+ import torchvision.transforms as transforms
7
+ from PIL import Image
8
+ from timm.data.auto_augment import rand_augment_transform
9
+
10
+ __all__ = ["ColorAug", "RandAug"]
11
+
12
+
13
+ class ImageAug:
14
+ def aug_image(self, image: Image.Image) -> Image.Image:
15
+ raise NotImplementedError
16
+
17
+ def __call__(self, feed_dict: dict or np.ndarray or Image.Image) -> dict or np.ndarray or Image.Image:
18
+ if isinstance(feed_dict, dict):
19
+ output_dict = feed_dict
20
+ image = feed_dict[self.key]
21
+ else:
22
+ output_dict = None
23
+ image = feed_dict
24
+ is_ndarray = isinstance(image, np.ndarray)
25
+ if is_ndarray:
26
+ image = Image.fromarray(image)
27
+
28
+ image = self.aug_image(image)
29
+
30
+ if is_ndarray:
31
+ image = np.array(image)
32
+
33
+ if output_dict is None:
34
+ return image
35
+ else:
36
+ output_dict[self.key] = image
37
+ return output_dict
38
+
39
+
40
+ class ColorAug(transforms.ColorJitter, ImageAug):
41
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, key="data"):
42
+ super().__init__(
43
+ brightness=brightness,
44
+ contrast=contrast,
45
+ saturation=saturation,
46
+ hue=hue,
47
+ )
48
+ self.key = key
49
+
50
+ def aug_image(self, image: Image.Image) -> Image.Image:
51
+ return transforms.ColorJitter.forward(self, image)
52
+
53
+ def forward(self, feed_dict: dict or np.ndarray or Image.Image) -> dict or np.ndarray or Image.Image:
54
+ return ImageAug.__call__(self, feed_dict)
55
+
56
+
57
+ class RandAug(ImageAug):
58
+ def __init__(self, config: dict[str, any], mean: tuple[float, float, float], key="data"):
59
+ n = config.get("n", 2)
60
+ m = config.get("m", 9)
61
+ mstd = config.get("mstd", 1.0)
62
+ inc = config.get("inc", 1)
63
+ tpct = config.get("tpct", 0.45)
64
+ config_str = f"rand-n{n}-m{m}-mstd{mstd}-inc{inc}"
65
+
66
+ aa_params = dict(
67
+ translate_pct=tpct,
68
+ img_mean=tuple([min(255, round(255 * x)) for x in mean]),
69
+ interpolation=Image.BICUBIC,
70
+ )
71
+ self.aug_op = rand_augment_transform(config_str, aa_params)
72
+ self.key = key
73
+
74
+ def aug_image(self, image: Image.Image) -> Image.Image:
75
+ return self.aug_op(image)
76
+
77
+ def __repr__(self):
78
+ return self.aug_op.__repr__()
efficientvit/apps/data_provider/base.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import copy
6
+ import warnings
7
+
8
+ import torch.utils.data
9
+ from torch.utils.data.distributed import DistributedSampler
10
+
11
+ from efficientvit.apps.data_provider.random_resolution import RRSController
12
+ from efficientvit.models.utils import val2tuple
13
+
14
+ __all__ = ["parse_image_size", "random_drop_data", "DataProvider"]
15
+
16
+
17
+ def parse_image_size(size: int or str) -> tuple[int, int]:
18
+ if isinstance(size, str):
19
+ size = [int(val) for val in size.split("-")]
20
+ return size[0], size[1]
21
+ else:
22
+ return val2tuple(size, 2)
23
+
24
+
25
+ def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)):
26
+ g = torch.Generator()
27
+ g.manual_seed(seed) # set random seed before sampling validation set
28
+ rand_indexes = torch.randperm(len(dataset), generator=g).tolist()
29
+
30
+ dropped_indexes = rand_indexes[:drop_size]
31
+ remaining_indexes = rand_indexes[drop_size:]
32
+
33
+ dropped_dataset = copy.deepcopy(dataset)
34
+ for key in keys:
35
+ setattr(dropped_dataset, key, [getattr(dropped_dataset, key)[idx] for idx in dropped_indexes])
36
+ setattr(dataset, key, [getattr(dataset, key)[idx] for idx in remaining_indexes])
37
+ return dataset, dropped_dataset
38
+
39
+
40
+ class DataProvider:
41
+ data_keys = ("samples",)
42
+ mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
43
+ SUB_SEED = 937162211 # random seed for sampling subset
44
+ VALID_SEED = 2147483647 # random seed for the validation set
45
+
46
+ name: str
47
+
48
+ def __init__(
49
+ self,
50
+ train_batch_size: int,
51
+ test_batch_size: int or None,
52
+ valid_size: int or float or None,
53
+ n_worker: int,
54
+ image_size: int or list[int] or str or list[str],
55
+ num_replicas: int or None = None,
56
+ rank: int or None = None,
57
+ train_ratio: float or None = None,
58
+ drop_last: bool = False,
59
+ ):
60
+ warnings.filterwarnings("ignore")
61
+ super().__init__()
62
+
63
+ # batch_size & valid_size
64
+ self.train_batch_size = train_batch_size
65
+ self.test_batch_size = test_batch_size or self.train_batch_size
66
+ self.valid_size = valid_size
67
+
68
+ # image size
69
+ if isinstance(image_size, list):
70
+ self.image_size = [parse_image_size(size) for size in image_size]
71
+ self.image_size.sort() # e.g., 160 -> 224
72
+ RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size)
73
+ self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1]
74
+ else:
75
+ self.image_size = parse_image_size(image_size)
76
+ RRSController.IMAGE_SIZE_LIST = [self.image_size]
77
+ self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size
78
+
79
+ # distributed configs
80
+ self.num_replicas = num_replicas
81
+ self.rank = rank
82
+
83
+ # build datasets
84
+ train_dataset, val_dataset, test_dataset = self.build_datasets()
85
+
86
+ if train_ratio is not None and train_ratio < 1.0:
87
+ assert 0 < train_ratio < 1
88
+ _, train_dataset = random_drop_data(
89
+ train_dataset,
90
+ int(train_ratio * len(train_dataset)),
91
+ self.SUB_SEED,
92
+ self.data_keys,
93
+ )
94
+
95
+ # build data loader
96
+ self.train = self.build_dataloader(train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True)
97
+ self.valid = self.build_dataloader(val_dataset, test_batch_size, n_worker, drop_last=False, train=False)
98
+ self.test = self.build_dataloader(test_dataset, test_batch_size, n_worker, drop_last=False, train=False)
99
+ if self.valid is None:
100
+ self.valid = self.test
101
+ self.sub_train = None
102
+
103
+ @property
104
+ def data_shape(self) -> tuple[int, ...]:
105
+ return 3, self.active_image_size[0], self.active_image_size[1]
106
+
107
+ def build_valid_transform(self, image_size: tuple[int, int] or None = None) -> any:
108
+ raise NotImplementedError
109
+
110
+ def build_train_transform(self, image_size: tuple[int, int] or None = None) -> any:
111
+ raise NotImplementedError
112
+
113
+ def build_datasets(self) -> tuple[any, any, any]:
114
+ raise NotImplementedError
115
+
116
+ def build_dataloader(self, dataset: any or None, batch_size: int, n_worker: int, drop_last: bool, train: bool):
117
+ if dataset is None:
118
+ return None
119
+ if isinstance(self.image_size, list) and train:
120
+ from efficientvit.apps.data_provider.random_resolution._data_loader import RRSDataLoader
121
+
122
+ dataloader_class = RRSDataLoader
123
+ else:
124
+ dataloader_class = torch.utils.data.DataLoader
125
+ if self.num_replicas is None:
126
+ return dataloader_class(
127
+ dataset=dataset,
128
+ batch_size=batch_size,
129
+ shuffle=True,
130
+ num_workers=n_worker,
131
+ pin_memory=True,
132
+ drop_last=drop_last,
133
+ )
134
+ else:
135
+ sampler = DistributedSampler(dataset, self.num_replicas, self.rank)
136
+ return dataloader_class(
137
+ dataset=dataset,
138
+ batch_size=batch_size,
139
+ sampler=sampler,
140
+ num_workers=n_worker,
141
+ pin_memory=True,
142
+ drop_last=drop_last,
143
+ )
144
+
145
+ def set_epoch(self, epoch: int) -> None:
146
+ RRSController.set_epoch(epoch, len(self.train))
147
+ if isinstance(self.train.sampler, DistributedSampler):
148
+ self.train.sampler.set_epoch(epoch)
149
+
150
+ def assign_active_image_size(self, new_size: int or tuple[int, int]) -> None:
151
+ self.active_image_size = val2tuple(new_size, 2)
152
+ new_transform = self.build_valid_transform(self.active_image_size)
153
+ # change the transform of the valid and test set
154
+ self.valid.dataset.transform = self.test.dataset.transform = new_transform
155
+
156
+ def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[any, any]:
157
+ if self.valid_size is not None:
158
+ if 0 < self.valid_size < 1:
159
+ valid_size = int(self.valid_size * len(train_dataset))
160
+ else:
161
+ assert self.valid_size >= 1
162
+ valid_size = int(self.valid_size)
163
+ train_dataset, val_dataset = random_drop_data(
164
+ train_dataset,
165
+ valid_size,
166
+ self.VALID_SEED,
167
+ self.data_keys,
168
+ )
169
+ val_dataset.transform = valid_transform
170
+ else:
171
+ val_dataset = None
172
+ return train_dataset, val_dataset
173
+
174
+ def build_sub_train_loader(self, n_samples: int, batch_size: int) -> any:
175
+ # used for resetting BN running statistics
176
+ if self.sub_train is None:
177
+ self.sub_train = {}
178
+ if self.active_image_size in self.sub_train:
179
+ return self.sub_train[self.active_image_size]
180
+
181
+ # construct dataset and dataloader
182
+ train_dataset = copy.deepcopy(self.train.dataset)
183
+ if n_samples < len(train_dataset):
184
+ _, train_dataset = random_drop_data(
185
+ train_dataset,
186
+ n_samples,
187
+ self.SUB_SEED,
188
+ self.data_keys,
189
+ )
190
+ RRSController.ACTIVE_SIZE = self.active_image_size
191
+ train_dataset.transform = self.build_train_transform(image_size=self.active_image_size)
192
+ data_loader = self.build_dataloader(train_dataset, batch_size, self.train.num_workers, True, False)
193
+
194
+ # pre-fetch data
195
+ self.sub_train[self.active_image_size] = [
196
+ data for data in data_loader for _ in range(max(1, n_samples // len(train_dataset)))
197
+ ]
198
+
199
+ return self.sub_train[self.active_image_size]
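
For orientation, a minimal sketch (not part of this upload) of how the `DataProvider` API in base.py above is typically filled in. `ToyDataProvider`, the `FakeData` dataset, and the transform choices are illustrative assumptions only; a real provider would plug in its own dataset and augmentation pipeline.

# Hypothetical sketch: subclassing DataProvider from base.py above.
# FakeData stands in for a real dataset; transforms are placeholders.
import torchvision.transforms as T
from torchvision.datasets import FakeData

from efficientvit.apps.data_provider.base import DataProvider, parse_image_size


class ToyDataProvider(DataProvider):
    name = "toy"

    def build_valid_transform(self, image_size=None):
        h, w = image_size or self.active_image_size
        return T.Compose([T.Resize((h, w)), T.ToTensor(), T.Normalize(**self.mean_std)])

    def build_train_transform(self, image_size=None):
        h, w = image_size or self.active_image_size
        return T.Compose([T.RandomResizedCrop((h, w)), T.ToTensor(), T.Normalize(**self.mean_std)])

    def build_datasets(self):
        train = FakeData(size=256, image_size=(3, 224, 224), transform=self.build_train_transform())
        test = FakeData(size=64, image_size=(3, 224, 224), transform=self.build_valid_transform())
        return train, None, test  # no held-out validation split in this toy example


# "H-W" strings and plain ints are both accepted by parse_image_size:
assert parse_image_size("160-192") == (160, 192)
assert parse_image_size(224) == (224, 224)

provider = ToyDataProvider(
    train_batch_size=32, test_batch_size=64, valid_size=None,
    n_worker=0, image_size="160-192",
)

With `valid_size=None` the provider falls back to using the test loader for validation, mirroring the fallback in the constructor above.
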
efficientvit/apps/data_provider/random_resolution/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """Random resolution data loader compatible with multi-processing and distributed training.
2
+
3
+ Replace PyTorch's DataLoader with RRSDataLoader to support random resolution
4
+ at training time; resolution sampling is controlled by RRSController.
5
+ """
6
+
7
+ from .controller import *
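
A usage sketch (again illustrative, not part of this upload): RRSDataLoader can stand in for torch.utils.data.DataLoader, with RRSController.set_epoch re-sampling the per-batch resolution schedule each epoch, as DataProvider.set_epoch does above. The candidate sizes and the FakeData dataset are assumptions; a real setup would use a train transform that actually consults RRSController.

# Hypothetical sketch: RRSDataLoader as a drop-in for torch.utils.data.DataLoader.
# With a plain ToTensor transform the batches keep their native size, since only
# RRS-aware train transforms read the resolution chosen by RRSController.
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

from efficientvit.apps.data_provider.random_resolution.controller import RRSController
from efficientvit.apps.data_provider.random_resolution._data_loader import RRSDataLoader

RRSController.IMAGE_SIZE_LIST = [(160, 160), (192, 192), (224, 224)]  # candidate resolutions
RRSController.ACTIVE_SIZE = (224, 224)

train_set = FakeData(size=128, image_size=(3, 224, 224), transform=ToTensor())
train_loader = RRSDataLoader(train_set, batch_size=16, shuffle=True, num_workers=0)

for epoch in range(2):
    RRSController.set_epoch(epoch, len(train_loader))  # one resolution choice per batch
    for images, labels in train_loader:
        pass  # forward/backward pass would go here
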
efficientvit/apps/data_provider/random_resolution/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (491 Bytes).
 
efficientvit/apps/data_provider/random_resolution/__pycache__/controller.cpython-310.pyc ADDED
Binary file (3.31 kB).
 
efficientvit/apps/data_provider/random_resolution/_data_loader.py ADDED
@@ -0,0 +1,1538 @@
1
+ r"""This file is based on torch/utils/data/data_loader.py
2
+
3
+ Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
4
+
5
+ To support these two classes, in `./_utils` we define many utility methods and
6
+ functions to be run in multiprocessing. E.g., the data loading worker loop is
7
+ in `./_utils/worker.py`.
8
+ """
9
+
10
+ import functools
11
+ import itertools
12
+ import logging
13
+ import multiprocessing as python_multiprocessing
14
+ import os
15
+ import queue
16
+ import threading
17
+ import warnings
18
+ from typing import Any, Callable, Generic, Iterable, List, Optional, Sequence, TypeVar, Union
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+ import torch.multiprocessing as multiprocessing
23
+ import torch.utils.data.graph_settings
24
+ from torch._utils import ExceptionWrapper
25
+ from torch.utils.data import (
26
+ BatchSampler,
27
+ Dataset,
28
+ IterableDataset,
29
+ IterDataPipe,
30
+ MapDataPipe,
31
+ RandomSampler,
32
+ Sampler,
33
+ SequentialSampler,
34
+ _utils,
35
+ )
36
+ from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper
37
+
38
+ from ._data_worker import _worker_loop
39
+
40
+ __all__ = ["RRSDataLoader"]
41
+
42
+ T_co = TypeVar("T_co", covariant=True)
43
+ T = TypeVar("T")
44
+ _worker_init_fn_t = Callable[[int], None]
45
+
46
+ # Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
47
+ # type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
48
+ # See https://github.com/python/mypy/issues/3737.
49
+ _collate_fn_t = Callable[[List[T]], Any]
50
+
51
+
52
+ # These functions used to be defined in this file. However, it was moved to
53
+ # _utils/collate.py. Although it is rather hard to access this from user land
54
+ # (one has to explicitly directly `import torch.utils.data.dataloader`), there
55
+ # probably is user code out there using it. This aliasing maintains BC in this
56
+ # aspect.
57
+ default_collate: _collate_fn_t = _utils.collate.default_collate
58
+ default_convert = _utils.collate.default_convert
59
+
60
+ get_worker_info = _utils.worker.get_worker_info
61
+
62
+ logger = logging.getLogger(__name__)
63
+
64
+
65
+ class _DatasetKind:
66
+ Map = 0
67
+ Iterable = 1
68
+
69
+ @staticmethod
70
+ def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
71
+ if kind == _DatasetKind.Map:
72
+ return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
73
+ else:
74
+ return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
75
+
76
+
77
+ class _InfiniteConstantSampler(Sampler):
78
+ r"""Analogous to ``itertools.repeat(None, None)``.
79
+ Used as sampler for :class:`~torch.utils.data.IterableDataset`.
80
+
81
+ Args:
82
+ data_source (Dataset): dataset to sample from
83
+ """
84
+
85
+ def __init__(self):
86
+ super().__init__(None)
87
+
88
+ def __iter__(self):
89
+ while True:
90
+ yield None
91
+
92
+
93
+ def _get_distributed_settings():
94
+ if dist.is_available() and dist.is_initialized():
95
+ return dist.get_world_size(), dist.get_rank()
96
+ else:
97
+ return 1, 0
98
+
99
+
100
+ def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
101
+ global_worker_id = worker_id
102
+ info = torch.utils.data.get_worker_info()
103
+ assert info is not None
104
+ total_workers = info.num_workers
105
+ datapipe = info.dataset
106
+ assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
107
+ # To distribute elements across distributed process evenly, we should shard data on distributed
108
+ # processes first then shard on worker processes
109
+ total_workers *= world_size
110
+ global_worker_id = global_worker_id * world_size + rank_id
111
+ # For BC, use default SHARDING_PRIORITIES
112
+ torch.utils.data.graph_settings.apply_sharding(datapipe, total_workers, global_worker_id)
113
+ if worker_init_fn is not None:
114
+ worker_init_fn(worker_id)
115
+
116
+
117
+ def _share_dist_seed(generator, pg):
118
+ _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)
119
+ if isinstance(pg, dist.ProcessGroup):
120
+ dist.broadcast(_shared_seed, src=0, group=pg)
121
+ return _shared_seed.item()
122
+
123
+
124
+ class RRSDataLoader(Generic[T_co]):
125
+ r"""
126
+ Data loader. Combines a dataset and a sampler, and provides an iterable over
127
+ the given dataset.
128
+
129
+ The :class:`~torch.utils.data.DataLoader` supports both map-style and
130
+ iterable-style datasets with single- or multi-process loading, customizing
131
+ loading order and optional automatic batching (collation) and memory pinning.
132
+
133
+ See :py:mod:`torch.utils.data` documentation page for more details.
134
+
135
+ Args:
136
+ dataset (Dataset): dataset from which to load the data.
137
+ batch_size (int, optional): how many samples per batch to load
138
+ (default: ``1``).
139
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
140
+ at every epoch (default: ``False``).
141
+ sampler (Sampler or Iterable, optional): defines the strategy to draw
142
+ samples from the dataset. Can be any ``Iterable`` with ``__len__``
143
+ implemented. If specified, :attr:`shuffle` must not be specified.
144
+ batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
145
+ returns a batch of indices at a time. Mutually exclusive with
146
+ :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
147
+ and :attr:`drop_last`.
148
+ num_workers (int, optional): how many subprocesses to use for data
149
+ loading. ``0`` means that the data will be loaded in the main process.
150
+ (default: ``0``)
151
+ collate_fn (Callable, optional): merges a list of samples to form a
152
+ mini-batch of Tensor(s). Used when using batched loading from a
153
+ map-style dataset.
154
+ pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
155
+ into device/CUDA pinned memory before returning them. If your data elements
156
+ are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
157
+ see the example below.
158
+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
159
+ if the dataset size is not divisible by the batch size. If ``False`` and
160
+ the size of dataset is not divisible by the batch size, then the last batch
161
+ will be smaller. (default: ``False``)
162
+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
163
+ from workers. Should always be non-negative. (default: ``0``)
164
+ worker_init_fn (Callable, optional): If not ``None``, this will be called on each
165
+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
166
+ input, after seeding and before data loading. (default: ``None``)
167
+ generator (torch.Generator, optional): If not ``None``, this RNG will be used
168
+ by RandomSampler to generate random indexes and multiprocessing to generate
169
+ `base_seed` for workers. (default: ``None``)
170
+ prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
171
+ in advance by each worker. ``2`` means there will be a total of
172
+ 2 * num_workers batches prefetched across all workers. (default value depends
173
+ on the set value for num_workers. If value of num_workers=0 default is ``None``.
174
+ Otherwise if value of num_workers>0 default is ``2``).
175
+ persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
176
+ the worker processes after a dataset has been consumed once. This allows to
177
+ keep the workers' `Dataset` instances alive. (default: ``False``)
178
+ pin_memory_device (str, optional): the data loader will copy Tensors
179
+ into device pinned memory before returning them if pin_memory is set to true.
180
+
181
+
182
+ .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
183
+ cannot be an unpicklable object, e.g., a lambda function. See
184
+ :ref:`multiprocessing-best-practices` on more details related
185
+ to multiprocessing in PyTorch.
186
+
187
+ .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
188
+ When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
189
+ it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
190
+ rounding depending on :attr:`drop_last`, regardless of multi-process loading
191
+ configurations. This represents the best guess PyTorch can make because PyTorch
192
+ trusts user :attr:`dataset` code in correctly handling multi-process
193
+ loading to avoid duplicate data.
194
+
195
+ However, if sharding results in multiple workers having incomplete last batches,
196
+ this estimate can still be inaccurate, because (1) an otherwise complete batch can
197
+ be broken into multiple ones and (2) more than one batch worth of samples can be
198
+ dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
199
+ cases in general.
200
+
201
+ See `Dataset Types`_ for more details on these two types of datasets and how
202
+ :class:`~torch.utils.data.IterableDataset` interacts with
203
+ `Multi-process data loading`_.
204
+
205
+ .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
206
+ :ref:`data-loading-randomness` notes for random seed related questions.
207
+ """
208
+
209
+ dataset: Dataset[T_co]
210
+ batch_size: Optional[int]
211
+ num_workers: int
212
+ pin_memory: bool
213
+ drop_last: bool
214
+ timeout: float
215
+ sampler: Union[Sampler, Iterable]
216
+ pin_memory_device: str
217
+ prefetch_factor: Optional[int]
218
+ _iterator: Optional["_BaseDataLoaderIter"]
219
+ __initialized = False
220
+
221
+ def __init__(
222
+ self,
223
+ dataset: Dataset[T_co],
224
+ batch_size: Optional[int] = 1,
225
+ shuffle: Optional[bool] = None,
226
+ sampler: Union[Sampler, Iterable, None] = None,
227
+ batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
228
+ num_workers: int = 0,
229
+ collate_fn: Optional[_collate_fn_t] = None,
230
+ pin_memory: bool = False,
231
+ drop_last: bool = False,
232
+ timeout: float = 0,
233
+ worker_init_fn: Optional[_worker_init_fn_t] = None,
234
+ multiprocessing_context=None,
235
+ generator=None,
236
+ *,
237
+ prefetch_factor: Optional[int] = None,
238
+ persistent_workers: bool = False,
239
+ pin_memory_device: str = ""
240
+ ):
241
+ torch._C._log_api_usage_once("python.data_loader")
242
+
243
+ if num_workers < 0:
244
+ raise ValueError(
245
+ "num_workers option should be non-negative; " "use num_workers=0 to disable multiprocessing."
246
+ )
247
+
248
+ if timeout < 0:
249
+ raise ValueError("timeout option should be non-negative")
250
+
251
+ if num_workers == 0 and prefetch_factor is not None:
252
+ raise ValueError(
253
+ "prefetch_factor option could only be specified in multiprocessing."
254
+ "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None."
255
+ )
256
+ elif num_workers > 0 and prefetch_factor is None:
257
+ prefetch_factor = 2
258
+ elif prefetch_factor is not None and prefetch_factor < 0:
259
+ raise ValueError("prefetch_factor option should be non-negative")
260
+
261
+ if persistent_workers and num_workers == 0:
262
+ raise ValueError("persistent_workers option needs num_workers > 0")
263
+
264
+ self.dataset = dataset
265
+ self.num_workers = num_workers
266
+ self.prefetch_factor = prefetch_factor
267
+ self.pin_memory = pin_memory
268
+ self.pin_memory_device = pin_memory_device
269
+ self.timeout = timeout
270
+ self.worker_init_fn = worker_init_fn
271
+ self.multiprocessing_context = multiprocessing_context
272
+
273
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
274
+ # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
275
+ if isinstance(self.dataset, IterDataPipe):
276
+ self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
277
+ elif isinstance(self.dataset, MapDataPipe):
278
+ self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
279
+
280
+ # Arg-check dataset related before checking samplers because we want to
281
+ # tell users that iterable-style datasets are incompatible with custom
282
+ # samplers first, so that they don't learn that this combo doesn't work
283
+ # after spending time fixing the custom sampler errors.
284
+ if isinstance(dataset, IterableDataset):
285
+ self._dataset_kind = _DatasetKind.Iterable
286
+ # NOTE [ Custom Samplers and IterableDataset ]
287
+ #
288
+ # `IterableDataset` does not support custom `batch_sampler` or
289
+ # `sampler` since the key is irrelevant (unless we support
290
+ # generator-style dataset one day...).
291
+ #
292
+ # For `sampler`, we always create a dummy sampler. This is an
293
+ # infinite sampler even when the dataset may have an implemented
294
+ # finite `__len__` because in multi-process data loading, naive
295
+ # settings will return duplicated data (which may be desired), and
296
+ # thus using a sampler with length matching that of dataset will
297
+ # cause data loss (you may have duplicates of the first couple
298
+ # batches, but never see anything afterwards). Therefore,
299
+ # `IterableDataset` always uses an infinite sampler, an instance of
300
+ # `_InfiniteConstantSampler` defined above.
301
+ #
302
+ # A custom `batch_sampler` essentially only controls the batch size.
303
+ # However, it is unclear how useful it would be since an iterable-style
304
+ # dataset can handle that within itself. Moreover, it is pointless
305
+ # in multi-process data loading as the assignment order of batches
306
+ # to workers is an implementation detail so users can not control
307
+ # how to batchify each worker's iterable. Thus, we disable this
308
+ # option. If this turns out to be useful in future, we can re-enable
309
+ # this, and support custom samplers that specify the assignments to
310
+ # specific workers.
311
+ if isinstance(dataset, IterDataPipe):
312
+ if shuffle is not None:
313
+ dataset = torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
314
+ # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
315
+ elif shuffle not in {False, None}:
316
+ raise ValueError(
317
+ "DataLoader with IterableDataset: expected unspecified "
318
+ "shuffle option, but got shuffle={}".format(shuffle)
319
+ )
320
+
321
+ if sampler is not None:
322
+ # See NOTE [ Custom Samplers and IterableDataset ]
323
+ raise ValueError(
324
+ "DataLoader with IterableDataset: expected unspecified "
325
+ "sampler option, but got sampler={}".format(sampler)
326
+ )
327
+ elif batch_sampler is not None:
328
+ # See NOTE [ Custom Samplers and IterableDataset ]
329
+ raise ValueError(
330
+ "DataLoader with IterableDataset: expected unspecified "
331
+ "batch_sampler option, but got batch_sampler={}".format(batch_sampler)
332
+ )
333
+ else:
334
+ shuffle = bool(shuffle)
335
+ self._dataset_kind = _DatasetKind.Map
336
+
337
+ if sampler is not None and shuffle:
338
+ raise ValueError("sampler option is mutually exclusive with " "shuffle")
339
+
340
+ if batch_sampler is not None:
341
+ # auto_collation with custom batch_sampler
342
+ if batch_size != 1 or shuffle or sampler is not None or drop_last:
343
+ raise ValueError(
344
+ "batch_sampler option is mutually exclusive " "with batch_size, shuffle, sampler, and " "drop_last"
345
+ )
346
+ batch_size = None
347
+ drop_last = False
348
+ elif batch_size is None:
349
+ # no auto_collation
350
+ if drop_last:
351
+ raise ValueError(
352
+ "batch_size=None option disables auto-batching " "and is mutually exclusive with drop_last"
353
+ )
354
+
355
+ if sampler is None: # give default samplers
356
+ if self._dataset_kind == _DatasetKind.Iterable:
357
+ # See NOTE [ Custom Samplers and IterableDataset ]
358
+ sampler = _InfiniteConstantSampler()
359
+ else: # map-style
360
+ if shuffle:
361
+ sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
362
+ else:
363
+ sampler = SequentialSampler(dataset) # type: ignore[arg-type]
364
+
365
+ if batch_size is not None and batch_sampler is None:
366
+ # auto_collation without custom batch_sampler
367
+ batch_sampler = BatchSampler(sampler, batch_size, drop_last)
368
+
369
+ self.batch_size = batch_size
370
+ self.drop_last = drop_last
371
+ self.sampler = sampler
372
+ self.batch_sampler = batch_sampler
373
+ self.generator = generator
374
+
375
+ if collate_fn is None:
376
+ if self._auto_collation:
377
+ collate_fn = _utils.collate.default_collate
378
+ else:
379
+ collate_fn = _utils.collate.default_convert
380
+
381
+ self.collate_fn = collate_fn
382
+ self.persistent_workers = persistent_workers
383
+
384
+ self.__initialized = True
385
+ self._IterableDataset_len_called = None # See NOTE [ IterableDataset and __len__ ]
386
+
387
+ self._iterator = None
388
+
389
+ self.check_worker_number_rationality()
390
+
391
+ torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined]
392
+
393
+ def _get_iterator(self) -> "_BaseDataLoaderIter":
394
+ if self.num_workers == 0:
395
+ return _SingleProcessDataLoaderIter(self)
396
+ else:
397
+ self.check_worker_number_rationality()
398
+ return _MultiProcessingDataLoaderIter(self)
399
+
400
+ @property
401
+ def multiprocessing_context(self):
402
+ return self.__multiprocessing_context
403
+
404
+ @multiprocessing_context.setter
405
+ def multiprocessing_context(self, multiprocessing_context):
406
+ if multiprocessing_context is not None:
407
+ if self.num_workers > 0:
408
+ if isinstance(multiprocessing_context, str):
409
+ valid_start_methods = multiprocessing.get_all_start_methods()
410
+ if multiprocessing_context not in valid_start_methods:
411
+ raise ValueError(
412
+ (
413
+ "multiprocessing_context option "
414
+ "should specify a valid start method in {!r}, but got "
415
+ "multiprocessing_context={!r}"
416
+ ).format(valid_start_methods, multiprocessing_context)
417
+ )
418
+ multiprocessing_context = multiprocessing.get_context(multiprocessing_context)
419
+
420
+ if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
421
+ raise TypeError(
422
+ (
423
+ "multiprocessing_context option should be a valid context "
424
+ "object or a string specifying the start method, but got "
425
+ "multiprocessing_context={}"
426
+ ).format(multiprocessing_context)
427
+ )
428
+ else:
429
+ raise ValueError(
430
+ (
431
+ "multiprocessing_context can only be used with "
432
+ "multi-process loading (num_workers > 0), but got "
433
+ "num_workers={}"
434
+ ).format(self.num_workers)
435
+ )
436
+
437
+ self.__multiprocessing_context = multiprocessing_context
438
+
439
+ def __setattr__(self, attr, val):
440
+ if self.__initialized and attr in (
441
+ "batch_size",
442
+ "batch_sampler",
443
+ "sampler",
444
+ "drop_last",
445
+ "dataset",
446
+ "persistent_workers",
447
+ ):
448
+ raise ValueError(
449
+ "{} attribute should not be set after {} is " "initialized".format(attr, self.__class__.__name__)
450
+ )
451
+
452
+ super().__setattr__(attr, val)
453
+
454
+ # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
455
+ # since '_BaseDataLoaderIter' references 'DataLoader'.
456
+ def __iter__(self) -> "_BaseDataLoaderIter":
457
+ # When using a single worker the returned iterator should be
458
+ # created every time to avoid resetting its state
459
+ # However, in the case of a multi-worker iterator
460
+ # the iterator is only created once in the lifetime of the
461
+ # DataLoader object so that workers can be reused
462
+ if self.persistent_workers and self.num_workers > 0:
463
+ if self._iterator is None:
464
+ self._iterator = self._get_iterator()
465
+ else:
466
+ self._iterator._reset(self)
467
+ return self._iterator
468
+ else:
469
+ return self._get_iterator()
470
+
471
+ @property
472
+ def _auto_collation(self):
473
+ return self.batch_sampler is not None
474
+
475
+ @property
476
+ def _index_sampler(self):
477
+ # The actual sampler used for generating indices for `_DatasetFetcher`
478
+ # (see _utils/fetch.py) to read data at each time. This would be
479
+ # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
480
+ # We can't change `.sampler` and `.batch_sampler` attributes for BC
481
+ # reasons.
482
+ if self._auto_collation:
483
+ return self.batch_sampler
484
+ else:
485
+ return self.sampler
486
+
487
+ def __len__(self) -> int:
488
+ if self._dataset_kind == _DatasetKind.Iterable:
489
+ # NOTE [ IterableDataset and __len__ ]
490
+ #
491
+ # For `IterableDataset`, `__len__` could be inaccurate when one naively
492
+ # does multi-processing data loading, since the samples will be duplicated.
493
+ # However, no real use case should be actually using that behavior, so
494
+ # it should count as a user error. We should generally trust user
495
+ # code to do the proper thing (e.g., configure each replica differently
496
+ # in `__iter__`), and give us the correct `__len__` if they choose to
497
+ # implement it (this will still throw if the dataset does not implement
498
+ # a `__len__`).
499
+ #
500
+ # To provide a further warning, we track if `__len__` was called on the
501
+ # `DataLoader`, save the returned value in `self._len_called`, and warn
502
+ # if the iterator ends up yielding more than this number of samples.
503
+
504
+ # Cannot statically verify that dataset is Sized
505
+ length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type]
506
+ if self.batch_size is not None: # IterableDataset doesn't allow custom sampler or batch_sampler
507
+ from math import ceil
508
+
509
+ if self.drop_last:
510
+ length = length // self.batch_size
511
+ else:
512
+ length = ceil(length / self.batch_size)
513
+ return length
514
+ else:
515
+ return len(self._index_sampler)
516
+
517
+ def check_worker_number_rationality(self):
518
+ # This function check whether the dataloader's worker number is rational based on
519
+ # current system's resource. Current rule is that if the number of workers this
520
+ # Dataloader will create is bigger than the number of logical cpus that is allowed to
521
+ # use, then we will pop up a warning to let the user pay attention.
522
+ #
523
+ # e.g. If the current system has 2 physical CPUs with 16 cores each, and each core supports 2
524
+ # threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current
525
+ # DataLoader process can use half of them which is 32, then the rational max number of
526
+ # worker that initiated from this process is 32.
527
+ # Now, let's say the created DataLoader has num_workers = 40, which is bigger than 32.
528
+ # So the warning message is triggered to notify the user to lower the worker number if
529
+ # necessary.
530
+ #
531
+ #
532
+ # [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is
533
+ # available (available in most of Linux system, but not OSX and Windows).
534
+ # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
535
+ # it doesn't respect cpuset.
536
+ # We don't take threading into account since each worker process is single threaded
537
+ # at this time.
538
+ #
539
+ # We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
540
+ # other than `torch.set_num_threads` to 1 in the worker process, if the passing
541
+ # in functions use 3rd party modules that rely on those threading flags to determine
542
+ # how many thread to create (eg. numpy, etc), then it is caller's responsibility to
543
+ # set those flags correctly.
544
+ def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
545
+ suggested_max_worker_msg = (
546
+ (
547
+ (
548
+ "Our suggested max number of worker in current system is {}{}, which is smaller "
549
+ "than what this DataLoader is going to create."
550
+ ).format(
551
+ num_worker_suggest,
552
+ ("" if cpuset_checked else " (`cpuset` is not taken into account)"),
553
+ )
554
+ )
555
+ if num_worker_suggest is not None
556
+ else ("DataLoader is not able to compute a suggested max number of worker in current system.")
557
+ )
558
+
559
+ warn_msg = (
560
+ "This DataLoader will create {} worker processes in total. {} "
561
+ "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
562
+ "lower the worker number to avoid potential slowness/freeze if necessary."
563
+ ).format(num_worker_created, suggested_max_worker_msg)
564
+ return warn_msg
565
+
566
+ if not self.num_workers or self.num_workers == 0:
567
+ return
568
+
569
+ # try to compute a suggested max number of worker based on system's resource
570
+ max_num_worker_suggest = None
571
+ cpuset_checked = False
572
+ if hasattr(os, "sched_getaffinity"):
573
+ try:
574
+ max_num_worker_suggest = len(os.sched_getaffinity(0))
575
+ cpuset_checked = True
576
+ except Exception:
577
+ pass
578
+ if max_num_worker_suggest is None:
579
+ # os.cpu_count() could return Optional[int]
580
+ # get cpu count first and check None in order to satisfy mypy check
581
+ cpu_count = os.cpu_count()
582
+ if cpu_count is not None:
583
+ max_num_worker_suggest = cpu_count
584
+
585
+ if max_num_worker_suggest is None:
586
+ warnings.warn(_create_warning_msg(max_num_worker_suggest, self.num_workers, cpuset_checked))
587
+ return
588
+
589
+ if self.num_workers > max_num_worker_suggest:
590
+ warnings.warn(_create_warning_msg(max_num_worker_suggest, self.num_workers, cpuset_checked))
591
+
592
+
593
+ class _BaseDataLoaderIter:
594
+ def __init__(self, loader: RRSDataLoader) -> None:
595
+ self._dataset = loader.dataset
596
+ self._shared_seed = None
597
+ self._pg = None
598
+ if isinstance(self._dataset, IterDataPipe):
599
+ if dist.is_available() and dist.is_initialized():
600
+ self._pg = dist.new_group(backend="gloo")
601
+ self._shared_seed = _share_dist_seed(loader.generator, self._pg)
602
+ shared_rng = torch.Generator()
603
+ shared_rng.manual_seed(self._shared_seed)
604
+ self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)
605
+ self._dataset_kind = loader._dataset_kind
606
+ self._IterableDataset_len_called = loader._IterableDataset_len_called
607
+ self._auto_collation = loader._auto_collation
608
+ self._drop_last = loader.drop_last
609
+ self._index_sampler = loader._index_sampler
610
+ self._num_workers = loader.num_workers
611
+ ws, rank = _get_distributed_settings()
612
+ self._world_size = ws
613
+ self._rank = rank
614
+ # for other backends, pin_memory_device needs to be set. If not set,
615
+ # the default behaviour is the CUDA device. If pin_memory_device is selected
616
+ # and pin_memory is not set, the default behaviour is false.
617
+ if len(loader.pin_memory_device) == 0:
618
+ self._pin_memory = loader.pin_memory and torch.cuda.is_available()
619
+ self._pin_memory_device = None
620
+ else:
621
+ if not loader.pin_memory:
622
+ warn_msg = (
623
+ "pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
624
+ "please set pin_memory to true, if you need to use the device pin memory"
625
+ )
626
+ warnings.warn(warn_msg)
627
+
628
+ self._pin_memory = loader.pin_memory
629
+ self._pin_memory_device = loader.pin_memory_device
630
+ self._timeout = loader.timeout
631
+ self._collate_fn = loader.collate_fn
632
+ self._sampler_iter = iter(self._index_sampler)
633
+ self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
634
+ self._persistent_workers = loader.persistent_workers
635
+ self._num_yielded = 0
636
+ self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)
637
+
638
+ def __iter__(self) -> "_BaseDataLoaderIter":
639
+ return self
640
+
641
+ def _reset(self, loader, first_iter=False):
642
+ self._sampler_iter = iter(self._index_sampler)
643
+ self._num_yielded = 0
644
+ self._IterableDataset_len_called = loader._IterableDataset_len_called
645
+ if isinstance(self._dataset, IterDataPipe):
646
+ self._shared_seed = _share_dist_seed(loader.generator, self._pg)
647
+ shared_rng = torch.Generator()
648
+ shared_rng.manual_seed(self._shared_seed)
649
+ self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)
650
+
651
+ def _next_index(self):
652
+ return next(self._sampler_iter) # may raise StopIteration
653
+
654
+ def _next_data(self):
655
+ raise NotImplementedError
656
+
657
+ def __next__(self) -> Any:
658
+ with torch.autograd.profiler.record_function(self._profile_name):
659
+ if self._sampler_iter is None:
660
+ # TODO(https://github.com/pytorch/pytorch/issues/76750)
661
+ self._reset() # type: ignore[call-arg]
662
+ data = self._next_data()
663
+ self._num_yielded += 1
664
+ if (
665
+ self._dataset_kind == _DatasetKind.Iterable
666
+ and self._IterableDataset_len_called is not None
667
+ and self._num_yielded > self._IterableDataset_len_called
668
+ ):
669
+ warn_msg = (
670
+ "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
671
+ "samples have been fetched. "
672
+ ).format(self._dataset, self._IterableDataset_len_called, self._num_yielded)
673
+ if self._num_workers > 0:
674
+ warn_msg += (
675
+ "For multiprocessing data-loading, this could be caused by not properly configuring the "
676
+ "IterableDataset replica at each worker. Please see "
677
+ "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
678
+ )
679
+ warnings.warn(warn_msg)
680
+ return data
681
+
682
+ def __len__(self) -> int:
683
+ return len(self._index_sampler)
684
+
685
+ def __getstate__(self):
686
+ # TODO: add limited pickling support for sharing an iterator
687
+ # across multiple threads for HOGWILD.
688
+ # Probably the best way to do this is by moving the sample pushing
689
+ # to a separate thread and then just sharing the data queue
690
+ # but signalling the end is tricky without a non-blocking API
691
+ raise NotImplementedError("{} cannot be pickled".format(self.__class__.__name__))
692
+
693
+
694
+ class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
695
+ def __init__(self, loader):
696
+ super().__init__(loader)
697
+ assert self._timeout == 0
698
+ assert self._num_workers == 0
699
+
700
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
701
+ # Taking care of distributed sharding
702
+ if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
703
+ # For BC, use default SHARDING_PRIORITIES
704
+ torch.utils.data.graph_settings.apply_sharding(self._dataset, self._world_size, self._rank)
705
+
706
+ self._dataset_fetcher = _DatasetKind.create_fetcher(
707
+ self._dataset_kind,
708
+ self._dataset,
709
+ self._auto_collation,
710
+ self._collate_fn,
711
+ self._drop_last,
712
+ )
713
+
714
+ def _next_data(self):
715
+ index = self._next_index() # may raise StopIteration
716
+ data = self._dataset_fetcher.fetch(index) # may raise StopIteration
717
+ if self._pin_memory:
718
+ data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
719
+ return data
720
+
721
+
722
+ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
723
+ r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
724
+
725
+ # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
726
+ #
727
+ # Preliminary:
728
+ #
729
+ # Our data model looks like this (queues are indicated with curly brackets):
730
+ #
731
+ # main process ||
732
+ # | ||
733
+ # {index_queue} ||
734
+ # | ||
735
+ # worker processes || DATA
736
+ # | ||
737
+ # {worker_result_queue} || FLOW
738
+ # | ||
739
+ # pin_memory_thread of main process || DIRECTION
740
+ # | ||
741
+ # {data_queue} ||
742
+ # | ||
743
+ # data output \/
744
+ #
745
+ # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
746
+ # `pin_memory=False`.
747
+ #
748
+ #
749
+ # Terminating multiprocessing logic requires very careful design. In
750
+ # particular, we need to make sure that
751
+ #
752
+ # 1. The iterator gracefully exits the workers when its last reference is
753
+ # gone or it is depleted.
754
+ #
755
+ # In this case, the workers should be gracefully exited because the
756
+ # main process may still need to continue to run, and we want cleaning
757
+ # up code in the workers to be executed (e.g., releasing GPU memory).
758
+ # Naturally, we implement the shutdown logic in `__del__` of
759
+ # DataLoaderIterator.
760
+ #
761
+ # We delay the discussion on the logic in this case until later.
762
+ #
763
+ # 2. The iterator exits the workers when the loader process and/or worker
764
+ # processes exits normally or with error.
765
+ #
766
+ # We set all workers and `pin_memory_thread` to have `daemon=True`.
767
+ #
768
+ # You may ask, why can't we make the workers non-daemonic, and
769
+ # gracefully exit using the same logic as we have in `__del__` when the
770
+ # iterator gets deleted (see 1 above)?
771
+ #
772
+ # First of all, `__del__` is **not** guaranteed to be called when
773
+ # interpreter exits. Even if it is called, by the time it executes,
774
+ # many Python core library resources may already be freed, and even
775
+ # simple things like acquiring an internal lock of a queue may hang.
776
+ # Therefore, in this case, we actually need to prevent `__del__` from
777
+ # being executed, and rely on the automatic termination of daemonic
778
+ # children.
779
+ #
780
+ # Thus, we register an `atexit` hook that sets a global flag
781
+ # `_utils.python_exit_status`. Since `atexit` hooks are executed in the
782
+ # reverse order of registration, we are guaranteed that this flag is
783
+ # set before library resources we use are freed (which, at least in
784
+ # CPython, is done via an `atexit` handler defined in
785
+ # `multiprocessing/util.py`
786
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
787
+ # registered when an object requiring this mechanism is first
788
+ # created, e.g., `mp.Queue`
789
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
790
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
791
+ # )
792
+ #
793
+ # So in `__del__`, we check if `_utils.python_exit_status` is set or
794
+ # `None` (freed), and perform no-op if so.
795
+ #
796
+ # However, simply letting library clean-up codes run can also be bad,
797
+ # because such codes (i.e., `multiprocessing.util._exit_function()`)
798
+ # include joining putting threads for `mp.Queue`, which can be blocking.
799
+ # Hence, the main process putting threads are called with
800
+ # `cancel_join_thread` at creation. See later section
801
+ # [ 3b. A process won't hang when putting into a queue; ]
802
+ # for more details.
803
+ #
804
+ # Here are two example cases where library clean-up codes can run
805
+ # before `__del__` is called:
806
+ #
807
+ # 1. If we hold onto a reference to the iterator, it more often
808
+ # than not tries to do `multiprocessing` library cleaning before
809
+ # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
810
+ # and thus prevents our cleaning-up code from running first.
811
+ #
812
+ # 2. A similar issue arises when a `DataLoader` is used in a subprocess.
813
+ # When a process ends, it shuts all its daemonic children
814
+ # down with a SIGTERM (instead of joining them without a timeout).
815
+ # Similarly for threads, but by a different mechanism. This fact,
816
+ # together with a few implementation details of multiprocessing, forces
817
+ # us to make workers daemonic. All of our problems arise when a
818
+ # DataLoader is used in a subprocess, and are caused by multiprocessing
819
+ # code which looks more or less like this:
820
+ #
821
+ # try:
822
+ # your_function_using_a_dataloader()
823
+ # finally:
824
+ # multiprocessing.util._exit_function()
825
+ #
826
+ # The joining/termination mentioned above happens inside
827
+ # `_exit_function()`. Now, if `your_function_using_a_dataloader()`
828
+ # throws, the stack trace stored in the exception will prevent the
829
+ # frame which uses `DataLoaderIter` to be freed. If the frame has any
830
+ # reference to the `DataLoaderIter` (e.g., in a method of the iter),
831
+ # its `__del__`, which starts the shutdown procedure, will not be
832
+ # called. That, in turn, means that workers aren't notified. Attempting
833
+ # to join in `_exit_function` will then result in a hang.
834
+ #
835
+ # For context, `_exit_function` is also registered as an `atexit` call.
836
+ # So it is unclear to me (@ssnl) why this is needed in a finally block.
837
+ # The code dates back to 2008 and there is no comment on the original
838
+ # PEP 371 or patch https://bugs.python.org/issue3050 (containing both
839
+ # the finally block and the `atexit` registration) that explains this.
840
+ #
841
+ #
842
+ # Finally, another choice is to just shutdown workers with logic in 1
843
+ # above whenever we see an error in `next`. This isn't ideal because
844
+ # a. It prevents users from using try-catch to resume data loading.
845
+ # b. It doesn't prevent hanging if users have references to the
846
+ # iterator.
847
+ #
848
+ # 3. All processes exit if any of them die unexpectedly by fatal signals.
849
+ #
850
+ # As shown above, the workers are set as daemonic children of the main
851
+ # process. However, automatic cleaning-up of such child processes only
852
+ # happens if the parent process exits gracefully (e.g., not via fatal
853
+ # signals like SIGKILL). So we must ensure that each process will exit
854
+ # even the process that should send/receive data to/from it were
855
+ # killed, i.e.,
856
+ #
857
+ # a. A process won't hang when getting from a queue.
858
+ #
859
+ # Even with carefully designed data dependencies (i.e., a `put()`
860
+ # always corresponding to a `get()`), hanging on `get()` can still
861
+ # happen when data in queue is corrupted (e.g., due to
862
+ # `cancel_join_thread` or unexpected exit).
863
+ #
864
+ # For child exit, we set a timeout whenever we try to get data
865
+ # from `data_queue`, and check the workers' status on each timeout
866
+ # and error.
867
+ # See `_DataLoaderiter._get_batch()` and
868
+ # `_DataLoaderiter._try_get_data()` for details.
869
+ #
870
+ # Additionally, for child exit on non-Windows platforms, we also
871
+ # register a SIGCHLD handler (which is supported on Windows) on
872
+ # the main process, which checks if any of the workers fail in the
873
+ # (Python) handler. This is more efficient and faster in detecting
874
+ # worker failures, compared to only using the above mechanism.
875
+ # See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
876
+ #
877
+ # For `.get()` calls where the sender(s) is not the workers, we
878
+ # guard them with timeouts, and check the status of the sender
879
+ # when timeout happens:
880
+ # + in the workers, the `_utils.worker.ManagerWatchdog` class
881
+ # checks the status of the main process.
882
+ # + if `pin_memory=True`, when getting from `pin_memory_thread`,
883
+ # check `pin_memory_thread` status periodically until `.get()`
884
+ # returns or see that `pin_memory_thread` died.
885
+ #
886
+ # b. A process won't hang when putting into a queue;
887
+ #
888
+ # We use `mp.Queue` which has a separate background thread to put
889
+ # objects from an unbounded buffer array. The background thread is
890
+ # daemonic and usually automatically joined when the process
891
+ # *exits*.
892
+ #
893
+ # In case that the receiver has ended abruptly while
894
+ # reading from the pipe, the join will hang forever. The usual
895
+ # solution for this in Python is calling `q.cancel_join_thread`,
896
+ # which prevents automatically joining it when finalizing
897
+ # (exiting).
898
+ #
899
+ # Nonetheless, `cancel_join_thread` must only be called when the
900
+ # queue is **not** going to be read from or write into by another
901
+ # process, because it may hold onto a lock or leave corrupted data
902
+ # in the queue, leading other readers/writers to hang.
903
+ #
904
+ # Hence,
905
+ # + For worker processes, we only do so (for their output
906
+ # queues, i.e., `worker_result_queue`) before exiting.
907
+ # + For `pin_memory_thread`, its output queue `data_queue` is a
908
+ # `queue.Queue` that does blocking `put` if the queue is full.
909
+ # So there is no above problem, but as a result, in
910
+ # `_pin_memory_loop`, we do need to wrap the `put` in a loop
911
+ # that breaks not only upon success, but also when the main
912
+ # process stops reading, i.e., is shutting down.
913
+ # + For loader process, we `cancel_join_thread()` for all
914
+ # `_index_queues` because the whole purpose of workers and
915
+ # `pin_memory_thread` is to serve the loader process. If
916
+ # loader process is already exiting, we don't really care if
917
+ # the queues are corrupted.
918
+ #
919
+ #
920
+ # Now let's get back to 1:
921
+ # how we gracefully exit the workers when the last reference to the
922
+ # iterator is gone.
923
+ #
924
+ # To achieve this, we implement the following logic along with the design
925
+ # choices mentioned above:
926
+ #
927
+ # `workers_done_event`:
928
+ # A `multiprocessing.Event` shared among the main process and all worker
929
+ # processes. This is used to signal the workers that the iterator is
930
+ # shutting down. After it is set, they will not send processed data to
931
+ # queues anymore, and only wait for the final `None` before exiting.
932
+ # `done_event` isn't strictly needed. I.e., we can just check for `None`
933
+ # from the input queue, but it allows us to skip wasting resources
934
+ # processing data if we are already shutting down.
935
+ #
936
+ # `pin_memory_thread_done_event`:
937
+ # A `threading.Event` for a similar purpose to that of
938
+ # `workers_done_event`, but is for the `pin_memory_thread`. The reason
939
+ # that separate events are needed is that `pin_memory_thread` reads from
940
+ # the output queue of the workers. But the workers, upon seeing that
941
+ # `workers_done_event` is set, only wants to see the final `None`, and is
942
+ # not required to flush all data in the output queue (e.g., it may call
943
+ # `cancel_join_thread` on that queue if its `IterableDataset` iterator
944
+ # happens to exhaust coincidentally, which is out of the control of the
945
+ # main process). Thus, since we will exit `pin_memory_thread` before the
946
+ # workers (see below), two separate events are used.
947
+ #
948
+ # NOTE: In short, the protocol is that the main process will set these
949
+ # `done_event`s and then the corresponding processes/threads a `None`,
950
+ # and that they may exit at any time after receiving the `None`.
951
+ #
952
+ # NOTE: Using `None` as the final signal is valid, since normal data will
953
+ # always be a 2-tuple with the 1st element being the index of the data
954
+ # transferred (different from dataset index/key), and the 2nd being
955
+ # either the dataset key or the data sample (depending on which part
956
+ # of the data model the queue is at).
957
+ #
958
+ # [ worker processes ]
959
+ # While loader process is alive:
960
+ # Get from `index_queue`.
961
+ # If get anything else,
962
+ # Check `workers_done_event`.
963
+ # If set, continue to next iteration
964
+ # i.e., keep getting until see the `None`, then exit.
965
+ # Otherwise, process data:
966
+ # If is fetching from an `IterableDataset` and the iterator
967
+ # is exhausted, send an `_IterableDatasetStopIteration`
968
+ # object to signal iteration end. The main process, upon
969
+ # receiving such an object, will send `None` to this
970
+ # worker and not use the corresponding `index_queue`
971
+ # anymore.
972
+ # If timed out,
973
+ # No matter `workers_done_event` is set (still need to see `None`)
974
+ # or not, must continue to next iteration.
975
+ # (outside loop)
976
+ # If `workers_done_event` is set, (this can be False with `IterableDataset`)
977
+ # `data_queue.cancel_join_thread()`. (Everything is ending here:
978
+ # main process won't read from it;
979
+ # other workers will also call
980
+ # `cancel_join_thread`.)
981
+ #
982
+ # [ pin_memory_thread ]
983
+ # # No need to check main thread. If this thread is alive, the main loader
984
+ # # thread must be alive, because this thread is set as daemonic.
985
+ # While `pin_memory_thread_done_event` is not set:
986
+ # Get from `index_queue`.
987
+ # If timed out, continue to get in the next iteration.
988
+ # Otherwise, process data.
989
+ # While `pin_memory_thread_done_event` is not set:
990
+ # Put processed data to `data_queue` (a `queue.Queue` with blocking put)
991
+ # If timed out, continue to put in the next iteration.
992
+ # Otherwise, break, i.e., continuing to the out loop.
993
+ #
994
+ # NOTE: we don't check the status of the main thread because
995
+ # 1. if the process is killed by fatal signal, `pin_memory_thread`
996
+ # ends.
997
+ # 2. in other cases, either the cleaning-up in __del__ or the
998
+ # automatic exit of daemonic thread will take care of it.
999
+ # This won't busy-wait either because `.get(timeout)` does not
1000
+ # busy-wait.
1001
+ #
1002
+ # [ main process ]
1003
+ # In the DataLoader Iter's `__del__`
1004
+ # b. Exit `pin_memory_thread`
1005
+ # i. Set `pin_memory_thread_done_event`.
1006
+ # ii Put `None` in `worker_result_queue`.
1007
+ # iii. Join the `pin_memory_thread`.
1008
+ # iv. `worker_result_queue.cancel_join_thread()`.
1009
+ #
1010
+ # c. Exit the workers.
1011
+ # i. Set `workers_done_event`.
1012
+ # ii. Put `None` in each worker's `index_queue`.
1013
+ # iii. Join the workers.
1014
+ # iv. Call `.cancel_join_thread()` on each worker's `index_queue`.
1015
+ #
1016
+ # NOTE: (c) is better placed after (b) because it may leave corrupted
1017
+ # data in `worker_result_queue`, which `pin_memory_thread`
1018
+ # reads from, in which case the `pin_memory_thread` can only
1019
+ # happen at timing out, which is slow. Nonetheless, the same thing
1020
+ # happens if a worker is killed by signal at unfortunate times,
1021
+ # but in other cases, we are better off having a non-corrupted
1022
+ # `worker_result_queue` for `pin_memory_thread`.
1023
+ #
1024
+ # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
1025
+ # can be omitted
1026
+ #
1027
+ # NB: `done_event`s isn't strictly needed. E.g., we can just check for
1028
+ # `None` from `index_queue`, but it allows us to skip wasting resources
1029
+ # processing indices already in `index_queue` if we are already shutting
1030
+ # down.
1031
+
1032
+ def __init__(self, loader):
1033
+ super().__init__(loader)
1034
+
1035
+ self._prefetch_factor = loader.prefetch_factor
1036
+
1037
+ assert self._num_workers > 0
1038
+ assert self._prefetch_factor > 0
1039
+
1040
+ if loader.multiprocessing_context is None:
1041
+ multiprocessing_context = multiprocessing
1042
+ else:
1043
+ multiprocessing_context = loader.multiprocessing_context
1044
+
1045
+ self._worker_init_fn = loader.worker_init_fn
1046
+
1047
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
1048
+ # Additional worker init function will take care of sharding in MP and Distributed
1049
+ if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
1050
+ self._worker_init_fn = functools.partial(
1051
+ _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank
1052
+ )
1053
+
1054
+ # No certainty which module multiprocessing_context is
1055
+ self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
1056
+ self._worker_pids_set = False
1057
+ self._shutdown = False
1058
+ self._workers_done_event = multiprocessing_context.Event()
1059
+
1060
+ self._index_queues = []
1061
+ self._workers = []
1062
+ for i in range(self._num_workers):
1063
+ # No certainty which module multiprocessing_context is
1064
+ index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
1065
+ # Need to `cancel_join_thread` here!
1066
+ # See sections (2) and (3b) above.
1067
+ index_queue.cancel_join_thread()
1068
+ w = multiprocessing_context.Process(
1069
+ target=_worker_loop,
1070
+ args=(
1071
+ self._dataset_kind,
1072
+ self._dataset,
1073
+ index_queue,
1074
+ self._worker_result_queue,
1075
+ self._workers_done_event,
1076
+ self._auto_collation,
1077
+ self._collate_fn,
1078
+ self._drop_last,
1079
+ self._base_seed,
1080
+ self._worker_init_fn,
1081
+ i,
1082
+ self._num_workers,
1083
+ self._persistent_workers,
1084
+ self._shared_seed,
1085
+ ),
1086
+ )
1087
+ w.daemon = True
1088
+ # NB: Process.start() actually takes some time as it needs to
1089
+ # start a process and pass the arguments over via a pipe.
1090
+ # Therefore, we only add a worker to self._workers list after
1091
+ # it started, so that we do not call .join() if program dies
1092
+ # before it starts, and __del__ tries to join but will get:
1093
+ # AssertionError: can only join a started process.
1094
+ w.start()
1095
+ self._index_queues.append(index_queue)
1096
+ self._workers.append(w)
1097
+
1098
+ if self._pin_memory:
1099
+ self._pin_memory_thread_done_event = threading.Event()
1100
+
1101
+ # Queue is not type-annotated
1102
+ self._data_queue = queue.Queue() # type: ignore[var-annotated]
1103
+ if self._pin_memory_device == "xpu":
1104
+ current_device = torch.xpu.current_device() # type: ignore[attr-defined]
1105
+ else:
1106
+ current_device = torch.cuda.current_device() # choose cuda for default
1107
+ pin_memory_thread = threading.Thread(
1108
+ target=_utils.pin_memory._pin_memory_loop,
1109
+ args=(
1110
+ self._worker_result_queue,
1111
+ self._data_queue,
1112
+ current_device,
1113
+ self._pin_memory_thread_done_event,
1114
+ self._pin_memory_device,
1115
+ ),
1116
+ )
1117
+ pin_memory_thread.daemon = True
1118
+ pin_memory_thread.start()
1119
+ # Similar to workers (see comment above), we only register
1120
+ # pin_memory_thread once it is started.
1121
+ self._pin_memory_thread = pin_memory_thread
1122
+ else:
1123
+ self._data_queue = self._worker_result_queue
1124
+
1125
+ # In some rare cases, persistent workers (daemonic processes)
1126
+ # would be terminated before `__del__` of iterator is invoked
1127
+ # when main process exits
1128
+ # It would cause failure when pin_memory_thread tries to read
1129
+ # corrupted data from worker_result_queue
1130
+ # atexit is used to shutdown thread and child processes in the
1131
+ # right sequence before main process exits
1132
+ if self._persistent_workers and self._pin_memory:
1133
+ import atexit
1134
+
1135
+ for w in self._workers:
1136
+ atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
1137
+
1138
+ # .pid can be None only before process is spawned (not the case, so ignore)
1139
+ _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
1140
+ _utils.signal_handling._set_SIGCHLD_handler()
1141
+ self._worker_pids_set = True
1142
+ self._reset(loader, first_iter=True)
1143
+
1144
+ def _reset(self, loader, first_iter=False):
1145
+ super()._reset(loader, first_iter)
1146
+ self._send_idx = 0 # idx of the next task to be sent to workers
1147
+ self._rcvd_idx = 0 # idx of the next task to be returned in __next__
1148
+ # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
1149
+ # map: task idx => - (worker_id,) if data isn't fetched (outstanding)
1150
+ # \ (worker_id, data) if data is already fetched (out-of-order)
1151
+ self._task_info = {}
1152
+ self._tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
1153
+ # A list of booleans representing whether each worker still has work to
1154
+ # do, i.e., not having exhausted its iterable dataset object. It always
1155
+ # contains all `True`s if not using an iterable-style dataset
1156
+ # (i.e., if kind != Iterable).
1157
+ # Note that this indicates that a worker still has work to do *for this epoch*.
1158
+ # It does not mean that a worker is dead. In case of `_persistent_workers`,
1159
+ # the worker will be reset to available in the next epoch.
1160
+ self._workers_status = [True for i in range(self._num_workers)]
1161
+ # Reset the worker queue cycle so it resumes next epoch at worker 0
1162
+ self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
1163
+ # We resume the prefetching in case it was enabled
1164
+ if not first_iter:
1165
+ for idx in range(self._num_workers):
1166
+ self._index_queues[idx].put(_utils.worker._ResumeIteration(self._shared_seed))
1167
+ resume_iteration_cnt = self._num_workers
1168
+ while resume_iteration_cnt > 0:
1169
+ return_idx, return_data = self._get_data()
1170
+ if isinstance(return_idx, _utils.worker._ResumeIteration):
1171
+ assert return_data is None
1172
+ resume_iteration_cnt -= 1
1173
+ # prime the prefetch loop
1174
+ for _ in range(self._prefetch_factor * self._num_workers):
1175
+ self._try_put_index()
1176
+
1177
+ def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
1178
+ # Tries to fetch data from `self._data_queue` once for a given timeout.
1179
+ # This can also be used as inner loop of fetching without timeout, with
1180
+ # the sender status as the loop condition.
1181
+ #
1182
+ # This raises a `RuntimeError` if any worker died unexpectedly. This error
1183
+ # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
1184
+ # (only for non-Windows platforms), or the manual check below on errors
1185
+ # and timeouts.
1186
+ #
1187
+ # Returns a 2-tuple:
1188
+ # (bool: whether data was successfully fetched, any: data if successful else None)
1189
+ try:
1190
+ data = self._data_queue.get(timeout=timeout)
1191
+ return (True, data)
1192
+ except Exception as e:
1193
+ # At timeout and error, we manually check whether any worker has
1194
+ # failed. Note that this is the only mechanism for Windows to detect
1195
+ # worker failures.
1196
+ failed_workers = []
1197
+ for worker_id, w in enumerate(self._workers):
1198
+ if self._workers_status[worker_id] and not w.is_alive():
1199
+ failed_workers.append(w)
1200
+ self._mark_worker_as_unavailable(worker_id)
1201
+ if len(failed_workers) > 0:
1202
+ pids_str = ", ".join(str(w.pid) for w in failed_workers)
1203
+ raise RuntimeError("DataLoader worker (pid(s) {}) exited unexpectedly".format(pids_str)) from e
1204
+ if isinstance(e, queue.Empty):
1205
+ return (False, None)
1206
+ import errno
1207
+ import tempfile
1208
+
1209
+ try:
1210
+ # Raise an exception if we are this close to the FDs limit.
1211
+ # Apparently, trying to open only one file is not a sufficient
1212
+ # test.
1213
+ # See NOTE [ DataLoader on Linux and open files limit ]
1214
+ fds_limit_margin = 10
1215
+ fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
1216
+ except OSError as e:
1217
+ if e.errno == errno.EMFILE:
1218
+ raise RuntimeError(
1219
+ "Too many open files. Communication with the"
1220
+ " workers is no longer possible. Please increase the"
1221
+ " limit using `ulimit -n` in the shell or change the"
1222
+ " sharing strategy by calling"
1223
+ " `torch.multiprocessing.set_sharing_strategy('file_system')`"
1224
+ " at the beginning of your code"
1225
+ ) from None
1226
+ raise
1227
+
1228
+ # NOTE [ DataLoader on Linux and open files limit ]
1229
+ #
1230
+ # On Linux when DataLoader is used with multiprocessing we pass the data between
1231
+ # the root process and the workers through SHM files. We remove those files from
1232
+ # the filesystem as soon as they are created and keep them alive by
1233
+ # passing around their file descriptors through AF_UNIX sockets. (See
1234
+ # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes` in
1235
+ # the wiki (https://github.com/pytorch/pytorch/wiki).)
1236
+ #
1237
+ # This sometimes leads us to exceeding the open files limit. When that happens,
1238
+ # and the offending file descriptor is coming over a socket, the `socket` Python
1239
+ # package silently strips the file descriptor from the message, setting only the
1240
+ # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that
1241
+ # it _indicates that some control data were discarded due to lack of space in
1242
+ # the buffer for ancillary data_). This might reflect the C implementation of
1243
+ # AF_UNIX sockets.
1244
+ #
1245
+ # This behaviour can be reproduced with the script and instructions at the
1246
+ # bottom of this note.
1247
+ #
1248
+ # When that happens, the standard Python `multiprocessing` (and not
1249
+ # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata`
1250
+ #
1251
+ # Sometimes, instead of the FD being stripped, you may get an `OSError:
1252
+ # Too many open files`, both in the script below and in DataLoader. However,
1253
+ # this is rare and seems to be nondeterministic.
1254
+ #
1255
+ #
1256
+ # #!/usr/bin/env python3
1257
+ # import sys
1258
+ # import socket
1259
+ # import os
1260
+ # import array
1261
+ # import shutil
1262
+ # import socket
1263
+ #
1264
+ #
1265
+ # if len(sys.argv) != 4:
1266
+ # print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)")
1267
+ # sys.exit(1)
1268
+ #
1269
+ # if __name__ == '__main__':
1270
+ # dirname = sys.argv[1]
1271
+ # sock_path = dirname + "/sock"
1272
+ # iterations = int(sys.argv[2])
1273
+ # def dummy_path(i):
1274
+ # return dirname + "/" + str(i) + ".dummy"
1275
+ #
1276
+ #
1277
+ # if sys.argv[3] == 'send':
1278
+ # while not os.path.exists(sock_path):
1279
+ # pass
1280
+ # client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
1281
+ # client.connect(sock_path)
1282
+ # for i in range(iterations):
1283
+ # fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
1284
+ # ancdata = array.array('i', [fd])
1285
+ # msg = bytes([i % 256])
1286
+ # print("Sending fd ", fd, " (iteration #", i, ")")
1287
+ # client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
1288
+ #
1289
+ #
1290
+ # else:
1291
+ # assert sys.argv[3] == 'recv'
1292
+ #
1293
+ # if os.path.exists(dirname):
1294
+ # raise Exception("Directory exists")
1295
+ #
1296
+ # os.mkdir(dirname)
1297
+ #
1298
+ # print("Opening socket...")
1299
+ # server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
1300
+ # server.bind(sock_path)
1301
+ #
1302
+ # print("Listening...")
1303
+ # for i in range(iterations):
1304
+ # a = array.array('i')
1305
+ # msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
1306
+ # assert(len(ancdata) == 1)
1307
+ # cmsg_level, cmsg_type, cmsg_data = ancdata[0]
1308
+ # a.frombytes(cmsg_data)
1309
+ # print("Received fd ", a[0], " (iteration #", i, ")")
1310
+ #
1311
+ # shutil.rmtree(dirname)
1312
+ #
1313
+ # Steps to reproduce:
1314
+ #
1315
+ # 1. Run two shells and set lower file descriptor limit in the receiving one:
1316
+ # (shell1) ulimit -n 1020
1317
+ # (shell2) ulimit -n 1022
1318
+ #
1319
+ # 2. Run the script above with the `recv` option in the first shell
1320
+ # (shell1) ./test_socket.py sock_tmp 1017 recv
1321
+ #
1322
+ # 3. Run the script with the `send` option in the second shell:
1323
+ # (shell2) ./test_socket.py sock_tmp 1017 send
1324
+
1325
+ def _get_data(self):
1326
+ # Fetches data from `self._data_queue`.
1327
+ #
1328
+ # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
1329
+ # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
1330
+ # in a loop. This is the only mechanism to detect worker failures for
1331
+ # Windows. For other platforms, a SIGCHLD handler is also used for
1332
+ # worker failure detection.
1333
+ #
1334
+ # If `pin_memory=True`, we also need to check if `pin_memory_thread` had
1335
+ # died at timeouts.
1336
+ if self._timeout > 0:
1337
+ success, data = self._try_get_data(self._timeout)
1338
+ if success:
1339
+ return data
1340
+ else:
1341
+ raise RuntimeError("DataLoader timed out after {} seconds".format(self._timeout))
1342
+ elif self._pin_memory:
1343
+ while self._pin_memory_thread.is_alive():
1344
+ success, data = self._try_get_data()
1345
+ if success:
1346
+ return data
1347
+ else:
1348
+ # while condition is false, i.e., pin_memory_thread died.
1349
+ raise RuntimeError("Pin memory thread exited unexpectedly")
1350
+ # In this case, `self._data_queue` is a `queue.Queue`, but we don't
1351
+ # need to call `.task_done()` because we don't use `.join()`.
1352
+ else:
1353
+ while True:
1354
+ success, data = self._try_get_data()
1355
+ if success:
1356
+ return data
1357
+
1358
+ def _next_data(self):
1359
+ while True:
1360
+ # If the worker responsible for `self._rcvd_idx` has already ended
1361
+ # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
1362
+ # we try to advance `self._rcvd_idx` to find the next valid index.
1363
+ #
1364
+ # This part needs to run in the loop because both the `self._get_data()`
1365
+ # call and `_IterableDatasetStopIteration` check below can mark
1366
+ # extra worker(s) as dead.
1367
+ while self._rcvd_idx < self._send_idx:
1368
+ info = self._task_info[self._rcvd_idx]
1369
+ worker_id = info[0]
1370
+ if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active
1371
+ break
1372
+ del self._task_info[self._rcvd_idx]
1373
+ self._rcvd_idx += 1
1374
+ else:
1375
+ # no valid `self._rcvd_idx` is found (i.e., didn't break)
1376
+ if not self._persistent_workers:
1377
+ self._shutdown_workers()
1378
+ raise StopIteration
1379
+
1380
+ # Now `self._rcvd_idx` is the batch index we want to fetch
1381
+
1382
+ # Check if the next sample has already been generated
1383
+ if len(self._task_info[self._rcvd_idx]) == 2:
1384
+ data = self._task_info.pop(self._rcvd_idx)[1]
1385
+ return self._process_data(data)
1386
+
1387
+ assert not self._shutdown and self._tasks_outstanding > 0
1388
+ idx, data = self._get_data()
1389
+ self._tasks_outstanding -= 1
1390
+ if self._dataset_kind == _DatasetKind.Iterable:
1391
+ # Check for _IterableDatasetStopIteration
1392
+ if isinstance(data, _utils.worker._IterableDatasetStopIteration):
1393
+ if self._persistent_workers:
1394
+ self._workers_status[data.worker_id] = False
1395
+ else:
1396
+ self._mark_worker_as_unavailable(data.worker_id)
1397
+ self._try_put_index()
1398
+ continue
1399
+
1400
+ if idx != self._rcvd_idx:
1401
+ # store out-of-order samples
1402
+ self._task_info[idx] += (data,)
1403
+ else:
1404
+ del self._task_info[idx]
1405
+ return self._process_data(data)
1406
+
1407
+ def _try_put_index(self):
1408
+ assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
1409
+
1410
+ try:
1411
+ index = self._next_index()
1412
+ except StopIteration:
1413
+ return
1414
+ for _ in range(self._num_workers): # find the next active worker, if any
1415
+ worker_queue_idx = next(self._worker_queue_idx_cycle)
1416
+ if self._workers_status[worker_queue_idx]:
1417
+ break
1418
+ else:
1419
+ # not found (i.e., didn't break)
1420
+ return
1421
+
1422
+ self._index_queues[worker_queue_idx].put((self._send_idx, index))
1423
+ self._task_info[self._send_idx] = (worker_queue_idx,)
1424
+ self._tasks_outstanding += 1
1425
+ self._send_idx += 1
1426
+
1427
+ def _process_data(self, data):
1428
+ self._rcvd_idx += 1
1429
+ self._try_put_index()
1430
+ if isinstance(data, ExceptionWrapper):
1431
+ data.reraise()
1432
+ return data
1433
+
1434
+ def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
1435
+ # Mark a worker as having finished its work e.g., due to
1436
+ # exhausting an `IterableDataset`. This should be used only when this
1437
+ # `_MultiProcessingDataLoaderIter` is going to continue running.
1438
+
1439
+ assert self._workers_status[worker_id] or (self._persistent_workers and shutdown)
1440
+
1441
+ # Signal termination to that specific worker.
1442
+ q = self._index_queues[worker_id]
1443
+ # Indicate that no more data will be put on this queue by the current
1444
+ # process.
1445
+ q.put(None)
1446
+
1447
+ # Note that we don't actually join the worker here, nor do we remove the
1448
+ # worker's pid from C side struct because (1) joining may be slow, and
1449
+ # (2) since we don't join, the worker may still raise error, and we
1450
+ # prefer capturing those, rather than ignoring them, even though they
1451
+ # are raised after the worker has finished its job.
1452
+ # Joining is deferred to `_shutdown_workers`, which is called when
1453
+ # all workers finish their jobs (e.g., `IterableDataset` replicas) or
1454
+ # when this iterator is garbage collected.
1455
+
1456
+ self._workers_status[worker_id] = False
1457
+
1458
+ assert self._workers_done_event.is_set() == shutdown
1459
+
1460
+ def _shutdown_workers(self):
1461
+ # Called when shutting down this `_MultiProcessingDataLoaderIter`.
1462
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
1463
+ # the logic of this function.
1464
+ if _utils is None or _utils.python_exit_status is True or _utils.python_exit_status is None:
1465
+ # See (2) of the note. If Python is shutting down, do no-op.
1466
+ return
1467
+ # Normal exit when last reference is gone / iterator is depleted.
1468
+ # See (1) and the second half of the note.
1469
+ if not self._shutdown:
1470
+ self._shutdown = True
1471
+ try:
1472
+ # Normal exit when last reference is gone / iterator is depleted.
1473
+ # See (1) and the second half of the note.
1474
+
1475
+ # Exit `pin_memory_thread` first because exiting workers may leave
1476
+ # corrupted data in `worker_result_queue` which `pin_memory_thread`
1477
+ # reads from.
1478
+ if hasattr(self, "_pin_memory_thread"):
1479
+ # Use hasattr in case error happens before we set the attribute.
1480
+ self._pin_memory_thread_done_event.set()
1481
+ # Send something to pin_memory_thread in case it is waiting
1482
+ # so that it can wake up and check `pin_memory_thread_done_event`
1483
+ self._worker_result_queue.put((None, None))
1484
+ self._pin_memory_thread.join()
1485
+ self._worker_result_queue.cancel_join_thread()
1486
+ self._worker_result_queue.close()
1487
+
1488
+ # Exit workers now.
1489
+ self._workers_done_event.set()
1490
+ for worker_id in range(len(self._workers)):
1491
+ # Get number of workers from `len(self._workers)` instead of
1492
+ # `self._num_workers` in case we error before starting all
1493
+ # workers.
1494
+ # If we are using workers_status with persistent_workers
1495
+ # we have to shut it down because the worker is paused
1496
+ if self._persistent_workers or self._workers_status[worker_id]:
1497
+ self._mark_worker_as_unavailable(worker_id, shutdown=True)
1498
+ for w in self._workers:
1499
+ # We should be able to join here, but in case anything went
1500
+ # wrong, we set a timeout and if the workers fail to join,
1501
+ # they are killed in the `finally` block.
1502
+ w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
1503
+ for q in self._index_queues:
1504
+ q.cancel_join_thread()
1505
+ q.close()
1506
+ finally:
1507
+ # Even though all this function does is putting into queues that
1508
+ # we have called `cancel_join_thread` on, weird things can
1509
+ # happen when a worker is killed by a signal, e.g., hanging in
1510
+ # `Event.set()`. So we need to guard this with SIGCHLD handler,
1511
+ # and remove pids from the C side data structure only at the
1512
+ # end.
1513
+ #
1514
+ # FIXME: Unfortunately, for Windows, we are missing a worker
1515
+ # error detection mechanism here in this function, as it
1516
+ # doesn't provide a SIGCHLD handler.
1517
+ if self._worker_pids_set:
1518
+ _utils.signal_handling._remove_worker_pids(id(self))
1519
+ self._worker_pids_set = False
1520
+ for w in self._workers:
1521
+ if w.is_alive():
1522
+ # Existing mechanisms try to make the workers exit
1523
+ # peacefully, but in case that we unfortunately reach
1524
+ # here, which we shouldn't, (e.g., pytorch/pytorch#39570),
1525
+ # we kill the worker.
1526
+ w.terminate()
1527
+
1528
+ # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
1529
+ @staticmethod
1530
+ def _clean_up_worker(w):
1531
+ try:
1532
+ w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
1533
+ finally:
1534
+ if w.is_alive():
1535
+ w.terminate()
1536
+
1537
+ def __del__(self):
1538
+ self._shutdown_workers()
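
The iterator above mirrors the stock `torch.utils.data.DataLoader` multiprocessing machinery, so the knobs its comments keep referring to are the usual constructor arguments. A minimal sketch of how they map onto the internals, using only the standard PyTorch API (the dataset, shapes, and values below are illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset; shapes and sizes are illustrative only.
dataset = TensorDataset(torch.randn(64, 3, 224, 224), torch.randint(0, 10, (64,)))

loader = DataLoader(
    dataset,
    batch_size=8,
    num_workers=2,            # two worker processes, each with its own index_queue
    prefetch_factor=2,        # keep 2 * num_workers batches in flight (_try_put_index)
    pin_memory=True,          # start the pin_memory_thread reading worker_result_queue
    persistent_workers=True,  # workers survive across epochs; _reset() resumes them
)

for images, labels in loader:   # each __next__ goes through _next_data() above
    pass                        # shutdown happens in __del__ -> _shutdown_workers()
```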
efficientvit/apps/data_provider/random_resolution/_data_worker.py ADDED
@@ -0,0 +1,358 @@
1
+ r"""This file is based on torch/utils/data/_utils/worker.py
2
+
3
+ Contains definitions of the methods used by the _BaseDataLoaderIter workers.
4
+ These **need** to be in global scope since Py2 doesn't support serializing
5
+ static methods.
6
+ """
7
+
8
+ import os
9
+ import queue
10
+ import random
11
+ from dataclasses import dataclass
12
+ from typing import TYPE_CHECKING, Optional, Union
13
+
14
+ import torch
15
+ from torch._utils import ExceptionWrapper
16
+ from torch.utils.data._utils import HAS_NUMPY, IS_WINDOWS, MP_STATUS_CHECK_INTERVAL, signal_handling
17
+
18
+ if TYPE_CHECKING:
19
+ from torch.utils.data import Dataset
20
+
21
+ from .controller import RRSController
22
+
23
+ if IS_WINDOWS:
24
+ import ctypes
25
+ from ctypes.wintypes import BOOL, DWORD, HANDLE
26
+
27
+ # On Windows, the parent ID of the worker process remains unchanged when the manager process
28
+ # is gone, and the only way to check it through the OS is to let the worker have a process handle
29
+ # of the manager and ask if the process status has changed.
30
+ class ManagerWatchdog:
31
+ def __init__(self):
32
+ self.manager_pid = os.getppid()
33
+
34
+ # mypy cannot detect this code is windows only
35
+ self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined]
36
+ self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
37
+ self.kernel32.OpenProcess.restype = HANDLE
38
+ self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
39
+ self.kernel32.WaitForSingleObject.restype = DWORD
40
+
41
+ # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
42
+ SYNCHRONIZE = 0x00100000
43
+ self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)
44
+
45
+ if not self.manager_handle:
46
+ raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined]
47
+
48
+ self.manager_dead = False
49
+
50
+ def is_alive(self):
51
+ if not self.manager_dead:
52
+ # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
53
+ self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
54
+ return not self.manager_dead
55
+
56
+ else:
57
+
58
+ class ManagerWatchdog: # type: ignore[no-redef]
59
+ def __init__(self):
60
+ self.manager_pid = os.getppid()
61
+ self.manager_dead = False
62
+
63
+ def is_alive(self):
64
+ if not self.manager_dead:
65
+ self.manager_dead = os.getppid() != self.manager_pid
66
+ return not self.manager_dead
67
+
68
+
69
+ _worker_info = None
70
+
71
+
72
+ class WorkerInfo:
73
+ id: int
74
+ num_workers: int
75
+ seed: int
76
+ dataset: "Dataset"
77
+ __initialized = False
78
+
79
+ def __init__(self, **kwargs):
80
+ for k, v in kwargs.items():
81
+ setattr(self, k, v)
82
+ self.__keys = tuple(kwargs.keys())
83
+ self.__initialized = True
84
+
85
+ def __setattr__(self, key, val):
86
+ if self.__initialized:
87
+ raise RuntimeError("Cannot assign attributes to {} objects".format(self.__class__.__name__))
88
+ return super().__setattr__(key, val)
89
+
90
+ def __repr__(self):
91
+ items = []
92
+ for k in self.__keys:
93
+ items.append("{}={}".format(k, getattr(self, k)))
94
+ return "{}({})".format(self.__class__.__name__, ", ".join(items))
95
+
96
+
97
+ def get_worker_info() -> Optional[WorkerInfo]:
98
+ r"""Returns the information about the current
99
+ :class:`~torch.utils.data.DataLoader` iterator worker process.
100
+
101
+ When called in a worker, this returns an object guaranteed to have the
102
+ following attributes:
103
+
104
+ * :attr:`id`: the current worker id.
105
+ * :attr:`num_workers`: the total number of workers.
106
+ * :attr:`seed`: the random seed set for the current worker. This value is
107
+ determined by main process RNG and the worker id. See
108
+ :class:`~torch.utils.data.DataLoader`'s documentation for more details.
109
+ * :attr:`dataset`: the copy of the dataset object in **this** process. Note
110
+ that this will be a different object in a different process than the one
111
+ in the main process.
112
+
113
+ When called in the main process, this returns ``None``.
114
+
115
+ .. note::
116
+ When used in a :attr:`worker_init_fn` passed over to
117
+ :class:`~torch.utils.data.DataLoader`, this method can be useful to
118
+ set up each worker process differently, for instance, using ``worker_id``
119
+ to configure the ``dataset`` object to only read a specific fraction of a
120
+ sharded dataset, or use ``seed`` to seed other libraries used in dataset
121
+ code.
122
+ """
123
+ return _worker_info
124
+
125
+
126
+ r"""Dummy class used to signal the end of an IterableDataset"""
127
+
128
+
129
+ @dataclass(frozen=True)
130
+ class _IterableDatasetStopIteration:
131
+ worker_id: int
132
+
133
+
134
+ r"""Dummy class used to resume the fetching when worker reuse is enabled"""
135
+
136
+
137
+ @dataclass(frozen=True)
138
+ class _ResumeIteration:
139
+ seed: Optional[int] = None
140
+
141
+
142
+ # The function `_generate_state` is adapted from `numpy.random.SeedSequence`
143
+ # from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
144
+ # It's MIT licensed, here is the copyright:
145
+
146
+ # Copyright (c) 2015 Melissa E. O'Neill
147
+ # Copyright (c) 2019 NumPy Developers
148
+ #
149
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
150
+ # of this software and associated documentation files (the "Software"), to deal
151
+ # in the Software without restriction, including without limitation the rights
152
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
153
+ # copies of the Software, and to permit persons to whom the Software is
154
+ # furnished to do so, subject to the following conditions:
155
+ #
156
+ # The above copyright notice and this permission notice shall be included in
157
+ # all copies or substantial portions of the Software.
158
+ #
159
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
160
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
161
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
162
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
163
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
164
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
165
+ # SOFTWARE.
166
+
167
+
168
+ # This function generates an array of int32 as the seed for
169
+ # `numpy.random`, in order to prevent state collision due to same
170
+ # seed and algorithm for `numpy.random` and `random` modules.
171
+ # TODO: Implement `SeedSequence` like object for `torch.random`
172
+ def _generate_state(base_seed, worker_id):
173
+ INIT_A = 0x43B0D7E5
174
+ MULT_A = 0x931E8875
175
+ INIT_B = 0x8B51F9DD
176
+ MULT_B = 0x58F38DED
177
+ MIX_MULT_L = 0xCA01F9DD
178
+ MIX_MULT_R = 0x4973F715
179
+ XSHIFT = 4 * 8 // 2
180
+ MASK32 = 0xFFFFFFFF
181
+
182
+ entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
183
+ pool = [0] * 4
184
+
185
+ hash_const_A = INIT_A
186
+
187
+ def hash(value):
188
+ nonlocal hash_const_A
189
+ value = (value ^ hash_const_A) & MASK32
190
+ hash_const_A = (hash_const_A * MULT_A) & MASK32
191
+ value = (value * hash_const_A) & MASK32
192
+ value = (value ^ (value >> XSHIFT)) & MASK32
193
+ return value
194
+
195
+ def mix(x, y):
196
+ result_x = (MIX_MULT_L * x) & MASK32
197
+ result_y = (MIX_MULT_R * y) & MASK32
198
+ result = (result_x - result_y) & MASK32
199
+ result = (result ^ (result >> XSHIFT)) & MASK32
200
+ return result
201
+
202
+ # Add in the entropy to the pool.
203
+ for i in range(len(pool)):
204
+ pool[i] = hash(entropy[i])
205
+
206
+ # Mix all bits together so late bits can affect earlier bits.
207
+ for i_src in range(len(pool)):
208
+ for i_dst in range(len(pool)):
209
+ if i_src != i_dst:
210
+ pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))
211
+
212
+ hash_const_B = INIT_B
213
+ state = []
214
+ for i_dst in range(4):
215
+ data_val = pool[i_dst]
216
+ data_val = (data_val ^ hash_const_B) & MASK32
217
+ hash_const_B = (hash_const_B * MULT_B) & MASK32
218
+ data_val = (data_val * hash_const_B) & MASK32
219
+ data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
220
+ state.append(data_val)
221
+ return state
222
+
223
+
224
+ def _worker_loop(
225
+ dataset_kind,
226
+ dataset,
227
+ index_queue,
228
+ data_queue,
229
+ done_event,
230
+ auto_collation,
231
+ collate_fn,
232
+ drop_last,
233
+ base_seed,
234
+ init_fn,
235
+ worker_id,
236
+ num_workers,
237
+ persistent_workers,
238
+ shared_seed,
239
+ ):
240
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
241
+ # logic of this function.
242
+
243
+ try:
244
+ # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
245
+ # module's handlers are executed after Python returns from C low-level
246
+ # handlers, likely when the same fatal signal had already happened
247
+ # again.
248
+ # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
249
+ signal_handling._set_worker_signal_handlers()
250
+
251
+ torch.set_num_threads(1)
252
+ seed = base_seed + worker_id
253
+ random.seed(seed)
254
+ torch.manual_seed(seed)
255
+ if HAS_NUMPY:
256
+ np_seed = _generate_state(base_seed, worker_id)
257
+ import numpy as np
258
+
259
+ np.random.seed(np_seed)
260
+
261
+ from torch.utils.data import IterDataPipe
262
+ from torch.utils.data.graph_settings import apply_random_seed
263
+
264
+ shared_rng = torch.Generator()
265
+ if isinstance(dataset, IterDataPipe):
266
+ assert shared_seed is not None
267
+ shared_rng.manual_seed(shared_seed)
268
+ dataset = apply_random_seed(dataset, shared_rng)
269
+
270
+ global _worker_info
271
+ _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset)
272
+
273
+ from torch.utils.data import _DatasetKind
274
+
275
+ init_exception = None
276
+
277
+ try:
278
+ if init_fn is not None:
279
+ init_fn(worker_id)
280
+
281
+ fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
282
+ except Exception:
283
+ init_exception = ExceptionWrapper(where="in DataLoader worker process {}".format(worker_id))
284
+
285
+ # When using Iterable mode, some workers can exit earlier than others due
286
+ # to the IterableDataset behaving differently for different workers.
287
+ # When such things happen, an `_IterableDatasetStopIteration` object is
288
+ # sent over to the main process with the ID of this worker, so that the
289
+ # main process won't send more tasks to this worker, and will send
290
+ # `None` to this worker to properly exit it.
291
+ #
292
+ # Note that we cannot set `done_event` from a worker as it is shared
293
+ # among all processes. Instead, we set the `iteration_end` flag to
294
+ # signify that the iterator is exhausted. When either `done_event` or
295
+ # `iteration_end` is set, we skip all processing steps and just wait for
296
+ # `None`.
297
+ iteration_end = False
298
+
299
+ watchdog = ManagerWatchdog()
300
+
301
+ while watchdog.is_alive():
302
+ try:
303
+ r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
304
+ except queue.Empty:
305
+ continue
306
+ if isinstance(r, _ResumeIteration):
307
+ # Acknowledge the main process
308
+ data_queue.put((r, None))
309
+ iteration_end = False
310
+
311
+ if isinstance(dataset, IterDataPipe):
312
+ assert r.seed is not None
313
+ shared_rng.manual_seed(r.seed)
314
+ dataset = apply_random_seed(dataset, shared_rng)
315
+
316
+ # Recreate the fetcher for worker-reuse policy
317
+ fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
318
+ continue
319
+ elif r is None:
320
+ # Received the final signal
321
+ assert done_event.is_set() or iteration_end
322
+ break
323
+ elif done_event.is_set() or iteration_end:
324
+ # `done_event` is set. But I haven't received the final signal
325
+ # (None) yet. I will keep continuing until I get it, and skip the
326
+ # processing steps.
327
+ continue
328
+ idx, index = r
329
+ """ Added """
330
+ RRSController.sample_resolution(batch_id=idx)
331
+ """ Added """
332
+ data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
333
+ if init_exception is not None:
334
+ data = init_exception
335
+ init_exception = None
336
+ else:
337
+ try:
338
+ data = fetcher.fetch(index)
339
+ except Exception as e:
340
+ if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
341
+ data = _IterableDatasetStopIteration(worker_id)
342
+ # Set `iteration_end`
343
+ # (1) to save future `next(...)` calls, and
344
+ # (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
345
+ iteration_end = True
346
+ else:
347
+ # It is important that we don't store exc_info in a variable.
348
+ # `ExceptionWrapper` does the correct thing.
349
+ # See NOTE [ Python Traceback Reference Cycle Problem ]
350
+ data = ExceptionWrapper(where="in DataLoader worker process {}".format(worker_id))
351
+ data_queue.put((idx, data))
352
+ del data, idx, index, r # save memory
353
+ except KeyboardInterrupt:
354
+ # Main process will raise KeyboardInterrupt anyway.
355
+ pass
356
+ if done_event.is_set():
357
+ data_queue.cancel_join_thread()
358
+ data_queue.close()
efficientvit/apps/data_provider/random_resolution/controller.py ADDED
@@ -0,0 +1,92 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import copy
6
+
7
+ import torch
8
+ import torchvision.transforms as transforms
9
+ import torchvision.transforms.functional as F
10
+
11
+ from efficientvit.models.utils import torch_random_choices
12
+
13
+ __all__ = [
14
+ "RRSController",
15
+ "get_interpolate",
16
+ "MyRandomResizedCrop",
17
+ ]
18
+
19
+
20
+ class RRSController:
21
+ ACTIVE_SIZE = (224, 224)
22
+ IMAGE_SIZE_LIST = [(224, 224)]
23
+
24
+ CHOICE_LIST = None
25
+
26
+ @staticmethod
27
+ def get_candidates() -> list[tuple[int, int]]:
28
+ return copy.deepcopy(RRSController.IMAGE_SIZE_LIST)
29
+
30
+ @staticmethod
31
+ def sample_resolution(batch_id: int) -> None:
32
+ RRSController.ACTIVE_SIZE = RRSController.CHOICE_LIST[batch_id]
33
+
34
+ @staticmethod
35
+ def set_epoch(epoch: int, batch_per_epoch: int) -> None:
36
+ g = torch.Generator()
37
+ g.manual_seed(epoch)
38
+ RRSController.CHOICE_LIST = torch_random_choices(
39
+ RRSController.get_candidates(),
40
+ g,
41
+ batch_per_epoch,
42
+ )
43
+
44
+
45
+ def get_interpolate(name: str) -> F.InterpolationMode:
46
+ mapping = {
47
+ "nearest": F.InterpolationMode.NEAREST,
48
+ "bilinear": F.InterpolationMode.BILINEAR,
49
+ "bicubic": F.InterpolationMode.BICUBIC,
50
+ "box": F.InterpolationMode.BOX,
51
+ "hamming": F.InterpolationMode.HAMMING,
52
+ "lanczos": F.InterpolationMode.LANCZOS,
53
+ }
54
+ if name in mapping:
55
+ return mapping[name]
56
+ elif name == "random":
57
+ return torch_random_choices(
58
+ [
59
+ F.InterpolationMode.NEAREST,
60
+ F.InterpolationMode.BILINEAR,
61
+ F.InterpolationMode.BICUBIC,
62
+ F.InterpolationMode.BOX,
63
+ F.InterpolationMode.HAMMING,
64
+ F.InterpolationMode.LANCZOS,
65
+ ],
66
+ )
67
+ else:
68
+ raise NotImplementedError
69
+
70
+
71
+ class MyRandomResizedCrop(transforms.RandomResizedCrop):
72
+ def __init__(
73
+ self,
74
+ scale=(0.08, 1.0),
75
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
76
+ interpolation: str = "random",
77
+ ):
78
+ super(MyRandomResizedCrop, self).__init__(224, scale, ratio)
79
+ self.interpolation = interpolation
80
+
81
+ def forward(self, img: torch.Tensor) -> torch.Tensor:
82
+ i, j, h, w = self.get_params(img, list(self.scale), list(self.ratio))
83
+ target_size = RRSController.ACTIVE_SIZE
84
+ return F.resized_crop(img, i, j, h, w, list(target_size), get_interpolate(self.interpolation))
85
+
86
+ def __repr__(self) -> str:
87
+ format_string = self.__class__.__name__
88
+ format_string += f"(\n\tsize={RRSController.get_candidates()},\n"
89
+ format_string += f"\tscale={tuple(round(s, 4) for s in self.scale)},\n"
90
+ format_string += f"\tratio={tuple(round(r, 4) for r in self.ratio)},\n"
91
+ format_string += f"\tinterpolation={self.interpolation})"
92
+ return format_string
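
A hypothetical end-to-end sketch of the controller above: `set_epoch` pre-samples one resolution per batch from `IMAGE_SIZE_LIST`, the worker loop in `_data_worker.py` calls `sample_resolution(batch_id)` before fetching, and `MyRandomResizedCrop` then crops to the currently active size. The candidate sizes and the manual loop are illustrative, not the repo's training configuration:

```python
import torch

from efficientvit.apps.data_provider.random_resolution.controller import (
    MyRandomResizedCrop,
    RRSController,
)

# Illustrative candidate resolutions; the data provider normally sets these.
RRSController.IMAGE_SIZE_LIST = [(160, 160), (192, 192), (224, 224), (256, 256)]
RRSController.set_epoch(epoch=0, batch_per_epoch=100)  # deterministic per-epoch choices

crop = MyRandomResizedCrop(interpolation="bicubic")
img = torch.rand(3, 320, 320)

for batch_id in range(3):
    RRSController.sample_resolution(batch_id)  # normally done inside _worker_loop
    out = crop(img)
    print(batch_id, RRSController.ACTIVE_SIZE, tuple(out.shape[-2:]))
```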
efficientvit/apps/setup.py ADDED
@@ -0,0 +1,135 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import os
6
+ import time
7
+ from copy import deepcopy
8
+
9
+ import torch.backends.cudnn
10
+ import torch.distributed
11
+ import torch.nn as nn
12
+
13
+ from efficientvit.apps.data_provider import DataProvider
14
+ from efficientvit.apps.trainer.run_config import RunConfig
15
+ from efficientvit.apps.utils import (
16
+ dist_init,
17
+ dump_config,
18
+ get_dist_local_rank,
19
+ get_dist_rank,
20
+ get_dist_size,
21
+ init_modules,
22
+ is_master,
23
+ load_config,
24
+ partial_update_config,
25
+ zero_last_gamma,
26
+ )
27
+ from efficientvit.models.utils import build_kwargs_from_config, load_state_dict_from_file
28
+
29
+ __all__ = [
30
+ "save_exp_config",
31
+ "setup_dist_env",
32
+ "setup_seed",
33
+ "setup_exp_config",
34
+ "setup_data_provider",
35
+ "setup_run_config",
36
+ "init_model",
37
+ ]
38
+
39
+
40
+ def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
41
+ if not is_master():
42
+ return
43
+ dump_config(exp_config, os.path.join(path, name))
44
+
45
+
46
+ def setup_dist_env(gpu: str or None = None) -> None:
47
+ if gpu is not None:
48
+ os.environ["CUDA_VISIBLE_DEVICES"] = gpu
49
+ if not torch.distributed.is_initialized():
50
+ dist_init()
51
+ torch.backends.cudnn.benchmark = True
52
+ torch.cuda.set_device(get_dist_local_rank())
53
+
54
+
55
+ def setup_seed(manual_seed: int, resume: bool) -> None:
56
+ if resume:
57
+ manual_seed = int(time.time())
58
+ manual_seed = get_dist_rank() + manual_seed
59
+ torch.manual_seed(manual_seed)
60
+ torch.cuda.manual_seed_all(manual_seed)
61
+
62
+
63
+ def setup_exp_config(config_path: str, recursive=True, opt_args: dict or None = None) -> dict:
64
+ # load config
65
+ if not os.path.isfile(config_path):
66
+ raise ValueError(config_path)
67
+
68
+ fpaths = [config_path]
69
+ if recursive:
70
+ extension = os.path.splitext(config_path)[1]
71
+ while os.path.dirname(config_path) != config_path:
72
+ config_path = os.path.dirname(config_path)
73
+ fpath = os.path.join(config_path, "default" + extension)
74
+ if os.path.isfile(fpath):
75
+ fpaths.append(fpath)
76
+ fpaths = fpaths[::-1]
77
+
78
+ default_config = load_config(fpaths[0])
79
+ exp_config = deepcopy(default_config)
80
+ for fpath in fpaths[1:]:
81
+ partial_update_config(exp_config, load_config(fpath))
82
+ # update config via args
83
+ if opt_args is not None:
84
+ partial_update_config(exp_config, opt_args)
85
+
86
+ return exp_config
87
+
88
+
89
+ def setup_data_provider(
90
+ exp_config: dict, data_provider_classes: list[type[DataProvider]], is_distributed: bool = True
91
+ ) -> DataProvider:
92
+ dp_config = exp_config["data_provider"]
93
+ dp_config["num_replicas"] = get_dist_size() if is_distributed else None
94
+ dp_config["rank"] = get_dist_rank() if is_distributed else None
95
+ dp_config["test_batch_size"] = dp_config.get("test_batch_size", None) or dp_config["base_batch_size"] * 2
96
+ dp_config["batch_size"] = dp_config["train_batch_size"] = dp_config["base_batch_size"]
97
+
98
+ data_provider_lookup = {provider.name: provider for provider in data_provider_classes}
99
+ data_provider_class = data_provider_lookup[dp_config["dataset"]]
100
+
101
+ data_provider_kwargs = build_kwargs_from_config(dp_config, data_provider_class)
102
+ data_provider = data_provider_class(**data_provider_kwargs)
103
+ return data_provider
104
+
105
+
106
+ def setup_run_config(exp_config: dict, run_config_cls: type[RunConfig]) -> RunConfig:
107
+ exp_config["run_config"]["init_lr"] = exp_config["run_config"]["base_lr"] * get_dist_size()
108
+
109
+ run_config = run_config_cls(**exp_config["run_config"])
110
+
111
+ return run_config
112
+
113
+
114
+ def init_model(
115
+ network: nn.Module,
116
+ init_from: str or None = None,
117
+ backbone_init_from: str or None = None,
118
+ rand_init="trunc_normal",
119
+ last_gamma=None,
120
+ ) -> None:
121
+ # initialization
122
+ init_modules(network, init_type=rand_init)
123
+ # zero gamma of last bn in each block
124
+ if last_gamma is not None:
125
+ zero_last_gamma(network, last_gamma)
126
+
127
+ # load weight
128
+ if init_from is not None and os.path.isfile(init_from):
129
+ network.load_state_dict(load_state_dict_from_file(init_from))
130
+ print(f"Loaded init from {init_from}")
131
+ elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
132
+ network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
133
+ print(f"Loaded backbone init from {backbone_init_from}")
134
+ else:
135
+ print(f"Random init ({rand_init}) with last gamma {last_gamma}")
efficientvit/apps/trainer/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from .base import *
6
+ from .run_config import *
efficientvit/apps/trainer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (226 Bytes).
 
efficientvit/apps/trainer/__pycache__/base.cpython-310.pyc ADDED
Binary file (8.47 kB).
 
efficientvit/apps/trainer/__pycache__/run_config.cpython-310.pyc ADDED
Binary file (4.04 kB).
 
efficientvit/apps/trainer/base.py ADDED
@@ -0,0 +1,299 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import os
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from efficientvit.apps.data_provider import DataProvider, parse_image_size
11
+ from efficientvit.apps.trainer.run_config import RunConfig
12
+ from efficientvit.apps.utils import EMA, dist_barrier, get_dist_local_rank, is_master
13
+ from efficientvit.models.nn.norm import reset_bn
14
+ from efficientvit.models.utils import is_parallel, load_state_dict_from_file
15
+
16
+ __all__ = ["Trainer"]
17
+
18
+
19
+ class Trainer:
20
+ def __init__(self, path: str, model: nn.Module, data_provider: DataProvider):
21
+ self.path = os.path.realpath(os.path.expanduser(path))
22
+ self.model = model.cuda()
23
+ self.data_provider = data_provider
24
+
25
+ self.ema = None
26
+
27
+ self.checkpoint_path = os.path.join(self.path, "checkpoint")
28
+ self.logs_path = os.path.join(self.path, "logs")
29
+ for path in [self.path, self.checkpoint_path, self.logs_path]:
30
+ os.makedirs(path, exist_ok=True)
31
+
32
+ self.best_val = 0.0
33
+ self.start_epoch = 0
34
+
35
+ @property
36
+ def network(self) -> nn.Module:
37
+ return self.model.module if is_parallel(self.model) else self.model
38
+
39
+ @property
40
+ def eval_network(self) -> nn.Module:
41
+ if self.ema is None:
42
+ model = self.model
43
+ else:
44
+ model = self.ema.shadows
45
+ model = model.module if is_parallel(model) else model
46
+ return model
47
+
48
+ def write_log(self, log_str, prefix="valid", print_log=True, mode="a") -> None:
49
+ if is_master():
50
+ fout = open(os.path.join(self.logs_path, f"{prefix}.log"), mode)
51
+ fout.write(log_str + "\n")
52
+ fout.flush()
53
+ fout.close()
54
+ if print_log:
55
+ print(log_str)
56
+
57
+ def save_model(
58
+ self,
59
+ checkpoint=None,
60
+ only_state_dict=True,
61
+ epoch=0,
62
+ model_name=None,
63
+ ) -> None:
64
+ if is_master():
65
+ if checkpoint is None:
66
+ if only_state_dict:
67
+ checkpoint = {"state_dict": self.network.state_dict()}
68
+ else:
69
+ checkpoint = {
70
+ "state_dict": self.network.state_dict(),
71
+ "epoch": epoch,
72
+ "best_val": self.best_val,
73
+ "optimizer": self.optimizer.state_dict(),
74
+ "lr_scheduler": self.lr_scheduler.state_dict(),
75
+ "ema": self.ema.state_dict() if self.ema is not None else None,
76
+ "scaler": self.scaler.state_dict() if self.enable_amp else None,
77
+ }
78
+
79
+ model_name = model_name or "checkpoint.pt"
80
+
81
+ latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
82
+ model_path = os.path.join(self.checkpoint_path, model_name)
83
+ with open(latest_fname, "w") as _fout:
84
+ _fout.write(model_path + "\n")
85
+ torch.save(checkpoint, model_path)
86
+
87
+ def load_model(self, model_fname=None) -> None:
88
+ latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
89
+ if model_fname is None and os.path.exists(latest_fname):
90
+ with open(latest_fname, "r") as fin:
91
+ model_fname = fin.readline()
92
+ if len(model_fname) > 0 and model_fname[-1] == "\n":
93
+ model_fname = model_fname[:-1]
94
+ try:
95
+ if model_fname is None:
96
+ model_fname = f"{self.checkpoint_path}/checkpoint.pt"
97
+ elif not os.path.exists(model_fname):
98
+ model_fname = f"{self.checkpoint_path}/{os.path.basename(model_fname)}"
99
+ if not os.path.exists(model_fname):
100
+ model_fname = f"{self.checkpoint_path}/checkpoint.pt"
101
+ print(f"=> loading checkpoint {model_fname}")
102
+ checkpoint = load_state_dict_from_file(model_fname, False)
103
+ except Exception:
104
+ self.write_log(f"fail to load checkpoint from {self.checkpoint_path}")
105
+ return
106
+
107
+ # load checkpoint
108
+ self.network.load_state_dict(checkpoint["state_dict"], strict=False)
109
+ log = []
110
+ if "epoch" in checkpoint:
111
+ self.start_epoch = checkpoint["epoch"] + 1
112
+ self.run_config.update_global_step(self.start_epoch)
113
+ log.append(f"epoch={self.start_epoch - 1}")
114
+ if "best_val" in checkpoint:
115
+ self.best_val = checkpoint["best_val"]
116
+ log.append(f"best_val={self.best_val:.2f}")
117
+ if "optimizer" in checkpoint:
118
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
119
+ log.append("optimizer")
120
+ if "lr_scheduler" in checkpoint:
121
+ self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
122
+ log.append("lr_scheduler")
123
+ if "ema" in checkpoint and self.ema is not None:
124
+ self.ema.load_state_dict(checkpoint["ema"])
125
+ log.append("ema")
126
+ if "scaler" in checkpoint and self.enable_amp:
127
+ self.scaler.load_state_dict(checkpoint["scaler"])
128
+ log.append("scaler")
129
+ self.write_log("Loaded: " + ", ".join(log))
130
+
131
+ """ validate """
132
+
133
+ def reset_bn(
134
+ self,
135
+ network: nn.Module or None = None,
136
+ subset_size: int = 16000,
137
+ subset_batch_size: int = 100,
138
+ data_loader=None,
139
+ progress_bar=False,
140
+ ) -> None:
141
+ network = network or self.network
142
+ if data_loader is None:
143
+ data_loader = []
144
+ for data in self.data_provider.build_sub_train_loader(subset_size, subset_batch_size):
145
+ if isinstance(data, list):
146
+ data_loader.append(data[0])
147
+ elif isinstance(data, dict):
148
+ data_loader.append(data["data"])
149
+ elif isinstance(data, torch.Tensor):
150
+ data_loader.append(data)
151
+ else:
152
+ raise NotImplementedError
153
+
154
+ network.eval()
155
+ reset_bn(
156
+ network,
157
+ data_loader,
158
+ sync=True,
159
+ progress_bar=progress_bar,
160
+ )
161
+
162
+ def _validate(self, model, data_loader, epoch) -> dict[str, any]:
163
+ raise NotImplementedError
164
+
165
+ def validate(self, model=None, data_loader=None, is_test=True, epoch=0) -> dict[str, any]:
166
+ model = model or self.eval_network
167
+ if data_loader is None:
168
+ if is_test:
169
+ data_loader = self.data_provider.test
170
+ else:
171
+ data_loader = self.data_provider.valid
172
+
173
+ model.eval()
174
+ return self._validate(model, data_loader, epoch)
175
+
176
+ def multires_validate(
177
+ self,
178
+ model=None,
179
+ data_loader=None,
180
+ is_test=True,
181
+ epoch=0,
182
+ eval_image_size=None,
183
+ ) -> dict[str, dict[str, any]]:
184
+ eval_image_size = eval_image_size or self.run_config.eval_image_size
185
+ eval_image_size = eval_image_size or self.data_provider.image_size
186
+ model = model or self.eval_network
187
+
188
+ if not isinstance(eval_image_size, list):
189
+ eval_image_size = [eval_image_size]
190
+
191
+ output_dict = {}
192
+ for r in eval_image_size:
193
+ self.data_provider.assign_active_image_size(parse_image_size(r))
194
+ if self.run_config.reset_bn:
195
+ self.reset_bn(
196
+ network=model,
197
+ subset_size=self.run_config.reset_bn_size,
198
+ subset_batch_size=self.run_config.reset_bn_batch_size,
199
+ progress_bar=True,
200
+ )
201
+ output_dict[f"r{r}"] = self.validate(model, data_loader, is_test, epoch)
202
+ return output_dict
203
+
204
+ """ training """
205
+
206
+ def prep_for_training(self, run_config: RunConfig, ema_decay: float or None = None, amp="fp32") -> None:
207
+ self.run_config = run_config
208
+ self.model = nn.parallel.DistributedDataParallel(
209
+ self.model.cuda(),
210
+ device_ids=[get_dist_local_rank()],
211
+ static_graph=True,
212
+ )
213
+
214
+ self.run_config.global_step = 0
215
+ self.run_config.batch_per_epoch = len(self.data_provider.train)
216
+ assert self.run_config.batch_per_epoch > 0, "Training set is empty"
217
+
218
+ # build optimizer
219
+ self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model)
220
+
221
+ if ema_decay is not None:
222
+ self.ema = EMA(self.network, ema_decay)
223
+
224
+ # amp
225
+ self.amp = amp
226
+ self.scaler = torch.cuda.amp.GradScaler(enabled=self.enable_amp)
227
+
228
+ @property
229
+ def enable_amp(self) -> bool:
230
+ return self.amp != "fp32"
231
+
232
+ @property
233
+ def amp_dtype(self) -> torch.dtype:
234
+ if self.amp == "fp16":
235
+ return torch.float16
236
+ elif self.amp == "bf16":
237
+ return torch.bfloat16
238
+ else:
239
+ return torch.float32
240
+
241
+ def sync_model(self):
242
+ print("Sync model")
243
+ self.save_model(model_name="sync.pt")
244
+ dist_barrier()
245
+ checkpoint = torch.load(os.path.join(self.checkpoint_path, "sync.pt"), map_location="cpu")
246
+ dist_barrier()
247
+ if is_master():
248
+ os.remove(os.path.join(self.checkpoint_path, "sync.pt"))
249
+ dist_barrier()
250
+
251
+ # load checkpoint
252
+ self.network.load_state_dict(checkpoint["state_dict"], strict=False)
253
+ if "optimizer" in checkpoint:
254
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
255
+ if "lr_scheduler" in checkpoint:
256
+ self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
257
+ if "ema" in checkpoint and self.ema is not None:
258
+ self.ema.load_state_dict(checkpoint["ema"])
259
+ if "scaler" in checkpoint and self.enable_amp:
260
+ self.scaler.load_state_dict(checkpoint["scaler"])
261
+
262
+ def before_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
263
+ for key in feed_dict:
264
+ if isinstance(feed_dict[key], torch.Tensor):
265
+ feed_dict[key] = feed_dict[key].cuda()
266
+ return feed_dict
267
+
268
+ def run_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
269
+ raise NotImplementedError
270
+
271
+ def after_step(self) -> None:
272
+ self.scaler.unscale_(self.optimizer)
273
+ # gradient clip
274
+ if self.run_config.grad_clip is not None:
275
+ torch.nn.utils.clip_grad_value_(self.model.parameters(), self.run_config.grad_clip)
276
+ # update
277
+ self.scaler.step(self.optimizer)
278
+ self.scaler.update()
279
+
280
+ self.lr_scheduler.step()
281
+ self.run_config.step()
282
+ # update ema
283
+ if self.ema is not None:
284
+ self.ema.step(self.network, self.run_config.global_step)
285
+
286
+ def _train_one_epoch(self, epoch: int) -> dict[str, any]:
287
+ raise NotImplementedError
288
+
289
+ def train_one_epoch(self, epoch: int) -> dict[str, any]:
290
+ self.model.train()
291
+
292
+ self.data_provider.set_epoch(epoch)
293
+
294
+ train_info_dict = self._train_one_epoch(epoch)
295
+
296
+ return train_info_dict
297
+
298
+ def train(self) -> None:
299
+ raise NotImplementedError
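
`Trainer` above leaves the task-specific pieces abstract. A hypothetical classification-style subclass might fill in the hooks roughly as below; the `"data"`/`"label"` feed-dict keys, the loss, and the metric are illustrative assumptions rather than the repo's actual trainers:

```python
import torch
import torch.nn.functional as F

from efficientvit.apps.trainer.base import Trainer


class ClsTrainer(Trainer):
    def run_step(self, feed_dict):
        # feed-dict keys "data"/"label" are assumed here for illustration
        images, labels = feed_dict["data"], feed_dict["label"]
        with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.enable_amp):
            output = self.model(images)
            loss = F.cross_entropy(output, labels)
        self.scaler.scale(loss).backward()  # after_step() unscales, clips, and steps
        return {"loss": loss}

    def _train_one_epoch(self, epoch):
        loss = torch.zeros(())
        for feed_dict in self.data_provider.train:
            feed_dict = self.before_step(feed_dict)  # moves tensors to GPU
            self.optimizer.zero_grad()
            loss = self.run_step(feed_dict)["loss"]
            self.after_step()                        # optimizer / scheduler / EMA step
        return {"loss": loss.item()}

    def _validate(self, model, data_loader, epoch):
        correct, total = 0, 0
        with torch.no_grad():
            for feed_dict in data_loader:
                feed_dict = self.before_step(feed_dict)
                pred = model(feed_dict["data"]).argmax(dim=1)
                correct += (pred == feed_dict["label"]).sum().item()
                total += feed_dict["label"].numel()
        return {"top1": 100.0 * correct / max(total, 1)}

    def train(self):
        for epoch in range(self.start_epoch, self.run_config.n_epochs):
            train_info = self.train_one_epoch(epoch)
            val_info = self.validate(epoch=epoch)
            self.write_log(f"epoch {epoch}: {train_info} {val_info}")
            self.save_model(epoch=epoch, only_state_dict=False)
```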
efficientvit/apps/trainer/run_config.py ADDED
@@ -0,0 +1,115 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import json
6
+
7
+ import numpy as np
8
+ import torch.nn as nn
9
+
10
+ from efficientvit.apps.utils import CosineLRwithWarmup, build_optimizer
11
+
12
+ __all__ = ["Scheduler", "RunConfig"]
13
+
14
+
15
+ class Scheduler:
16
+ PROGRESS = 0
17
+
18
+
19
+ class RunConfig:
20
+ n_epochs: int
21
+ init_lr: float
22
+ warmup_epochs: int
23
+ warmup_lr: float
24
+ lr_schedule_name: str
25
+ lr_schedule_param: dict
26
+ optimizer_name: str
27
+ optimizer_params: dict
28
+ weight_decay: float
29
+ no_wd_keys: list
30
+ grad_clip: float # allow none to turn off grad clipping
31
+ reset_bn: bool
32
+ reset_bn_size: int
33
+ reset_bn_batch_size: int
34
+ eval_image_size: list # allow none to use image_size in data_provider
35
+
36
+ @property
37
+ def none_allowed(self):
38
+ return ["grad_clip", "eval_image_size"]
39
+
40
+ def __init__(self, **kwargs): # arguments must be passed as kwargs
41
+ for k, val in kwargs.items():
42
+ setattr(self, k, val)
43
+
44
+ # check that all relevant configs are there
45
+ annotations = {}
46
+ for clas in type(self).mro():
47
+ if hasattr(clas, "__annotations__"):
48
+ annotations.update(clas.__annotations__)
49
+ for k, k_type in annotations.items():
50
+ assert hasattr(self, k), f"Key {k} with type {k_type} required for initialization."
51
+ attr = getattr(self, k)
52
+ if k in self.none_allowed:
53
+ k_type = (k_type, type(None))
54
+ assert isinstance(attr, k_type), f"Key {k} must be type {k_type}, provided={attr}."
55
+
56
+ self.global_step = 0
57
+ self.batch_per_epoch = 1
58
+
59
+ def build_optimizer(self, network: nn.Module) -> tuple[any, any]:
60
+ r"""require setting 'batch_per_epoch' before building optimizer & lr_scheduler"""
61
+ param_dict = {}
62
+ for name, param in network.named_parameters():
63
+ if param.requires_grad:
64
+ opt_config = [self.weight_decay, self.init_lr]
65
+ if self.no_wd_keys is not None and len(self.no_wd_keys) > 0:
66
+ if np.any([key in name for key in self.no_wd_keys]):
67
+ opt_config[0] = 0
68
+ opt_key = json.dumps(opt_config)
69
+ param_dict[opt_key] = param_dict.get(opt_key, []) + [param]
70
+
71
+ net_params = []
72
+ for opt_key, param_list in param_dict.items():
73
+ wd, lr = json.loads(opt_key)
74
+ net_params.append({"params": param_list, "weight_decay": wd, "lr": lr})
75
+
76
+ optimizer = build_optimizer(net_params, self.optimizer_name, self.optimizer_params, self.init_lr)
77
+ # build lr scheduler
78
+ if self.lr_schedule_name == "cosine":
79
+ decay_steps = []
80
+ for epoch in self.lr_schedule_param.get("step", []):
81
+ decay_steps.append(epoch * self.batch_per_epoch)
82
+ decay_steps.append(self.n_epochs * self.batch_per_epoch)
83
+ decay_steps.sort()
84
+ lr_scheduler = CosineLRwithWarmup(
85
+ optimizer,
86
+ self.warmup_epochs * self.batch_per_epoch,
87
+ self.warmup_lr,
88
+ decay_steps,
89
+ )
90
+ else:
91
+ raise NotImplementedError
92
+ return optimizer, lr_scheduler
93
+
94
+ def update_global_step(self, epoch, batch_id=0) -> None:
95
+ self.global_step = epoch * self.batch_per_epoch + batch_id
96
+ Scheduler.PROGRESS = self.progress
97
+
98
+ @property
99
+ def progress(self) -> float:
100
+ warmup_steps = self.warmup_epochs * self.batch_per_epoch
101
+ steps = max(0, self.global_step - warmup_steps)
102
+ return steps / (self.n_epochs * self.batch_per_epoch)
103
+
104
+ def step(self) -> None:
105
+ self.global_step += 1
106
+ Scheduler.PROGRESS = self.progress
107
+
108
+ def get_remaining_epoch(self, epoch, post=True) -> int:
109
+ return self.n_epochs + self.warmup_epochs - epoch - int(post)
110
+
111
+ def epoch_format(self, epoch: int) -> str:
112
+ epoch_format = f"%.{len(str(self.n_epochs))}d"
113
+ epoch_format = f"[{epoch_format}/{epoch_format}]"
114
+ epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs)
115
+ return epoch_format
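
As a small sanity check of the annotation-driven validation in `RunConfig.__init__` above: every annotated field must be supplied as a keyword argument, with `grad_clip` and `eval_image_size` allowed to be `None`. The values below are illustrative, not a recommended recipe:

```python
from efficientvit.apps.trainer.run_config import RunConfig

run_config = RunConfig(
    n_epochs=300,
    init_lr=1e-3,
    warmup_epochs=5,
    warmup_lr=1e-5,
    lr_schedule_name="cosine",
    lr_schedule_param={},
    optimizer_name="adamw",
    optimizer_params={"betas": (0.9, 0.999)},
    weight_decay=0.05,
    no_wd_keys=["bias", "norm"],
    grad_clip=None,          # None disables gradient clipping in Trainer.after_step
    reset_bn=False,
    reset_bn_size=16000,
    reset_bn_batch_size=100,
    eval_image_size=None,    # None falls back to the data provider's image_size
)
run_config.batch_per_epoch = 1000   # must be set before build_optimizer()
print(run_config.progress)          # 0.0 at global_step == 0
```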
efficientvit/apps/utils/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from .dist import *
6
+ from .ema import *
7
+ from .export import *
8
+ from .init import *
9
+ from .lr import *
10
+ from .metric import *
11
+ from .misc import *
12
+ from .opt import *
efficientvit/apps/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (314 Bytes).
 
efficientvit/apps/utils/__pycache__/dist.cpython-310.pyc ADDED
Binary file (2.13 kB).
 
efficientvit/apps/utils/__pycache__/ema.cpython-310.pyc ADDED
Binary file (1.91 kB).
 
efficientvit/apps/utils/__pycache__/export.cpython-310.pyc ADDED
Binary file (1.35 kB).
 
efficientvit/apps/utils/__pycache__/init.cpython-310.pyc ADDED
Binary file (2.01 kB).
 
efficientvit/apps/utils/__pycache__/lr.cpython-310.pyc ADDED
Binary file (1.74 kB).
 
efficientvit/apps/utils/__pycache__/metric.cpython-310.pyc ADDED
Binary file (1.61 kB).