Elron commited on
Commit
547609e
1 Parent(s): 31cee3d

Upload dataclass.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dataclass.py +331 -72
dataclass.py CHANGED
@@ -1,89 +1,348 @@
1
- # Let's modify the code to allow the finalfield and field functions to accept the same parameters as dataclasses.field
2
- class AbstractFieldValue:
3
- def __init__(self):
4
- raise TypeError("Abstract field must be overridden in subclass")
5
-
6
  import dataclasses
 
 
 
7
 
8
- class FinalField:
9
- def __init__(self, *, default=dataclasses.MISSING, default_factory=dataclasses.MISSING,
10
- init=True, repr=True, hash=None, compare=True, metadata=None):
11
- self.field = dataclasses.field(default=default, default_factory=default_factory,
12
- init=init, repr=repr, hash=hash, compare=compare, metadata=metadata)
13
 
14
- def abstractfield():
15
- return dataclasses.field(default_factory=AbstractFieldValue)
16
 
17
- def finalfield(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING,
18
- init=True, repr=True, hash=None, compare=True, metadata=None):
19
- return FinalField(default=default, default_factory=default_factory,
20
- init=init, repr=repr, hash=hash, compare=compare, metadata=metadata)
21
 
22
- def field(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING,
23
- init=True, repr=True, hash=None, compare=True, metadata=None):
24
- return dataclasses.field(default=default, default_factory=default_factory,
25
- init=init, repr=repr, hash=hash, compare=compare, metadata=metadata)
 
 
 
 
 
 
26
 
27
- class DataclassMeta(type):
28
- def __new__(cls, name, bases, attrs):
29
- attrs['__finalfields__'] = attrs.get('__finalfields__', [])
30
- for base in bases:
31
- if issubclass(base, Dataclass) and hasattr(base, '__finalfields__'):
32
- for field in base.__finalfields__:
33
- if field in attrs:
34
- raise TypeError(f"Final field '{field}' cannot be overridden in subclass")
35
- attrs['__finalfields__'].append(field)
36
 
37
- for attr_name, attr_value in list(attrs.items()):
38
- if isinstance(attr_value, FinalField):
39
- attrs[attr_name] = attr_value.field # Replace the final field marker with the actual field
40
- attrs['__finalfields__'].append(attr_name)
 
41
 
42
- new_class = super().__new__(cls, name, bases, attrs)
43
- new_class = dataclasses.dataclass(new_class)
44
 
45
- return new_class
 
 
 
46
 
47
 
 
 
 
 
48
 
49
- class Dataclass(metaclass=DataclassMeta):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  pass
51
 
52
- if __name__ == '__main__':
53
- # Test classes
54
- class GrandparentClass(Dataclass):
55
- abstract_field: int = abstractfield()
56
- final_field: str = finalfield(default_factory=lambda: 'Hello')
57
 
58
- class ParentClass(GrandparentClass):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  pass
60
 
61
- try:
62
- class CorrectChildClass(ParentClass):
63
- abstract_field: int = 1 # This correctly overrides the abstract field
64
- correct_child_class_instance = CorrectChildClass()
65
- print(f'CorrectChildClass instance: {correct_child_class_instance} - passed')
66
- except Exception as e:
67
- print(f'CorrectChildClass: {str(e)} - failed')
68
-
69
- try:
70
- class IncorrectChildClass1(ParentClass):
71
- pass # This fails to override the abstract field
72
- print(f'IncorrectChildClass1: {IncorrectChildClass1} - passed')
73
- except Exception as e:
74
- print(f'IncorrectChildClass1: {str(e)} - failed')
75
-
76
- try:
77
- incorrect_child_class1_instance = IncorrectChildClass1()
78
- print(f'IncorrectChildClass1 instance: {incorrect_child_class1_instance} - failed')
79
- except Exception as e:
80
- print(f'IncorrectChildClass1 instantiation: {str(e)} - passed')
81
-
82
- # Testing the final field functionality
83
-
84
- try:
85
- class IncorrectChildClass2(ParentClass):
86
- final_field: str = 'Hello' # This attempts to override the final field
87
- print(f'IncorrectChildClass2: {IncorrectChildClass2} - failed')
88
- except Exception as e:
89
- print(f'IncorrectChildClass2: {str(e)} - passed')
 
1
+ import copy
 
 
 
 
2
  import dataclasses
3
+ from abc import ABCMeta
4
+ from copy import deepcopy
5
+ from typing import Any, final
6
 
7
+ _FIELDS = "__fields__"
 
 
 
 
8
 
 
 
9
 
10
+ @dataclasses.dataclass
11
+ class Field:
12
+ """
13
+ An alternative to dataclasses.dataclass decorator for a more flexible field definition.
14
 
15
+ Attributes:
16
+ default (Any, optional): Default value for the field. Defaults to None.
17
+ name (str, optional): Name of the field. Defaults to None.
18
+ type (type, optional): Type of the field. Defaults to None.
19
+ default_factory (Any, optional): A function that returns the default value. Defaults to None.
20
+ final (bool, optional): A boolean indicating if the field is final (cannot be overridden). Defaults to False.
21
+ abstract (bool, optional): A boolean indicating if the field is abstract (must be implemented by subclasses). Defaults to False.
22
+ required (bool, optional): A boolean indicating if the field is required. Defaults to False.
23
+ origin_cls (type, optional): The original class that defined the field. Defaults to None.
24
+ """
25
 
26
+ default: Any = None
27
+ name: str = None
28
+ type: type = None
29
+ init: bool = True
30
+ default_factory: Any = None
31
+ final: bool = False
32
+ abstract: bool = False
33
+ required: bool = False
34
+ origin_cls: type = None
35
 
36
+ def get_default(self):
37
+ if self.default_factory is not None:
38
+ return self.default_factory()
39
+ else:
40
+ return self.default
41
 
 
 
42
 
43
+ @dataclasses.dataclass
44
+ class FinalField(Field):
45
+ def __post_init__(self):
46
+ self.final = True
47
 
48
 
49
+ @dataclasses.dataclass
50
+ class RequiredField(Field):
51
+ def __post_init__(self):
52
+ self.required = True
53
 
54
+
55
+ @dataclasses.dataclass
56
+ class AbstractField(Field):
57
+ def __post_init__(self):
58
+ self.abstract = True
59
+
60
+
61
+ class FinalFieldError(TypeError):
62
+ pass
63
+
64
+
65
+ class RequiredFieldError(TypeError):
66
+ pass
67
+
68
+
69
+ class AbstractFieldError(TypeError):
70
  pass
71
 
 
 
 
 
 
72
 
73
+ class TypeMismatchError(TypeError):
74
+ pass
75
+
76
+
77
+ standart_variables = dir(object)
78
+
79
+
80
+ def is_possible_field(field_name, field_value):
81
+ """
82
+ Check if a name-value pair can potentially represent a field.
83
+
84
+ Args:
85
+ field_name (str): The name of the field.
86
+ field_value: The value of the field.
87
+
88
+ Returns:
89
+ bool: True if the name-value pair can represent a field, False otherwise.
90
+ """
91
+ return field_name not in standart_variables and not field_name.startswith("__") and not callable(field_value)
92
+
93
+
94
+ def get_fields(cls, attrs):
95
+ """
96
+ Get the fields for a class based on its attributes.
97
+
98
+ Args:
99
+ cls (type): The class to get the fields for.
100
+ attrs (dict): The attributes of the class.
101
+
102
+ Returns:
103
+ dict: A dictionary mapping field names to Field instances.
104
+ """
105
+
106
+ fields = {**getattr(cls, _FIELDS, {})}
107
+ annotations = {**attrs.get("__annotations__", {})}
108
+
109
+ for attr_name, attr_value in attrs.items():
110
+ if attr_name not in annotations and is_possible_field(attr_name, attr_value):
111
+ if attr_name in fields:
112
+ if not isinstance(attr_value, fields[attr_name].type):
113
+ raise TypeMismatchError(
114
+ f"Type mismatch for field '{attr_name}' of class '{fields[attr_name].origin_cls}'. Expected {fields[attr_name].type}, got {type(attr_value)}"
115
+ )
116
+ annotations[attr_name] = fields[attr_name].type
117
+
118
+ for field_name, field_type in annotations.items():
119
+ if field_name in fields and fields[field_name].final:
120
+ raise FinalFieldError(
121
+ f"Final field {field_name} defined in {fields[field_name].origin_cls} overridden in {cls}"
122
+ )
123
+
124
+ args = {
125
+ "name": field_name,
126
+ "type": field_type,
127
+ "origin_cls": attrs["__qualname__"],
128
+ }
129
+
130
+ if field_name in attrs:
131
+ field = attrs[field_name]
132
+ if isinstance(field, Field):
133
+ args = {**dataclasses.asdict(field), **args}
134
+ elif isinstance(field, dataclasses.Field):
135
+ args = {
136
+ "default": field.default,
137
+ "name": field.name,
138
+ "type": field.type,
139
+ "init": field.init,
140
+ "default_factory": field.default_factory,
141
+ **args,
142
+ }
143
+ else:
144
+ args["default"] = field
145
+ else:
146
+ args["default"] = dataclasses.MISSING
147
+ args["default_factory"] = None
148
+ args["required"] = True
149
+
150
+ field_instance = Field(**args)
151
+ fields[field_name] = field_instance
152
+
153
+ return fields
154
+
155
+
156
+ def is_dataclass(obj):
157
+ """Returns True if obj is a dataclass or an instance of a
158
+ dataclass."""
159
+ cls = obj if isinstance(obj, type) else type(obj)
160
+ return hasattr(cls, _FIELDS)
161
+
162
+
163
+ def class_fields(obj):
164
+ all_fields = fields(obj)
165
+ return [field for field in all_fields if field.origin_cls == obj.__class__.__qualname__]
166
+
167
+
168
+ def fields(cls):
169
+ return list(getattr(cls, _FIELDS).values())
170
+
171
+
172
+ def fields_names(cls):
173
+ return list(getattr(cls, _FIELDS).keys())
174
+
175
+
176
+ def final_fields(cls):
177
+ return [field for field in fields(cls) if field.final]
178
+
179
+
180
+ def required_fields(cls):
181
+ return [field for field in fields(cls) if field.required]
182
+
183
+
184
+ def abstract_fields(cls):
185
+ return [field for field in fields(cls) if field.abstract]
186
+
187
+
188
+ def is_abstract_field(field):
189
+ return field.abstract
190
+
191
+
192
+ def is_final_field(field):
193
+ return field.final
194
+
195
+
196
+ def get_field_default(field):
197
+ if field.default_factory is not None:
198
+ return field.default_factory()
199
+ else:
200
+ return field.default
201
+
202
+
203
+ def asdict(obj):
204
+ assert is_dataclass(obj), f"{obj} must be a dataclass, got {type(obj)} with bases {obj.__class__.__bases__}"
205
+ return _asdict_inner(obj)
206
+
207
+
208
+ def _asdict_inner(obj):
209
+ if is_dataclass(obj):
210
+ result = {}
211
+ for field in fields(obj):
212
+ v = getattr(obj, field.name)
213
+ result[field.name] = _asdict_inner(v)
214
+ return result
215
+ elif isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple
216
+ return type(obj)(*[_asdict_inner(v) for v in obj])
217
+ elif isinstance(obj, (list, tuple)):
218
+ return type(obj)([_asdict_inner(v) for v in obj])
219
+ elif isinstance(obj, dict):
220
+ return type(obj)({_asdict_inner(k): _asdict_inner(v) for k, v in obj.items()})
221
+ else:
222
+ return copy.deepcopy(obj)
223
+
224
+
225
+ class DataclassMeta(ABCMeta):
226
+ """
227
+ Metaclass for Dataclass.
228
+ Checks for final fields when a subclass is created.
229
+ """
230
+
231
+ @final
232
+ def __init__(cls, name, bases, attrs):
233
+ super().__init__(name, bases, attrs)
234
+ setattr(cls, _FIELDS, get_fields(cls, attrs))
235
+
236
+
237
+ class Dataclass(metaclass=DataclassMeta):
238
+ """
239
+ Base class for data-like classes that provides additional functionality and control
240
+ over Python's built-in @dataclasses.dataclass decorator. Other classes can inherit from
241
+ this class to get the benefits of this implementation. As a base class, it ensures that
242
+ all subclasses will automatically be data classes.
243
+
244
+ The usage and field definitions are similar to Python's built-in @dataclasses.dataclass decorator.
245
+ However, this implementation provides additional classes for defining "final", "required",
246
+ and "abstract" fields.
247
+
248
+ Key enhancements of this custom implementation:
249
+
250
+ 1. Automatic Data Class Creation: All subclasses automatically become data classes,
251
+ without needing to use the @dataclasses.dataclass decorator.
252
+
253
+ 2. Field Immutability: Supports creation of "final" fields (using FinalField class) that
254
+ cannot be overridden by subclasses. This functionality is not natively supported in
255
+ Python or in the built-in dataclasses module.
256
+
257
+ 3. Required Fields: Supports creation of "required" fields (using RequiredField class) that
258
+ must be provided when creating an instance of the class, adding a level of validation
259
+ not present in the built-in dataclasses module.
260
+
261
+ 4. Abstract Fields: Supports creation of "abstract" fields (using AbstractField class) that
262
+ must be overridden by any non-abstract subclass. This is similar to abstract methods in
263
+ an abc.ABC class, but applied to fields.
264
+
265
+ 5. Type Checking: Performs type checking to ensure that if a field is redefined in a subclass,
266
+ the type of the field remains consistent, adding static type checking not natively supported
267
+ in Python.
268
+
269
+ 6. Error Definitions: Defines specific error types (FinalFieldError, RequiredFieldError,
270
+ AbstractFieldError, TypeMismatchError) for providing detailed error information during debugging.
271
+
272
+ 7. MetaClass Usage: Uses a metaclass (DataclassMeta) for customization of class creation,
273
+ allowing checks and alterations to be made at the time of class creation, providing more control.
274
+
275
+ Example:
276
+ ```
277
+ class Parent(Dataclass):
278
+ final_field: int = FinalField(1) # this field cannot be overridden
279
+ required_field: str = RequiredField()
280
+ also_required_field: float
281
+ abstract_field: int = AbstractField()
282
+
283
+ class Child(Parent):
284
+ abstract_field = 3 # now once overridden, this is no longer abstract
285
+ required_field = Field(name="required_field", default="provided", type=str)
286
+
287
+ class Mixin(Dataclass):
288
+ mixin_field = Field(name="mixin_field", default="mixin", type=str)
289
+
290
+ class GrandChild(Child, Mixin):
291
+ pass
292
+
293
+ grand_child = GrandChild()
294
+ print(grand_child.to_dict())
295
+ ```
296
+
297
+ """
298
+
299
+ @final
300
+ def __init__(self, *args, **kwargs):
301
+ """
302
+ Initialize fields based on kwargs.
303
+ Checks for abstract fields when an instance is created.
304
+ """
305
+ init_fields = [field for field in fields(self) if field.init]
306
+ for field, arg in zip(init_fields, args):
307
+ kwargs[field.name] = arg
308
+
309
+ for field in abstract_fields(self):
310
+ raise AbstractFieldError(
311
+ f"Abstract field '{field.name}' of class {field.origin_cls} not implemented in {self.__class__.__name__}"
312
+ )
313
+
314
+ for field in required_fields(self):
315
+ if field.name not in kwargs:
316
+ raise RequiredFieldError(
317
+ f"Required field '{field.name}' of class {field.origin_cls} not set in {self.__class__.__name__}"
318
+ )
319
+
320
+ for field in fields(self):
321
+ if field.name in kwargs:
322
+ setattr(self, field.name, kwargs[field.name])
323
+ else:
324
+ setattr(self, field.name, get_field_default(field))
325
+
326
+ self.__post_init__()
327
+
328
+ @property
329
+ def __is_dataclass__(self) -> bool:
330
+ return True
331
+
332
+ def __post_init__(self):
333
+ """
334
+ Post initialization hook.
335
+ """
336
  pass
337
 
338
+ def to_dict(self):
339
+ """
340
+ Convert to dict.
341
+ """
342
+ return asdict(self)
343
+
344
+ def __repr__(self) -> str:
345
+ """
346
+ String representation.
347
+ """
348
+ return f"{self.__class__.__name__}({', '.join([f'{field.name}={repr(getattr(self, field.name))}' for field in fields(self)])})"