Elron commited on
Commit
129c8dd
·
verified ·
1 Parent(s): e8292e5

Upload dataclass.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dataclass.py +55 -5
dataclass.py CHANGED
@@ -1,11 +1,17 @@
1
  import copy
2
  import dataclasses
 
3
  from abc import ABCMeta
 
4
  from typing import Any, final
5
 
6
  _FIELDS = "__fields__"
7
 
8
 
 
 
 
 
9
  @dataclasses.dataclass
10
  class Field:
11
  """An alternative to dataclasses.dataclass decorator for a more flexible field definition.
@@ -21,7 +27,7 @@ class Field:
21
  origin_cls (type, optional): The original class that defined the field. Defaults to None.
22
  """
23
 
24
- default: Any = None
25
  name: str = None
26
  type: type = None
27
  init: bool = True
@@ -51,13 +57,18 @@ class RequiredField(Field):
51
  self.required = True
52
 
53
 
 
 
 
 
54
  @dataclasses.dataclass
55
  class OptionalField(Field):
56
  def __post_init__(self):
57
  self.required = False
58
- assert (
59
- self.default is not None or self.default_factory is not None
60
- ), "OptionalField must have default or default_factory"
 
61
 
62
 
63
  @dataclasses.dataclass
@@ -269,7 +280,46 @@ class DataclassMeta(ABCMeta):
269
  @final
270
  def __init__(cls, name, bases, attrs):
271
  super().__init__(name, bases, attrs)
272
- setattr(cls, _FIELDS, get_fields(cls, attrs))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
 
275
  class Dataclass(metaclass=DataclassMeta):
 
1
  import copy
2
  import dataclasses
3
+ import functools
4
  from abc import ABCMeta
5
+ from inspect import Parameter, Signature
6
  from typing import Any, final
7
 
8
  _FIELDS = "__fields__"
9
 
10
 
11
+ class Undefined:
12
+ pass
13
+
14
+
15
  @dataclasses.dataclass
16
  class Field:
17
  """An alternative to dataclasses.dataclass decorator for a more flexible field definition.
 
27
  origin_cls (type, optional): The original class that defined the field. Defaults to None.
28
  """
29
 
30
+ default: Any = Undefined
31
  name: str = None
32
  type: type = None
33
  init: bool = True
 
57
  self.required = True
58
 
59
 
60
+ class MissingDefaultError(TypeError):
61
+ pass
62
+
63
+
64
  @dataclasses.dataclass
65
  class OptionalField(Field):
66
  def __post_init__(self):
67
  self.required = False
68
+ if self.default is Undefined and self.default_factory is None:
69
+ raise MissingDefaultError(
70
+ "OptionalField must have default or default_factory"
71
+ )
72
 
73
 
74
  @dataclasses.dataclass
 
280
  @final
281
  def __init__(cls, name, bases, attrs):
282
  super().__init__(name, bases, attrs)
283
+ fields = get_fields(cls, attrs)
284
+ setattr(cls, _FIELDS, fields)
285
+ cls.update_init_signature()
286
+
287
+ def update_init_signature(cls):
288
+ parameters = []
289
+
290
+ for name, field in getattr(cls, _FIELDS).items():
291
+ if field.init and not field.internal:
292
+ if field.default is not Undefined:
293
+ default_value = field.default
294
+ elif field.default_factory is not None:
295
+ default_value = field.default_factory()
296
+ else:
297
+ default_value = Parameter.empty
298
+
299
+ if isinstance(default_value, dataclasses._MISSING_TYPE):
300
+ default_value = Parameter.empty
301
+ param = Parameter(
302
+ name,
303
+ Parameter.POSITIONAL_OR_KEYWORD,
304
+ default=default_value,
305
+ annotation=field.type,
306
+ )
307
+ parameters.append(param)
308
+
309
+ if getattr(cls, "__allow_unexpected_arguments__", False):
310
+ parameters.append(Parameter("_argv", Parameter.VAR_POSITIONAL))
311
+ parameters.append(Parameter("_kwargs", Parameter.VAR_KEYWORD))
312
+
313
+ signature = Signature(parameters, __validate_parameters__=False)
314
+
315
+ original_init = cls.__init__
316
+
317
+ @functools.wraps(original_init)
318
+ def custom_cls_init(self, *args, **kwargs):
319
+ original_init(self, *args, **kwargs)
320
+
321
+ custom_cls_init.__signature__ = signature
322
+ cls.__init__ = custom_cls_init
323
 
324
 
325
  class Dataclass(metaclass=DataclassMeta):