reach-vb (HF staff) committed
Commit 50f8b94
1 parent: 8f05c80

afff4f781c38808b9aafcdd5ec92a88a4aca77a59bb34d95614d31dab397a490

Files changed (50)
  1. lib/python3.11/site-packages/functorch/dim/op_properties.py +311 -0
  2. lib/python3.11/site-packages/functorch/dim/reference.py +645 -0
  3. lib/python3.11/site-packages/functorch/dim/tree_map.py +14 -0
  4. lib/python3.11/site-packages/functorch/dim/wrap_type.py +71 -0
  5. lib/python3.11/site-packages/functorch/einops/__init__.py +3 -0
  6. lib/python3.11/site-packages/functorch/einops/__pycache__/__init__.cpython-311.pyc +0 -0
  7. lib/python3.11/site-packages/functorch/einops/__pycache__/_parsing.cpython-311.pyc +0 -0
  8. lib/python3.11/site-packages/functorch/einops/__pycache__/rearrange.cpython-311.pyc +0 -0
  9. lib/python3.11/site-packages/functorch/einops/_parsing.py +302 -0
  10. lib/python3.11/site-packages/functorch/einops/rearrange.py +207 -0
  11. lib/python3.11/site-packages/functorch/experimental/__init__.py +6 -0
  12. lib/python3.11/site-packages/functorch/experimental/__pycache__/__init__.cpython-311.pyc +0 -0
  13. lib/python3.11/site-packages/functorch/experimental/__pycache__/_map.cpython-311.pyc +0 -0
  14. lib/python3.11/site-packages/functorch/experimental/__pycache__/control_flow.cpython-311.pyc +0 -0
  15. lib/python3.11/site-packages/functorch/experimental/__pycache__/ops.cpython-311.pyc +0 -0
  16. lib/python3.11/site-packages/functorch/experimental/_map.py +393 -0
  17. lib/python3.11/site-packages/functorch/experimental/control_flow.py +6 -0
  18. lib/python3.11/site-packages/functorch/experimental/ops.py +1 -0
  19. lib/python3.11/site-packages/huggingface_hub/__init__.py +650 -0
  20. lib/python3.11/site-packages/huggingface_hub/__pycache__/__init__.cpython-311.pyc +0 -0
  21. lib/python3.11/site-packages/huggingface_hub/__pycache__/_commit_api.cpython-311.pyc +0 -0
  22. lib/python3.11/site-packages/huggingface_hub/__pycache__/_commit_scheduler.cpython-311.pyc +0 -0
  23. lib/python3.11/site-packages/huggingface_hub/__pycache__/_inference_endpoints.cpython-311.pyc +0 -0
  24. lib/python3.11/site-packages/huggingface_hub/__pycache__/_login.cpython-311.pyc +0 -0
  25. lib/python3.11/site-packages/huggingface_hub/__pycache__/_multi_commits.cpython-311.pyc +0 -0
  26. lib/python3.11/site-packages/huggingface_hub/__pycache__/_snapshot_download.cpython-311.pyc +0 -0
  27. lib/python3.11/site-packages/huggingface_hub/__pycache__/_space_api.cpython-311.pyc +0 -0
  28. lib/python3.11/site-packages/huggingface_hub/__pycache__/_tensorboard_logger.cpython-311.pyc +0 -0
  29. lib/python3.11/site-packages/huggingface_hub/__pycache__/_webhooks_payload.cpython-311.pyc +0 -0
  30. lib/python3.11/site-packages/huggingface_hub/__pycache__/_webhooks_server.cpython-311.pyc +0 -0
  31. lib/python3.11/site-packages/huggingface_hub/__pycache__/community.cpython-311.pyc +0 -0
  32. lib/python3.11/site-packages/huggingface_hub/__pycache__/constants.cpython-311.pyc +0 -0
  33. lib/python3.11/site-packages/huggingface_hub/__pycache__/fastai_utils.cpython-311.pyc +0 -0
  34. lib/python3.11/site-packages/huggingface_hub/__pycache__/file_download.cpython-311.pyc +0 -0
  35. lib/python3.11/site-packages/huggingface_hub/__pycache__/hf_api.cpython-311.pyc +0 -0
  36. lib/python3.11/site-packages/huggingface_hub/__pycache__/hf_file_system.cpython-311.pyc +0 -0
  37. lib/python3.11/site-packages/huggingface_hub/__pycache__/hub_mixin.cpython-311.pyc +0 -0
  38. lib/python3.11/site-packages/huggingface_hub/__pycache__/inference_api.cpython-311.pyc +0 -0
  39. lib/python3.11/site-packages/huggingface_hub/__pycache__/keras_mixin.cpython-311.pyc +0 -0
  40. lib/python3.11/site-packages/huggingface_hub/__pycache__/lfs.cpython-311.pyc +0 -0
  41. lib/python3.11/site-packages/huggingface_hub/__pycache__/repocard.cpython-311.pyc +0 -0
  42. lib/python3.11/site-packages/huggingface_hub/__pycache__/repocard_data.cpython-311.pyc +0 -0
  43. lib/python3.11/site-packages/huggingface_hub/__pycache__/repository.cpython-311.pyc +0 -0
  44. lib/python3.11/site-packages/huggingface_hub/_commit_api.py +670 -0
  45. lib/python3.11/site-packages/huggingface_hub/_commit_scheduler.py +327 -0
  46. lib/python3.11/site-packages/huggingface_hub/_inference_endpoints.py +373 -0
  47. lib/python3.11/site-packages/huggingface_hub/_login.py +395 -0
  48. lib/python3.11/site-packages/huggingface_hub/_multi_commits.py +305 -0
  49. lib/python3.11/site-packages/huggingface_hub/_snapshot_download.py +319 -0
  50. lib/python3.11/site-packages/huggingface_hub/_space_api.py +154 -0
lib/python3.11/site-packages/functorch/dim/op_properties.py ADDED
@@ -0,0 +1,311 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ import torch
+
+ # pointwise operators can go through a faster pathway
+
+ tensor_magic_methods = ["add", ""]
+ pointwise_magic_methods_with_reverse = (
+     "add",
+     "sub",
+     "mul",
+     "floordiv",
+     "div",
+     "truediv",
+     "mod",
+     "pow",
+     "lshift",
+     "rshift",
+     "and",
+     "or",
+     "xor",
+ )
+ pointwise_magic_methods = (
+     *(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)),
+     "eq",
+     "gt",
+     "le",
+     "lt",
+     "ge",
+     "gt",
+     "ne",
+     "neg",
+     "pos",
+     "abs",
+     "invert",
+     "iadd",
+     "isub",
+     "imul",
+     "ifloordiv",
+     "idiv",
+     "itruediv",
+     "imod",
+     "ipow",
+     "ilshift",
+     "irshift",
+     "iand",
+     "ior",
+     "ixor",
+     "int",
+     "long",
+     "float",
+     "complex",
+ )
+
+ pointwise_methods = (*(f"__{m}__" for m in pointwise_magic_methods),)
+
+ pointwise = (
+     *(getattr(torch.Tensor, m) for m in pointwise_methods),
+     torch.nn.functional.dropout,
+     torch.where,
+     torch.Tensor.abs,
+     torch.abs,
+     torch.Tensor.acos,
+     torch.acos,
+     torch.Tensor.acosh,
+     torch.acosh,
+     torch.Tensor.add,
+     torch.add,
+     torch.Tensor.addcdiv,
+     torch.addcdiv,
+     torch.Tensor.addcmul,
+     torch.addcmul,
+     torch.Tensor.addr,
+     torch.addr,
+     torch.Tensor.angle,
+     torch.angle,
+     torch.Tensor.asin,
+     torch.asin,
+     torch.Tensor.asinh,
+     torch.asinh,
+     torch.Tensor.atan,
+     torch.atan,
+     torch.Tensor.atan2,
+     torch.atan2,
+     torch.Tensor.atanh,
+     torch.atanh,
+     torch.Tensor.bitwise_and,
+     torch.bitwise_and,
+     torch.Tensor.bitwise_left_shift,
+     torch.bitwise_left_shift,
+     torch.Tensor.bitwise_not,
+     torch.bitwise_not,
+     torch.Tensor.bitwise_or,
+     torch.bitwise_or,
+     torch.Tensor.bitwise_right_shift,
+     torch.bitwise_right_shift,
+     torch.Tensor.bitwise_xor,
+     torch.bitwise_xor,
+     torch.Tensor.ceil,
+     torch.ceil,
+     torch.celu,
+     torch.nn.functional.celu,
+     torch.Tensor.clamp,
+     torch.clamp,
+     torch.Tensor.clamp_max,
+     torch.clamp_max,
+     torch.Tensor.clamp_min,
+     torch.clamp_min,
+     torch.Tensor.copysign,
+     torch.copysign,
+     torch.Tensor.cos,
+     torch.cos,
+     torch.Tensor.cosh,
+     torch.cosh,
+     torch.Tensor.deg2rad,
+     torch.deg2rad,
+     torch.Tensor.digamma,
+     torch.digamma,
+     torch.Tensor.div,
+     torch.div,
+     torch.dropout,
+     torch.nn.functional.dropout,
+     torch.nn.functional.elu,
+     torch.Tensor.eq,
+     torch.eq,
+     torch.Tensor.erf,
+     torch.erf,
+     torch.Tensor.erfc,
+     torch.erfc,
+     torch.Tensor.erfinv,
+     torch.erfinv,
+     torch.Tensor.exp,
+     torch.exp,
+     torch.Tensor.exp2,
+     torch.exp2,
+     torch.Tensor.expm1,
+     torch.expm1,
+     torch.feature_dropout,
+     torch.Tensor.float_power,
+     torch.float_power,
+     torch.Tensor.floor,
+     torch.floor,
+     torch.Tensor.floor_divide,
+     torch.floor_divide,
+     torch.Tensor.fmod,
+     torch.fmod,
+     torch.Tensor.frac,
+     torch.frac,
+     torch.Tensor.frexp,
+     torch.frexp,
+     torch.Tensor.gcd,
+     torch.gcd,
+     torch.Tensor.ge,
+     torch.ge,
+     torch.nn.functional.gelu,
+     torch.nn.functional.glu,
+     torch.Tensor.gt,
+     torch.gt,
+     torch.Tensor.hardshrink,
+     torch.hardshrink,
+     torch.nn.functional.hardshrink,
+     torch.nn.functional.hardsigmoid,
+     torch.nn.functional.hardswish,
+     torch.nn.functional.hardtanh,
+     torch.Tensor.heaviside,
+     torch.heaviside,
+     torch.Tensor.hypot,
+     torch.hypot,
+     torch.Tensor.i0,
+     torch.i0,
+     torch.Tensor.igamma,
+     torch.igamma,
+     torch.Tensor.igammac,
+     torch.igammac,
+     torch.Tensor.isclose,
+     torch.isclose,
+     torch.Tensor.isfinite,
+     torch.isfinite,
+     torch.Tensor.isinf,
+     torch.isinf,
+     torch.Tensor.isnan,
+     torch.isnan,
+     torch.Tensor.isneginf,
+     torch.isneginf,
+     torch.Tensor.isposinf,
+     torch.isposinf,
+     torch.Tensor.isreal,
+     torch.isreal,
+     torch.Tensor.kron,
+     torch.kron,
+     torch.Tensor.lcm,
+     torch.lcm,
+     torch.Tensor.ldexp,
+     torch.ldexp,
+     torch.Tensor.le,
+     torch.le,
+     torch.nn.functional.leaky_relu,
+     torch.Tensor.lerp,
+     torch.lerp,
+     torch.Tensor.lgamma,
+     torch.lgamma,
+     torch.Tensor.log,
+     torch.log,
+     torch.Tensor.log10,
+     torch.log10,
+     torch.Tensor.log1p,
+     torch.log1p,
+     torch.Tensor.log2,
+     torch.log2,
+     torch.nn.functional.logsigmoid,
+     torch.Tensor.logical_and,
+     torch.logical_and,
+     torch.Tensor.logical_not,
+     torch.logical_not,
+     torch.Tensor.logical_or,
+     torch.logical_or,
+     torch.Tensor.logical_xor,
+     torch.logical_xor,
+     torch.Tensor.logit,
+     torch.logit,
+     torch.Tensor.lt,
+     torch.lt,
+     torch.Tensor.maximum,
+     torch.maximum,
+     torch.Tensor.minimum,
+     torch.minimum,
+     torch.nn.functional.mish,
+     torch.Tensor.mvlgamma,
+     torch.mvlgamma,
+     torch.Tensor.nan_to_num,
+     torch.nan_to_num,
+     torch.Tensor.ne,
+     torch.ne,
+     torch.Tensor.neg,
+     torch.neg,
+     torch.Tensor.nextafter,
+     torch.nextafter,
+     torch.Tensor.outer,
+     torch.outer,
+     torch.polar,
+     torch.Tensor.polygamma,
+     torch.polygamma,
+     torch.Tensor.positive,
+     torch.positive,
+     torch.Tensor.pow,
+     torch.pow,
+     torch.Tensor.prelu,
+     torch.prelu,
+     torch.nn.functional.prelu,
+     torch.Tensor.rad2deg,
+     torch.rad2deg,
+     torch.Tensor.reciprocal,
+     torch.reciprocal,
+     torch.Tensor.relu,
+     torch.relu,
+     torch.nn.functional.relu,
+     torch.nn.functional.relu6,
+     torch.Tensor.remainder,
+     torch.remainder,
+     torch.Tensor.round,
+     torch.round,
+     torch.rrelu,
+     torch.nn.functional.rrelu,
+     torch.Tensor.rsqrt,
+     torch.rsqrt,
+     torch.rsub,
+     torch.selu,
+     torch.nn.functional.selu,
+     torch.Tensor.sgn,
+     torch.sgn,
+     torch.Tensor.sigmoid,
+     torch.sigmoid,
+     torch.nn.functional.sigmoid,
+     torch.Tensor.sign,
+     torch.sign,
+     torch.Tensor.signbit,
+     torch.signbit,
+     torch.nn.functional.silu,
+     torch.Tensor.sin,
+     torch.sin,
+     torch.Tensor.sinc,
+     torch.sinc,
+     torch.Tensor.sinh,
+     torch.sinh,
+     torch.nn.functional.softplus,
+     torch.nn.functional.softshrink,
+     torch.Tensor.sqrt,
+     torch.sqrt,
+     torch.Tensor.square,
+     torch.square,
+     torch.Tensor.sub,
+     torch.sub,
+     torch.Tensor.tan,
+     torch.tan,
+     torch.Tensor.tanh,
+     torch.tanh,
+     torch.nn.functional.tanh,
+     torch.threshold,
+     torch.nn.functional.threshold,
+     torch.trapz,
+     torch.Tensor.true_divide,
+     torch.true_divide,
+     torch.Tensor.trunc,
+     torch.trunc,
+     torch.Tensor.xlogy,
+     torch.xlogy,
+     torch.rand_like,
+ )
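
The set above is consumed by reference.py (next file), where membership decides whether a call takes the fast pointwise pathway. A minimal sketch of that membership check, assuming the vendored functorch build imports cleanly:

    import torch
    from functorch.dim import op_properties

    # membership in this set is what routes an op down the fast pointwise path
    pointwise = set(op_properties.pointwise)
    assert torch.add in pointwise          # pointwise: fast pathway
    assert torch.matmul not in pointwise   # not pointwise: generic batching rules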
lib/python3.11/site-packages/functorch/dim/reference.py ADDED
@@ -0,0 +1,645 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # reference python implementations for C ops
+ import torch
+
+ from functorch._C import dim as _C
+ from . import op_properties
+ from .batch_tensor import _enable_layers
+ from .tree_map import tree_flatten, tree_map
+
+ DimList = _C.DimList
+ import operator
+ from functools import reduce
+
+
+ # use dict to avoid writing C++ bindings for set
+ pointwise = set(op_properties.pointwise)
+
+
+ def prod(x):
+     return reduce(operator.mul, x, 1)
+
+
+ def _wrap_dim(d, N, keepdim):
+     from . import Dim
+
+     if isinstance(d, Dim):
+         assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
+         return d
+     elif d >= 0:
+         return d - N
+     else:
+         return d
+
+
+ def _dims(d, N, keepdim, single_dim):
+     from . import Dim
+
+     if isinstance(d, (Dim, int)):
+         return ltuple((_wrap_dim(d, N, keepdim),))
+     assert not single_dim, f"expected a single dimension or int but found: {d}"
+     return ltuple(_wrap_dim(x, N, keepdim) for x in d)
+
+
+ def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
+     from . import DimensionMismatchError
+
+     not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
+     if len(not_bound) == 1:
+         idx, d = not_bound[0]
+         rhs_so_far = prod(r.size for r in rhs if r.is_bound)
+         if lhs_size % rhs_so_far != 0:
+             rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
+             raise DimensionMismatchError(
+                 f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}"
+             )
+         new_size = lhs_size // rhs_so_far
+         d.size = new_size
+     elif len(not_bound) > 1:
+         rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
+         raise DimensionMismatchError(
+             f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}"
+         )
+     else:
+         rhs_size = prod(r.size for r in rhs)
+         if lhs_size != rhs_size:
+             raise DimensionMismatchError(
+                 f"Dimension sizes do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
+             )
+
+
+ def _tensor_levels(inp):
+     from . import _Tensor
+
+     if isinstance(inp, _Tensor):
+         return inp._tensor, llist(inp._levels), inp._has_device
+     else:
+         return inp, llist(range(-inp.ndim, 0)), True
+
+
+ def _match_levels(v, from_levels, to_levels):
+     view = []
+     permute = []
+     requires_view = False
+     size = v.size()
+     for t in to_levels:
+         try:
+             idx = from_levels.index(t)
+             permute.append(idx)
+             view.append(size[idx])
+         except ValueError:
+             view.append(1)
+             requires_view = True
+     if permute != list(range(len(permute))):
+         v = v.permute(*permute)
+     if requires_view:
+         v = v.view(*view)
+     return v
+
+
+ # make a single dimension positional but do not permute it,
+ # used to do multi-tensor operators where the dim being acted on
+ # should not physically move if possible
+ def _positional_no_permute(self, dim, expand_dim=False):
+     from . import Tensor
+
+     ptensor, levels = self._tensor, llist(self._levels)
+     try:
+         idx = levels.index(dim)
+     except ValueError:
+         if not expand_dim:
+             raise
+         idx = 0
+         ptensor = ptensor.expand(dim.size, *ptensor.size())
+         levels.insert(0, 0)
+     idx_batched = 0
+     for i in range(idx):
+         if isinstance(levels[i], int):
+             levels[i] -= 1
+             idx_batched += 1
+     levels[idx] = -idx_batched - 1
+     return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched
+
+
+ def seq(a, b):
+     from . import Dim
+
+     if isinstance(a, Dim) != isinstance(b, Dim):
+         return False
+     if isinstance(a, Dim):
+         return a is b
+     else:
+         return a == b
+
+
+ class isin:
+     def __contains__(self, item):
+         for x in self:
+             if seq(item, x):
+                 return True
+         return False
+
+     def index(self, item):
+         for i, x in enumerate(self):
+             if seq(item, x):
+                 return i
+         raise ValueError
+
+
+ class llist(isin, list):
+     pass
+
+
+ class ltuple(isin, tuple):
+     pass
+
+
+ empty_dict = {}
+
+
+ @classmethod
+ def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
+     from . import _Tensor, Tensor, TensorLike
+     from .delayed_mul_tensor import DelayedMulTensor
+
+     if orig is torch.Tensor.__mul__:
+         lhs, rhs = args
+         if (
+             isinstance(lhs, _Tensor)
+             and isinstance(rhs, _Tensor)
+             and lhs.ndim == 0
+             and rhs.ndim == 0
+         ):
+             return DelayedMulTensor(lhs, rhs)
+     all_dims = llist()
+     flat_args, unflatten = tree_flatten((args, kwargs))
+     device_holding_tensor = None
+     for f in flat_args:
+         if isinstance(f, _Tensor):
+             if f._has_device:
+                 device_holding_tensor = f._batchtensor
+             for d in f.dims:
+                 if d not in all_dims:
+                     all_dims.append(d)
+
+     def unwrap(t):
+         if isinstance(t, _Tensor):
+             r = t._batchtensor
+             if device_holding_tensor is not None and not t._has_device:
+                 r = r.to(device=device_holding_tensor.device)
+             return r
+         return t
+
+     if orig in pointwise:
+         result_levels = llist()
+         arg_levels = llist()
+         to_expand = []
+         for i, f in enumerate(flat_args):
+             if isinstance(f, TensorLike):
+                 ptensor, levels, _ = _tensor_levels(f)
+                 if (
+                     isinstance(f, _Tensor)
+                     and not f._has_device
+                     and device_holding_tensor is not None
+                 ):
+                     ptensor = ptensor.to(device=device_holding_tensor.device)
+                 flat_args[i] = ptensor
+                 for l in levels:
+                     if l not in result_levels:
+                         result_levels.append(l)
+                 to_expand.append((i, levels))
+
+         for i, levels in to_expand:
+             flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
+         args, kwargs = unflatten(flat_args)
+         result = orig(*args, **kwargs)
+
+         def wrap(t):
+             if isinstance(t, TensorLike):
+                 return Tensor.from_positional(
+                     t, result_levels, device_holding_tensor is not None
+                 )
+             return t
+
+         return tree_map(wrap, result)
+     else:
+
+         def wrap(t):
+             if isinstance(t, TensorLike):
+                 return Tensor.from_batched(t, device_holding_tensor is not None)
+             return t
+
+         with _enable_layers(all_dims):
+             print(f"batch_tensor for {orig}")
+             args, kwargs = unflatten(unwrap(f) for f in flat_args)
+             result = orig(*args, **kwargs)
+             # print("END", orig)
+             return tree_map(wrap, result)
+
+
+ def positional(self, *dims):
+     from . import Dim, DimensionBindError, Tensor
+
+     ptensor, levels = self._tensor, llist(self._levels)
+     flat_dims = llist()
+     view = []
+     needs_view = False
+     ndim = self.ndim
+     for d in dims:
+         if isinstance(d, DimList):
+             flat_dims.extend(d)
+             view.extend(e.size for e in d)
+         elif isinstance(d, Dim):
+             flat_dims.append(d)
+             view.append(d.size)
+         elif isinstance(d, int):
+             d = _wrap_dim(d, ndim, False)
+             flat_dims.append(d)
+             view.append(ptensor.size(d))
+         else:
+             flat_dims.extend(d)
+             view.append(prod(e.size for e in d))
+             needs_view = True
+
+     permute = list(range(len(levels)))
+     nflat = len(flat_dims)
+     for i, d in enumerate(flat_dims):
+         try:
+             idx = levels.index(d)
+         except ValueError as e:
+             raise DimensionBindError(
+                 f"tensor of dimensions {self.dims} does not contain dim {d}"
+             ) from e
+         p = permute[idx]
+         del levels[idx]
+         del permute[idx]
+         levels.insert(i, 0)
+         permute.insert(i, p)
+     ptensor = ptensor.permute(*permute)
+     seen = 0
+     for i in range(len(levels) - 1, -1, -1):
+         if isinstance(levels[i], int):
+             seen += 1
+             levels[i] = -seen
+     result = Tensor.from_positional(ptensor, levels, self._has_device)
+     if needs_view:
+         result = result.reshape(*view, *result.size()[len(flat_dims) :])
+     return result
+
+
+ def _contains_dim(input):
+     from . import Dim
+
+     for i in input:
+         if isinstance(i, Dim):
+             return True
+
+
+ def expand(self, *sizes):
+     if not _contains_dim(sizes):
+         return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
+     dims = sizes
+     sizes = [d.size for d in dims] + [-1] * self.ndim
+     self = self.expand(*sizes)
+     return self[dims]
+
+
+ _not_present = object()
+
+
+ def _getarg(name, offset, args, kwargs, default):
+     if len(args) > offset:
+         return args[offset]
+     return kwargs.get(name, default)
+
+
+ def _patcharg(name, offset, args, kwargs, value):
+     if len(args) > offset:
+         args[offset] = value
+     else:
+         kwargs[name] = value
+
+
+ def _wrap(
+     orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
+ ):
+     from . import Dim, Tensor, TensorLike
+
+     def fn(self, *args, **kwargs):
+         dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
+         if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
+             with _enable_layers(self.dims):
+                 print(f"dim fallback batch_tensor for {orig}")
+                 return Tensor.from_batched(
+                     orig(self._batchtensor, *args, **kwargs), self._has_device
+                 )
+         keepdim = (
+             _getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False
+         )
+         t, levels = self._tensor, llist(self._levels)
+         dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
+         dim_indices = tuple(levels.index(d) for d in dims)
+         if reduce and not keepdim:
+             new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
+         else:
+             new_levels = levels
+
+         if len(dim_indices) == 1:
+             dim_indices = dim_indices[
+                 0
+             ]  # so that dims that really only take a single argument work...
+         args = list(args)
+         _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)
+
+         def wrap(t):
+             if isinstance(t, TensorLike):
+                 return Tensor.from_positional(t, new_levels, self._has_device)
+             return t
+
+         with _enable_layers(new_levels):
+             print(f"dim used batch_tensor for {orig}")
+             r = orig(t, *args, **kwargs)
+             return tree_map(wrap, r)
+
+     return fn
+
+
+ def _def(name, *args, **kwargs):
+     from . import _Tensor
+
+     orig = getattr(torch.Tensor, name)
+     setattr(_Tensor, name, _wrap(orig, *args, **kwargs))
+
+
+ no_slice = slice(None)
+
+ _orig_getitem = torch.Tensor.__getitem__
+
+
+ class dim_tracker:
+     def __init__(self):
+         self.dims = llist()
+         self.count = []
+
+     def record(self, d):
+         if d not in self.dims:
+             self.dims.append(d)
+             self.count.append(1)
+
+     def __getitem__(self, d):
+         return self.count[self.dims.index(d)]
+
+
+ def t__getitem__(self, input):
+     from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike
+
+     # * bail to the original __getitem__ if we have a single non-Dim tensor, or a non-tensor
+     # * locate ... or an unbound tensor list, and determine its size, bind dim list
+     #   (remember that None does not count toward the total dim count)
+     # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim,
+     #   produce the re-view if needed
+     # * for each single-use dim index, replace with no_slice and mark that it will be added
+     #   (keep track of whether we have to call super)
+     # * call super if needed
+     # * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
+
+     # this handles bool indexing, as well as some other simple cases.
+
+     is_simple = (
+         not isinstance(input, Dim)
+         and not isinstance(input, (tuple, list))
+         and
+         # WAR for functorch bug where zero-dim tensors in getitem are not handled correctly.
+         not (isinstance(input, TensorLike) and input.ndim == 0)
+     )
+
+     if is_simple:
+         if isinstance(self, _Tensor):
+             return _Tensor.__torch_function__(_orig_getitem, None, (self, input))
+         else:
+             return _orig_getitem(self, input)
+
+     # can further optimize this case
+     if not isinstance(input, tuple):
+         input = [input]
+     else:
+         input = list(input)
+
+     dims_indexed = 0
+     expanding_object = None
+     dimlists = []
+     for i, s in enumerate(input):
+         if s is ... or isinstance(s, DimList) and not s.is_bound:
+             if expanding_object is not None:
+                 msg = (
+                     "at most one ... or unbound dimension list can exist in indexing list but"
+                     f" found 2 at offsets {i} and {expanding_object}"
+                 )
+                 raise DimensionBindError(msg)
+             expanding_object = i
+
+         if isinstance(s, DimList):
+             dims_indexed += len(s) if s.is_bound else 0
+             dimlists.append(i)
+         elif s is not None and s is not ...:
+             dims_indexed += 1
+
+     ndim = self.ndim
+     if dims_indexed > ndim:
+         raise IndexError(
+             f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions."
+         )
+     if expanding_object is not None:
+         expanding_ndims = ndim - dims_indexed
+         obj = input[expanding_object]
+         if obj is ...:
+             input[expanding_object : expanding_object + 1] = [
+                 no_slice
+             ] * expanding_ndims
+         else:
+             obj.bind_len(expanding_ndims)
+     # flatten the dimlists into the indexing
+     for i in reversed(dimlists):
+         input[i : i + 1] = input[i]
+     dims_indexed = 0
+     requires_view = False
+     size = self.size()
+     view_sizes = []
+     dims_seen = dim_tracker()
+
+     def add_dims(t):
+         if not isinstance(t, _Tensor):
+             return
+         for d in t.dims:
+             dims_seen.record(d)
+
+     add_dims(self)
+     dim_packs = []
+     for i, idx in enumerate(input):
+         if idx is None:
+             input[i] = no_slice
+             view_sizes.append(1)
+             requires_view = True
+         else:
+             sz = size[dims_indexed]
+             if isinstance(idx, Dim):
+                 idx.size = sz
+                 dims_seen.record(idx)
+                 view_sizes.append(sz)
+             elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
+                 for d in idx:
+                     dims_seen.record(d)
+                 _bind_dims_to_size(sz, idx, f"offset {i}")
+                 view_sizes.extend(d.size for d in idx)
+                 requires_view = True
+                 dim_packs.append(i)
+             else:
+                 add_dims(idx)
+                 view_sizes.append(sz)
+             dims_indexed += 1
+     if requires_view:
+         self = self.view(*view_sizes)
+     for i in reversed(dim_packs):
+         input[i : i + 1] = input[i]
+
+     # currently:
+     # input is flat, containing either Dim, or Tensor, or something valid for standard indexing
+     # self may have first-class dims as well.
+
+     # to index:
+     # drop the first class dims from self, they just become direct indices of their positions
+
+     # figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index.
+     # these dimensions will appear and need to be bound at the first place tensor occurs
+
+     if isinstance(self, _Tensor):
+         ptensor_self, levels = self._tensor, list(self._levels)
+         # indices to ptensor rather than self which has first-class dimensions
+         input_it = iter(input)
+         flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels]
+         has_device = self._has_device
+         to_pad = 0
+     else:
+         ptensor_self, flat_inputs = self, input
+         to_pad = ptensor_self.ndim - len(flat_inputs)
+         has_device = True
+
+     result_levels = []
+     index_levels = []
+     tensor_insert_point = None
+     to_expand = {}
+     requires_getindex = False
+     for i, inp in enumerate(flat_inputs):
+         if isinstance(inp, Dim) and dims_seen[inp] == 1:
+             flat_inputs[i] = no_slice
+             result_levels.append(inp)
+         elif isinstance(inp, TensorLike):
+             requires_getindex = True
+             if tensor_insert_point is None:
+                 tensor_insert_point = len(result_levels)
+             ptensor, levels, _ = _tensor_levels(inp)
+             to_expand[i] = levels
+             flat_inputs[i] = ptensor
+             for l in levels:
+                 if l not in index_levels:
+                     index_levels.append(l)
+         else:
+             requires_getindex = True
+             result_levels.append(0)
+
+     if tensor_insert_point is not None:
+         result_levels[tensor_insert_point:tensor_insert_point] = index_levels
+
+     for i, levels in to_expand.items():
+         flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels)
+
+     if requires_getindex:
+         result = _orig_getitem(ptensor_self, flat_inputs)
+     else:
+         result = ptensor_self
+
+     next_positional = -1
+     if to_pad > 0:
+         result_levels.extend([0] * to_pad)
+     for i, r in enumerate(reversed(result_levels)):
+         if isinstance(r, int):
+             result_levels[-1 - i] = next_positional
+             next_positional -= 1
+
+     return Tensor.from_positional(result, result_levels, has_device)
+
+
+ # XXX - dim is optional and can be the outer-most dimension...
+ def stack(tensors, new_dim, dim=0, out=None):
+     if isinstance(dim, int):
+         return torch.stack(tensors, dim, out).index(dim, new_dim)
+     index = None
+     if out is not None:
+         out, index = _positional_no_permute(out, dim, expand_dim=True)
+     ptensors = []
+     for t in tensors:
+         pt, pi = _positional_no_permute(t, dim, expand_dim=True)
+         if index is not None and pi != index:
+             pt = pt.move_dim(pi, index)
+         else:
+             index = pi
+         ptensors.append(pt)
+     pr = torch.stack(ptensors, index, out=out)
+     return pr.index((index, index + 1), (new_dim, dim))
+
+
+ _orig_split = torch.Tensor.split
+
+
+ def split(self, split_size_or_sections, dim=0):
+     from . import _Tensor, Dim
+
+     if isinstance(split_size_or_sections, int) or any(
+         isinstance(t, int) for t in split_size_or_sections
+     ):
+         if isinstance(dim, Dim):
+             raise ValueError(
+                 "when dim is specified as a Dim object, split sizes must also be dimensions."
+             )
+         return _orig_split(self, split_size_or_sections, dim=dim)
+
+     if isinstance(dim, Dim):
+         assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}"
+         self, dim = _positional_no_permute(self, dim)
+
+     size = self.size(dim)
+     total_bound_size = 0
+     unbound = []
+     sizes = []
+     for i, d in enumerate(split_size_or_sections):
+         if d.is_bound:
+             sizes.append(d.size)
+             total_bound_size += d.size
+         else:
+             sizes.append(0)
+             unbound.append(i)
+
+     if unbound:
+         assert (
+             total_bound_size <= size
+         ), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
+         remaining_size = size - total_bound_size
+         chunk_size = -(-remaining_size // len(unbound))
+         for u in unbound:
+             sz = min(chunk_size, remaining_size)
+             split_size_or_sections[u].size = sz
+             sizes[u] = sz
+             remaining_size -= sz
+     else:
+         assert (
+             total_bound_size == size
+         ), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
+     return tuple(
+         t.index(dim, d)
+         for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))
+     )
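
For context, a short sketch of the user-facing behavior these reference implementations back. The `dims`/`order` calls are the public functorch.dim surface used elsewhere in this commit (see rearrange.py below); the snippet assumes a working functorch._C build:

    import torch
    from functorch.dim import dims

    b, c = dims(2)                # two first-class dimension objects
    x = torch.randn(4, 3)
    xd = x[b, c]                  # bind positional dims 0 and 1 to b and c
    y = (xd * 2).order(b, c)      # pointwise op, then back to a positional tensor
    assert y.shape == (4, 3)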
lib/python3.11/site-packages/functorch/dim/tree_map.py ADDED
@@ -0,0 +1,14 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from functorch._C import dim
+
+ tree_flatten = dim.tree_flatten
+
+
+ def tree_map(fn, tree):
+     vs, unflatten = tree_flatten(tree)
+     return unflatten(fn(v) for v in vs)
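
A small usage sketch, assuming the C `tree_flatten` handles standard Python containers the way `torch.utils._pytree` does:

    from functorch.dim.tree_map import tree_map

    # apply a function to every leaf while preserving the container structure
    result = tree_map(lambda v: v * 2, {"a": 1, "b": (2, 3)})
    # expected: {"a": 2, "b": (4, 6)}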
lib/python3.11/site-packages/functorch/dim/wrap_type.py ADDED
@@ -0,0 +1,71 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from types import (
+     BuiltinMethodType,
+     FunctionType,
+     GetSetDescriptorType,
+     MethodDescriptorType,
+     WrapperDescriptorType,
+ )
+
+ from functorch._C import dim as _C
+
+ _wrap_method = _C._wrap_method
+
+ FUNC_TYPES = (
+     FunctionType,
+     MethodDescriptorType,
+     BuiltinMethodType,
+     WrapperDescriptorType,
+ )
+ PROPERTY_TYPES = (GetSetDescriptorType, property)
+
+
+ def _py_wrap_method(orig, __torch_function__):
+     def impl(*args, **kwargs):
+         return __torch_function__(orig, None, args, kwargs)
+
+     return impl
+
+
+ def wrap_type(use_c, to_patch, pattern, __torch_function__):
+     if use_c:
+         wrap_method = _wrap_method
+     else:
+         wrap_method = _py_wrap_method
+
+     all = {}
+     for t in reversed(pattern.mro()[:-1]):  # skip object
+         all.update(t.__dict__)
+
+     def wrap_attr(orig):
+         return property(wrap_method(orig.__get__, __torch_function__))
+
+     for name, obj in all.items():
+         if name in (
+             "__dict__",
+             "__new__",
+             "__init__",
+             "__repr__",
+             "__weakref__",
+             "__doc__",
+             "__module__",
+             "__dir__",
+         ):
+             continue
+
+         # skip things that have been overloaded
+         # things that come from object like `__eq__` still need to be patched, however.
+         if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr(
+             object, name, None
+         ):
+             continue
+
+         if isinstance(obj, FUNC_TYPES):
+             setattr(to_patch, name, wrap_method(obj, __torch_function__))
+         elif isinstance(obj, PROPERTY_TYPES):
+             setattr(to_patch, name, wrap_attr(obj))
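
A self-contained sketch of what the pure-Python wrapping path (`_py_wrap_method`) does for a single method, with a hypothetical `dispatcher` standing in for functorch's `__torch_function__`:

    def dispatcher(orig, cls, args, kwargs):
        # stand-in for _Tensor.__torch_function__: log, then run the original
        print(f"intercepted {getattr(orig, '__name__', orig)}")
        return orig(*args, **kwargs)

    def py_wrap_method(orig):
        def impl(*args, **kwargs):
            return dispatcher(orig, None, args, kwargs)
        return impl

    wrapped_upper = py_wrap_method(str.upper)
    assert wrapped_upper("abc") == "ABC"  # prints: intercepted upper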
lib/python3.11/site-packages/functorch/einops/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .rearrange import rearrange
+
+ __all__ = ["rearrange"]
lib/python3.11/site-packages/functorch/einops/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (302 Bytes).
 
lib/python3.11/site-packages/functorch/einops/__pycache__/_parsing.cpython-311.pyc ADDED
Binary file (14.2 kB).
 
lib/python3.11/site-packages/functorch/einops/__pycache__/rearrange.cpython-311.pyc ADDED
Binary file (10.8 kB).
 
lib/python3.11/site-packages/functorch/einops/_parsing.py ADDED
@@ -0,0 +1,302 @@
+ """Adapted from https://github.com/arogozhnikov/einops/blob/36c7bb16e57d6e57f8f3050f9e07abdf3f00469f/einops/parsing.py.
+
+ MIT License
+
+ Copyright (c) 2018 Alex Rogozhnikov
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+ from __future__ import annotations
+
+ import keyword
+ import warnings
+ from typing import Collection, List, Mapping, Optional, Set, Tuple, Union
+
+ _ellipsis: str = "…"  # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
+
+
+ class AnonymousAxis:
+     """Used by `ParsedExpression` to represent an axis with a size (> 1), but no associated identifier.
+
+     Note: Different instances of this class are not equal to each other, even if they have the same value.
+     """
+
+     def __init__(self, value: str) -> None:
+         self.value = int(value)
+         if self.value < 1:
+             raise ValueError(
+                 f"Anonymous axis should have positive length, not {self.value}"
+             )
+
+     def __repr__(self) -> str:
+         return f"{self.value}-axis"
+
+
+ class ParsedExpression:
+     """Structure containing information about one side of an `einops`-style pattern (e.g. 'b c (h w)')."""
+
+     def __init__(
+         self,
+         expression: str,
+         *,
+         allow_underscore: bool = False,
+         allow_duplicates: bool = False,
+     ) -> None:
+         """Parse the expression and store relevant metadata.
+
+         Args:
+             expression (str): the `einops`-pattern to parse
+             allow_underscore (bool): whether to allow axis identifier names to begin with an underscore
+             allow_duplicates (bool): whether to allow an identifier to appear more than once in the expression
+         """
+         self.has_ellipsis: bool = False
+         self.has_ellipsis_parenthesized: Optional[bool] = None
+         self.identifiers: Set[Union[str, AnonymousAxis]] = set()
+         # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
+         self.has_non_unitary_anonymous_axes: bool = False
+         # composition keeps structure of composite axes, see how different corner cases are handled in tests
+         self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
+         if "." in expression:
+             if "..." not in expression:
+                 raise ValueError(
+                     "Expression may contain dots only inside ellipsis (...)"
+                 )
+             if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
+                 raise ValueError(
+                     "Expression may contain dots only inside ellipsis (...); only one ellipsis per tensor"
+                 )
+             expression = expression.replace("...", _ellipsis)
+             self.has_ellipsis = True
+
+         bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None
+
+         def add_axis_name(x: str) -> None:
+             if x in self.identifiers:
+                 if not (allow_underscore and x == "_") and not allow_duplicates:
+                     raise ValueError(
+                         f"Indexing expression contains duplicate dimension '{x}'"
+                     )
+             if x == _ellipsis:
+                 self.identifiers.add(_ellipsis)
+                 if bracket_group is None:
+                     self.composition.append(_ellipsis)
+                     self.has_ellipsis_parenthesized = False
+                 else:
+                     bracket_group.append(_ellipsis)
+                     self.has_ellipsis_parenthesized = True
+             else:
+                 is_number = str.isdecimal(x)
+                 if is_number and int(x) == 1:
+                     # handling the case of anonymous axis of length 1
+                     if bracket_group is None:
+                         self.composition.append([])
+                     else:
+                         pass  # no need to think about 1s inside parenthesis
+                     return
+                 is_axis_name, reason = self.check_axis_name_return_reason(
+                     x, allow_underscore=allow_underscore
+                 )
+                 if not (is_number or is_axis_name):
+                     raise ValueError(f"Invalid axis identifier: {x}\n{reason}")
+                 axis_name: Union[str, AnonymousAxis] = (
+                     AnonymousAxis(x) if is_number else x
+                 )
+                 self.identifiers.add(axis_name)
+                 if is_number:
+                     self.has_non_unitary_anonymous_axes = True
+                 if bracket_group is None:
+                     self.composition.append([axis_name])
+                 else:
+                     bracket_group.append(axis_name)
+
+         current_identifier = None
+         for char in expression:
+             if char in "() ":
+                 if current_identifier is not None:
+                     add_axis_name(current_identifier)
+                 current_identifier = None
+                 if char == "(":
+                     if bracket_group is not None:
+                         raise ValueError(
+                             "Axis composition is one-level (brackets inside brackets not allowed)"
+                         )
+                     bracket_group = []
+                 elif char == ")":
+                     if bracket_group is None:
+                         raise ValueError("Brackets are not balanced")
+                     self.composition.append(bracket_group)
+                     bracket_group = None
+             elif str.isalnum(char) or char in ["_", _ellipsis]:
+                 if current_identifier is None:
+                     current_identifier = char
+                 else:
+                     current_identifier += char
+             else:
+                 raise ValueError(f"Unknown character '{char}'")
+
+         if bracket_group is not None:
+             raise ValueError(f"Imbalanced parentheses in expression: '{expression}'")
+         if current_identifier is not None:
+             add_axis_name(current_identifier)
+
+     @staticmethod
+     def check_axis_name_return_reason(
+         name: str, allow_underscore: bool = False
+     ) -> Tuple[bool, str]:
+         """Check if the given axis name is valid, and a message explaining why if not.
+
+         Valid axes names are python identifiers except keywords, and should not start or end with an underscore.
+
+         Args:
+             name (str): the axis name to check
+             allow_underscore (bool): whether axis names are allowed to start with an underscore
+
+         Returns:
+             Tuple[bool, str]: whether the axis name is valid, a message explaining why if not
+         """
+         if not str.isidentifier(name):
+             return False, "not a valid python identifier"
+         elif name[0] == "_" or name[-1] == "_":
+             if name == "_" and allow_underscore:
+                 return True, ""
+             return False, "axis name should not start or end with an underscore"
+         else:
+             if keyword.iskeyword(name):
+                 warnings.warn(
+                     f"It is discouraged to use axes names that are keywords: {name}",
+                     RuntimeWarning,
+                 )
+             if name in ["axis"]:
+                 warnings.warn(
+                     "It is discouraged to use 'axis' as an axis name and will raise an error in future",
+                     FutureWarning,
+                 )
+             return True, ""
+
+     @staticmethod
+     def check_axis_name(name: str) -> bool:
+         """Check if the name is a valid axis name.
+
+         Args:
+             name (str): the axis name to check
+
+         Returns:
+             bool: whether the axis name is valid
+         """
+         is_valid, _ = ParsedExpression.check_axis_name_return_reason(name)
+         return is_valid
+
+
+ def parse_pattern(
+     pattern: str, axes_lengths: Mapping[str, int]
+ ) -> Tuple[ParsedExpression, ParsedExpression]:
+     """Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.
+
+     Args:
+         pattern (str): the `einops`-style rearrangement pattern
+         axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
+
+     Returns:
+         Tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
+     """
+     # adapted from einops.einops._prepare_transformation_recipe
+     # https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/einops/einops.py
+     try:
+         left_str, right_str = pattern.split("->")
+     except ValueError:
+         raise ValueError("Pattern must contain a single '->' separator") from None
+
+     if _ellipsis in axes_lengths:
+         raise ValueError(f"'{_ellipsis}' is not an allowed axis identifier")
+
+     left = ParsedExpression(left_str)
+     right = ParsedExpression(right_str)
+
+     if not left.has_ellipsis and right.has_ellipsis:
+         raise ValueError(
+             f"Ellipsis found in right side, but not left side of a pattern {pattern}"
+         )
+     if left.has_ellipsis and left.has_ellipsis_parenthesized:
+         raise ValueError(
+             f"Ellipsis inside parentheses on the left side is not allowed: {pattern}"
+         )
+
+     return left, right
+
+
+ def validate_rearrange_expressions(
+     left: ParsedExpression, right: ParsedExpression, axes_lengths: Mapping[str, int]
+ ) -> None:
+     """Perform expression validations that are specific to the `rearrange` operation.
+
+     Args:
+         left (ParsedExpression): left-hand side expression
+         right (ParsedExpression): right-hand side expression
+         axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
+     """
+     for length in axes_lengths.values():
+         if (length_type := type(length)) is not int:
+             raise TypeError(
+                 f"rearrange axis lengths must be integers, got: {length_type}"
+             )
+
+     if left.has_non_unitary_anonymous_axes or right.has_non_unitary_anonymous_axes:
+         raise ValueError("rearrange only supports unnamed axes of size 1")
+
+     difference = set.symmetric_difference(left.identifiers, right.identifiers)
+     if len(difference) > 0:
+         raise ValueError(
+             f"Identifiers only on one side of rearrange expression (should be on both): {difference}"
+         )
+
+     unmatched_axes = axes_lengths.keys() - left.identifiers
+     if len(unmatched_axes) > 0:
+         raise ValueError(
+             f"Identifiers not found in rearrange expression: {unmatched_axes}"
+         )
+
+
+ def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
+     """Convert a collection of strings representing first class dims into a comma-separated string.
+
+     Args:
+         collection (Collection[Union[str, Collection[str]]]): the collection of strings to convert
+
+     Returns:
+         str: the comma-separated string
+
+     Examples:
+         >>> comma_separate(('d0',))
+         'd0'
+
+         >>> comma_separate(('d0', 'd1', 'd2', 'd3'))
+         'd0, d1, d2, d3'
+
+         >>> comma_separate([('d1', 'd4')])
+         '(d1, d4)'
+
+         >>> comma_separate([('d0',), (), ('d1',), ('d2',), ('d3', 'd4')])
+         '(d0,), (), (d1,), (d2,), (d3, d4)'
+     """
+     return ", ".join(
+         item
+         if isinstance(item, str)
+         else f"({comma_separate(item)}{',' if len(item) == 1 else ''})"
+         for item in collection
+     )
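
Concretely, the parser turns each side of a pattern into a nested `composition` list; the values below follow directly from the code above:

    from functorch.einops._parsing import ParsedExpression, parse_pattern

    expr = ParsedExpression("b c (h w)")
    print(expr.composition)                   # [['b'], ['c'], ['h', 'w']]
    print(sorted(expr.identifiers, key=str))  # ['b', 'c', 'h', 'w']

    left, right = parse_pattern("b c (h w) -> b (c h w)", {})
    print(right.composition)                  # [['b'], ['c', 'h', 'w']]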
lib/python3.11/site-packages/functorch/einops/rearrange.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ from typing import Callable, Dict, List, Sequence, Tuple, Union
5
+
6
+ import torch
7
+
8
+ from functorch._C import dim as _C
9
+ from ._parsing import (
10
+ _ellipsis,
11
+ AnonymousAxis,
12
+ comma_separate,
13
+ parse_pattern,
14
+ validate_rearrange_expressions,
15
+ )
16
+
17
+ __all__ = ["rearrange"]
18
+
19
+ dims = _C.dims
20
+
21
+
22
+ @functools.lru_cache(256)
23
+ def _create_rearrange_callable(
24
+ tensor_ndim: int, pattern: str, **axes_lengths: int
25
+ ) -> Callable[[torch.Tensor], torch.Tensor]:
26
+ r"""Translate an `einops`-style pattern into a callable that performs the rearrange using first-class dimensions.
27
+
28
+ Since the an equivalent result is computed for tensors with the same number of dimensions, with the same pattern and
29
+ specified axes lengths, this function can be memoized.
30
+
31
+ Args:
32
+ tensor_ndim (int): the number of dimensions in the tensor to rearrange
33
+ pattern (str): the `einops`-style rearrangement pattern
34
+ axes_lengths (int): any additional length specifications for dimensions
35
+
36
+ Returns:
37
+ Callable[[torch.Tensor], torch.Tensor]: a callable that performs the rearrangement
38
+ """
39
+ left, right = parse_pattern(pattern, axes_lengths)
40
+ validate_rearrange_expressions(left, right, axes_lengths)
41
+
42
+ n_anon_dims = sum(not dim for dim in left.composition)
43
+ if left.has_ellipsis:
44
+ n_ellipsis_dims = tensor_ndim - (len(left.composition) - 1)
45
+ n_named_dims = len(left.identifiers) - 1
46
+
47
+ if (pattern_ndim := n_anon_dims + n_named_dims) > tensor_ndim:
48
+ raise ValueError(
49
+ f"Number of dimensions in pattern ({pattern_ndim}) must be less than or equal to the number of "
50
+ f"dimensions in the tensor ({tensor_ndim})"
51
+ )
52
+ else:
53
+ n_ellipsis_dims = 0
54
+ n_named_dims = len(left.identifiers)
55
+
56
+ if (pattern_ndim := len(left.composition)) != tensor_ndim:
57
+ raise ValueError(
58
+ f"Number of dimensions in pattern ({pattern_ndim}) must be equal to the number of dimensions in "
59
+ f"the tensor ({tensor_ndim})"
60
+ )
61
+ n_dims = n_named_dims + n_ellipsis_dims + n_anon_dims
62
+
63
+ if n_dims == 0:
64
+ # an identity rearrangement on a 0-dimension tensor
65
+ return lambda tensor: tensor
66
+
67
+ first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
68
+ identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
69
+ anon_axes: List[AnonymousAxis] = []
70
+
71
+ # map the left-hand side identifiers to strings representing first class dims
72
+ dims_i = 0
73
+ for dimension in left.composition:
74
+ if isinstance(dimension, list):
75
+ for identifier in dimension:
76
+ # non-unitary anon axes are not allowed in rearrange & unitary anon axes are represented as empty lists
77
+ assert isinstance(identifier, str)
78
+ identifier_dim_map[identifier] = (first_class_dims[dims_i],)
79
+ dims_i += 1
80
+ if not dimension:
81
+ # unitary anonymous axis
82
+ anon_axis = AnonymousAxis("1")
83
+ identifier_dim_map[anon_axis] = (first_class_dims[dims_i],)
84
+ anon_axes.append(anon_axis)
85
+ dimension.append(anon_axis)
86
+ dims_i += 1
87
+ elif dimension == _ellipsis:
88
+ identifier = _ellipsis
89
+ identifier_dim_map[identifier] = tuple(
90
+ first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)
91
+ )
92
+ dims_i += n_ellipsis_dims
93
+ else:
94
+ raise ValueError(f"Unexpected dimension: {dimension}")
95
+
96
+ def composition_to_dims(
97
+ composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]]
98
+ ) -> List[Union[str, Tuple[str, ...]]]:
99
+ """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
100
+ class dims."""
101
+         dim_composition: List[Union[str, Tuple[str, ...]]] = []
+         for dimension in composition:
+             if isinstance(dimension, list):
+                 dim_composition.append(
+                     tuple(
+                         dim
+                         for identifier in dimension
+                         for dim in identifier_dim_map[identifier]
+                     )
+                 )
+             elif dimension == _ellipsis:
+                 dim_composition.extend(identifier_dim_map[_ellipsis])
+             else:
+                 raise ValueError(f"Unexpected dimension: {dimension}")
+         return dim_composition
+
+     left_dims = composition_to_dims(left.composition)
+     right_dims = composition_to_dims(right.composition)
+     anon_dims = tuple(identifier_dim_map[axis][0] for axis in anon_axes)
+     specified_lengths = tuple(
+         (identifier_dim_map[axis][0], length) for axis, length in axes_lengths.items()
+     )
+
+     custom_rearrange_callable_name = "do_rearrange"
+     custom_rearrange_callable_code = (
+         (
+             f"def {custom_rearrange_callable_name}(tensor):\n"
+             f"    {comma_separate(first_class_dims)} = dims({n_dims})\n"
+         )
+         + (
+             "".join(
+                 f"    {dim}.size = {length}\n" for (dim, length) in specified_lengths
+             )
+             if specified_lengths
+             else ""
+         )
+         + f"    tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
+         + (
+             f"    return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
+             if anon_dims
+             else "    return tensor\n"
+         )
+     )
+
+     exec(custom_rearrange_callable_code)
+     return locals()[custom_rearrange_callable_name]
+
+
+ def rearrange(
+     tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
+     pattern: str,
+     **axes_lengths: int,
+ ) -> torch.Tensor:
+     r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional
+     tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
+     stack, concatenate and other operations.
+
+     See: https://einops.rocks/api/rearrange/
+
+     Args:
+         tensor (Tensor or sequence of Tensor): the tensor(s) to rearrange
+         pattern (str): the rearrangement pattern
+         axes_lengths (int): any additional length specifications for dimensions
+
+     Returns:
+         Tensor: the rearranged tensor
+
+     Examples:
+         >>> # suppose we have a set of 32 images in "h w c" format (height-width-channel)
+         >>> images = torch.randn((32, 30, 40, 3))
+
+         >>> # stack along first (batch) axis, output is a single array
+         >>> rearrange(images, 'b h w c -> b h w c').shape
+         torch.Size([32, 30, 40, 3])
+
+         >>> # concatenate images along height (vertical axis), 960 = 32 * 30
+         >>> rearrange(images, 'b h w c -> (b h) w c').shape
+         torch.Size([960, 40, 3])
+
+         >>> # concatenate images along horizontal axis, 1280 = 32 * 40
+         >>> rearrange(images, 'b h w c -> h (b w) c').shape
+         torch.Size([30, 1280, 3])
+
+         >>> # reorder axes to "b c h w" format for deep learning
+         >>> rearrange(images, 'b h w c -> b c h w').shape
+         torch.Size([32, 3, 30, 40])
+
+         >>> # flatten each image into a vector, 3600 = 30 * 40 * 3
+         >>> rearrange(images, 'b h w c -> b (c h w)').shape
+         torch.Size([32, 3600])
+
+         >>> # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
+         >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
+         torch.Size([128, 15, 20, 3])
+
+         >>> # space-to-depth operation
+         >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
+         torch.Size([32, 15, 20, 12])
+     """
+     if not isinstance(tensor, torch.Tensor):
+         tensor = torch.stack(tensor)
+
+     rearrange_callable = _create_rearrange_callable(
+         tensor.ndim, pattern, **axes_lengths
+     )
+
+     return rearrange_callable(tensor)
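
A quick, hedged usage sketch of the `rearrange` implementation above (assuming `rearrange` is re-exported from `functorch.einops`, as the package `__init__` in this commit suggests; shapes follow the docstring examples):

```python
import torch
from functorch.einops import rearrange  # re-export assumed from this commit

# a batch of 32 images in "b h w c" layout, as in the docstring examples
images = torch.randn(32, 30, 40, 3)

# split each image into four quadrants and fold them into the batch dim;
# under the hood this compiles to first-class-dim indexing plus .order()
tiles = rearrange(images, "b (h1 h) (w1 w) c -> (b h1 w1) h w c", h1=2, w1=2)
assert tiles.shape == (128, 15, 20, 3)
```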
lib/python3.11/site-packages/functorch/experimental/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # PyTorch forward-mode is not mature yet
+ from torch._functorch.apis import chunk_vmap
+ from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
+ from torch._functorch.eager_transforms import hessian, jacfwd, jvp
+
+ from functorch import functionalize
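
For orientation, a minimal sketch of how these experimental forward-mode re-exports can be used on a toy function (call signatures follow `torch.func`):

```python
import torch
from functorch.experimental import jacfwd, jvp

def f(x):
    return x.sin().sum()

x = torch.randn(3)
v = torch.ones(3)

# forward-mode Jacobian-vector product of f at x along tangent v
out, tangent_out = jvp(f, (x,), (v,))

# full derivative of f via forward-mode autodiff; shape (3,) since f: R^3 -> R
grad = jacfwd(f)(x)
```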
lib/python3.11/site-packages/functorch/experimental/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (605 Bytes).
lib/python3.11/site-packages/functorch/experimental/__pycache__/_map.cpython-311.pyc ADDED
Binary file (23.8 kB).
lib/python3.11/site-packages/functorch/experimental/__pycache__/control_flow.cpython-311.pyc ADDED
Binary file (433 Bytes).
lib/python3.11/site-packages/functorch/experimental/__pycache__/ops.cpython-311.pyc ADDED
Binary file (300 Bytes).
lib/python3.11/site-packages/functorch/experimental/_map.py ADDED
@@ -0,0 +1,393 @@
+ import torch
+ import torch.utils._pytree as pytree
+ from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
+ from torch._dispatch.python import suspend_functionalization
+ from torch._functorch.aot_autograd import AOTConfig, create_joint
+ from torch._functorch.eager_transforms import (
+     _unwrap_all_tensors_from_functional,
+     _wrap_all_tensors_to_functional,
+     functionalize,
+ )
+
+ from torch._higher_order_ops.cond import (
+     _has_potential_branch_input_alias,
+     _has_potential_branch_input_mutation,
+     UnsupportedAliasMutationException,
+ )
+ from torch._ops import HigherOrderOperator
+ from torch._subclasses.fake_tensor import FakeTensorMode
+ from torch.fx.experimental.proxy_tensor import (
+     disable_proxy_modes_tracing,
+     make_fx,
+     ProxyTorchDispatchMode,
+     track_tensor_tree,
+ )
+ from torch.multiprocessing.reductions import StorageWeakRef
+ from torch.utils._python_dispatch import (
+     _get_current_dispatch_mode,
+     _pop_mode_temporarily,
+ )
+
+
+ # TODO: We add this to prevent dynamo from tracing into map_wrapper;
+ # remove the wrapper call when it's ready.
+ class MapWrapper(HigherOrderOperator):
+     def __call__(self, xs, *args):
+         return map_wrapper(xs, *args)
+
+
+ map = MapWrapper("map", _deprecated_global_ns=True)
+ map_impl = HigherOrderOperator("map_impl", _deprecated_global_ns=True)
+
+ dummy_aot_config = AOTConfig(
+     fw_compiler=None,
+     bw_compiler=None,
+     partition_fn=None,
+     decompositions={},
+     num_params_buffers=0,
+     aot_id=0,
+     keep_inference_input_mutations=False,
+ )
+
+
+ def create_fw_bw_graph(f, num_mapped_args, *args):
+     mapped_xs = args[:num_mapped_args]
+     pos_args = args[num_mapped_args:]
+
+     # Note: We create "clean" environments for make_fx by suspending all dispatch keys
+     # between Autograd and the Python key. Currently, we only suspend functionalization, but more can be
+     # added when required. We will encounter two problems if we don't suspend functionalization:
+     #
+     # 1. make_fx fails to capture operations on inputs: the inputs are wrapped as _to_functional_tensor_wrapper,
+     # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
+     # However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
+     # fetch the proxy for the inputs and to fail to capture any operations on them.
+     #
+     # 2. make_fx fails to capture outputs: the outputs after ProxyTorchDispatchMode are further
+     # wrapped as FunctionalTensorWrapper in the Functionalize key after return. However, the tracer
+     # only associates the inner tensor with a proxy in ProxyTorchDispatchMode. Therefore,
+     # when creating the output node, it fails to associate the wrapped tensor with its proxy.
+     # Instead, it will create a _tensor_constant as output.
+
+     with suspend_functionalization():
+         with disable_proxy_modes_tracing():
+
+             def from_fun(t):
+                 if isinstance(t, torch.Tensor):
+                     if t.dtype != torch.bool:
+                         return torch.empty_strided(
+                             t.size(),
+                             t.stride(),
+                             dtype=t.dtype,
+                             requires_grad=t.requires_grad,
+                         )
+                     else:
+                         return t.clone()
+                 return t
+
+             example_xs = [from_fun(xs) for xs in _unstack_pytree(mapped_xs)[0]]
+             example_pos_args = [
+                 from_fun(arg) if isinstance(arg, torch.Tensor) else arg
+                 for arg in pos_args
+             ]
+             example_flat_out = pytree.tree_map(
+                 from_fun, f(*example_xs, *example_pos_args)
+             )
+             if any(
+                 not isinstance(out, torch.Tensor)
+                 for out in example_flat_out
+                 if out is not None
+             ):
+                 raise RuntimeError(
+                     "Expect outputs of map to only contain tensors or None. "
+                     f"Got types {[type(out) for out in example_flat_out]}."
+                 )
+             example_grad = [from_fun(out) for out in example_flat_out]
+
+         fw_graph = make_fx(f)(*example_xs, *example_pos_args)
+
+         def joint_f(*example_args):
+             joint_mapped_args = example_args[:joint_num_mapped]
+             args = example_args[joint_num_mapped:]
+
+             mapped_input = joint_mapped_args[:num_mapped_args]
+             mapped_grads = joint_mapped_args[num_mapped_args:]
+
+             def fw_with_masks(*args):
+                 fw_out = f(*args)
+                 return fw_out, [
+                     True
+                     if isinstance(ret, torch.Tensor) and ret.requires_grad
+                     else False
+                     for ret in fw_out
+                 ]
+
+             joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
+             _, grads = joint(
+                 list(mapped_input) + list(args),
+                 [
+                     grad
+                     for grad in mapped_grads
+                     if grad is not None and grad.requires_grad
+                 ],
+             )
+
+             # In order to keep map functional for the backward graph,
+             # we clone outputs that alias inputs
+             input_storage = {
+                 StorageWeakRef(arg._typed_storage())
+                 for arg in example_args
+                 if isinstance(arg, torch.Tensor)
+             }
+
+             def maybe_clone(t):
+                 if (
+                     isinstance(t, torch.Tensor)
+                     and StorageWeakRef(t._typed_storage()) in input_storage
+                 ):
+                     return t.clone()
+                 return t
+
+             return pytree.tree_map(maybe_clone, grads)
+
+         joint_num_mapped = len(example_grad) + len(example_xs)
+         joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args)
+         return fw_graph, joint_graph
+
+
+ def map_wrapper(f, xs, *args):
+     flat_xs, xs_spec = pytree.tree_flatten(xs)
+     if not all(isinstance(t, torch.Tensor) for t in flat_xs):
+         raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.")
+
+     num_mapped_args = len(flat_xs)
+     shapes = [xs.shape for xs in flat_xs]
+     leading_dim_size = shapes[0][0]
+     if leading_dim_size == 0:
+         raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")
+
+     if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
+         raise RuntimeError(
+             f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
+         )
+
+     out_spec = None
+
+     def flat_fn(*flat_args):
+         xs = pytree.tree_unflatten(flat_args[:num_mapped_args], xs_spec)
+         unflattened_out = f(xs, *flat_args[num_mapped_args:])
+         flat_out, tmp_out_spec = pytree.tree_flatten(unflattened_out)
+
+         nonlocal out_spec
+         out_spec = tmp_out_spec
+         return flat_out
+
+     return pytree.tree_unflatten(
+         map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec
+     )
+
+
+ class MapAutogradOp(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args):
+         ctx.save_for_backward(*flat_args)
+         ctx._joint_graph = joint_graph
+         ctx._num_mapped_args = num_mapped_args
+         with torch._C._AutoDispatchBelowAutograd():
+             return (*map_impl(fw_graph, num_mapped_args, *flat_args),)
+
+     @staticmethod
+     def backward(ctx, *flat_grads):
+         fw_args = ctx.saved_tensors
+         fw_mapped_args = fw_args[: ctx._num_mapped_args]
+         pos_args = fw_args[ctx._num_mapped_args :]
+
+         grads = map_impl(
+             ctx._joint_graph,
+             ctx._num_mapped_args + len(flat_grads),
+             *fw_mapped_args,
+             *flat_grads,
+             *pos_args,
+         )
+         return None, None, None, *grads
+
+
+ def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
+     xs = list(args[:num_mapped])
+     pos_args = list(args[num_mapped:])
+     leading_dim_size = xs[0].shape[0]
+
+     example_input = _unstack_pytree(xs)[0]
+     body_graph = f
+     if not isinstance(body_graph, torch.fx.GraphModule):
+         body_graph = make_fx(body_graph)(*example_input, *pos_args)
+
+     with disable_proxy_modes_tracing():
+         example_outs = body_graph(*example_input, *pos_args)
+
+         def expand_tensor(t):
+             if isinstance(t, torch.Tensor):
+                 return t.expand(leading_dim_size, *t.shape)
+             return t
+
+         expanded_outs = pytree.tree_map(expand_tensor, example_outs)
+
+     next_name = None
+     i = 0
+     while not next_name:
+         candidate = f"body_graph_{i}"
+         if hasattr(proxy_mode.tracer.root, candidate):
+             i += 1
+         else:
+             next_name = candidate
+
+     proxy_mode.tracer.root.register_module(next_name, body_graph)
+     node_args = (body_graph, num_mapped, *args)
+     proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
+     out_proxy = proxy_mode.tracer.create_proxy(
+         "call_function", func_overload, proxy_args, {}, name="map_impl"
+     )
+     return track_tensor_tree(
+         expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
+     )
+
+
+ def _unstack_pytree(xs):
+     flat_xs, inspec = pytree.tree_flatten(xs)
+     if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
+         raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
+
+     if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
+         raise RuntimeError(
+             f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
+         )
+
+     a = zip(*flat_xs)
+     pytrees = []
+     for leaves in a:  # renamed from `tuple` to avoid shadowing the builtin
+         pytrees.append(pytree.tree_unflatten(leaves, inspec))
+     return pytrees
+
+
+ def _stack_pytree(pytrees):
+     flat_out = []
+     out_spec = None
+     for pt in pytrees:
+         flat_pt, out_spec = pytree.tree_flatten(pt)
+         flat_out.append(flat_pt)
+     b = zip(*flat_out)
+     stacked_out = []
+     for leaves in b:
+         if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
+             stacked_out.append(torch.stack(leaves))
+         elif all(leaf is None for leaf in leaves):
+             # The backward graph can return None outputs when forward inputs don't require grad.
+             # When we eagerly execute the backward graph, we need to call _stack_pytree on its output,
+             # therefore we need to deal with None outputs.
+             stacked_out.append(None)
+         else:
+             raise RuntimeError(f"Cannot stack {leaves}.")
+     return pytree.tree_unflatten(stacked_out, out_spec)
+
+
+ @map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
+ def map_dense(f, num_mapped_args, *args):
+     xs = args[:num_mapped_args]
+     pos_args = args[num_mapped_args:]
+     pytrees = []
+     for inp in _unstack_pytree(xs):
+         pytrees.append(f(*inp, *pos_args))
+     return _stack_pytree(pytrees)
+
+
+ @map_impl.py_impl(DispatchKey.Autograd)
+ def map_autograd(f, num_mapped_args, *args):
+     fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *args)
+     flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *args)
+     return flat_out
+
+
+ @map_impl.py_impl(ProxyTorchDispatchMode)
+ def map_proxy_torch_dispatch_mode(f, num_mapped, *args):
+     mode = _get_current_dispatch_mode()
+     assert mode is not None, "Mode should always be enabled for python fallback key"
+     with _pop_mode_temporarily() as mode:
+         if mode.enable_tracing:
+             return trace_map(mode, map_impl, f, num_mapped, *args)
+         else:
+             return map_impl(f, num_mapped, *args)
+
+
+ @map_impl.py_impl(FakeTensorMode)
+ def map_fake_tensor_mode(f, num_mapped, *args):
+     return map_dense(f, num_mapped, *args)
+
+
+ @map_impl.py_impl(DispatchKey.Functionalize)
+ def map_func(f, num_mapped, *args):
+     reapply_views = torch._C._functionalization_reapply_views_tls()
+     xs = args[:num_mapped]
+     pos_args = args[num_mapped:]
+     unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
+     unwrapped_args = _unwrap_all_tensors_from_functional(
+         pos_args, reapply_views=reapply_views
+     )
+     mode = "mutations_and_views" if reapply_views else "mutations"
+
+     with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
+         functional_map_fn = functionalize(f, remove=mode)
+         with disable_proxy_modes_tracing():
+             example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
+
+         if _has_potential_branch_input_mutation(f, example_inputs):
+             raise UnsupportedAliasMutationException("torch.map is mutating the input!")
+
+         if _has_potential_branch_input_alias(f, example_inputs):
+             raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
+
+         map_return = map_impl(
+             functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
+         )
+         return _wrap_all_tensors_to_functional(map_return, level=0)
+
+
+ @map_impl.py_impl(torch._C._functorch.TransformType.Functionalize)
+ def map_functionalize(interpreter, f, num_mapped, *args):
+     """
+     Functionalization implementation for torch.map. Currently:
+     1. We don't allow any input mutation inside the map function.
+     2. Our check for the above condition is not exhaustive.
+     """
+     xs = args[:num_mapped]
+     pos_args = args[num_mapped:]
+     reapply_views = interpreter.functionalize_add_back_views()
+     mode = "mutations_and_views" if reapply_views else "mutations"
+     # At this point, we will see functionalized tensors, so we need to unwrap them first
+     unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
+     unwrapped_args = _unwrap_all_tensors_from_functional(
+         pos_args, reapply_views=reapply_views
+     )
+
+     functional_map_fn = functionalize(f, remove=mode)
+
+     with interpreter.lower():
+         with disable_proxy_modes_tracing():
+             example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
+         if _has_potential_branch_input_mutation(f, example_inputs):
+             raise UnsupportedAliasMutationException("torch.map is mutating the input!")
+
+         if _has_potential_branch_input_alias(f, example_inputs):
+             raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
+
+         map_return = map_impl(
+             functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
+         )
+         return _wrap_all_tensors_to_functional(map_return, level=interpreter.level())
+
+
+ # TODO(voz): Make this automatic for keys; this is very ugly atm
+ map_impl.fallthrough(DispatchKey.PythonDispatcher)
+ map_impl.fallthrough(DispatchKey.PythonTLSSnapshot)
+ map_impl.fallthrough(DispatchKey.ADInplaceOrView)
+ map_impl.fallthrough(DispatchKey.BackendSelect)
+ map_impl.fallthrough(DispatchKey.AutocastCPU)
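
A small usage sketch of the `map` operator assembled above: as `map_dense` shows, it unstacks `xs` along the leading dimension, applies the body to each slice (plus any positional args), and restacks the results. The shapes below are illustrative assumptions:

```python
import torch
from functorch.experimental.control_flow import map

def body(x, y):
    # x: one slice of xs with shape (3,); y: a broadcast positional argument
    return x + y

xs = torch.randn(4, 3)  # mapped over the leading dimension of size 4
y = torch.randn(3)
out = map(body, xs, y)  # stacked result with shape (4, 3)
```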
lib/python3.11/site-packages/functorch/experimental/control_flow.py ADDED
@@ -0,0 +1,6 @@
+ from torch._higher_order_ops.cond import (  # noqa: F401
+     cond,
+     UnsupportedAliasMutationException,
+ )
+
+ from ._map import map  # noqa: F401
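
The `cond` re-exported here is the higher-order conditional: a predicate, two branch callables, and a tuple of operands. A minimal sketch under those assumptions (eagerly, this simply runs the selected branch; shapes are illustrative):

```python
import torch
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

x = torch.randn(4)
out = cond(x.sum() > 0, true_fn, false_fn, (x,))
```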
lib/python3.11/site-packages/functorch/experimental/ops.py ADDED
@@ -0,0 +1 @@
+ from torch._ops import HigherOrderOperator  # noqa: F401
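
`HigherOrderOperator` is the extension point that `_map.py` above builds on. A minimal sketch of declaring one (the name `my_op` is hypothetical, purely for illustration):

```python
from torch._ops import HigherOrderOperator

# Hypothetical operator; real ones (like `map_impl` above) register
# per-dispatch-key Python implementations via `@my_op.py_impl(...)`.
my_op = HigherOrderOperator("my_op")
```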
lib/python3.11/site-packages/huggingface_hub/__init__.py ADDED
@@ -0,0 +1,650 @@
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # ***********
+ # `huggingface_hub` init has 2 modes:
+ # - Normal usage:
+ #       If imported to use it, all modules and functions are lazy-loaded. This means
+ #       they exist at top level in the module but are imported only the first time they are
+ #       used. This way, `from huggingface_hub import something` will import `something`
+ #       quickly without the hassle of importing all the features from `huggingface_hub`.
+ # - Static check:
+ #       If statically analyzed, all modules and functions are loaded normally. This way
+ #       static type checking works properly, as does autocomplete in text editors and
+ #       IDEs.
+ #
+ # The static imports are done inside the `if TYPE_CHECKING:` statement at
+ # the bottom of this file. Since module/function imports are duplicated, it is
+ # mandatory to make sure to add them twice when adding one. This is checked in the
+ # `make quality` command.
+ #
+ # To update the static imports, please run the following command and commit the changes.
+ # ```
+ # # Use script
+ # python utils/check_static_imports.py --update-file
+ #
+ # # Or run style on codebase
+ # make style
+ # ```
+ #
+ # ***********
+ # Lazy loader vendored from https://github.com/scientific-python/lazy_loader
+ import importlib
+ import os
+ import sys
+ from typing import TYPE_CHECKING
+
+
+ __version__ = "0.20.2"
+
+ # Alphabetical order of definitions is ensured in tests
+ # WARNING: any comment added in this dictionary definition will be lost when
+ # re-generating the file!
+ _SUBMOD_ATTRS = {
+     "_commit_scheduler": [
+         "CommitScheduler",
+     ],
+     "_inference_endpoints": [
+         "InferenceEndpoint",
+         "InferenceEndpointError",
+         "InferenceEndpointStatus",
+         "InferenceEndpointTimeoutError",
+         "InferenceEndpointType",
+     ],
+     "_login": [
+         "interpreter_login",
+         "login",
+         "logout",
+         "notebook_login",
+     ],
+     "_multi_commits": [
+         "MultiCommitException",
+         "plan_multi_commits",
+     ],
+     "_snapshot_download": [
+         "snapshot_download",
+     ],
+     "_space_api": [
+         "SpaceHardware",
+         "SpaceRuntime",
+         "SpaceStage",
+         "SpaceStorage",
+         "SpaceVariable",
+     ],
+     "_tensorboard_logger": [
+         "HFSummaryWriter",
+     ],
+     "_webhooks_payload": [
+         "WebhookPayload",
+         "WebhookPayloadComment",
+         "WebhookPayloadDiscussion",
+         "WebhookPayloadDiscussionChanges",
+         "WebhookPayloadEvent",
+         "WebhookPayloadMovedTo",
+         "WebhookPayloadRepo",
+         "WebhookPayloadUrl",
+         "WebhookPayloadWebhook",
+     ],
+     "_webhooks_server": [
+         "WebhooksServer",
+         "webhook_endpoint",
+     ],
+     "community": [
+         "Discussion",
+         "DiscussionComment",
+         "DiscussionCommit",
+         "DiscussionEvent",
+         "DiscussionStatusChange",
+         "DiscussionTitleChange",
+         "DiscussionWithDetails",
+     ],
+     "constants": [
+         "CONFIG_NAME",
+         "FLAX_WEIGHTS_NAME",
+         "HUGGINGFACE_CO_URL_HOME",
+         "HUGGINGFACE_CO_URL_TEMPLATE",
+         "PYTORCH_WEIGHTS_NAME",
+         "REPO_TYPE_DATASET",
+         "REPO_TYPE_MODEL",
+         "REPO_TYPE_SPACE",
+         "TF2_WEIGHTS_NAME",
+         "TF_WEIGHTS_NAME",
+     ],
+     "fastai_utils": [
+         "_save_pretrained_fastai",
+         "from_pretrained_fastai",
+         "push_to_hub_fastai",
+     ],
+     "file_download": [
+         "HfFileMetadata",
+         "_CACHED_NO_EXIST",
+         "cached_download",
+         "get_hf_file_metadata",
+         "hf_hub_download",
+         "hf_hub_url",
+         "try_to_load_from_cache",
+     ],
+     "hf_api": [
+         "Collection",
+         "CollectionItem",
+         "CommitInfo",
+         "CommitOperation",
+         "CommitOperationAdd",
+         "CommitOperationCopy",
+         "CommitOperationDelete",
+         "GitCommitInfo",
+         "GitRefInfo",
+         "GitRefs",
+         "HfApi",
+         "RepoUrl",
+         "User",
+         "UserLikes",
+         "accept_access_request",
+         "add_collection_item",
+         "add_space_secret",
+         "add_space_variable",
+         "cancel_access_request",
+         "change_discussion_status",
+         "comment_discussion",
+         "create_branch",
+         "create_collection",
+         "create_commit",
+         "create_commits_on_pr",
+         "create_discussion",
+         "create_inference_endpoint",
+         "create_pull_request",
+         "create_repo",
+         "create_tag",
+         "dataset_info",
+         "delete_branch",
+         "delete_collection",
+         "delete_collection_item",
+         "delete_file",
+         "delete_folder",
+         "delete_inference_endpoint",
+         "delete_repo",
+         "delete_space_secret",
+         "delete_space_storage",
+         "delete_space_variable",
+         "delete_tag",
+         "duplicate_space",
+         "edit_discussion_comment",
+         "file_exists",
+         "get_collection",
+         "get_dataset_tags",
+         "get_discussion_details",
+         "get_full_repo_name",
+         "get_inference_endpoint",
+         "get_model_tags",
+         "get_paths_info",
+         "get_repo_discussions",
+         "get_safetensors_metadata",
+         "get_space_runtime",
+         "get_space_variables",
+         "get_token_permission",
+         "grant_access",
+         "like",
+         "list_accepted_access_requests",
+         "list_collections",
+         "list_datasets",
+         "list_files_info",
+         "list_inference_endpoints",
+         "list_liked_repos",
+         "list_metrics",
+         "list_models",
+         "list_pending_access_requests",
+         "list_rejected_access_requests",
+         "list_repo_commits",
+         "list_repo_files",
+         "list_repo_likers",
+         "list_repo_refs",
+         "list_repo_tree",
+         "list_spaces",
+         "merge_pull_request",
+         "model_info",
+         "move_repo",
+         "parse_safetensors_file_metadata",
+         "pause_inference_endpoint",
+         "pause_space",
+         "preupload_lfs_files",
+         "reject_access_request",
+         "rename_discussion",
+         "repo_exists",
+         "repo_info",
+         "repo_type_and_id_from_hf_id",
+         "request_space_hardware",
+         "request_space_storage",
+         "restart_space",
+         "resume_inference_endpoint",
+         "run_as_future",
+         "scale_to_zero_inference_endpoint",
+         "set_space_sleep_time",
+         "space_info",
+         "super_squash_history",
+         "unlike",
+         "update_collection_item",
+         "update_collection_metadata",
+         "update_inference_endpoint",
+         "update_repo_visibility",
+         "upload_file",
+         "upload_folder",
+         "whoami",
+     ],
+     "hf_file_system": [
+         "HfFileSystem",
+         "HfFileSystemFile",
+         "HfFileSystemResolvedPath",
+     ],
+     "hub_mixin": [
+         "ModelHubMixin",
+         "PyTorchModelHubMixin",
+     ],
+     "inference._client": [
+         "InferenceClient",
+         "InferenceTimeoutError",
+     ],
+     "inference._generated._async_client": [
+         "AsyncInferenceClient",
+     ],
+     "inference_api": [
+         "InferenceApi",
+     ],
+     "keras_mixin": [
+         "KerasModelHubMixin",
+         "from_pretrained_keras",
+         "push_to_hub_keras",
+         "save_pretrained_keras",
+     ],
+     "repocard": [
+         "DatasetCard",
+         "ModelCard",
+         "RepoCard",
+         "SpaceCard",
+         "metadata_eval_result",
+         "metadata_load",
+         "metadata_save",
+         "metadata_update",
+     ],
+     "repocard_data": [
+         "CardData",
+         "DatasetCardData",
+         "EvalResult",
+         "ModelCardData",
+         "SpaceCardData",
+     ],
+     "repository": [
+         "Repository",
+     ],
+     "utils": [
+         "CacheNotFound",
+         "CachedFileInfo",
+         "CachedRepoInfo",
+         "CachedRevisionInfo",
+         "CorruptedCacheException",
+         "DeleteCacheStrategy",
+         "HFCacheInfo",
+         "HfFolder",
+         "cached_assets_path",
+         "configure_http_backend",
+         "dump_environment_info",
+         "get_session",
+         "get_token",
+         "logging",
+         "scan_cache_dir",
+     ],
+     "utils.endpoint_helpers": [
+         "DatasetFilter",
+         "ModelFilter",
+     ],
+ }
+
+
+ def _attach(package_name, submodules=None, submod_attrs=None):
+     """Attach lazily loaded submodules, functions, or other attributes.
+
+     Typically, modules import submodules and attributes as follows:
+
+     ```py
+     import mysubmodule
+     import anothersubmodule
+
+     from .foo import someattr
+     ```
+
+     The idea is to replace a package's `__getattr__`, `__dir__`, and
+     `__all__`, such that all imports work exactly the way they would
+     with normal imports, except that the import occurs upon first use.
+
+     The typical way to call this function, replacing the above imports, is:
+
+     ```python
+     __getattr__, __dir__, __all__ = lazy.attach(
+         __name__,
+         ['mysubmodule', 'anothersubmodule'],
+         {'foo': ['someattr']}
+     )
+     ```
+     This functionality requires Python 3.7 or higher.
+
+     Args:
+         package_name (`str`):
+             Typically use `__name__`.
+         submodules (`set`):
+             List of submodules to attach.
+         submod_attrs (`dict`):
+             Dictionary of submodule -> list of attributes / functions.
+             These attributes are imported as they are used.
+
+     Returns:
+         __getattr__, __dir__, __all__
+
+     """
+     if submod_attrs is None:
+         submod_attrs = {}
+
+     if submodules is None:
+         submodules = set()
+     else:
+         submodules = set(submodules)
+
+     attr_to_modules = {attr: mod for mod, attrs in submod_attrs.items() for attr in attrs}
+
+     __all__ = list(submodules | attr_to_modules.keys())
+
+     def __getattr__(name):
+         if name in submodules:
+             return importlib.import_module(f"{package_name}.{name}")
+         elif name in attr_to_modules:
+             submod_path = f"{package_name}.{attr_to_modules[name]}"
+             submod = importlib.import_module(submod_path)
+             attr = getattr(submod, name)
+
+             # If the attribute lives in a file (module) with the same
+             # name as the attribute, ensure that the attribute and *not*
+             # the module is accessible on the package.
+             if name == attr_to_modules[name]:
+                 pkg = sys.modules[package_name]
+                 pkg.__dict__[name] = attr
+
+             return attr
+         else:
+             raise AttributeError(f"No {package_name} attribute {name}")
+
+     def __dir__():
+         return __all__
+
+     if os.environ.get("EAGER_IMPORT", ""):
+         for attr in set(attr_to_modules.keys()) | submodules:
+             __getattr__(attr)
+
+     return __getattr__, __dir__, list(__all__)
+
+
+ __getattr__, __dir__, __all__ = _attach(__name__, submodules=[], submod_attrs=_SUBMOD_ATTRS)
+
+ # WARNING: any content below this statement is generated automatically. Any manual edit
+ # will be lost when re-generating this file!
+ #
+ # To update the static imports, please run the following command and commit the changes.
+ # ```
+ # # Use script
+ # python utils/check_static_imports.py --update-file
+ #
+ # # Or run style on codebase
+ # make style
+ # ```
+ if TYPE_CHECKING:  # pragma: no cover
+     from ._commit_scheduler import CommitScheduler  # noqa: F401
+     from ._inference_endpoints import (
+         InferenceEndpoint,  # noqa: F401
+         InferenceEndpointError,  # noqa: F401
+         InferenceEndpointStatus,  # noqa: F401
+         InferenceEndpointTimeoutError,  # noqa: F401
+         InferenceEndpointType,  # noqa: F401
+     )
+     from ._login import (
+         interpreter_login,  # noqa: F401
+         login,  # noqa: F401
+         logout,  # noqa: F401
+         notebook_login,  # noqa: F401
+     )
+     from ._multi_commits import (
+         MultiCommitException,  # noqa: F401
+         plan_multi_commits,  # noqa: F401
+     )
+     from ._snapshot_download import snapshot_download  # noqa: F401
+     from ._space_api import (
+         SpaceHardware,  # noqa: F401
+         SpaceRuntime,  # noqa: F401
+         SpaceStage,  # noqa: F401
+         SpaceStorage,  # noqa: F401
+         SpaceVariable,  # noqa: F401
+     )
+     from ._tensorboard_logger import HFSummaryWriter  # noqa: F401
+     from ._webhooks_payload import (
+         WebhookPayload,  # noqa: F401
+         WebhookPayloadComment,  # noqa: F401
+         WebhookPayloadDiscussion,  # noqa: F401
+         WebhookPayloadDiscussionChanges,  # noqa: F401
+         WebhookPayloadEvent,  # noqa: F401
+         WebhookPayloadMovedTo,  # noqa: F401
+         WebhookPayloadRepo,  # noqa: F401
+         WebhookPayloadUrl,  # noqa: F401
+         WebhookPayloadWebhook,  # noqa: F401
+     )
+     from ._webhooks_server import (
+         WebhooksServer,  # noqa: F401
+         webhook_endpoint,  # noqa: F401
+     )
+     from .community import (
+         Discussion,  # noqa: F401
+         DiscussionComment,  # noqa: F401
+         DiscussionCommit,  # noqa: F401
+         DiscussionEvent,  # noqa: F401
+         DiscussionStatusChange,  # noqa: F401
+         DiscussionTitleChange,  # noqa: F401
+         DiscussionWithDetails,  # noqa: F401
+     )
+     from .constants import (
+         CONFIG_NAME,  # noqa: F401
+         FLAX_WEIGHTS_NAME,  # noqa: F401
+         HUGGINGFACE_CO_URL_HOME,  # noqa: F401
+         HUGGINGFACE_CO_URL_TEMPLATE,  # noqa: F401
+         PYTORCH_WEIGHTS_NAME,  # noqa: F401
+         REPO_TYPE_DATASET,  # noqa: F401
+         REPO_TYPE_MODEL,  # noqa: F401
+         REPO_TYPE_SPACE,  # noqa: F401
+         TF2_WEIGHTS_NAME,  # noqa: F401
+         TF_WEIGHTS_NAME,  # noqa: F401
+     )
+     from .fastai_utils import (
+         _save_pretrained_fastai,  # noqa: F401
+         from_pretrained_fastai,  # noqa: F401
+         push_to_hub_fastai,  # noqa: F401
+     )
+     from .file_download import (
+         _CACHED_NO_EXIST,  # noqa: F401
+         HfFileMetadata,  # noqa: F401
+         cached_download,  # noqa: F401
+         get_hf_file_metadata,  # noqa: F401
+         hf_hub_download,  # noqa: F401
+         hf_hub_url,  # noqa: F401
+         try_to_load_from_cache,  # noqa: F401
+     )
+     from .hf_api import (
+         Collection,  # noqa: F401
+         CollectionItem,  # noqa: F401
+         CommitInfo,  # noqa: F401
+         CommitOperation,  # noqa: F401
+         CommitOperationAdd,  # noqa: F401
+         CommitOperationCopy,  # noqa: F401
+         CommitOperationDelete,  # noqa: F401
+         GitCommitInfo,  # noqa: F401
+         GitRefInfo,  # noqa: F401
+         GitRefs,  # noqa: F401
+         HfApi,  # noqa: F401
+         RepoUrl,  # noqa: F401
+         User,  # noqa: F401
+         UserLikes,  # noqa: F401
+         accept_access_request,  # noqa: F401
+         add_collection_item,  # noqa: F401
+         add_space_secret,  # noqa: F401
+         add_space_variable,  # noqa: F401
+         cancel_access_request,  # noqa: F401
+         change_discussion_status,  # noqa: F401
+         comment_discussion,  # noqa: F401
+         create_branch,  # noqa: F401
+         create_collection,  # noqa: F401
+         create_commit,  # noqa: F401
+         create_commits_on_pr,  # noqa: F401
+         create_discussion,  # noqa: F401
+         create_inference_endpoint,  # noqa: F401
+         create_pull_request,  # noqa: F401
+         create_repo,  # noqa: F401
+         create_tag,  # noqa: F401
+         dataset_info,  # noqa: F401
+         delete_branch,  # noqa: F401
+         delete_collection,  # noqa: F401
+         delete_collection_item,  # noqa: F401
+         delete_file,  # noqa: F401
+         delete_folder,  # noqa: F401
+         delete_inference_endpoint,  # noqa: F401
+         delete_repo,  # noqa: F401
+         delete_space_secret,  # noqa: F401
+         delete_space_storage,  # noqa: F401
+         delete_space_variable,  # noqa: F401
+         delete_tag,  # noqa: F401
+         duplicate_space,  # noqa: F401
+         edit_discussion_comment,  # noqa: F401
+         file_exists,  # noqa: F401
+         get_collection,  # noqa: F401
+         get_dataset_tags,  # noqa: F401
+         get_discussion_details,  # noqa: F401
+         get_full_repo_name,  # noqa: F401
+         get_inference_endpoint,  # noqa: F401
+         get_model_tags,  # noqa: F401
+         get_paths_info,  # noqa: F401
+         get_repo_discussions,  # noqa: F401
+         get_safetensors_metadata,  # noqa: F401
+         get_space_runtime,  # noqa: F401
+         get_space_variables,  # noqa: F401
+         get_token_permission,  # noqa: F401
+         grant_access,  # noqa: F401
+         like,  # noqa: F401
+         list_accepted_access_requests,  # noqa: F401
+         list_collections,  # noqa: F401
+         list_datasets,  # noqa: F401
+         list_files_info,  # noqa: F401
+         list_inference_endpoints,  # noqa: F401
+         list_liked_repos,  # noqa: F401
+         list_metrics,  # noqa: F401
+         list_models,  # noqa: F401
+         list_pending_access_requests,  # noqa: F401
+         list_rejected_access_requests,  # noqa: F401
+         list_repo_commits,  # noqa: F401
+         list_repo_files,  # noqa: F401
+         list_repo_likers,  # noqa: F401
+         list_repo_refs,  # noqa: F401
+         list_repo_tree,  # noqa: F401
+         list_spaces,  # noqa: F401
+         merge_pull_request,  # noqa: F401
+         model_info,  # noqa: F401
+         move_repo,  # noqa: F401
+         parse_safetensors_file_metadata,  # noqa: F401
+         pause_inference_endpoint,  # noqa: F401
+         pause_space,  # noqa: F401
+         preupload_lfs_files,  # noqa: F401
+         reject_access_request,  # noqa: F401
+         rename_discussion,  # noqa: F401
+         repo_exists,  # noqa: F401
+         repo_info,  # noqa: F401
+         repo_type_and_id_from_hf_id,  # noqa: F401
+         request_space_hardware,  # noqa: F401
+         request_space_storage,  # noqa: F401
+         restart_space,  # noqa: F401
+         resume_inference_endpoint,  # noqa: F401
+         run_as_future,  # noqa: F401
+         scale_to_zero_inference_endpoint,  # noqa: F401
+         set_space_sleep_time,  # noqa: F401
+         space_info,  # noqa: F401
+         super_squash_history,  # noqa: F401
+         unlike,  # noqa: F401
+         update_collection_item,  # noqa: F401
+         update_collection_metadata,  # noqa: F401
+         update_inference_endpoint,  # noqa: F401
+         update_repo_visibility,  # noqa: F401
+         upload_file,  # noqa: F401
+         upload_folder,  # noqa: F401
+         whoami,  # noqa: F401
+     )
+     from .hf_file_system import (
+         HfFileSystem,  # noqa: F401
+         HfFileSystemFile,  # noqa: F401
+         HfFileSystemResolvedPath,  # noqa: F401
+     )
+     from .hub_mixin import (
+         ModelHubMixin,  # noqa: F401
+         PyTorchModelHubMixin,  # noqa: F401
+     )
+     from .inference._client import (
+         InferenceClient,  # noqa: F401
+         InferenceTimeoutError,  # noqa: F401
+     )
+     from .inference._generated._async_client import AsyncInferenceClient  # noqa: F401
+     from .inference_api import InferenceApi  # noqa: F401
+     from .keras_mixin import (
+         KerasModelHubMixin,  # noqa: F401
+         from_pretrained_keras,  # noqa: F401
+         push_to_hub_keras,  # noqa: F401
+         save_pretrained_keras,  # noqa: F401
+     )
+     from .repocard import (
+         DatasetCard,  # noqa: F401
+         ModelCard,  # noqa: F401
+         RepoCard,  # noqa: F401
+         SpaceCard,  # noqa: F401
+         metadata_eval_result,  # noqa: F401
+         metadata_load,  # noqa: F401
+         metadata_save,  # noqa: F401
+         metadata_update,  # noqa: F401
+     )
+     from .repocard_data import (
+         CardData,  # noqa: F401
+         DatasetCardData,  # noqa: F401
+         EvalResult,  # noqa: F401
+         ModelCardData,  # noqa: F401
+         SpaceCardData,  # noqa: F401
+     )
+     from .repository import Repository  # noqa: F401
+     from .utils import (
+         CachedFileInfo,  # noqa: F401
+         CachedRepoInfo,  # noqa: F401
+         CachedRevisionInfo,  # noqa: F401
+         CacheNotFound,  # noqa: F401
+         CorruptedCacheException,  # noqa: F401
+         DeleteCacheStrategy,  # noqa: F401
+         HFCacheInfo,  # noqa: F401
+         HfFolder,  # noqa: F401
+         cached_assets_path,  # noqa: F401
+         configure_http_backend,  # noqa: F401
+         dump_environment_info,  # noqa: F401
+         get_session,  # noqa: F401
+         get_token,  # noqa: F401
+         logging,  # noqa: F401
+         scan_cache_dir,  # noqa: F401
+     )
+     from .utils.endpoint_helpers import (
+         DatasetFilter,  # noqa: F401
+         ModelFilter,  # noqa: F401
+     )
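
The practical effect of the lazy loader above, as a hedged sketch (the download call needs network access; `gpt2` is just a well-known public repo used for illustration):

```python
import huggingface_hub

# The submodule behind `hf_hub_download` is imported only on first attribute
# access, via the package-level __getattr__ installed by _attach().
path = huggingface_hub.hf_hub_download(repo_id="gpt2", filename="config.json")

# Setting EAGER_IMPORT forces every lazy attribute to resolve at import time,
# which surfaces import errors early:
#     EAGER_IMPORT=1 python -c "import huggingface_hub"
```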
lib/python3.11/site-packages/huggingface_hub/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (13.8 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/_commit_api.cpython-311.pyc ADDED
Binary file (33.8 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/_commit_scheduler.cpython-311.pyc ADDED
Binary file (18.6 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/_inference_endpoints.cpython-311.pyc ADDED
Binary file (18.7 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/_login.cpython-311.pyc ADDED
Binary file (17.5 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/_multi_commits.cpython-311.pyc ADDED
Binary file (16.8 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/_snapshot_download.cpython-311.pyc ADDED
Binary file (15 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/_space_api.cpython-311.pyc ADDED
Binary file (6.64 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/_tensorboard_logger.cpython-311.pyc ADDED
Binary file (7.74 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/_webhooks_payload.cpython-311.pyc ADDED
Binary file (4.71 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/_webhooks_server.cpython-311.pyc ADDED
Binary file (18.8 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/community.cpython-311.pyc ADDED
Binary file (16 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/constants.cpython-311.pyc ADDED
Binary file (7.69 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/fastai_utils.cpython-311.pyc ADDED
Binary file (20.1 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/file_download.cpython-311.pyc ADDED
Binary file (75.4 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/hf_api.cpython-311.pyc ADDED
Binary file (375 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/hf_file_system.cpython-311.pyc ADDED
Binary file (35.4 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/hub_mixin.cpython-311.pyc ADDED
Binary file (18.6 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/inference_api.cpython-311.pyc ADDED
Binary file (9.45 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/keras_mixin.cpython-311.pyc ADDED
Binary file (21.6 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/lfs.cpython-311.pyc ADDED
Binary file (27.1 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/repocard.cpython-311.pyc ADDED
Binary file (37.7 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/repocard_data.cpython-311.pyc ADDED
Binary file (34.5 kB).
lib/python3.11/site-packages/huggingface_hub/__pycache__/repository.cpython-311.pyc ADDED
Binary file (72.1 kB).
lib/python3.11/site-packages/huggingface_hub/_commit_api.py ADDED
@@ -0,0 +1,670 @@
1
+ """
2
+ Type definitions and utilities for the `create_commit` API
3
+ """
4
+ import base64
5
+ import io
6
+ import os
7
+ import warnings
8
+ from collections import defaultdict
9
+ from contextlib import contextmanager
10
+ from dataclasses import dataclass, field
11
+ from itertools import groupby
12
+ from pathlib import Path, PurePosixPath
13
+ from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union
14
+
15
+ from tqdm.contrib.concurrent import thread_map
16
+
17
+ from huggingface_hub import get_session
18
+
19
+ from .constants import ENDPOINT, HF_HUB_ENABLE_HF_TRANSFER
20
+ from .lfs import UploadInfo, lfs_upload, post_lfs_batch_info
21
+ from .utils import (
22
+ EntryNotFoundError,
23
+ build_hf_headers,
24
+ chunk_iterable,
25
+ hf_raise_for_status,
26
+ logging,
27
+ tqdm_stream_file,
28
+ validate_hf_hub_args,
29
+ )
30
+ from .utils import tqdm as hf_tqdm
31
+
32
+
33
+ if TYPE_CHECKING:
34
+ from .hf_api import RepoFile
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ UploadMode = Literal["lfs", "regular"]
41
+
42
+ # Max is 1,000 per request on the Hub for HfApi.get_paths_info
43
+ # Otherwise we get:
44
+ # HfHubHTTPError: 413 Client Error: Payload Too Large for url: https://huggingface.co/api/datasets/xxx (Request ID: xxx)\n\ntoo many parameters
45
+ # See https://github.com/huggingface/huggingface_hub/issues/1503
46
+ FETCH_LFS_BATCH_SIZE = 500
47
+
48
+
49
+ @dataclass
50
+ class CommitOperationDelete:
51
+ """
52
+ Data structure holding necessary info to delete a file or a folder from a repository
53
+ on the Hub.
54
+
55
+ Args:
56
+ path_in_repo (`str`):
57
+ Relative filepath in the repo, for example: `"checkpoints/1fec34a/weights.bin"`
58
+ for a file or `"checkpoints/1fec34a/"` for a folder.
59
+ is_folder (`bool` or `Literal["auto"]`, *optional*)
60
+ Whether the Delete Operation applies to a folder or not. If "auto", the path
61
+ type (file or folder) is guessed automatically by looking if path ends with
62
+ a "/" (folder) or not (file). To explicitly set the path type, you can set
63
+ `is_folder=True` or `is_folder=False`.
64
+ """
65
+
66
+ path_in_repo: str
67
+ is_folder: Union[bool, Literal["auto"]] = "auto"
68
+
69
+ def __post_init__(self):
70
+ self.path_in_repo = _validate_path_in_repo(self.path_in_repo)
71
+
72
+ if self.is_folder == "auto":
73
+ self.is_folder = self.path_in_repo.endswith("/")
74
+ if not isinstance(self.is_folder, bool):
75
+ raise ValueError(
76
+ f"Wrong value for `is_folder`. Must be one of [`True`, `False`, `'auto'`]. Got '{self.is_folder}'."
77
+ )
78
+
79
+
80
+ @dataclass
81
+ class CommitOperationCopy:
82
+ """
83
+ Data structure holding necessary info to copy a file in a repository on the Hub.
84
+
85
+ Limitations:
86
+ - Only LFS files can be copied. To copy a regular file, you need to download it locally and re-upload it
87
+ - Cross-repository copies are not supported.
88
+
89
+ Note: you can combine a [`CommitOperationCopy`] and a [`CommitOperationDelete`] to rename an LFS file on the Hub.
90
+
91
+ Args:
92
+ src_path_in_repo (`str`):
93
+ Relative filepath in the repo of the file to be copied, e.g. `"checkpoints/1fec34a/weights.bin"`.
94
+ path_in_repo (`str`):
95
+ Relative filepath in the repo where to copy the file, e.g. `"checkpoints/1fec34a/weights_copy.bin"`.
96
+ src_revision (`str`, *optional*):
97
+ The git revision of the file to be copied. Can be any valid git revision.
98
+ Default to the target commit revision.
99
+ """
100
+
101
+ src_path_in_repo: str
102
+ path_in_repo: str
103
+ src_revision: Optional[str] = None
104
+
105
+ def __post_init__(self):
106
+ self.src_path_in_repo = _validate_path_in_repo(self.src_path_in_repo)
107
+ self.path_in_repo = _validate_path_in_repo(self.path_in_repo)
108
+
109
+
110
+ @dataclass
111
+ class CommitOperationAdd:
112
+ """
113
+ Data structure holding necessary info to upload a file to a repository on the Hub.
114
+
115
+ Args:
116
+ path_in_repo (`str`):
117
+ Relative filepath in the repo, for example: `"checkpoints/1fec34a/weights.bin"`
118
+ path_or_fileobj (`str`, `Path`, `bytes`, or `BinaryIO`):
119
+ Either:
120
+ - a path to a local file (as `str` or `pathlib.Path`) to upload
121
+ - a buffer of bytes (`bytes`) holding the content of the file to upload
122
+ - a "file object" (subclass of `io.BufferedIOBase`), typically obtained
123
+ with `open(path, "rb")`. It must support `seek()` and `tell()` methods.
124
+
125
+ Raises:
126
+ [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
127
+ If `path_or_fileobj` is not one of `str`, `Path`, `bytes` or `io.BufferedIOBase`.
128
+ [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
129
+ If `path_or_fileobj` is a `str` or `Path` but not a path to an existing file.
130
+ [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
131
+ If `path_or_fileobj` is a `io.BufferedIOBase` but it doesn't support both
132
+ `seek()` and `tell()`.
133
+ """
134
+
135
+ path_in_repo: str
136
+ path_or_fileobj: Union[str, Path, bytes, BinaryIO]
137
+ upload_info: UploadInfo = field(init=False, repr=False)
138
+
139
+ # Internal attributes
140
+
141
+ # set to "lfs" or "regular" once known
142
+ _upload_mode: Optional[UploadMode] = field(init=False, repr=False, default=None)
143
+
144
+ # set to True if .gitignore rules prevent the file from being uploaded as LFS
145
+ # (server-side check)
146
+ _should_ignore: Optional[bool] = field(init=False, repr=False, default=None)
147
+
148
+ # set to True once the file has been uploaded as LFS
149
+ _is_uploaded: bool = field(init=False, repr=False, default=False)
150
+
151
+ # set to True once the file has been committed
152
+ _is_committed: bool = field(init=False, repr=False, default=False)
153
+
154
+ def __post_init__(self) -> None:
155
+ """Validates `path_or_fileobj` and compute `upload_info`."""
156
+ self.path_in_repo = _validate_path_in_repo(self.path_in_repo)
157
+
158
+ # Validate `path_or_fileobj` value
159
+ if isinstance(self.path_or_fileobj, Path):
160
+ self.path_or_fileobj = str(self.path_or_fileobj)
161
+ if isinstance(self.path_or_fileobj, str):
162
+ path_or_fileobj = os.path.normpath(os.path.expanduser(self.path_or_fileobj))
163
+ if not os.path.isfile(path_or_fileobj):
164
+ raise ValueError(f"Provided path: '{path_or_fileobj}' is not a file on the local file system")
165
+ elif not isinstance(self.path_or_fileobj, (io.BufferedIOBase, bytes)):
166
+ # ^^ Inspired from: https://stackoverflow.com/questions/44584829/how-to-determine-if-file-is-opened-in-binary-or-text-mode
167
+ raise ValueError(
168
+ "path_or_fileobj must be either an instance of str, bytes or"
169
+ " io.BufferedIOBase. If you passed a file-like object, make sure it is"
170
+ " in binary mode."
171
+ )
172
+ if isinstance(self.path_or_fileobj, io.BufferedIOBase):
173
+ try:
174
+ self.path_or_fileobj.tell()
175
+ self.path_or_fileobj.seek(0, os.SEEK_CUR)
176
+ except (OSError, AttributeError) as exc:
177
+ raise ValueError(
178
+ "path_or_fileobj is a file-like object but does not implement seek() and tell()"
179
+ ) from exc
180
+
181
+ # Compute "upload_info" attribute
182
+ if isinstance(self.path_or_fileobj, str):
183
+ self.upload_info = UploadInfo.from_path(self.path_or_fileobj)
184
+ elif isinstance(self.path_or_fileobj, bytes):
185
+ self.upload_info = UploadInfo.from_bytes(self.path_or_fileobj)
186
+ else:
187
+ self.upload_info = UploadInfo.from_fileobj(self.path_or_fileobj)
188
+
189
+ @contextmanager
190
+ def as_file(self, with_tqdm: bool = False) -> Iterator[BinaryIO]:
191
+ """
192
+ A context manager that yields a file-like object allowing to read the underlying
193
+ data behind `path_or_fileobj`.
194
+
195
+ Args:
196
+ with_tqdm (`bool`, *optional*, defaults to `False`):
197
+ If True, iterating over the file object will display a progress bar. Only
198
+ works if the file-like object is a path to a file. Pure bytes and buffers
199
+ are not supported.
200
+
201
+ Example:
202
+
203
+ ```python
204
+ >>> operation = CommitOperationAdd(
205
+ ... path_in_repo="remote/dir/weights.h5",
206
+ ... path_or_fileobj="./local/weights.h5",
207
+ ... )
208
+ CommitOperationAdd(path_in_repo='remote/dir/weights.h5', path_or_fileobj='./local/weights.h5')
209
+
210
+ >>> with operation.as_file() as file:
211
+ ... content = file.read()
212
+
213
+ >>> with operation.as_file(with_tqdm=True) as file:
214
+ ... while True:
215
+ ... data = file.read(1024)
216
+ ... if not data:
217
+ ... break
218
+ config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s]
219
+
220
+ >>> with operation.as_file(with_tqdm=True) as file:
221
+ ... requests.put(..., data=file)
222
+ config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s]
223
+ ```
224
+ """
225
+ if isinstance(self.path_or_fileobj, str) or isinstance(self.path_or_fileobj, Path):
226
+ if with_tqdm:
227
+ with tqdm_stream_file(self.path_or_fileobj) as file:
228
+ yield file
229
+ else:
230
+ with open(self.path_or_fileobj, "rb") as file:
231
+ yield file
232
+ elif isinstance(self.path_or_fileobj, bytes):
233
+ yield io.BytesIO(self.path_or_fileobj)
234
+ elif isinstance(self.path_or_fileobj, io.BufferedIOBase):
235
+ prev_pos = self.path_or_fileobj.tell()
236
+ yield self.path_or_fileobj
237
+ self.path_or_fileobj.seek(prev_pos, io.SEEK_SET)
238
+
239
+ def b64content(self) -> bytes:
240
+ """
241
+ The base64-encoded content of `path_or_fileobj`
242
+
243
+ Returns: `bytes`
244
+ """
245
+ with self.as_file() as file:
246
+ return base64.b64encode(file.read())
247
+
248
+
249
+ def _validate_path_in_repo(path_in_repo: str) -> str:
250
+ # Validate `path_in_repo` value to prevent a server-side issue
251
+ if path_in_repo.startswith("/"):
252
+ path_in_repo = path_in_repo[1:]
253
+ if path_in_repo == "." or path_in_repo == ".." or path_in_repo.startswith("../"):
254
+ raise ValueError(f"Invalid `path_in_repo` in CommitOperation: '{path_in_repo}'")
255
+ if path_in_repo.startswith("./"):
256
+ path_in_repo = path_in_repo[2:]
257
+ if any(part == ".git" for part in path_in_repo.split("/")):
258
+ raise ValueError(
259
+ "Invalid `path_in_repo` in CommitOperation: cannot update files under a '.git/' folder (path:"
260
+ f" '{path_in_repo}')."
261
+ )
262
+ return path_in_repo
263
+
264
+
265
+ CommitOperation = Union[CommitOperationAdd, CommitOperationCopy, CommitOperationDelete]
266
+
267
+
268
+ def _warn_on_overwriting_operations(operations: List[CommitOperation]) -> None:
269
+ """
270
+ Warn user when a list of operations is expected to overwrite itself in a single
271
+ commit.
272
+
273
+ Rules:
274
+ - If a filepath is updated by multiple `CommitOperationAdd` operations, a warning
275
+ message is triggered.
276
+ - If a filepath is updated at least once by a `CommitOperationAdd` and then deleted
277
+ by a `CommitOperationDelete`, a warning is triggered.
278
+ - If a `CommitOperationDelete` deletes a filepath that is then updated by a
279
+ `CommitOperationAdd`, no warning is triggered. This is usually useless (no need to
280
+ delete before upload) but can happen if a user deletes an entire folder and then
281
+ add new files to it.
282
+ """
283
+ nb_additions_per_path: Dict[str, int] = defaultdict(int)
284
+ for operation in operations:
285
+ path_in_repo = operation.path_in_repo
286
+ if isinstance(operation, CommitOperationAdd):
287
+ if nb_additions_per_path[path_in_repo] > 0:
288
+ warnings.warn(
289
+ "About to update multiple times the same file in the same commit:"
290
+ f" '{path_in_repo}'. This can cause undesired inconsistencies in"
291
+ " your repo."
292
+ )
293
+ nb_additions_per_path[path_in_repo] += 1
294
+ for parent in PurePosixPath(path_in_repo).parents:
295
+ # Also keep track of number of updated files per folder
296
+ # => warns if deleting a folder overwrite some contained files
297
+ nb_additions_per_path[str(parent)] += 1
298
+ if isinstance(operation, CommitOperationDelete):
299
+ if nb_additions_per_path[str(PurePosixPath(path_in_repo))] > 0:
300
+ if operation.is_folder:
301
+ warnings.warn(
302
+ "About to delete a folder containing files that have just been"
303
+ f" updated within the same commit: '{path_in_repo}'. This can"
304
+ " cause undesired inconsistencies in your repo."
305
+ )
306
+ else:
307
+ warnings.warn(
308
+ "About to delete a file that have just been updated within the"
309
+ f" same commit: '{path_in_repo}'. This can cause undesired"
310
+ " inconsistencies in your repo."
311
+ )
312
+
313
+
314
+ @validate_hf_hub_args
315
+ def _upload_lfs_files(
316
+ *,
317
+ additions: List[CommitOperationAdd],
318
+ repo_type: str,
319
+ repo_id: str,
320
+ token: Optional[str],
321
+ endpoint: Optional[str] = None,
322
+ num_threads: int = 5,
323
+ revision: Optional[str] = None,
324
+ ):
325
+ """
326
+ Uploads the content of `additions` to the Hub using the large file storage protocol.
327
+
328
+ Relevant external documentation:
329
+ - LFS Batch API: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md
330
+
331
+ Args:
332
+ additions (`List` of `CommitOperationAdd`):
333
+ The files to be uploaded
334
+ repo_type (`str`):
335
+ Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
336
+ repo_id (`str`):
337
+ A namespace (user or an organization) and a repo name separated
338
+ by a `/`.
339
+ token (`str`, *optional*):
340
+ An authentication token ( See https://huggingface.co/settings/tokens )
341
+ num_threads (`int`, *optional*):
342
+ The number of concurrent threads to use when uploading. Defaults to 5.
343
+ revision (`str`, *optional*):
344
+ The git revision to upload to.
345
+
346
+ Raises:
347
+ `RuntimeError` if an upload failed for any reason.
348
+ `ValueError` if the server returns malformed responses.
349
+ `requests.HTTPError` if the LFS batch endpoint returned an HTTP error.
352
+
353
+ """
354
+ # Step 1: retrieve upload instructions from the LFS batch endpoint.
355
+ # Upload instructions are retrieved in chunks of 256 files to avoid reaching
356
+ # the payload limit.
357
+ batch_actions: List[Dict] = []
358
+ for chunk in chunk_iterable(additions, chunk_size=256):
359
+ batch_actions_chunk, batch_errors_chunk = post_lfs_batch_info(
360
+ upload_infos=[op.upload_info for op in chunk],
361
+ token=token,
362
+ repo_id=repo_id,
363
+ repo_type=repo_type,
364
+ revision=revision,
365
+ endpoint=endpoint,
366
+ )
367
+
368
+ # If at least 1 error, we do not retrieve information for other chunks
369
+ if batch_errors_chunk:
370
+ message = "\n".join(
371
+ [
372
+ f'Encountered error for file with OID {err.get("oid")}: `{err.get("error", {}).get("message")}`'
373
+ for err in batch_errors_chunk
374
+ ]
375
+ )
376
+ raise ValueError(f"LFS batch endpoint returned errors:\n{message}")
377
+
378
+ batch_actions += batch_actions_chunk
379
+ oid2addop = {add_op.upload_info.sha256.hex(): add_op for add_op in additions}
380
+
381
+ # Step 2: ignore files that have already been uploaded
382
+ filtered_actions = []
383
+ for action in batch_actions:
384
+ if action.get("actions") is None:
385
+ logger.debug(
386
+ f"Content of file {oid2addop[action['oid']].path_in_repo} is already"
387
+ " present upstream - skipping upload."
388
+ )
389
+ else:
390
+ filtered_actions.append(action)
391
+
392
+ if len(filtered_actions) == 0:
393
+ logger.debug("No LFS files to upload.")
394
+ return
395
+
396
+ # Step 3: upload files concurrently according to these instructions
397
+ def _wrapped_lfs_upload(batch_action) -> None:
398
+ operation = oid2addop[batch_action["oid"]]  # looked up outside `try` so the except clause can reference it
399
+ try:
400
+ lfs_upload(operation=operation, lfs_batch_action=batch_action, token=token)
401
+ except Exception as exc:
402
+ raise RuntimeError(f"Error while uploading '{operation.path_in_repo}' to the Hub.") from exc
403
+
404
+ if HF_HUB_ENABLE_HF_TRANSFER:
405
+ logger.debug(f"Uploading {len(filtered_actions)} LFS files to the Hub using `hf_transfer`.")
406
+ for action in hf_tqdm(filtered_actions):
407
+ _wrapped_lfs_upload(action)
408
+ elif len(filtered_actions) == 1:
409
+ logger.debug("Uploading 1 LFS file to the Hub")
410
+ _wrapped_lfs_upload(filtered_actions[0])
411
+ else:
412
+ logger.debug(
413
+ f"Uploading {len(filtered_actions)} LFS files to the Hub using up to {num_threads} threads concurrently"
414
+ )
415
+ thread_map(
416
+ _wrapped_lfs_upload,
417
+ filtered_actions,
418
+ desc=f"Upload {len(filtered_actions)} LFS files",
419
+ max_workers=num_threads,
420
+ tqdm_class=hf_tqdm,
421
+ )
422
+
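The chunking above relies on `chunk_iterable`, imported from the package's utils and not shown in this diff. As an assumption about its behavior (not the library's actual code), a minimal equivalent could look like:

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def chunk_iterable(iterable: Iterable[T], chunk_size: int) -> Iterator[List[T]]:
    """Yield successive lists of at most `chunk_size` items from `iterable`."""
    iterator = iter(iterable)
    while chunk := list(islice(iterator, chunk_size)):
        yield chunk
```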
423
+
424
+ def _validate_preupload_info(preupload_info: dict):
425
+ files = preupload_info.get("files")
426
+ if not isinstance(files, list):
427
+ raise ValueError("preupload_info is improperly formatted")
428
+ for file_info in files:
429
+ if not (
430
+ isinstance(file_info, dict)
431
+ and isinstance(file_info.get("path"), str)
432
+ and isinstance(file_info.get("uploadMode"), str)
433
+ and (file_info["uploadMode"] in ("lfs", "regular"))
434
+ ):
435
+ raise ValueError("preupload_info is improperly formatted:")
436
+ return preupload_info
437
+
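To make the expected schema concrete, here is a hedged example of a server response that would pass this validator (paths and modes are made up). Note that the validator does not check the `shouldIgnore` field that `_fetch_upload_modes` below also reads:

```python
preupload_info = {
    "files": [
        {"path": "README.md", "uploadMode": "regular"},
        {"path": "model.safetensors", "uploadMode": "lfs"},
    ]
}

# Returns its argument unchanged when the structure is valid.
assert _validate_preupload_info(preupload_info) is preupload_info
```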
438
+
439
+ @validate_hf_hub_args
440
+ def _fetch_upload_modes(
441
+ additions: Iterable[CommitOperationAdd],
442
+ repo_type: str,
443
+ repo_id: str,
444
+ token: Optional[str],
445
+ revision: str,
446
+ endpoint: Optional[str] = None,
447
+ create_pr: bool = False,
448
+ gitignore_content: Optional[str] = None,
449
+ ) -> None:
450
+ """
451
+ Requests the Hub "preupload" endpoint to determine whether each input file should be uploaded as a regular git blob
452
+ or as a git LFS blob. Input `additions` are mutated in place with the upload mode.
453
+
454
+ Args:
455
+ additions (`Iterable` of :class:`CommitOperationAdd`):
456
+ Iterable of :class:`CommitOperationAdd` describing the files to
457
+ upload to the Hub.
458
+ repo_type (`str`):
459
+ Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
460
+ repo_id (`str`):
461
+ A namespace (user or an organization) and a repo name separated
462
+ by a `/`.
463
+ token (`str`, *optional*):
464
+ An authentication token (see https://huggingface.co/settings/tokens).
465
+ revision (`str`):
466
+ The git revision to upload the files to. Can be any valid git revision.
467
+ gitignore_content (`str`, *optional*):
468
+ The content of the `.gitignore` file to know which files should be ignored. The order of priority
469
+ is to first check if `gitignore_content` is passed, then check if the `.gitignore` file is present
470
+ in the list of files to commit and finally default to the `.gitignore` file already hosted on the Hub
471
+ (if any).
472
+ Raises:
473
+ [`~utils.HfHubHTTPError`]
474
+ If the Hub API returned an error.
475
+ [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
476
+ If the Hub API response is improperly formatted.
477
+ """
478
+ endpoint = endpoint if endpoint is not None else ENDPOINT
479
+ headers = build_hf_headers(token=token)
480
+
481
+ # Fetch upload mode (LFS or regular) chunk by chunk.
482
+ upload_modes: Dict[str, UploadMode] = {}
483
+ should_ignore_info: Dict[str, bool] = {}
484
+
485
+ for chunk in chunk_iterable(additions, 256):
486
+ payload: Dict = {
487
+ "files": [
488
+ {
489
+ "path": op.path_in_repo,
490
+ "sample": base64.b64encode(op.upload_info.sample).decode("ascii"),
491
+ "size": op.upload_info.size,
492
+ "sha": op.upload_info.sha256.hex(),
493
+ }
494
+ for op in chunk
495
+ ]
496
+ }
497
+ if gitignore_content is not None:
498
+ payload["gitIgnore"] = gitignore_content
499
+
500
+ resp = get_session().post(
501
+ f"{endpoint}/api/{repo_type}s/{repo_id}/preupload/{revision}",
502
+ json=payload,
503
+ headers=headers,
504
+ params={"create_pr": "1"} if create_pr else None,
505
+ )
506
+ hf_raise_for_status(resp)
507
+ preupload_info = _validate_preupload_info(resp.json())
508
+ upload_modes.update(**{file["path"]: file["uploadMode"] for file in preupload_info["files"]})
509
+ should_ignore_info.update(**{file["path"]: file["shouldIgnore"] for file in preupload_info["files"]})
510
+
511
+ # Set upload mode for each addition operation
512
+ for addition in additions:
513
+ addition._upload_mode = upload_modes[addition.path_in_repo]
514
+ addition._should_ignore = should_ignore_info[addition.path_in_repo]
515
+
516
+ # Empty files cannot be uploaded as LFS (S3 would fail with a 501 Not Implemented)
517
+ # => empty files are uploaded as "regular" to still allow users to commit them.
518
+ for addition in additions:
519
+ if addition.upload_info.size == 0:
520
+ addition._upload_mode = "regular"
521
+
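As an illustration of the request body built in the loop above, a single-file payload might look as follows. The values are placeholders, and it is assumed (following `UploadInfo`) that `sample` is the first 512 bytes of the file:

```python
import base64
from hashlib import sha256

content = b"hello world"  # placeholder file content
payload = {
    "files": [
        {
            "path": "data/hello.txt",
            "sample": base64.b64encode(content[:512]).decode("ascii"),
            "size": len(content),
            "sha": sha256(content).hexdigest(),
        }
    ]
}
```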
522
+
523
+ @validate_hf_hub_args
524
+ def _fetch_lfs_files_to_copy(
525
+ copies: Iterable[CommitOperationCopy],
526
+ repo_type: str,
527
+ repo_id: str,
528
+ token: Optional[str],
529
+ revision: str,
530
+ endpoint: Optional[str] = None,
531
+ ) -> Dict[Tuple[str, Optional[str]], "RepoFile"]:
532
+ """
533
+ Requests file information from the Hub for the LFS files to be copied, including their sha256.
534
+
535
+ Args:
536
+ copies (`Iterable` of :class:`CommitOperationCopy`):
537
+ Iterable of :class:`CommitOperationCopy` describing the files to
538
+ copy on the Hub.
539
+ repo_type (`str`):
540
+ Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
541
+ repo_id (`str`):
542
+ A namespace (user or an organization) and a repo name separated
543
+ by a `/`.
544
+ token (`str`, *optional*):
545
+ An authentication token (see https://huggingface.co/settings/tokens).
546
+ revision (`str`):
547
+ The git revision to upload the files to. Can be any valid git revision.
548
+
549
+ Returns: `Dict[Tuple[str, Optional[str]], RepoFile]`
550
+ Key is the file path and revision of the file to copy, value is the repo file.
551
+
552
+ Raises:
553
+ [`~utils.HfHubHTTPError`]
554
+ If the Hub API returned an error.
555
+ [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
556
+ If the Hub API response is improperly formatted.
557
+ """
558
+ from .hf_api import HfApi, RepoFolder
559
+
560
+ hf_api = HfApi(endpoint=endpoint, token=token)
561
+ files_to_copy = {}
562
+ for src_revision, operations in groupby(copies, key=lambda op: op.src_revision):
563
+ operations = list(operations) # type: ignore
564
+ paths = [op.src_path_in_repo for op in operations]
565
+ for offset in range(0, len(paths), FETCH_LFS_BATCH_SIZE):
566
+ src_repo_files = hf_api.get_paths_info(
567
+ repo_id=repo_id,
568
+ paths=paths[offset : offset + FETCH_LFS_BATCH_SIZE],
569
+ revision=src_revision or revision,
570
+ repo_type=repo_type,
571
+ )
572
+ for src_repo_file in src_repo_files:
573
+ if isinstance(src_repo_file, RepoFolder):
574
+ raise NotImplementedError("Copying a folder is not implemented.")
575
+ if not src_repo_file.lfs:
576
+ raise NotImplementedError("Copying a non-LFS file is not implemented")
577
+ files_to_copy[(src_repo_file.rfilename, src_revision)] = src_repo_file
578
+ for operation in operations:
579
+ if (operation.src_path_in_repo, src_revision) not in files_to_copy:
580
+ raise EntryNotFoundError(
581
+ f"Cannot copy {operation.src_path_in_repo} at revision "
582
+ f"{src_revision or revision}: file is missing on repo."
583
+ )
584
+ return files_to_copy
585
+
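One subtlety in the function above: `itertools.groupby` only groups *consecutive* items, so if `copies` is not sorted by `src_revision`, the same revision can appear in several groups. That is harmless here (the result dict accumulates), but it costs extra API calls. A minimal sketch of pre-sorting, assuming callers want exactly one group per revision:

```python
from itertools import groupby


def iter_copies_by_revision(copies):
    # Sort first so that groupby emits exactly one group per src_revision.
    # `or ""` keeps None (the default revision) sortable alongside strings.
    ordered = sorted(copies, key=lambda op: op.src_revision or "")
    yield from groupby(ordered, key=lambda op: op.src_revision)
```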
586
+
587
+ def _prepare_commit_payload(
588
+ operations: Iterable[CommitOperation],
589
+ files_to_copy: Dict[Tuple[str, Optional[str]], "RepoFile"],
590
+ commit_message: str,
591
+ commit_description: Optional[str] = None,
592
+ parent_commit: Optional[str] = None,
593
+ ) -> Iterable[Dict[str, Any]]:
594
+ """
595
+ Builds the payload to POST to the `/commit` API of the Hub.
596
+
597
+ Payload is returned as an iterator so that it can be streamed as ndjson in the
598
+ POST request.
599
+
600
+ For more information, see:
601
+ - https://github.com/huggingface/huggingface_hub/issues/1085#issuecomment-1265208073
602
+ - http://ndjson.org/
603
+ """
604
+ commit_description = commit_description if commit_description is not None else ""
605
+
606
+ # 1. Send a header item with the commit metadata
607
+ header_value = {"summary": commit_message, "description": commit_description}
608
+ if parent_commit is not None:
609
+ header_value["parentCommit"] = parent_commit
610
+ yield {"key": "header", "value": header_value}
611
+
612
+ nb_ignored_files = 0
613
+
614
+ # 2. Send operations, one per line
615
+ for operation in operations:
616
+ # Skip ignored files
617
+ if isinstance(operation, CommitOperationAdd) and operation._should_ignore:
618
+ logger.debug(f"Skipping file '{operation.path_in_repo}' in commit (ignored by gitignore file).")
619
+ nb_ignored_files += 1
620
+ continue
621
+
622
+ # 2.a. Case adding a regular file
623
+ if isinstance(operation, CommitOperationAdd) and operation._upload_mode == "regular":
624
+ yield {
625
+ "key": "file",
626
+ "value": {
627
+ "content": operation.b64content().decode(),
628
+ "path": operation.path_in_repo,
629
+ "encoding": "base64",
630
+ },
631
+ }
632
+ # 2.b. Case adding an LFS file
633
+ elif isinstance(operation, CommitOperationAdd) and operation._upload_mode == "lfs":
634
+ yield {
635
+ "key": "lfsFile",
636
+ "value": {
637
+ "path": operation.path_in_repo,
638
+ "algo": "sha256",
639
+ "oid": operation.upload_info.sha256.hex(),
640
+ "size": operation.upload_info.size,
641
+ },
642
+ }
643
+ # 2.c. Case deleting a file or folder
644
+ elif isinstance(operation, CommitOperationDelete):
645
+ yield {
646
+ "key": "deletedFolder" if operation.is_folder else "deletedFile",
647
+ "value": {"path": operation.path_in_repo},
648
+ }
649
+ # 2.d. Case copying a file or folder
650
+ elif isinstance(operation, CommitOperationCopy):
651
+ file_to_copy = files_to_copy[(operation.src_path_in_repo, operation.src_revision)]
652
+ if not file_to_copy.lfs:
653
+ raise NotImplementedError("Copying a non-LFS file is not implemented")
654
+ yield {
655
+ "key": "lfsFile",
656
+ "value": {
657
+ "path": operation.path_in_repo,
658
+ "algo": "sha256",
659
+ "oid": file_to_copy.lfs["sha256"],
660
+ },
661
+ }
662
+ # 2.e. Never expected to happen
663
+ else:
664
+ raise ValueError(
665
+ f"Unknown operation to commit. Operation: {operation}. Upload mode:"
666
+ f" {getattr(operation, '_upload_mode', None)}"
667
+ )
668
+
669
+ if nb_ignored_files > 0:
670
+ logger.info(f"Skipped {nb_ignored_files} file(s) in commit (ignored by gitignore file).")
lib/python3.11/site-packages/huggingface_hub/_commit_scheduler.py ADDED
@@ -0,0 +1,327 @@
1
+ import atexit
2
+ import logging
3
+ import os
4
+ import time
5
+ from concurrent.futures import Future
6
+ from dataclasses import dataclass
7
+ from io import SEEK_END, SEEK_SET, BytesIO
8
+ from pathlib import Path
9
+ from threading import Lock, Thread
10
+ from typing import Dict, List, Optional, Union
11
+
12
+ from .hf_api import IGNORE_GIT_FOLDER_PATTERNS, CommitInfo, CommitOperationAdd, HfApi
13
+ from .utils import filter_repo_objects
14
+
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ @dataclass(frozen=True)
20
+ class _FileToUpload:
21
+ """Temporary dataclass to store info about files to upload. Not meant to be used directly."""
22
+
23
+ local_path: Path
24
+ path_in_repo: str
25
+ size_limit: int
26
+ last_modified: float
27
+
28
+
29
+ class CommitScheduler:
30
+ """
31
+ Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes).
32
+
33
+ The scheduler is started when instantiated and runs indefinitely. At the end of your script, a final commit is
34
+ triggered. Check out the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads)
35
+ to learn more about how to use it.
36
+
37
+ Args:
38
+ repo_id (`str`):
39
+ The id of the repo to commit to.
40
+ folder_path (`str` or `Path`):
41
+ Path to the local folder to upload regularly.
42
+ every (`int` or `float`, *optional*):
43
+ The number of minutes between each commit. Defaults to 5 minutes.
44
+ path_in_repo (`str`, *optional*):
45
+ Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder
46
+ of the repository.
47
+ repo_type (`str`, *optional*):
48
+ The type of the repo to commit to. Defaults to `model`.
49
+ revision (`str`, *optional*):
50
+ The revision of the repo to commit to. Defaults to `main`.
51
+ private (`bool`, *optional*):
52
+ Whether to make the repo private. Defaults to `False`. This value is ignored if the repo already exists.
53
+ token (`str`, *optional*):
54
+ The token to use to commit to the repo. Defaults to the token saved on the machine.
55
+ allow_patterns (`List[str]` or `str`, *optional*):
56
+ If provided, only files matching at least one pattern are uploaded.
57
+ ignore_patterns (`List[str]` or `str`, *optional*):
58
+ If provided, files matching any of the patterns are not uploaded.
59
+ squash_history (`bool`, *optional*):
60
+ Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
61
+ useful to avoid degraded performance on the repo when it grows too large.
62
+ hf_api (`HfApi`, *optional*):
63
+ The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...).
64
+
65
+ Example:
66
+ ```py
67
+ >>> from pathlib import Path
68
+ >>> from huggingface_hub import CommitScheduler
69
+
70
+ # Scheduler uploads every 10 minutes
71
+ >>> csv_path = Path("watched_folder/data.csv")
72
+ >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10)
73
+
74
+ >>> with csv_path.open("a") as f:
75
+ ... f.write("first line")
76
+
77
+ # Some time later (...)
78
+ >>> with csv_path.open("a") as f:
79
+ ... f.write("second line")
80
+ ```
81
+ """
82
+
83
+ def __init__(
84
+ self,
85
+ *,
86
+ repo_id: str,
87
+ folder_path: Union[str, Path],
88
+ every: Union[int, float] = 5,
89
+ path_in_repo: Optional[str] = None,
90
+ repo_type: Optional[str] = None,
91
+ revision: Optional[str] = None,
92
+ private: bool = False,
93
+ token: Optional[str] = None,
94
+ allow_patterns: Optional[Union[List[str], str]] = None,
95
+ ignore_patterns: Optional[Union[List[str], str]] = None,
96
+ squash_history: bool = False,
97
+ hf_api: Optional["HfApi"] = None,
98
+ ) -> None:
99
+ self.api = hf_api or HfApi(token=token)
100
+
101
+ # Folder
102
+ self.folder_path = Path(folder_path).expanduser().resolve()
103
+ self.path_in_repo = path_in_repo or ""
104
+ self.allow_patterns = allow_patterns
105
+
106
+ if ignore_patterns is None:
107
+ ignore_patterns = []
108
+ elif isinstance(ignore_patterns, str):
109
+ ignore_patterns = [ignore_patterns]
110
+ self.ignore_patterns = ignore_patterns + IGNORE_GIT_FOLDER_PATTERNS
111
+
112
+ if self.folder_path.is_file():
113
+ raise ValueError(f"'folder_path' must be a directory, not a file: '{self.folder_path}'.")
114
+ self.folder_path.mkdir(parents=True, exist_ok=True)
115
+
116
+ # Repository
117
+ repo_url = self.api.create_repo(repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True)
118
+ self.repo_id = repo_url.repo_id
119
+ self.repo_type = repo_type
120
+ self.revision = revision
121
+ self.token = token
122
+
123
+ # Keep track of already uploaded files
124
+ self.last_uploaded: Dict[Path, float] = {} # key is local path, value is timestamp
125
+
126
+ # Scheduler
127
+ if every <= 0:
128
+ raise ValueError(f"'every' must be a positive number, not '{every}'.")
129
+ self.lock = Lock()
130
+ self.every = every
131
+ self.squash_history = squash_history
132
+
133
+ logger.info(f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes.")
134
+ self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True)
135
+ self._scheduler_thread.start()
136
+ atexit.register(self._push_to_hub)
137
+
138
+ self.__stopped = False
139
+
140
+ def stop(self) -> None:
141
+ """Stop the scheduler.
142
+
143
+ A stopped scheduler cannot be restarted. Mostly for testing purposes.
144
+ """
145
+ self.__stopped = True
146
+
147
+ def _run_scheduler(self) -> None:
148
+ """Dumb thread waiting between each scheduled push to Hub."""
149
+ while True:
150
+ self.last_future = self.trigger()
151
+ time.sleep(self.every * 60)
152
+ if self.__stopped:
153
+ break
154
+
155
+ def trigger(self) -> Future:
156
+ """Trigger a `push_to_hub` and return a future.
157
+
158
+ This method is automatically called every `every` minutes. You can also call it manually to trigger a commit
159
+ immediately, without waiting for the next scheduled commit.
160
+ """
161
+ return self.api.run_as_future(self._push_to_hub)
162
+
163
+ def _push_to_hub(self) -> Optional[CommitInfo]:
164
+ if self.__stopped: # If stopped, already scheduled commits are ignored
165
+ return None
166
+
167
+ logger.info("(Background) scheduled commit triggered.")
168
+ try:
169
+ value = self.push_to_hub()
170
+ if self.squash_history:
171
+ logger.info("(Background) squashing repo history.")
172
+ self.api.super_squash_history(repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision)
173
+ return value
174
+ except Exception as e:
175
+ logger.error(f"Error while pushing to Hub: {e}") # Depending on the setup, error might be silenced
176
+ raise
177
+
178
+ def push_to_hub(self) -> Optional[CommitInfo]:
179
+ """
180
+ Push folder to the Hub and return the commit info.
181
+
182
+ <Tip warning={true}>
183
+
184
+ This method is not meant to be called directly. It is run in the background by the scheduler, respecting a
185
+ queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency
186
+ issues.
187
+
188
+ </Tip>
189
+
190
+ The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and
191
+ uploads only changed files. If no changes are found, the method returns without committing anything. If you want
192
+ to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful
193
+ for example to compress data together in a single file before committing. For more details and examples, check
194
+ out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads).
195
+ """
196
+ # Check files to upload (with lock)
197
+ with self.lock:
198
+ logger.debug("Listing files to upload for scheduled commit.")
199
+
200
+ # List files from folder (taken from `_prepare_upload_folder_additions`)
201
+ relpath_to_abspath = {
202
+ path.relative_to(self.folder_path).as_posix(): path
203
+ for path in sorted(self.folder_path.glob("**/*")) # sorted to be deterministic
204
+ if path.is_file()
205
+ }
206
+ prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else ""
207
+
208
+ # Filter with pattern + filter out unchanged files + retrieve current file size
209
+ files_to_upload: List[_FileToUpload] = []
210
+ for relpath in filter_repo_objects(
211
+ relpath_to_abspath.keys(), allow_patterns=self.allow_patterns, ignore_patterns=self.ignore_patterns
212
+ ):
213
+ local_path = relpath_to_abspath[relpath]
214
+ stat = local_path.stat()
215
+ if self.last_uploaded.get(local_path) is None or self.last_uploaded[local_path] != stat.st_mtime:
216
+ files_to_upload.append(
217
+ _FileToUpload(
218
+ local_path=local_path,
219
+ path_in_repo=prefix + relpath,
220
+ size_limit=stat.st_size,
221
+ last_modified=stat.st_mtime,
222
+ )
223
+ )
224
+
225
+ # Return if nothing to upload
226
+ if len(files_to_upload) == 0:
227
+ logger.debug("Dropping schedule commit: no changed file to upload.")
228
+ return None
229
+
230
+ # Convert `_FileToUpload` to `CommitOperationAdd` (=> compute file shas + limit to file size)
231
+ logger.debug("Preparing commit operations for changed files.")
232
+ add_operations = [
233
+ CommitOperationAdd(
234
+ # Cap the file to its current size, even if the user appends data to it while a scheduled commit is happening
235
+ path_or_fileobj=PartialFileIO(file_to_upload.local_path, size_limit=file_to_upload.size_limit),
236
+ path_in_repo=file_to_upload.path_in_repo,
237
+ )
238
+ for file_to_upload in files_to_upload
239
+ ]
240
+
241
+ # Upload files (append mode expected - no need for lock)
242
+ logger.debug("Uploading files for scheduled commit.")
243
+ commit_info = self.api.create_commit(
244
+ repo_id=self.repo_id,
245
+ repo_type=self.repo_type,
246
+ operations=add_operations,
247
+ commit_message="Scheduled Commit",
248
+ revision=self.revision,
249
+ )
250
+
251
+ # Successful commit: keep track of the latest "last_modified" for each file
252
+ for file in files_to_upload:
253
+ self.last_uploaded[file.local_path] = file.last_modified
254
+ return commit_info
255
+
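As the `push_to_hub` docstring suggests, the append-only behavior can be replaced by subclassing. A minimal sketch under assumptions of our own (the zipping strategy and archive name are hypothetical, not library code):

```python
import shutil
import tempfile
from pathlib import Path

from huggingface_hub import CommitScheduler


class ZipScheduler(CommitScheduler):
    """Hypothetical scheduler that uploads the watched folder as one zip archive."""

    def push_to_hub(self):
        with self.lock:  # reuse the scheduler's lock to avoid concurrent scans
            archive = shutil.make_archive(
                str(Path(tempfile.mkdtemp()) / "snapshot"), "zip", self.folder_path
            )
        return self.api.upload_file(
            path_or_fileobj=archive,
            path_in_repo="snapshot.zip",
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            revision=self.revision,
        )
```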
256
+
257
+ class PartialFileIO(BytesIO):
258
+ """A file-like object that reads only the first part of a file.
259
+
260
+ Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the
261
+ file is uploaded (i.e. the part that was available when the filesystem was first scanned).
262
+
263
+ In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal
264
+ disturbance for the user. The object is passed to `CommitOperationAdd`.
265
+
266
+ Only supports `read`, `tell` and `seek` methods.
267
+
268
+ Args:
269
+ file_path (`str` or `Path`):
270
+ Path to the file to read.
271
+ size_limit (`int`):
272
+ The maximum number of bytes to read from the file. If the file is larger than this, only the first part
273
+ will be read (and uploaded).
274
+ """
275
+
276
+ def __init__(self, file_path: Union[str, Path], size_limit: int) -> None:
277
+ self._file_path = Path(file_path)
278
+ self._file = self._file_path.open("rb")
279
+ self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size)
280
+
281
+ def __del__(self) -> None:
282
+ self._file.close()
283
+ return super().__del__()
284
+
285
+ def __repr__(self) -> str:
286
+ return f"<PartialFileIO file_path={self._file_path} size_limit={self._size_limit}>"
287
+
288
+ def __len__(self) -> int:
289
+ return self._size_limit
290
+
291
+ def __getattribute__(self, name: str):
292
+ if name.startswith("_") or name in ("read", "tell", "seek"): # only 3 public methods supported
293
+ return super().__getattribute__(name)
294
+ raise NotImplementedError(f"PartialFileIO does not support '{name}'.")
295
+
296
+ def tell(self) -> int:
297
+ """Return the current file position."""
298
+ return self._file.tell()
299
+
300
+ def seek(self, __offset: int, __whence: int = SEEK_SET) -> int:
301
+ """Change the stream position to the given offset.
302
+
303
+ Behavior is the same as a regular file, except that the position is capped to the size limit.
304
+ """
305
+ if __whence == SEEK_END:
306
+ # SEEK_END => set from the truncated end
307
+ __offset = len(self) + __offset
308
+ __whence = SEEK_SET
309
+
310
+ pos = self._file.seek(__offset, __whence)
311
+ if pos > self._size_limit:
312
+ return self._file.seek(self._size_limit)
313
+ return pos
314
+
315
+ def read(self, __size: Optional[int] = -1) -> bytes:
316
+ """Read at most `__size` bytes from the file.
317
+
318
+ Behavior is the same as a regular file, except that it is capped to the size limit.
319
+ """
320
+ current = self._file.tell()
321
+ if __size is None or __size < 0:
322
+ # Read until file limit
323
+ truncated_size = self._size_limit - current
324
+ else:
325
+ # Read until file limit or __size
326
+ truncated_size = min(__size, self._size_limit - current)
327
+ return self._file.read(truncated_size)
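To illustrate the capping behavior of `PartialFileIO`, a small hedged demo (the file name is hypothetical):

```python
from io import SEEK_END
from pathlib import Path

demo = Path("demo.txt")
demo.write_text("0123456789ABCDEF")  # 16 bytes on disk

f = PartialFileIO(demo, size_limit=10)
assert len(f) == 10                # capped at size_limit
assert f.read() == b"0123456789"   # reads stop at the size limit
f.seek(0, SEEK_END)
assert f.tell() == 10              # SEEK_END is relative to the truncated end
```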
lib/python3.11/site-packages/huggingface_hub/_inference_endpoints.py ADDED
@@ -0,0 +1,373 @@
1
+ import time
2
+ from dataclasses import dataclass, field
3
+ from datetime import datetime
4
+ from enum import Enum
5
+ from typing import TYPE_CHECKING, Dict, Optional
6
+
7
+ from .inference._client import InferenceClient
8
+ from .inference._generated._async_client import AsyncInferenceClient
9
+ from .utils import logging, parse_datetime
10
+
11
+
12
+ if TYPE_CHECKING:
13
+ from .hf_api import HfApi
14
+
15
+
16
+ logger = logging.get_logger(__name__)
17
+
18
+
19
+ class InferenceEndpointError(Exception):
20
+ """Generic exception when dealing with Inference Endpoints."""
21
+
22
+
23
+ class InferenceEndpointTimeoutError(InferenceEndpointError, TimeoutError):
24
+ """Exception for timeouts while waiting for Inference Endpoint."""
25
+
26
+
27
+ class InferenceEndpointStatus(str, Enum):
28
+ PENDING = "pending"
29
+ INITIALIZING = "initializing"
30
+ UPDATING = "updating"
31
+ UPDATE_FAILED = "updateFailed"
32
+ RUNNING = "running"
33
+ PAUSED = "paused"
34
+ FAILED = "failed"
35
+ SCALED_TO_ZERO = "scaledToZero"
36
+
37
+
38
+ class InferenceEndpointType(str, Enum):
39
+ PUBLIC = "public"
40
+ PROTECTED = "protected"
41
+ PRIVATE = "private"
42
+
43
+
44
+ @dataclass
45
+ class InferenceEndpoint:
46
+ """
47
+ Contains information about a deployed Inference Endpoint.
48
+
49
+ Args:
50
+ name (`str`):
51
+ The unique name of the Inference Endpoint.
52
+ namespace (`str`):
53
+ The namespace where the Inference Endpoint is located.
54
+ repository (`str`):
55
+ The name of the model repository deployed on this Inference Endpoint.
56
+ status ([`InferenceEndpointStatus`]):
57
+ The current status of the Inference Endpoint.
58
+ url (`str`, *optional*):
59
+ The URL of the Inference Endpoint, if available. Only a deployed Inference Endpoint will have a URL.
60
+ framework (`str`):
61
+ The machine learning framework used for the model.
62
+ revision (`str`):
63
+ The specific model revision deployed on the Inference Endpoint.
64
+ task (`str`):
65
+ The task associated with the deployed model.
66
+ created_at (`datetime.datetime`):
67
+ The timestamp when the Inference Endpoint was created.
68
+ updated_at (`datetime.datetime`):
69
+ The timestamp of the last update of the Inference Endpoint.
70
+ type ([`InferenceEndpointType`]):
71
+ The type of the Inference Endpoint (public, protected, private).
72
+ raw (`Dict`):
73
+ The raw dictionary data returned from the API.
74
+ token (`str`, *optional*):
75
+ Authentication token for the Inference Endpoint, if set when requesting the API.
76
+
77
+ Example:
78
+ ```python
79
+ >>> from huggingface_hub import get_inference_endpoint
80
+ >>> endpoint = get_inference_endpoint("my-text-to-image")
81
+ >>> endpoint
82
+ InferenceEndpoint(name='my-text-to-image', ...)
83
+
84
+ # Get status
85
+ >>> endpoint.status
86
+ 'running'
87
+ >>> endpoint.url
88
+ 'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud'
89
+
90
+ # Run inference
91
+ >>> endpoint.client.text_to_image(...)
92
+
93
+ # Pause endpoint to save $$$
94
+ >>> endpoint.pause()
95
+
96
+ # ...
97
+ # Resume and wait for deployment
98
+ >>> endpoint.resume()
99
+ >>> endpoint.wait()
100
+ >>> endpoint.client.text_to_image(...)
101
+ ```
102
+ """
103
+
104
+ # Field in __repr__
105
+ name: str = field(init=False)
106
+ namespace: str
107
+ repository: str = field(init=False)
108
+ status: InferenceEndpointStatus = field(init=False)
109
+ url: Optional[str] = field(init=False)
110
+
111
+ # Other fields
112
+ framework: str = field(repr=False, init=False)
113
+ revision: str = field(repr=False, init=False)
114
+ task: str = field(repr=False, init=False)
115
+ created_at: datetime = field(repr=False, init=False)
116
+ updated_at: datetime = field(repr=False, init=False)
117
+ type: InferenceEndpointType = field(repr=False, init=False)
118
+
119
+ # Raw dict from the API
120
+ raw: Dict = field(repr=False)
121
+
122
+ # Internal fields
123
+ _token: Optional[str] = field(repr=False, compare=False)
124
+ _api: "HfApi" = field(repr=False, compare=False)
125
+
126
+ @classmethod
127
+ def from_raw(
128
+ cls, raw: Dict, namespace: str, token: Optional[str] = None, api: Optional["HfApi"] = None
129
+ ) -> "InferenceEndpoint":
130
+ """Initialize object from raw dictionary."""
131
+ if api is None:
132
+ from .hf_api import HfApi
133
+
134
+ api = HfApi()
135
+ if token is None:
136
+ token = api.token
137
+
138
+ # All other fields are populated in __post_init__
139
+ return cls(raw=raw, namespace=namespace, _token=token, _api=api)
140
+
141
+ def __post_init__(self) -> None:
142
+ """Populate fields from raw dictionary."""
143
+ self._populate_from_raw()
144
+
145
+ @property
146
+ def client(self) -> InferenceClient:
147
+ """Returns a client to make predictions on this Inference Endpoint.
148
+
149
+ Returns:
150
+ [`InferenceClient`]: an inference client pointing to the deployed endpoint.
151
+
152
+ Raises:
153
+ [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
154
+ """
155
+ if self.url is None:
156
+ raise InferenceEndpointError(
157
+ "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
158
+ "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
159
+ )
160
+ return InferenceClient(model=self.url, token=self._token)
161
+
162
+ @property
163
+ def async_client(self) -> AsyncInferenceClient:
164
+ """Returns a client to make predictions on this Inference Endpoint.
165
+
166
+ Returns:
167
+ [`AsyncInferenceClient`]: an asyncio-compatible inference client pointing to the deployed endpoint.
168
+
169
+ Raises:
170
+ [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
171
+ """
172
+ if self.url is None:
173
+ raise InferenceEndpointError(
174
+ "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
175
+ "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
176
+ )
177
+ return AsyncInferenceClient(model=self.url, token=self._token)
178
+
179
+ def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "InferenceEndpoint":
180
+ """Wait for the Inference Endpoint to be deployed.
181
+
182
+ Information from the server is fetched every `refresh_every` seconds (5s by default). If the Inference Endpoint is not deployed after `timeout`
183
+ seconds, an [`InferenceEndpointTimeoutError`] is raised. The [`InferenceEndpoint`] is mutated in place with the latest
184
+ data.
185
+
186
+ Args:
187
+ timeout (`int`, *optional*):
188
+ The maximum time to wait for the Inference Endpoint to be deployed, in seconds. If `None`, will wait
189
+ indefinitely.
190
+ refresh_every (`int`, *optional*):
191
+ The time to wait between each fetch of the Inference Endpoint status, in seconds. Defaults to 5s.
192
+
193
+ Returns:
194
+ [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
195
+ """
196
+ if self.url is not None: # Means the endpoint is deployed
197
+ logger.info("Inference Endpoint is ready to be used.")
198
+ return self
199
+
200
+ if timeout is not None and timeout < 0:
201
+ raise ValueError("`timeout` cannot be negative.")
202
+ if refresh_every <= 0:
203
+ raise ValueError("`refresh_every` must be positive.")
204
+
205
+ start = time.time()
206
+ while True:
207
+ self.fetch()
208
+ if self.url is not None: # Means the endpoint is deployed
209
+ logger.info("Inference Endpoint is ready to be used.")
210
+ return self
211
+ if timeout is not None:
212
+ if time.time() - start > timeout:
213
+ raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.")
214
+ logger.info(f"Inference Endpoint is not deployed yet ({self.status}). Waiting {refresh_every}s...")
215
+ time.sleep(refresh_every)
216
+
217
+ def fetch(self) -> "InferenceEndpoint":
218
+ """Fetch latest information about the Inference Endpoint.
219
+
220
+ Returns:
221
+ [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
222
+ """
223
+ obj = self._api.get_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
224
+ self.raw = obj.raw
225
+ self._populate_from_raw()
226
+ return self
227
+
228
+ def update(
229
+ self,
230
+ *,
231
+ # Compute update
232
+ accelerator: Optional[str] = None,
233
+ instance_size: Optional[str] = None,
234
+ instance_type: Optional[str] = None,
235
+ min_replica: Optional[int] = None,
236
+ max_replica: Optional[int] = None,
237
+ # Model update
238
+ repository: Optional[str] = None,
239
+ framework: Optional[str] = None,
240
+ revision: Optional[str] = None,
241
+ task: Optional[str] = None,
242
+ ) -> "InferenceEndpoint":
243
+ """Update the Inference Endpoint.
244
+
245
+ This method allows the update of either the compute configuration, the deployed model, or both. All arguments are
246
+ optional but at least one must be provided.
247
+
248
+ This is an alias for [`HfApi.update_inference_endpoint`]. The current object is mutated in place with the
249
+ latest data from the server.
250
+
251
+ Args:
252
+ accelerator (`str`, *optional*):
253
+ The hardware accelerator to be used for inference (e.g. `"cpu"`).
254
+ instance_size (`str`, *optional*):
255
+ The size or type of the instance to be used for hosting the model (e.g. `"large"`).
256
+ instance_type (`str`, *optional*):
257
+ The cloud instance type where the Inference Endpoint will be deployed (e.g. `"c6i"`).
258
+ min_replica (`int`, *optional*):
259
+ The minimum number of replicas (instances) to keep running for the Inference Endpoint.
260
+ max_replica (`int`, *optional*):
261
+ The maximum number of replicas (instances) to scale to for the Inference Endpoint.
262
+
263
+ repository (`str`, *optional*):
264
+ The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
265
+ framework (`str`, *optional*):
266
+ The machine learning framework used for the model (e.g. `"custom"`).
267
+ revision (`str`, *optional*):
268
+ The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`).
269
+ task (`str`, *optional*):
270
+ The task on which to deploy the model (e.g. `"text-classification"`).
271
+
272
+ Returns:
273
+ [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
274
+ """
275
+ # Make API call
276
+ obj = self._api.update_inference_endpoint(
277
+ name=self.name,
278
+ namespace=self.namespace,
279
+ accelerator=accelerator,
280
+ instance_size=instance_size,
281
+ instance_type=instance_type,
282
+ min_replica=min_replica,
283
+ max_replica=max_replica,
284
+ repository=repository,
285
+ framework=framework,
286
+ revision=revision,
287
+ task=task,
288
+ token=self._token,
289
+ )
290
+
291
+ # Mutate current object
292
+ self.raw = obj.raw
293
+ self._populate_from_raw()
294
+ return self
295
+
296
+ def pause(self) -> "InferenceEndpoint":
297
+ """Pause the Inference Endpoint.
298
+
299
+ A paused Inference Endpoint will not be charged. It can be resumed at any time using [`InferenceEndpoint.resume`].
300
+ This is different from scaling the Inference Endpoint to zero with [`InferenceEndpoint.scale_to_zero`]: a
301
+ scaled-to-zero Inference Endpoint is automatically restarted when a request is made to it.
302
+
303
+ This is an alias for [`HfApi.pause_inference_endpoint`]. The current object is mutated in place with the
304
+ latest data from the server.
305
+
306
+ Returns:
307
+ [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
308
+ """
309
+ obj = self._api.pause_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
310
+ self.raw = obj.raw
311
+ self._populate_from_raw()
312
+ return self
313
+
314
+ def resume(self) -> "InferenceEndpoint":
315
+ """Resume the Inference Endpoint.
316
+
317
+ This is an alias for [`HfApi.resume_inference_endpoint`]. The current object is mutated in place with the
318
+ latest data from the server.
319
+
320
+ Returns:
321
+ [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
322
+ """
323
+ obj = self._api.resume_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
324
+ self.raw = obj.raw
325
+ self._populate_from_raw()
326
+ return self
327
+
328
+ def scale_to_zero(self) -> "InferenceEndpoint":
329
+ """Scale Inference Endpoint to zero.
330
+
331
+ An Inference Endpoint scaled to zero will not be charged. It will be resumed on the next request to it, with a
332
+ cold start delay. This is different from pausing the Inference Endpoint with [`InferenceEndpoint.pause`], which
333
+ would require a manual resume with [`InferenceEndpoint.resume`].
334
+
335
+ This is an alias for [`HfApi.scale_to_zero_inference_endpoint`]. The current object is mutated in place with the
336
+ latest data from the server.
337
+
338
+ Returns:
339
+ [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
340
+ """
341
+ obj = self._api.scale_to_zero_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
342
+ self.raw = obj.raw
343
+ self._populate_from_raw()
344
+ return self
345
+
346
+ def delete(self) -> None:
347
+ """Delete the Inference Endpoint.
348
+
349
+ This operation is not reversible. If you don't want to be charged for an Inference Endpoint, it is preferable
350
+ to pause it with [`InferenceEndpoint.pause`] or scale it to zero with [`InferenceEndpoint.scale_to_zero`].
351
+
352
+ This is an alias for [`HfApi.delete_inference_endpoint`].
353
+ """
354
+ self._api.delete_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
355
+
356
+ def _populate_from_raw(self) -> None:
357
+ """Populate fields from raw dictionary.
358
+
359
+ Called in __post_init__ + each time the Inference Endpoint is updated.
360
+ """
361
+ # Repr fields
362
+ self.name = self.raw["name"]
363
+ self.repository = self.raw["model"]["repository"]
364
+ self.status = self.raw["status"]["state"]
365
+ self.url = self.raw["status"].get("url")
366
+
367
+ # Other fields
368
+ self.framework = self.raw["model"]["framework"]
369
+ self.revision = self.raw["model"]["revision"]
370
+ self.task = self.raw["model"]["task"]
371
+ self.created_at = parse_datetime(self.raw["status"]["createdAt"])
372
+ self.updated_at = parse_datetime(self.raw["status"]["updatedAt"])
373
+ self.type = self.raw["type"]
lib/python3.11/site-packages/huggingface_hub/_login.py ADDED
@@ -0,0 +1,395 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Contains methods to login to the Hub."""
15
+ import os
16
+ import subprocess
17
+ from functools import partial
18
+ from getpass import getpass
19
+ from pathlib import Path
20
+ from typing import Optional
21
+
22
+ from . import constants
23
+ from .commands._cli_utils import ANSI
24
+ from .utils import (
25
+ capture_output,
26
+ get_token,
27
+ is_google_colab,
28
+ is_notebook,
29
+ list_credential_helpers,
30
+ logging,
31
+ run_subprocess,
32
+ set_git_credential,
33
+ unset_git_credential,
34
+ )
35
+ from .utils._token import _get_token_from_environment, _get_token_from_google_colab
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ _HF_LOGO_ASCII = """
41
+ _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
42
+ _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
43
+ _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
44
+ _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
45
+ _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
46
+ """
47
+
48
+
49
+ def login(
50
+ token: Optional[str] = None,
51
+ add_to_git_credential: bool = False,
52
+ new_session: bool = True,
53
+ write_permission: bool = False,
54
+ ) -> None:
55
+ """Login the machine to access the Hub.
56
+
57
+ The `token` is persisted in cache and set as a git credential. Once done, the machine
58
+ is logged in and the access token will be available across all `huggingface_hub`
59
+ components. If `token` is not provided, the user will be prompted for it either with
60
+ a widget (in a notebook) or via the terminal.
61
+
62
+ To log in from outside of a script, one can also use `huggingface-cli login`, which is
63
+ a CLI command that wraps [`login`].
64
+
65
+ <Tip>
66
+
67
+ [`login`] is a drop-in replacement method for [`notebook_login`] as it wraps and
68
+ extends its capabilities.
69
+
70
+ </Tip>
71
+
72
+ <Tip>
73
+
74
+ When the token is not passed, [`login`] will automatically detect if the script runs
75
+ in a notebook or not. However, this detection might not be accurate due to the
76
+ variety of notebooks that exist nowadays. If that is the case, you can always force
77
+ the UI by using [`notebook_login`] or [`interpreter_login`].
78
+
79
+ </Tip>
80
+
81
+ Args:
82
+ token (`str`, *optional*):
83
+ User access token generated from https://huggingface.co/settings/tokens.
84
+ add_to_git_credential (`bool`, defaults to `False`):
85
+ If `True`, token will be set as git credential. If no git credential helper
86
+ is configured, a warning will be displayed to the user. If `token` is `None`,
87
+ the value of `add_to_git_credential` is ignored and the end user will be
88
+ prompted for it again.
89
+ new_session (`bool`, defaults to `True`):
90
+ If `True`, will request a token even if one is already saved on the machine.
91
+ write_permission (`bool`, defaults to `False`):
92
+ If `True`, requires a token with write permission.
93
+ Raises:
94
+ [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
95
+ If an organization token is passed. Only personal account tokens are valid
96
+ to login.
97
+ [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
98
+ If token is invalid.
99
+ [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
100
+ If running in a notebook but `ipywidgets` is not installed.
101
+ """
102
+ if token is not None:
103
+ if not add_to_git_credential:
104
+ print(
105
+ "Token will not been saved to git credential helper. Pass"
106
+ " `add_to_git_credential=True` if you want to set the git"
107
+ " credential as well."
108
+ )
109
+ _login(token, add_to_git_credential=add_to_git_credential, write_permission=write_permission)
110
+ elif is_notebook():
111
+ notebook_login(new_session=new_session, write_permission=write_permission)
112
+ else:
113
+ interpreter_login(new_session=new_session, write_permission=write_permission)
114
+
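For reference, a minimal programmatic usage sketch of the function above (the token value is a placeholder):

```python
from huggingface_hub import login, logout

# Explicit token: skips the prompt/widget entirely.
login(token="hf_xxx", add_to_git_credential=True)  # placeholder token

# Interactive: prompts in the terminal, or shows a widget in a notebook.
login()

logout()
```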
115
+
116
+ def logout() -> None:
117
+ """Logout the machine from the Hub.
118
+
119
+ The token is deleted from the machine and removed from the git credentials.
120
+ """
121
+ if get_token() is None:
122
+ print("Not logged in!")
123
+ return
124
+
125
+ # Delete token from git credentials
126
+ unset_git_credential()
127
+
128
+ # Delete token file
129
+ try:
130
+ Path(constants.HF_TOKEN_PATH).unlink()
131
+ except FileNotFoundError:
132
+ pass
133
+
134
+ # Check if still logged in
135
+ if _get_token_from_google_colab() is not None:
136
+ raise EnvironmentError(
137
+ "You are automatically logged in using a Google Colab secret.\n"
138
+ "To log out, you must unset the `HF_TOKEN` secret in your Colab settings."
139
+ )
140
+ if _get_token_from_environment() is not None:
141
+ raise EnvironmentError(
142
+ "Token has been deleted from your machine but you are still logged in.\n"
143
+ "To log out, you must clear out both `HF_TOKEN` and `HUGGING_FACE_HUB_TOKEN` environment variables."
144
+ )
145
+
146
+ print("Successfully logged out.")
147
+
148
+
149
+ ###
150
+ # Interpreter-based login (text)
151
+ ###
152
+
153
+
154
+ def interpreter_login(new_session: bool = True, write_permission: bool = False) -> None:
155
+ """
156
+ Displays a prompt to login to the HF website and store the token.
157
+
158
+ This is equivalent to [`login`] without passing a token when not run in a notebook.
159
+ [`interpreter_login`] is useful if you want to force the use of the terminal prompt
160
+ instead of a notebook widget.
161
+
162
+ For more details, see [`login`].
163
+
164
+ Args:
165
+ new_session (`bool`, defaults to `True`):
166
+ If `True`, will request a token even if one is already saved on the machine.
167
+ write_permission (`bool`, defaults to `False`):
168
+ If `True`, requires a token with write permission.
169
+
170
+ """
171
+ if not new_session and _current_token_okay(write_permission=write_permission):
172
+ print("User is already logged in.")
173
+ return
174
+
175
+ from .commands.delete_cache import _ask_for_confirmation_no_tui
176
+
177
+ print(_HF_LOGO_ASCII)
178
+ if get_token() is not None:
179
+ print(
180
+ " A token is already saved on your machine. Run `huggingface-cli"
181
+ " whoami` to get more information or `huggingface-cli logout` if you want"
182
+ " to log out."
183
+ )
184
+ print(" Setting a new token will erase the existing one.")
185
+
186
+ print(" To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .")
187
+ if os.name == "nt":
188
+ print("Token can be pasted using 'Right-Click'.")
189
+ token = getpass("Token: ")
190
+ add_to_git_credential = _ask_for_confirmation_no_tui("Add token as git credential?")
191
+
192
+ _login(token=token, add_to_git_credential=add_to_git_credential, write_permission=write_permission)
193
+
194
+
195
+ ###
196
+ # Notebook-based login (widget)
197
+ ###
198
+
199
+ NOTEBOOK_LOGIN_PASSWORD_HTML = """<center> <img
200
+ src=https://huggingface.co/front/assets/huggingface_logo-noborder.svg
201
+ alt='Hugging Face'> <br> Immediately click login after typing your password or
202
+ it might be stored in plain text in this notebook file. </center>"""
203
+
204
+
205
+ NOTEBOOK_LOGIN_TOKEN_HTML_START = """<center> <img
206
+ src=https://huggingface.co/front/assets/huggingface_logo-noborder.svg
207
+ alt='Hugging Face'> <br> Copy a token from <a
208
+ href="https://huggingface.co/settings/tokens" target="_blank">your Hugging Face
209
+ tokens page</a> and paste it below. <br> Immediately click login after copying
210
+ your token or it might be stored in plain text in this notebook file. </center>"""
211
+
212
+
213
+ NOTEBOOK_LOGIN_TOKEN_HTML_END = """
214
+ <b>Pro Tip:</b> If you don't already have one, you can create a dedicated
215
+ 'notebooks' token with 'write' access, that you can then easily reuse for all
216
+ notebooks. </center>"""
217
+
218
+
219
+ def notebook_login(new_session: bool = True, write_permission: bool = False) -> None:
220
+ """
221
+ Displays a widget to login to the HF website and store the token.
222
+
223
+ This is equivalent to [`login`] without passing a token when run in a notebook.
224
+ [`notebook_login`] is useful if you want to force the use of the notebook widget
225
+ instead of a prompt in the terminal.
226
+
227
+ For more details, see [`login`].
228
+
229
+ Args:
230
+ new_session (`bool`, defaults to `True`):
231
+ If `True`, will request a token even if one is already saved on the machine.
232
+ write_permission (`bool`, defaults to `False`):
233
+ If `True`, requires a token with write permission.
234
+ """
235
+ try:
236
+ import ipywidgets.widgets as widgets # type: ignore
237
+ from IPython.display import display # type: ignore
238
+ except ImportError:
239
+ raise ImportError(
240
+ "The `notebook_login` function can only be used in a notebook (Jupyter or"
241
+ " Colab) and you need the `ipywidgets` module: `pip install ipywidgets`."
242
+ )
243
+ if not new_session and _current_token_okay(write_permission=write_permission):
244
+ print("User is already logged in.")
245
+ return
246
+
247
+ box_layout = widgets.Layout(display="flex", flex_flow="column", align_items="center", width="50%")
248
+
249
+ token_widget = widgets.Password(description="Token:")
250
+ git_checkbox_widget = widgets.Checkbox(value=True, description="Add token as git credential?")
251
+ token_finish_button = widgets.Button(description="Login")
252
+
253
+ login_token_widget = widgets.VBox(
254
+ [
255
+ widgets.HTML(NOTEBOOK_LOGIN_TOKEN_HTML_START),
256
+ token_widget,
257
+ git_checkbox_widget,
258
+ token_finish_button,
259
+ widgets.HTML(NOTEBOOK_LOGIN_TOKEN_HTML_END),
260
+ ],
261
+ layout=box_layout,
262
+ )
263
+ display(login_token_widget)
264
+
265
+ # On click events
266
+ def login_token_event(t, write_permission: bool = False):
267
+ """
268
+ Event handler for the login button.
269
+
270
+ Args:
271
+ write_permission (`bool`, defaults to `False`):
272
+ If `True`, requires a token with write permission.
273
+ """
274
+ token = token_widget.value
275
+ add_to_git_credential = git_checkbox_widget.value
276
+ # Erase token and clear value to make sure it's not saved in the notebook.
277
+ token_widget.value = ""
278
+ # Hide inputs
279
+ login_token_widget.children = [widgets.Label("Connecting...")]
280
+ try:
281
+ with capture_output() as captured:
282
+ _login(token, add_to_git_credential=add_to_git_credential, write_permission=write_permission)
283
+ message = captured.getvalue()
284
+ except Exception as error:
285
+ message = str(error)
286
+ # Print result (success message or error)
287
+ login_token_widget.children = [widgets.Label(line) for line in message.split("\n") if line.strip()]
288
+
289
+ token_finish_button.on_click(partial(login_token_event, write_permission=write_permission))
290
+
291
+
292
+ ###
293
+ # Login private helpers
294
+ ###
295
+
296
+
297
+ def _login(token: str, add_to_git_credential: bool, write_permission: bool = False) -> None:
298
+ from .hf_api import get_token_permission # avoid circular import
299
+
300
+ if token.startswith("api_org"):
301
+ raise ValueError("You must use your personal account token, not an organization token.")
302
+
303
+ permission = get_token_permission(token)
304
+ if permission is None:
305
+ raise ValueError("Invalid token passed!")
306
+ elif write_permission and permission != "write":
307
+ raise ValueError(
308
+ "Token is valid but is 'read-only' and a 'write' token is required.\nPlease provide a new token with"
309
+ " correct permission."
310
+ )
311
+ print(f"Token is valid (permission: {permission}).")
312
+
313
+ if add_to_git_credential:
314
+ if _is_git_credential_helper_configured():
315
+ set_git_credential(token)
316
+ print(
317
+ "Your token has been saved in your configured git credential helpers"
318
+ + f" ({','.join(list_credential_helpers())})."
319
+ )
320
+ else:
321
+ print("Token has not been saved to git credential helper.")
322
+
323
+ # Save token
324
+ path = Path(constants.HF_TOKEN_PATH)
325
+ path.parent.mkdir(parents=True, exist_ok=True)
326
+ path.write_text(token)
327
+ print(f"Your token has been saved to {constants.HF_TOKEN_PATH}")
328
+ print("Login successful")
329
+
330
+
331
+ def _current_token_okay(write_permission: bool = False) -> bool:
332
+ """Check if the current token is valid.
333
+
334
+ Args:
335
+ write_permission (`bool`, defaults to `False`):
336
+ If `True`, requires a token with write permission.
337
+
338
+ Returns:
339
+ `bool`: `True` if the current token is valid, `False` otherwise.
340
+ """
341
+ from .hf_api import get_token_permission # avoid circular import
342
+
343
+ permission = get_token_permission()
344
+ if permission is None or (write_permission and permission != "write"):
345
+ return False
346
+ return True
347
+
348
+
349
+ def _is_git_credential_helper_configured() -> bool:
350
+ """Check if a git credential helper is configured.
351
+
352
+ Warns the user if this is not the case (except in Google Colab, where "store" is set by default
353
+ by `huggingface_hub`).
354
+ """
355
+ helpers = list_credential_helpers()
356
+ if len(helpers) > 0:
357
+ return True # Do not warn: at least 1 helper is set
358
+
359
+ # Only in Google Colab to avoid the warning message
360
+ # See https://github.com/huggingface/huggingface_hub/issues/1043#issuecomment-1247010710
361
+ if is_google_colab():
362
+ _set_store_as_git_credential_helper_globally()
363
+ return True # Do not warn: "store" is used by default in Google Colab
364
+
365
+ # Otherwise, warn user
366
+ print(
367
+ ANSI.red(
368
+ "Cannot authenticate through git-credential as no helper is defined on your"
369
+ " machine.\nYou might have to re-authenticate when pushing to the Hugging"
370
+ " Face Hub.\nRun the following command in your terminal in case you want to"
371
+ " set the 'store' credential helper as default.\n\ngit config --global"
372
+ " credential.helper store\n\nRead"
373
+ " https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more"
374
+ " details."
375
+ )
376
+ )
377
+ return False
378
+
379
+
380
+ def _set_store_as_git_credential_helper_globally() -> None:
381
+ """Set the credential.helper to `store` globally.
382
+
383
+ To be used only in Google Colab, where we assume the user doesn't care about the git
384
+ credential config. It is the only case where we don't want to display the
385
+ warning message in [`notebook_login()`].
386
+
387
+ Related:
388
+ - https://github.com/huggingface/huggingface_hub/issues/1043
389
+ - https://github.com/huggingface/huggingface_hub/issues/1051
390
+ - https://git-scm.com/docs/git-credential-store
391
+ """
392
+ try:
393
+ run_subprocess("git config --global credential.helper store")
394
+ except subprocess.CalledProcessError as exc:
395
+ raise EnvironmentError(exc.stderr)
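
For orientation, these private helpers back the public `login()` entry point of `huggingface_hub`. A minimal sketch of the expected call, with a placeholder token value:

```python
from huggingface_hub import login

# Validates the token, optionally stores it in a configured git credential
# helper, then saves it to HF_TOKEN_PATH.
login(token="hf_xxx", add_to_git_credential=True)  # "hf_xxx" is a placeholder
```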
lib/python3.11/site-packages/huggingface_hub/_multi_commits.py ADDED
@@ -0,0 +1,305 @@
1
+ # coding=utf-8
2
+ # Copyright 2023-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Contains utilities for multi-commits (i.e. pushing changes iteratively on a PR)."""
16
+ import re
17
+ from dataclasses import dataclass, field
18
+ from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple, Union
19
+
20
+ from ._commit_api import CommitOperationAdd, CommitOperationDelete
21
+ from .community import DiscussionWithDetails
22
+ from .utils import experimental
23
+ from .utils._cache_manager import _format_size
24
+ from .utils.insecure_hashlib import sha256
25
+
26
+
27
+ if TYPE_CHECKING:
28
+ from .hf_api import HfApi
29
+
30
+
31
+ class MultiCommitException(Exception):
32
+ """Base exception for any exception happening while doing a multi-commit."""
33
+
34
+
35
+ MULTI_COMMIT_PR_DESCRIPTION_TEMPLATE = """
36
+ ## {commit_message}
37
+
38
+ {commit_description}
39
+
40
+ **Multi commit ID:** {multi_commit_id}
41
+
42
+ Scheduled commits:
43
+
44
+ {multi_commit_strategy}
45
+
46
+ _This is a PR opened using the `huggingface_hub` library in the context of a multi-commit. The PR can be commented on as a usual PR. However, please be aware that manually updating the PR description, changing the PR status, or pushing new commits is not recommended, as it might corrupt the commit process. Learn more about multi-commits [in this guide](https://huggingface.co/docs/huggingface_hub/main/guides/upload)._
47
+ """
48
+
49
+ MULTI_COMMIT_PR_COMPLETION_COMMENT_TEMPLATE = """
50
+ Multi-commit is now completed! You can ping the repo owner to review the changes. This PR can now be commented on or modified without risk of corrupting it.
51
+
52
+ _This is a comment posted using the `huggingface_hub` library in the context of a multi-commit. Learn more about multi-commits [in this guide](https://huggingface.co/docs/huggingface_hub/main/guides/upload)._
53
+ """
54
+
55
+ MULTI_COMMIT_PR_CLOSING_COMMENT_TEMPLATE = """
56
+ `create_pr=False` has been passed, so the PR has been automatically merged.
57
+
58
+ _This is a comment posted using the `huggingface_hub` library in the context of a multi-commit. Learn more about multi-commits [in this guide](https://huggingface.co/docs/huggingface_hub/main/guides/upload)._
59
+ """
60
+
61
+ MULTI_COMMIT_PR_CLOSE_COMMENT_FAILURE_NO_CHANGES_TEMPLATE = """
62
+ Cannot merge the Pull Request as no changes are associated with it. This PR will be closed automatically.
63
+
64
+ _This is a comment posted using the `huggingface_hub` library in the context of a multi-commit. Learn more about multi-commits [in this guide](https://huggingface.co/docs/huggingface_hub/main/guides/upload)._
65
+ """
66
+
67
+ MULTI_COMMIT_PR_CLOSE_COMMENT_FAILURE_BAD_REQUEST_TEMPLATE = """
68
+ An error occurred while trying to merge the Pull Request: `{error_message}`.
69
+
70
+ _This is a comment posted using the `huggingface_hub` library in the context of a multi-commit. Learn more about multi-commits [in this guide](https://huggingface.co/docs/huggingface_hub/main/guides/upload)._
71
+ """
72
+
73
+
74
+ STEP_ID_REGEX = re.compile(r"- \[(?P<completed>[ |x])\].*(?P<step_id>[a-fA-F0-9]{64})", flags=re.MULTILINE)
75
+
76
+
77
+ @experimental
78
+ def plan_multi_commits(
79
+ operations: Iterable[Union[CommitOperationAdd, CommitOperationDelete]],
80
+ max_operations_per_commit: int = 50,
81
+ max_upload_size_per_commit: int = 2 * 1024 * 1024 * 1024,
82
+ ) -> Tuple[List[List[CommitOperationAdd]], List[List[CommitOperationDelete]]]:
83
+ """Split a list of operations in a list of commits to perform.
84
+
85
+ Implementation follows a sub-optimal (yet simple) algorithm:
86
+ 1. Delete operations are grouped together in commits of at most `max_operations_per_commit` operations.
87
+ 2. All additions exceeding `max_upload_size_per_commit` are committed 1 by 1.
88
+ 3. All remaining additions are grouped together and split each time the `max_operations_per_commit` or the
89
+ `max_upload_size_per_commit` limit is reached.
90
+
91
+ We do not try to optimize the splitting to get the lowest number of commits, as this is an NP-hard problem (see
92
+ [bin packing problem](https://en.wikipedia.org/wiki/Bin_packing_problem)). For our use case, it is not problematic
93
+ to use a sub-optimal solution, so we favored an easy-to-explain implementation.
94
+
95
+ Args:
96
+ operations (`List` of [`~hf_api.CommitOperation`]):
97
+ The list of operations to split into commits.
98
+ max_operations_per_commit (`int`):
99
+ Maximum number of operations in a single commit. Defaults to 50.
100
+ max_upload_size_per_commit (`int`):
101
+ Maximum size to upload (in bytes) in a single commit. Defaults to 2GB. Files bigger than this limit are
102
+ uploaded one per commit.
103
+
104
+ Returns:
105
+ `Tuple[List[List[CommitOperationAdd]], List[List[CommitOperationDelete]]]`: a tuple. First item is a list of
106
+ lists of [`CommitOperationAdd`] representing the addition commits to push. The second item is a list of lists
107
+ of [`CommitOperationDelete`] representing the deletion commits.
108
+
109
+ <Tip warning={true}>
110
+
111
+ `plan_multi_commits` is experimental. Its API and behavior is subject to change in the future without prior notice.
112
+
113
+ </Tip>
114
+
115
+ Example:
116
+ ```python
117
+ >>> from huggingface_hub import HfApi, plan_multi_commits
118
+ >>> addition_commits, deletion_commits = plan_multi_commits(
119
+ ... operations=[
120
+ ... CommitOperationAdd(...),
121
+ ... CommitOperationAdd(...),
122
+ ... CommitOperationDelete(...),
123
+ ... CommitOperationDelete(...),
124
+ ... CommitOperationAdd(...),
125
+ ... ],
126
+ ... )
127
+ >>> HfApi().create_commits_on_pr(
128
+ ... repo_id="my-cool-model",
129
+ ... addition_commits=addition_commits,
130
+ ... deletion_commits=deletion_commits,
131
+ ... (...)
132
+ ... verbose=True,
133
+ ... )
134
+ ```
135
+
136
+ <Tip warning={true}>
137
+
138
+ The initial order of the operations is not guaranteed! All deletions will be performed before additions. If you
139
+ are not updating the same file multiple times, this is fine.
140
+
141
+ </Tip>
142
+ """
143
+ addition_commits: List[List[CommitOperationAdd]] = []
144
+ deletion_commits: List[List[CommitOperationDelete]] = []
145
+
146
+ additions: List[CommitOperationAdd] = []
147
+ additions_size = 0
148
+ deletions: List[CommitOperationDelete] = []
149
+ for op in operations:
150
+ if isinstance(op, CommitOperationDelete):
151
+ # Group delete operations together
152
+ deletions.append(op)
153
+ if len(deletions) >= max_operations_per_commit:
154
+ deletion_commits.append(deletions)
155
+ deletions = []
156
+
157
+ elif op.upload_info.size >= max_upload_size_per_commit:
158
+ # Upload huge files 1 by 1
159
+ addition_commits.append([op])
160
+
161
+ elif additions_size + op.upload_info.size < max_upload_size_per_commit:
162
+ # Group other additions and split if size limit is reached (either max_nb_files or max_upload_size)
163
+ additions.append(op)
164
+ additions_size += op.upload_info.size
165
+
166
+ else:
167
+ addition_commits.append(additions)
168
+ additions = [op]
169
+ additions_size = op.upload_info.size
170
+
171
+ if len(additions) >= max_operations_per_commit:
172
+ addition_commits.append(additions)
173
+ additions = []
174
+ additions_size = 0
175
+
176
+ if len(additions) > 0:
177
+ addition_commits.append(additions)
178
+ if len(deletions) > 0:
179
+ deletion_commits.append(deletions)
180
+
181
+ return addition_commits, deletion_commits
182
+
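
To make the splitting behavior concrete, here is a small self-contained sketch with in-memory operations (all file names and contents are invented for illustration):

```python
from huggingface_hub import CommitOperationAdd, CommitOperationDelete, plan_multi_commits

# Three tiny in-memory additions and one deletion (illustrative names only).
operations = [
    CommitOperationAdd(path_in_repo=f"data/file_{i}.txt", path_or_fileobj=b"hello")
    for i in range(3)
] + [CommitOperationDelete(path_in_repo="old_file.txt")]

addition_commits, deletion_commits = plan_multi_commits(operations, max_operations_per_commit=2)
# Expected grouping: two addition commits ([2 ops], [1 op]) and one deletion commit ([1 op]).
```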
183
+
184
+ @dataclass
185
+ class MultiCommitStep:
186
+ """Dataclass containing a list of CommitOperation to commit at once.
187
+
188
+ A [`MultiCommitStep`] is one atomic part of a [`MultiCommitStrategy`]. Each step is identified by its own
189
+ deterministic ID based on the list of commit operations (hexadecimal sha256). The ID is persistent between re-runs if
190
+ the list of commits is kept the same.
191
+ """
192
+
193
+ operations: List[Union[CommitOperationAdd, CommitOperationDelete]]
194
+
195
+ id: str = field(init=False)
196
+ completed: bool = False
197
+
198
+ def __post_init__(self) -> None:
199
+ if len(self.operations) == 0:
200
+ raise ValueError("A MultiCommitStep must have at least 1 commit operation, got 0.")
201
+
202
+ # Generate commit id
203
+ sha = sha256()
204
+ for op in self.operations:
205
+ if isinstance(op, CommitOperationAdd):
206
+ sha.update(b"ADD")
207
+ sha.update(op.path_in_repo.encode())
208
+ sha.update(op.upload_info.sha256)
209
+ elif isinstance(op, CommitOperationDelete):
210
+ sha.update(b"DELETE")
211
+ sha.update(op.path_in_repo.encode())
212
+ sha.update(str(op.is_folder).encode())
213
+ else:
214
+ raise NotImplementedError()
215
+ self.id = sha.hexdigest()
216
+
217
+ def __str__(self) -> str:
218
+ """Format a step for PR description.
219
+
220
+ Formatting can be changed in the future as long as it is single line, starts with `- [ ]`/`- [x]` and contains
221
+ `self.id`. Must be able to match `STEP_ID_REGEX`.
222
+ """
223
+ additions = [op for op in self.operations if isinstance(op, CommitOperationAdd)]
224
+ file_deletions = [op for op in self.operations if isinstance(op, CommitOperationDelete) and not op.is_folder]
225
+ folder_deletions = [op for op in self.operations if isinstance(op, CommitOperationDelete) and op.is_folder]
226
+ if len(additions) > 0:
227
+ return (
228
+ f"- [{'x' if self.completed else ' '}] Upload {len(additions)} file(s) "
229
+ f"totalling {_format_size(sum(add.upload_info.size for add in additions))}"
230
+ f" ({self.id})"
231
+ )
232
+ else:
233
+ return (
234
+ f"- [{'x' if self.completed else ' '}] Delete {len(file_deletions)} file(s) and"
235
+ f" {len(folder_deletions)} folder(s) ({self.id})"
236
+ )
237
+
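
As a sanity check on that formatting contract, a rendered step line should be matched by `STEP_ID_REGEX`; a minimal sketch with a dummy 64-character hex ID:

```python
from huggingface_hub._multi_commits import STEP_ID_REGEX

dummy_id = "a" * 64  # stands in for a real sha256 hex digest
line = f"- [x] Upload 2 file(s) totalling 10.0M ({dummy_id})"
match = STEP_ID_REGEX.search(line)
assert match is not None
assert match.group("completed") == "x"
assert match.group("step_id") == dummy_id
```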
238
+
239
+ @dataclass
240
+ class MultiCommitStrategy:
241
+ """Dataclass containing a list of [`MultiCommitStep`] to commit iteratively.
242
+
243
+ A strategy is identified by its own deterministic ID based on the list of its steps (hexadecimal sha256). The ID is
244
+ persistent between re-runs if the list of commits is kept the same.
245
+ """
246
+
247
+ addition_commits: List[MultiCommitStep]
248
+ deletion_commits: List[MultiCommitStep]
249
+
250
+ id: str = field(init=False)
251
+ all_steps: Set[str] = field(init=False)
252
+
253
+ def __post_init__(self) -> None:
254
+ self.all_steps = {step.id for step in self.addition_commits + self.deletion_commits}
255
+ if len(self.all_steps) < len(self.addition_commits) + len(self.deletion_commits):
256
+ raise ValueError("Got duplicate commits in MultiCommitStrategy. All commits must be unique.")
257
+
258
+ if len(self.all_steps) == 0:
259
+ raise ValueError("A MultiCommitStrategy must have at least 1 commit, got 0.")
260
+
261
+ # Generate strategy id
262
+ sha = sha256()
263
+ for step in self.addition_commits + self.deletion_commits:
264
+ sha.update("new step".encode())
265
+ sha.update(step.id.encode())
266
+ self.id = sha.hexdigest()
267
+
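
Tying the two dataclasses together, a self-contained sketch of assembling a strategy by hand (paths and contents are again invented):

```python
from huggingface_hub import CommitOperationAdd, CommitOperationDelete
from huggingface_hub._multi_commits import MultiCommitStep, MultiCommitStrategy

strategy = MultiCommitStrategy(
    addition_commits=[MultiCommitStep(operations=[CommitOperationAdd(path_in_repo="a.txt", path_or_fileobj=b"hi")])],
    deletion_commits=[MultiCommitStep(operations=[CommitOperationDelete(path_in_repo="b.txt")])],
)
print(strategy.id)              # deterministic sha256 over the step IDs
print(len(strategy.all_steps))  # 2 unique steps
```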
268
+
269
+ def multi_commit_create_pull_request(
270
+ api: "HfApi",
271
+ repo_id: str,
272
+ commit_message: str,
273
+ commit_description: Optional[str],
274
+ strategy: MultiCommitStrategy,
275
+ token: Optional[str],
276
+ repo_type: Optional[str],
277
+ ) -> DiscussionWithDetails:
278
+ return api.create_pull_request(
279
+ repo_id=repo_id,
280
+ title=f"[WIP] {commit_message} (multi-commit {strategy.id})",
281
+ description=multi_commit_generate_comment(
282
+ commit_message=commit_message, commit_description=commit_description, strategy=strategy
283
+ ),
284
+ token=token,
285
+ repo_type=repo_type,
286
+ )
287
+
288
+
289
+ def multi_commit_generate_comment(
290
+ commit_message: str,
291
+ commit_description: Optional[str],
292
+ strategy: MultiCommitStrategy,
293
+ ) -> str:
294
+ return MULTI_COMMIT_PR_DESCRIPTION_TEMPLATE.format(
295
+ commit_message=commit_message,
296
+ commit_description=commit_description or "",
297
+ multi_commit_id=strategy.id,
298
+ multi_commit_strategy="\n".join(
299
+ str(commit) for commit in strategy.deletion_commits + strategy.addition_commits
300
+ ),
301
+ )
302
+
303
+
304
+ def multi_commit_parse_pr_description(description: str) -> Set[str]:
305
+ return {match[1] for match in STEP_ID_REGEX.findall(description)}
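
On the parsing side, a short sketch showing how step IDs are recovered from a PR description rendered with the templates above (dummy IDs again):

```python
from huggingface_hub._multi_commits import multi_commit_parse_pr_description

description = (
    "Scheduled commits:\n"
    f"- [x] Delete 1 file(s) and 0 folder(s) ({'a' * 64})\n"
    f"- [ ] Upload 2 file(s) totalling 10.0M ({'b' * 64})\n"
)
step_ids = multi_commit_parse_pr_description(description)
assert step_ids == {"a" * 64, "b" * 64}
```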
lib/python3.11/site-packages/huggingface_hub/_snapshot_download.py ADDED
@@ -0,0 +1,319 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Dict, List, Literal, Optional, Union
4
+
5
+ import requests
6
+ from tqdm.auto import tqdm as base_tqdm
7
+ from tqdm.contrib.concurrent import thread_map
8
+
9
+ from .constants import (
10
+ DEFAULT_ETAG_TIMEOUT,
11
+ DEFAULT_REVISION,
12
+ HF_HUB_CACHE,
13
+ HF_HUB_ENABLE_HF_TRANSFER,
14
+ REPO_TYPES,
15
+ )
16
+ from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
17
+ from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
18
+ from .utils import (
19
+ GatedRepoError,
20
+ LocalEntryNotFoundError,
21
+ OfflineModeIsEnabled,
22
+ RepositoryNotFoundError,
23
+ RevisionNotFoundError,
24
+ filter_repo_objects,
25
+ logging,
26
+ validate_hf_hub_args,
27
+ )
28
+ from .utils import tqdm as hf_tqdm
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ @validate_hf_hub_args
35
+ def snapshot_download(
36
+ repo_id: str,
37
+ *,
38
+ repo_type: Optional[str] = None,
39
+ revision: Optional[str] = None,
40
+ cache_dir: Union[str, Path, None] = None,
41
+ local_dir: Union[str, Path, None] = None,
42
+ local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
43
+ library_name: Optional[str] = None,
44
+ library_version: Optional[str] = None,
45
+ user_agent: Optional[Union[Dict, str]] = None,
46
+ proxies: Optional[Dict] = None,
47
+ etag_timeout: float = DEFAULT_ETAG_TIMEOUT,
48
+ resume_download: bool = False,
49
+ force_download: bool = False,
50
+ token: Optional[Union[bool, str]] = None,
51
+ local_files_only: bool = False,
52
+ allow_patterns: Optional[Union[List[str], str]] = None,
53
+ ignore_patterns: Optional[Union[List[str], str]] = None,
54
+ max_workers: int = 8,
55
+ tqdm_class: Optional[base_tqdm] = None,
56
+ endpoint: Optional[str] = None,
57
+ ) -> str:
58
+ """Download repo files.
59
+
60
+ Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
61
+ a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order
62
+ to keep their actual filename relative to that folder. You can also filter which files to download using
63
+ `allow_patterns` and `ignore_patterns`.
64
+
65
+ If `local_dir` is provided, the file structure from the repo will be replicated in this location. You can configure
66
+ how you want to move those files:
67
+ - If `local_dir_use_symlinks="auto"` (default), files are downloaded and stored in the cache directory as blob
68
+ files. Small files (<5MB) are duplicated in `local_dir` while a symlink is created for bigger files. The goal
69
+ is to be able to manually edit and save small files without corrupting the cache while saving disk space for
70
+ binary files. The 5MB threshold can be configured with the `HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD`
71
+ environment variable.
72
+ - If `local_dir_use_symlinks=True`, files are downloaded, stored in the cache directory and symlinked in `local_dir`.
73
+ This is optimal in terms of disk usage but files must not be manually edited.
74
+ - If `local_dir_use_symlinks=False` and the blob files exist in the cache directory, they are duplicated in the
75
+ local dir. This means disk usage is not optimized.
76
+ - Finally, if `local_dir_use_symlinks=False` and the blob files do not exist in the cache directory, then the
77
+ files are downloaded and directly placed under `local_dir`. This means if you need to download them again later,
78
+ they will be re-downloaded entirely.
79
+
80
+ An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly
81
+ configured. It is also not possible to filter which files to download when cloning a repository using git.
82
+
83
+ Args:
84
+ repo_id (`str`):
85
+ A user or an organization name and a repo name separated by a `/`.
86
+ repo_type (`str`, *optional*):
87
+ Set to `"dataset"` or `"space"` if downloading from a dataset or space,
88
+ `None` or `"model"` if downloading from a model. Default is `None`.
89
+ revision (`str`, *optional*):
90
+ An optional Git revision id which can be a branch name, a tag, or a
91
+ commit hash.
92
+ cache_dir (`str`, `Path`, *optional*):
93
+ Path to the folder where cached files are stored.
94
+ local_dir (`str` or `Path`, *optional*):
95
+ If provided, the downloaded files will be placed under this directory, either as symlinks (default) or
96
+ regular files (see description for more details).
97
+ local_dir_use_symlinks (`"auto"` or `bool`, defaults to `"auto"`):
98
+ To be used with `local_dir`. If set to "auto", the cache directory will be used and the file will be either
99
+ duplicated or symlinked to the local directory depending on its size. If set to `True`, a symlink will be
100
+ created, no matter the file size. If set to `False`, the file will either be duplicated from cache (if it
101
+ already exists) or downloaded from the Hub and not cached. See description for more details.
102
+ library_name (`str`, *optional*):
103
+ The name of the library to which the object corresponds.
104
+ library_version (`str`, *optional*):
105
+ The version of the library.
106
+ user_agent (`str`, `dict`, *optional*):
107
+ The user-agent info in the form of a dictionary or a string.
108
+ proxies (`dict`, *optional*):
109
+ Dictionary mapping protocol to the URL of the proxy passed to
110
+ `requests.request`.
111
+ etag_timeout (`float`, *optional*, defaults to `10`):
112
+ When fetching ETag, how many seconds to wait for the server to send
113
+ data before giving up; this is passed to `requests.request`.
114
+ resume_download (`bool`, *optional*, defaults to `False`):
115
+ If `True`, resume a previously interrupted download.
116
+ force_download (`bool`, *optional*, defaults to `False`):
117
+ Whether the file should be downloaded even if it already exists in the local cache.
118
+ token (`str`, `bool`, *optional*):
119
+ A token to be used for the download.
120
+ - If `True`, the token is read from the HuggingFace config
121
+ folder.
122
+ - If a string, it's used as the authentication token.
123
+ local_files_only (`bool`, *optional*, defaults to `False`):
124
+ If `True`, avoid downloading the file and return the path to the
125
+ local cached file if it exists.
126
+ allow_patterns (`List[str]` or `str`, *optional*):
127
+ If provided, only files matching at least one pattern are downloaded.
128
+ ignore_patterns (`List[str]` or `str`, *optional*):
129
+ If provided, files matching any of the patterns are not downloaded.
130
+ max_workers (`int`, *optional*):
131
+ Number of concurrent threads to download files (1 thread = 1 file download).
132
+ Defaults to 8.
133
+ tqdm_class (`tqdm`, *optional*):
134
+ If provided, overwrites the default behavior for the progress bar. Passed
135
+ argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior.
136
+ Note that the `tqdm_class` is not passed to each individual download.
137
+ Defaults to the custom HF progress bar that can be disabled by setting
138
+ `HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
139
+
140
+ Returns:
141
+ Local folder path (string) of repo snapshot
142
+
143
+ <Tip>
144
+
145
+ Raises the following errors:
146
+
147
+ - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
148
+ if `token=True` and the token cannot be found.
149
+ - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
150
+ ETag cannot be determined.
151
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
152
+ if some parameter value is invalid
153
+
154
+ </Tip>
155
+ """
156
+ if cache_dir is None:
157
+ cache_dir = HF_HUB_CACHE
158
+ if revision is None:
159
+ revision = DEFAULT_REVISION
160
+ if isinstance(cache_dir, Path):
161
+ cache_dir = str(cache_dir)
162
+
163
+ if repo_type is None:
164
+ repo_type = "model"
165
+ if repo_type not in REPO_TYPES:
166
+ raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")
167
+
168
+ storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
169
+
170
+ repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None
171
+ api_call_error: Optional[Exception] = None
172
+ if not local_files_only:
173
+ # try/except logic to handle different errors => taken from `hf_hub_download`
174
+ try:
175
+ # if we have internet connection we want to list files to download
176
+ api = HfApi(
177
+ library_name=library_name, library_version=library_version, user_agent=user_agent, endpoint=endpoint
178
+ )
179
+ repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
180
+ except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
181
+ # Actually raise for those subclasses of ConnectionError
182
+ raise
183
+ except (
184
+ requests.exceptions.ConnectionError,
185
+ requests.exceptions.Timeout,
186
+ OfflineModeIsEnabled,
187
+ ) as error:
188
+ # Internet connection is down
189
+ # => will try to use local files only
190
+ api_call_error = error
191
+ pass
192
+ except RevisionNotFoundError:
193
+ # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted)
194
+ raise
195
+ except requests.HTTPError as error:
196
+ # Multiple reasons for an http error:
197
+ # - Repository is private and invalid/missing token sent
198
+ # - Repository is gated and invalid/missing token sent
199
+ # - Hub is down (error 500 or 504)
200
+ # => let's switch to 'local_files_only=True' to check if the files are already cached.
201
+ # (if it's not the case, the error will be re-raised)
202
+ api_call_error = error
203
+ pass
204
+
205
+ # At this stage, if `repo_info` is None it means either:
206
+ # - internet connection is down
207
+ # - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True)
208
+ # - repo is private/gated and invalid/missing token sent
209
+ # - Hub is down
210
+ # => let's look if we can find the appropriate folder in the cache:
211
+ # - if the specified revision is a commit hash, look inside "snapshots".
212
+ # - if the specified revision is a branch or tag, look inside "refs".
213
+ if repo_info is None:
214
+ # Try to get which commit hash corresponds to the specified revision
215
+ commit_hash = None
216
+ if REGEX_COMMIT_HASH.match(revision):
217
+ commit_hash = revision
218
+ else:
219
+ ref_path = os.path.join(storage_folder, "refs", revision)
220
+ if os.path.exists(ref_path):
221
+ # retrieve commit_hash from refs file
222
+ with open(ref_path) as f:
223
+ commit_hash = f.read()
224
+
225
+ # Try to locate snapshot folder for this commit hash
226
+ if commit_hash is not None:
227
+ snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
228
+ if os.path.exists(snapshot_folder):
229
+ # Snapshot folder exists => let's return it
230
+ # (but we can't check if all the files are actually there)
231
+ return snapshot_folder
232
+
233
+ # If we couldn't find the appropriate folder on disk, raise an error.
234
+ if local_files_only:
235
+ raise LocalEntryNotFoundError(
236
+ "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
237
+ "outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass "
238
+ "'local_files_only=False' as input."
239
+ )
240
+ elif isinstance(api_call_error, OfflineModeIsEnabled):
241
+ raise LocalEntryNotFoundError(
242
+ "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
243
+ "outgoing traffic has been disabled. To enable repo look-ups and downloads online, set "
244
+ "'HF_HUB_OFFLINE=0' as environment variable."
245
+ ) from api_call_error
246
+ elif isinstance(api_call_error, (RepositoryNotFoundError, GatedRepoError)):
247
+ # Repo not found => let's raise the actual error
248
+ raise api_call_error
249
+ else:
250
+ # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
251
+ raise LocalEntryNotFoundError(
252
+ "An error happened while trying to locate the files on the Hub and we cannot find the appropriate"
253
+ " snapshot folder for the specified revision on the local disk. Please check your internet connection"
254
+ " and try again."
255
+ ) from api_call_error
256
+
257
+ # At this stage, internet connection is up and running
258
+ # => let's download the files!
259
+ assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
260
+ assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."
261
+ filtered_repo_files = list(
262
+ filter_repo_objects(
263
+ items=[f.rfilename for f in repo_info.siblings],
264
+ allow_patterns=allow_patterns,
265
+ ignore_patterns=ignore_patterns,
266
+ )
267
+ )
268
+ commit_hash = repo_info.sha
269
+ snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
270
+ # if passed revision is not identical to commit_hash
271
+ # then revision has to be a branch name or tag name.
272
+ # In that case store a ref.
273
+ if revision != commit_hash:
274
+ ref_path = os.path.join(storage_folder, "refs", revision)
275
+ os.makedirs(os.path.dirname(ref_path), exist_ok=True)
276
+ with open(ref_path, "w") as f:
277
+ f.write(commit_hash)
278
+
279
+ # we pass the commit_hash to hf_hub_download
280
+ # so no network call happens if we already
281
+ # have the file locally.
282
+ def _inner_hf_hub_download(repo_file: str):
283
+ return hf_hub_download(
284
+ repo_id,
285
+ filename=repo_file,
286
+ repo_type=repo_type,
287
+ revision=commit_hash,
288
+ endpoint=endpoint,
289
+ cache_dir=cache_dir,
290
+ local_dir=local_dir,
291
+ local_dir_use_symlinks=local_dir_use_symlinks,
292
+ library_name=library_name,
293
+ library_version=library_version,
294
+ user_agent=user_agent,
295
+ proxies=proxies,
296
+ etag_timeout=etag_timeout,
297
+ resume_download=resume_download,
298
+ force_download=force_download,
299
+ token=token,
300
+ )
301
+
302
+ if HF_HUB_ENABLE_HF_TRANSFER:
303
+ # when using hf_transfer we don't want extra parallelism
304
+ # from the one hf_transfer provides
305
+ for file in filtered_repo_files:
306
+ _inner_hf_hub_download(file)
307
+ else:
308
+ thread_map(
309
+ _inner_hf_hub_download,
310
+ filtered_repo_files,
311
+ desc=f"Fetching {len(filtered_repo_files)} files",
312
+ max_workers=max_workers,
313
+ # User can use its own tqdm class or the default one from `huggingface_hub.utils`
314
+ tqdm_class=tqdm_class or hf_tqdm,
315
+ )
316
+
317
+ if local_dir is not None:
318
+ return str(os.path.realpath(local_dir))
319
+ return snapshot_folder
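
For reference, a typical call looks like the following; the `repo_id` and patterns are illustrative:

```python
from huggingface_hub import snapshot_download

# Download only the JSON files of a repo into the cache and get back
# the local path of the snapshot folder.
local_folder = snapshot_download(repo_id="gpt2", allow_patterns=["*.json"])
print(local_folder)
```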
lib/python3.11/site-packages/huggingface_hub/_space_api.py ADDED
@@ -0,0 +1,154 @@
1
+ # coding=utf-8
2
+ # Copyright 2019-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from dataclasses import dataclass
16
+ from datetime import datetime
17
+ from enum import Enum
18
+ from typing import Dict, Optional
19
+
20
+ from huggingface_hub.utils import parse_datetime
21
+
22
+
23
+ class SpaceStage(str, Enum):
24
+ """
25
+ Enumeration of the possible stages of a Space on the Hub.
26
+
27
+ Value can be compared to a string:
28
+ ```py
29
+ assert SpaceStage.BUILDING == "BUILDING"
30
+ ```
31
+
32
+ Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceInfo.ts#L61 (private url).
33
+ """
34
+
35
+ # Copied from moon-landing > server > repo_types > SpaceInfo.ts (private repo)
36
+ NO_APP_FILE = "NO_APP_FILE"
37
+ CONFIG_ERROR = "CONFIG_ERROR"
38
+ BUILDING = "BUILDING"
39
+ BUILD_ERROR = "BUILD_ERROR"
40
+ RUNNING = "RUNNING"
41
+ RUNNING_BUILDING = "RUNNING_BUILDING"
42
+ RUNTIME_ERROR = "RUNTIME_ERROR"
43
+ DELETING = "DELETING"
44
+ STOPPED = "STOPPED"
45
+ PAUSED = "PAUSED"
46
+
47
+
48
+ class SpaceHardware(str, Enum):
49
+ """
50
+ Enumeration of hardware available to run your Space on the Hub.
51
+
52
+ Value can be compared to a string:
53
+ ```py
54
+ assert SpaceHardware.CPU_BASIC == "cpu-basic"
55
+ ```
56
+
57
+ Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceInfo.ts#L73 (private url).
58
+ """
59
+
60
+ CPU_BASIC = "cpu-basic"
61
+ CPU_UPGRADE = "cpu-upgrade"
62
+ T4_SMALL = "t4-small"
63
+ T4_MEDIUM = "t4-medium"
64
+ ZERO_A10G = "zero-a10g"
65
+ A10G_SMALL = "a10g-small"
66
+ A10G_LARGE = "a10g-large"
67
+ A10G_LARGEX2 = "a10g-largex2"
68
+ A10G_LARGEX4 = "a10g-largex4"
69
+ A100_LARGE = "a100-large"
70
+
71
+
72
+ class SpaceStorage(str, Enum):
73
+ """
74
+ Enumeration of persistent storage available for your Space on the Hub.
75
+
76
+ Value can be compared to a string:
77
+ ```py
78
+ assert SpaceStorage.SMALL == "small"
79
+ ```
80
+
81
+ Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceHardwareFlavor.ts#L24 (private url).
82
+ """
83
+
84
+ SMALL = "small"
85
+ MEDIUM = "medium"
86
+ LARGE = "large"
87
+
88
+
89
+ @dataclass
90
+ class SpaceRuntime:
91
+ """
92
+ Contains information about the current runtime of a Space.
93
+
94
+ Args:
95
+ stage (`str`):
96
+ Current stage of the space. Example: RUNNING.
97
+ hardware (`str` or `None`):
98
+ Current hardware of the space. Example: "cpu-basic". Can be `None` if Space
99
+ is `BUILDING` for the first time.
100
+ requested_hardware (`str` or `None`):
101
+ Requested hardware. Can be different from `hardware`, especially if the request
102
+ has just been made. Example: "t4-medium". Can be `None` if no hardware has
103
+ been requested yet.
104
+ sleep_time (`int` or `None`):
105
+ Number of seconds the Space will be kept alive after the last request. By default (if value is `None`), the
106
+ Space will never go to sleep if it's running on upgraded hardware, while it will go to sleep after 48
107
+ hours on free 'cpu-basic' hardware. For more details, see https://huggingface.co/docs/hub/spaces-gpus#sleep-time.
108
+ raw (`dict`):
109
+ Raw response from the server. Contains more information about the Space
110
+ runtime, like the number of replicas, number of CPUs, memory size, etc.
111
+ """
112
+
113
+ stage: SpaceStage
114
+ hardware: Optional[SpaceHardware]
115
+ requested_hardware: Optional[SpaceHardware]
116
+ sleep_time: Optional[int]
117
+ storage: Optional[SpaceStorage]
118
+ raw: Dict
119
+
120
+ def __init__(self, data: Dict) -> None:
121
+ self.stage = data["stage"]
122
+ self.hardware = data.get("hardware", {}).get("current")
123
+ self.requested_hardware = data.get("hardware", {}).get("requested")
124
+ self.sleep_time = data.get("gcTimeout")
125
+ self.storage = data.get("storage")
126
+ self.raw = data
127
+
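
A minimal sketch of how this dataclass consumes a raw server payload (the dict below is hand-written for illustration, not a real API response):

```python
from huggingface_hub._space_api import SpaceRuntime

runtime = SpaceRuntime(
    {
        "stage": "RUNNING",
        "hardware": {"current": "cpu-basic", "requested": "t4-small"},
        "gcTimeout": 3600,
    }
)
assert runtime.stage == "RUNNING"
assert runtime.hardware == "cpu-basic"
assert runtime.requested_hardware == "t4-small"
assert runtime.sleep_time == 3600
assert runtime.storage is None  # "storage" key absent in this payload
```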
128
+
129
+ @dataclass
130
+ class SpaceVariable:
131
+ """
132
+ Contains information about the current variables of a Space.
133
+
134
+ Args:
135
+ key (`str`):
136
+ Variable key. Example: `"MODEL_REPO_ID"`
137
+ value (`str`):
138
+ Variable value. Example: `"the_model_repo_id"`.
139
+ description (`str` or `None`):
140
+ Description of the variable. Example: `"Model Repo ID of the implemented model"`.
141
+ updated_at (`datetime`):
142
+ Datetime of the last update of the variable.
143
+ """
144
+
145
+ key: str
146
+ value: str
147
+ description: Optional[str]
148
+ updated_at: datetime
149
+
150
+ def __init__(self, key: str, values: Dict) -> None:
151
+ self.key = key
152
+ self.value = values["value"]
153
+ self.description = values.get("description")
154
+ self.updated_at = parse_datetime(values["updatedAt"])