Elron commited on
Commit
75da725
1 Parent(s): 938f93e

Upload dataclass.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dataclass.py +71 -10
dataclass.py CHANGED
@@ -27,10 +27,12 @@ class Field:
27
  name: str = None
28
  type: type = None
29
  init: bool = True
 
30
  default_factory: Any = None
31
  final: bool = False
32
  abstract: bool = False
33
  required: bool = False
 
34
  origin_cls: type = None
35
 
36
  def get_default(self):
@@ -58,6 +60,20 @@ class AbstractField(Field):
58
  self.abstract = True
59
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  class FinalFieldError(TypeError):
62
  pass
63
 
@@ -74,6 +90,10 @@ class TypeMismatchError(TypeError):
74
  pass
75
 
76
 
 
 
 
 
77
  standart_variables = dir(object)
78
 
79
 
@@ -102,17 +122,21 @@ def get_fields(cls, attrs):
102
  Returns:
103
  dict: A dictionary mapping field names to Field instances.
104
  """
105
-
106
- fields = {**getattr(cls, _FIELDS, {})}
 
107
  annotations = {**attrs.get("__annotations__", {})}
108
 
109
  for attr_name, attr_value in attrs.items():
110
  if attr_name not in annotations and is_possible_field(attr_name, attr_value):
111
  if attr_name in fields:
112
- if not isinstance(attr_value, fields[attr_name].type):
113
- raise TypeMismatchError(
114
- f"Type mismatch for field '{attr_name}' of class '{fields[attr_name].origin_cls}'. Expected {fields[attr_name].type}, got {type(attr_value)}"
115
- )
 
 
 
116
  annotations[attr_name] = fields[attr_name].type
117
 
118
  for field_name, field_type in annotations.items():
@@ -150,6 +174,10 @@ def get_fields(cls, attrs):
150
  field_instance = Field(**args)
151
  fields[field_name] = field_instance
152
 
 
 
 
 
153
  return fields
154
 
155
 
@@ -296,15 +324,46 @@ class Dataclass(metaclass=DataclassMeta):
296
 
297
  """
298
 
 
 
299
  @final
300
- def __init__(self, *args, **kwargs):
301
  """
302
  Initialize fields based on kwargs.
303
  Checks for abstract fields when an instance is created.
304
  """
305
- init_fields = [field for field in fields(self) if field.init]
306
- for field, arg in zip(init_fields, args):
307
- kwargs[field.name] = arg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
 
309
  for field in abstract_fields(self):
310
  raise AbstractFieldError(
@@ -321,6 +380,8 @@ class Dataclass(metaclass=DataclassMeta):
321
  if field.name in kwargs:
322
  setattr(self, field.name, kwargs[field.name])
323
  else:
 
 
324
  setattr(self, field.name, get_field_default(field))
325
 
326
  self.__post_init__()
 
27
  name: str = None
28
  type: type = None
29
  init: bool = True
30
+ also_positional: bool = True
31
  default_factory: Any = None
32
  final: bool = False
33
  abstract: bool = False
34
  required: bool = False
35
+ internal: bool = False
36
  origin_cls: type = None
37
 
38
  def get_default(self):
 
60
  self.abstract = True
61
 
62
 
63
+ @dataclasses.dataclass
64
+ class NonPositionalField(Field):
65
+ def __post_init__(self):
66
+ self.also_positional = False
67
+
68
+
69
+ @dataclasses.dataclass
70
+ class InternalField(Field):
71
+ def __post_init__(self):
72
+ self.internal = True
73
+ self.init = False
74
+ self.also_positional = False
75
+
76
+
77
  class FinalFieldError(TypeError):
78
  pass
79
 
 
90
  pass
91
 
92
 
93
+ class UnexpectedArgumentError(TypeError):
94
+ pass
95
+
96
+
97
  standart_variables = dir(object)
98
 
99
 
 
122
  Returns:
123
  dict: A dictionary mapping field names to Field instances.
124
  """
125
+ fields = {}
126
+ for base in cls.__bases__:
127
+ fields = {**getattr(base, _FIELDS, {}), **fields}
128
  annotations = {**attrs.get("__annotations__", {})}
129
 
130
  for attr_name, attr_value in attrs.items():
131
  if attr_name not in annotations and is_possible_field(attr_name, attr_value):
132
  if attr_name in fields:
133
+ try:
134
+ if not isinstance(attr_value, fields[attr_name].type):
135
+ raise TypeMismatchError(
136
+ f"Type mismatch for field '{attr_name}' of class '{fields[attr_name].origin_cls}'. Expected {fields[attr_name].type}, got {type(attr_value)}"
137
+ )
138
+ except TypeError:
139
+ pass
140
  annotations[attr_name] = fields[attr_name].type
141
 
142
  for field_name, field_type in annotations.items():
 
174
  field_instance = Field(**args)
175
  fields[field_name] = field_instance
176
 
177
+ if cls.__allow_unexpected_arguments__:
178
+ fields["_argv"] = InternalField(name="_argv", type=tuple, default=())
179
+ fields["_kwargs"] = InternalField(name="_kwargs", type=dict, default={})
180
+
181
  return fields
182
 
183
 
 
324
 
325
  """
326
 
327
+ __allow_unexpected_arguments__ = False
328
+
329
  @final
330
+ def __init__(self, *argv, **kwargs):
331
  """
332
  Initialize fields based on kwargs.
333
  Checks for abstract fields when an instance is created.
334
  """
335
+ _init_fields = [field for field in fields(self) if field.init]
336
+ _init_fields_names = [field.name for field in _init_fields]
337
+ _init_positional_fields_names = [field.name for field in _init_fields if field.also_positional]
338
+
339
+ for name in _init_positional_fields_names[: len(argv)]:
340
+ if name in kwargs:
341
+ raise TypeError(f"{self.__class__.__name__} got multiple values for argument '{name}'")
342
+
343
+ if len(argv) <= len(_init_positional_fields_names):
344
+ unexpected_argv = []
345
+ else:
346
+ unexpected_argv = argv[len(_init_positional_fields_names) :]
347
+
348
+ unexpected_kwargs = {k: v for k, v in kwargs.items() if k not in _init_fields_names}
349
+
350
+ if self.__allow_unexpected_arguments__:
351
+ self._argv = unexpected_argv
352
+ self._kwargs = unexpected_kwargs
353
+
354
+ else:
355
+ if len(unexpected_argv) > 0:
356
+ raise UnexpectedArgumentError(
357
+ f"Too many positional arguments {unexpected_argv} for class {self.__class__.__name__}.\nShould be only {len(_init_positional_fields_names)} positional arguments: {', '.join(_init_positional_fields_names)}"
358
+ )
359
+
360
+ if len(unexpected_kwargs) > 0:
361
+ raise UnexpectedArgumentError(
362
+ f"Unexpected keyword argument(s) {unexpected_kwargs} for class {self.__class__.__name__}.\nShould be one of: {fields_names(self)}"
363
+ )
364
+
365
+ for name, arg in zip(_init_positional_fields_names, argv):
366
+ kwargs[name] = arg
367
 
368
  for field in abstract_fields(self):
369
  raise AbstractFieldError(
 
380
  if field.name in kwargs:
381
  setattr(self, field.name, kwargs[field.name])
382
  else:
383
+ if field.name in ["_argv", "_kwargs"] and self.__allow_unexpected_arguments__:
384
+ continue
385
  setattr(self, field.name, get_field_default(field))
386
 
387
  self.__post_init__()