Elron commited on
Commit
c5d9b09
1 Parent(s): 4e49444

Upload dataclass.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dataclass.py +51 -47
dataclass.py CHANGED
@@ -1,7 +1,6 @@
1
  import copy
2
  import dataclasses
3
  from abc import ABCMeta
4
- from copy import deepcopy
5
  from typing import Any, final
6
 
7
  _FIELDS = "__fields__"
@@ -9,8 +8,7 @@ _FIELDS = "__fields__"
9
 
10
  @dataclasses.dataclass
11
  class Field:
12
- """
13
- An alternative to dataclasses.dataclass decorator for a more flexible field definition.
14
 
15
  Attributes:
16
  default (Any, optional): Default value for the field. Defaults to None.
@@ -38,8 +36,7 @@ class Field:
38
  def get_default(self):
39
  if self.default_factory is not None:
40
  return self.default_factory()
41
- else:
42
- return self.default
43
 
44
 
45
  @dataclasses.dataclass
@@ -107,8 +104,7 @@ standart_variables = dir(object)
107
 
108
 
109
  def is_possible_field(field_name, field_value):
110
- """
111
- Check if a name-value pair can potentially represent a field.
112
 
113
  Args:
114
  field_name (str): The name of the field.
@@ -117,12 +113,15 @@ def is_possible_field(field_name, field_value):
117
  Returns:
118
  bool: True if the name-value pair can represent a field, False otherwise.
119
  """
120
- return field_name not in standart_variables and not field_name.startswith("__") and not callable(field_value)
 
 
 
 
121
 
122
 
123
  def get_fields(cls, attrs):
124
- """
125
- Get the fields for a class based on its attributes.
126
 
127
  Args:
128
  cls (type): The class to get the fields for.
@@ -191,15 +190,16 @@ def get_fields(cls, attrs):
191
 
192
 
193
  def is_dataclass(obj):
194
- """Returns True if obj is a dataclass or an instance of a
195
- dataclass."""
196
  cls = obj if isinstance(obj, type) else type(obj)
197
  return hasattr(cls, _FIELDS)
198
 
199
 
200
  def class_fields(obj):
201
  all_fields = fields(obj)
202
- return [field for field in all_fields if field.origin_cls == obj.__class__.__qualname__]
 
 
203
 
204
 
205
  def fields(cls):
@@ -233,31 +233,36 @@ def is_final_field(field):
233
  def get_field_default(field):
234
  if field.default_factory is not None:
235
  return field.default_factory()
236
- else:
237
- return field.default
238
 
239
 
240
  def asdict(obj):
241
- assert is_dataclass(obj), f"{obj} must be a dataclass, got {type(obj)} with bases {obj.__class__.__bases__}"
 
 
242
  return _asdict_inner(obj)
243
 
244
 
245
  def _asdict_inner(obj):
246
  if is_dataclass(obj):
247
  return obj.to_dict()
248
- elif isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple
 
249
  return type(obj)(*[_asdict_inner(v) for v in obj])
250
- elif isinstance(obj, (list, tuple)):
 
251
  return type(obj)([_asdict_inner(v) for v in obj])
252
- elif isinstance(obj, dict):
 
253
  return type(obj)({_asdict_inner(k): _asdict_inner(v) for k, v in obj.items()})
254
- else:
255
- return copy.deepcopy(obj)
256
 
257
 
258
  class DataclassMeta(ABCMeta):
259
- """
260
- Metaclass for Dataclass.
261
  Checks for final fields when a subclass is created.
262
  """
263
 
@@ -268,7 +273,8 @@ class DataclassMeta(ABCMeta):
268
 
269
 
270
  class Dataclass(metaclass=DataclassMeta):
271
- """
 
272
  Base class for data-like classes that provides additional functionality and control
273
  over Python's built-in @dataclasses.dataclass decorator. Other classes can inherit from
274
  this class to get the benefits of this implementation. As a base class, it ensures that
@@ -324,7 +330,7 @@ class Dataclass(metaclass=DataclassMeta):
324
  pass
325
 
326
  grand_child = GrandChild()
327
- print(grand_child.to_dict())
328
  ```
329
 
330
  """
@@ -333,17 +339,21 @@ class Dataclass(metaclass=DataclassMeta):
333
 
334
  @final
335
  def __init__(self, *argv, **kwargs):
336
- """
337
- Initialize fields based on kwargs.
338
  Checks for abstract fields when an instance is created.
339
  """
340
  _init_fields = [field for field in fields(self) if field.init]
341
  _init_fields_names = [field.name for field in _init_fields]
342
- _init_positional_fields_names = [field.name for field in _init_fields if field.also_positional]
 
 
343
 
344
  for name in _init_positional_fields_names[: len(argv)]:
345
  if name in kwargs:
346
- raise TypeError(f"{self.__class__.__name__} got multiple values for argument '{name}'")
 
 
347
 
348
  expected_unexpected_argv = kwargs.pop("_argv", None)
349
 
@@ -360,11 +370,15 @@ class Dataclass(metaclass=DataclassMeta):
360
 
361
  expected_unexpected_kwargs = kwargs.pop("_kwargs", None)
362
  unexpected_kwargs = {
363
- k: v for k, v in kwargs.items() if k not in _init_fields_names and k not in ["_argv", "_kwargs"]
 
 
364
  }
365
 
366
  if expected_unexpected_kwargs is not None:
367
- intersection = set(unexpected_kwargs.keys()) & set(expected_unexpected_kwargs.keys())
 
 
368
  assert (
369
  len(intersection) == 0
370
  ), f"Cannot specify the same arguments in both _kwargs and in unexpected keyword arguments. Got {intersection} in both."
@@ -416,31 +430,21 @@ class Dataclass(metaclass=DataclassMeta):
416
  return True
417
 
418
  def __pre_init__(self, **kwargs):
419
- """
420
- Pre initialization hook.
421
- """
422
  pass
423
 
424
  def __post_init__(self):
425
- """
426
- Post initialization hook.
427
- """
428
  pass
429
 
430
  def _to_raw_dict(self):
431
- """
432
- Convert to raw dict
433
- """
434
  return {field.name: getattr(self, field.name) for field in fields(self)}
435
 
436
  def to_dict(self):
437
- """
438
- Convert to dict.
439
- """
440
  return _asdict_inner(self._to_raw_dict())
441
 
442
  def __repr__(self) -> str:
443
- """
444
- String representation.
445
- """
446
- return f"{self.__class__.__name__}({', '.join([f'{field.name}={repr(getattr(self, field.name))}' for field in fields(self)])})"
 
1
  import copy
2
  import dataclasses
3
  from abc import ABCMeta
 
4
  from typing import Any, final
5
 
6
  _FIELDS = "__fields__"
 
8
 
9
  @dataclasses.dataclass
10
  class Field:
11
+ """An alternative to dataclasses.dataclass decorator for a more flexible field definition.
 
12
 
13
  Attributes:
14
  default (Any, optional): Default value for the field. Defaults to None.
 
36
  def get_default(self):
37
  if self.default_factory is not None:
38
  return self.default_factory()
39
+ return self.default
 
40
 
41
 
42
  @dataclasses.dataclass
 
104
 
105
 
106
  def is_possible_field(field_name, field_value):
107
+ """Check if a name-value pair can potentially represent a field.
 
108
 
109
  Args:
110
  field_name (str): The name of the field.
 
113
  Returns:
114
  bool: True if the name-value pair can represent a field, False otherwise.
115
  """
116
+ return (
117
+ field_name not in standart_variables
118
+ and not field_name.startswith("__")
119
+ and not callable(field_value)
120
+ )
121
 
122
 
123
  def get_fields(cls, attrs):
124
+ """Get the fields for a class based on its attributes.
 
125
 
126
  Args:
127
  cls (type): The class to get the fields for.
 
190
 
191
 
192
  def is_dataclass(obj):
193
+ """Returns True if obj is a dataclass or an instance of a dataclass."""
 
194
  cls = obj if isinstance(obj, type) else type(obj)
195
  return hasattr(cls, _FIELDS)
196
 
197
 
198
  def class_fields(obj):
199
  all_fields = fields(obj)
200
+ return [
201
+ field for field in all_fields if field.origin_cls == obj.__class__.__qualname__
202
+ ]
203
 
204
 
205
  def fields(cls):
 
233
  def get_field_default(field):
234
  if field.default_factory is not None:
235
  return field.default_factory()
236
+
237
+ return field.default
238
 
239
 
240
  def asdict(obj):
241
+ assert is_dataclass(
242
+ obj
243
+ ), f"{obj} must be a dataclass, got {type(obj)} with bases {obj.__class__.__bases__}"
244
  return _asdict_inner(obj)
245
 
246
 
247
  def _asdict_inner(obj):
248
  if is_dataclass(obj):
249
  return obj.to_dict()
250
+
251
+ if isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple
252
  return type(obj)(*[_asdict_inner(v) for v in obj])
253
+
254
+ if isinstance(obj, (list, tuple)):
255
  return type(obj)([_asdict_inner(v) for v in obj])
256
+
257
+ if isinstance(obj, dict):
258
  return type(obj)({_asdict_inner(k): _asdict_inner(v) for k, v in obj.items()})
259
+
260
+ return copy.deepcopy(obj)
261
 
262
 
263
  class DataclassMeta(ABCMeta):
264
+ """Metaclass for Dataclass.
265
+
266
  Checks for final fields when a subclass is created.
267
  """
268
 
 
273
 
274
 
275
  class Dataclass(metaclass=DataclassMeta):
276
+ """Base class for data-like classes that provides additional functionality and control.
277
+
278
  Base class for data-like classes that provides additional functionality and control
279
  over Python's built-in @dataclasses.dataclass decorator. Other classes can inherit from
280
  this class to get the benefits of this implementation. As a base class, it ensures that
 
330
  pass
331
 
332
  grand_child = GrandChild()
333
+ logging.info(grand_child.to_dict())
334
  ```
335
 
336
  """
 
339
 
340
  @final
341
  def __init__(self, *argv, **kwargs):
342
+ """Initialize fields based on kwargs.
343
+
344
  Checks for abstract fields when an instance is created.
345
  """
346
  _init_fields = [field for field in fields(self) if field.init]
347
  _init_fields_names = [field.name for field in _init_fields]
348
+ _init_positional_fields_names = [
349
+ field.name for field in _init_fields if field.also_positional
350
+ ]
351
 
352
  for name in _init_positional_fields_names[: len(argv)]:
353
  if name in kwargs:
354
+ raise TypeError(
355
+ f"{self.__class__.__name__} got multiple values for argument '{name}'"
356
+ )
357
 
358
  expected_unexpected_argv = kwargs.pop("_argv", None)
359
 
 
370
 
371
  expected_unexpected_kwargs = kwargs.pop("_kwargs", None)
372
  unexpected_kwargs = {
373
+ k: v
374
+ for k, v in kwargs.items()
375
+ if k not in _init_fields_names and k not in ["_argv", "_kwargs"]
376
  }
377
 
378
  if expected_unexpected_kwargs is not None:
379
+ intersection = set(unexpected_kwargs.keys()) & set(
380
+ expected_unexpected_kwargs.keys()
381
+ )
382
  assert (
383
  len(intersection) == 0
384
  ), f"Cannot specify the same arguments in both _kwargs and in unexpected keyword arguments. Got {intersection} in both."
 
430
  return True
431
 
432
  def __pre_init__(self, **kwargs):
433
+ """Pre initialization hook."""
 
 
434
  pass
435
 
436
  def __post_init__(self):
437
+ """Post initialization hook."""
 
 
438
  pass
439
 
440
  def _to_raw_dict(self):
441
+ """Convert to raw dict."""
 
 
442
  return {field.name: getattr(self, field.name) for field in fields(self)}
443
 
444
  def to_dict(self):
445
+ """Convert to dict."""
 
 
446
  return _asdict_inner(self._to_raw_dict())
447
 
448
  def __repr__(self) -> str:
449
+ """String representation."""
450
+ return f"{self.__class__.__name__}({', '.join([f'{field.name}={getattr(self, field.name)!r}' for field in fields(self)])})"