Elron commited on
Commit
8afaaba
·
verified ·
1 Parent(s): 7b21e0b

Upload dataclass.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dataclass.py +27 -3
dataclass.py CHANGED
@@ -1,9 +1,10 @@
1
  import copy
2
  import dataclasses
3
  import functools
 
4
  from abc import ABCMeta
5
  from inspect import Parameter, Signature
6
- from typing import Any, final
7
 
8
  _FIELDS = "__fields__"
9
 
@@ -36,8 +37,10 @@ class Field:
36
  final: bool = False
37
  abstract: bool = False
38
  required: bool = False
 
39
  internal: bool = False
40
  origin_cls: type = None
 
41
 
42
  def get_default(self):
43
  if self.default_factory is not None:
@@ -51,6 +54,12 @@ class FinalField(Field):
51
  self.final = True
52
 
53
 
 
 
 
 
 
 
54
  @dataclasses.dataclass
55
  class RequiredField(Field):
56
  def __post_init__(self):
@@ -111,7 +120,7 @@ class UnexpectedArgumentError(TypeError):
111
  pass
112
 
113
 
114
- standart_variables = dir(object)
115
 
116
 
117
  def is_possible_field(field_name, field_value):
@@ -125,7 +134,7 @@ def is_possible_field(field_name, field_value):
125
  bool: True if the name-value pair can represent a field, False otherwise.
126
  """
127
  return (
128
- field_name not in standart_variables
129
  and not field_name.startswith("__")
130
  and not callable(field_value)
131
  )
@@ -229,6 +238,10 @@ def required_fields(cls):
229
  return [field for field in fields(cls) if field.required]
230
 
231
 
 
 
 
 
232
  def abstract_fields(cls):
233
  return [field for field in fields(cls) if field.abstract]
234
 
@@ -241,6 +254,10 @@ def is_final_field(field):
241
  return field.final
242
 
243
 
 
 
 
 
244
  def get_field_default(field):
245
  if field.default_factory is not None:
246
  return field.default_factory()
@@ -394,6 +411,7 @@ class Dataclass(metaclass=DataclassMeta):
394
  """Initialize fields based on kwargs.
395
 
396
  Checks for abstract fields when an instance is created.
 
397
  """
398
  _init_fields = [field for field in fields(self) if field.init]
399
  _init_fields_names = [field.name for field in _init_fields]
@@ -401,6 +419,12 @@ class Dataclass(metaclass=DataclassMeta):
401
  field.name for field in _init_fields if field.also_positional
402
  ]
403
 
 
 
 
 
 
 
404
  for name in _init_positional_fields_names[: len(argv)]:
405
  if name in kwargs:
406
  raise TypeError(
 
1
  import copy
2
  import dataclasses
3
  import functools
4
+ import warnings
5
  from abc import ABCMeta
6
  from inspect import Parameter, Signature
7
+ from typing import Any, Dict, final
8
 
9
  _FIELDS = "__fields__"
10
 
 
37
  final: bool = False
38
  abstract: bool = False
39
  required: bool = False
40
+ deprecated: bool = False
41
  internal: bool = False
42
  origin_cls: type = None
43
+ metadata: Dict[str, str] = dataclasses.field(default_factory=dict)
44
 
45
  def get_default(self):
46
  if self.default_factory is not None:
 
54
  self.final = True
55
 
56
 
57
+ @dataclasses.dataclass
58
+ class DeprecatedField(Field):
59
+ def __post_init__(self):
60
+ self.deprecated = True
61
+
62
+
63
  @dataclasses.dataclass
64
  class RequiredField(Field):
65
  def __post_init__(self):
 
120
  pass
121
 
122
 
123
+ standard_variables = dir(object)
124
 
125
 
126
  def is_possible_field(field_name, field_value):
 
134
  bool: True if the name-value pair can represent a field, False otherwise.
135
  """
136
  return (
137
+ field_name not in standard_variables
138
  and not field_name.startswith("__")
139
  and not callable(field_value)
140
  )
 
238
  return [field for field in fields(cls) if field.required]
239
 
240
 
241
+ def deprecated_fields(cls):
242
+ return [field for field in fields(cls) if field.deprecated]
243
+
244
+
245
  def abstract_fields(cls):
246
  return [field for field in fields(cls) if field.abstract]
247
 
 
254
  return field.final
255
 
256
 
257
+ def is_deprecated_field(field):
258
+ return field.deprecated
259
+
260
+
261
  def get_field_default(field):
262
  if field.default_factory is not None:
263
  return field.default_factory()
 
411
  """Initialize fields based on kwargs.
412
 
413
  Checks for abstract fields when an instance is created.
414
+ Warn when a deprecated is used
415
  """
416
  _init_fields = [field for field in fields(self) if field.init]
417
  _init_fields_names = [field.name for field in _init_fields]
 
419
  field.name for field in _init_fields if field.also_positional
420
  ]
421
 
422
+ _init_deprecated_fields = [field for field in _init_fields if field.deprecated]
423
+ for dep_field in _init_deprecated_fields:
424
+ warnings.warn(
425
+ dep_field.metadata["deprecation_msg"], DeprecationWarning, stacklevel=2
426
+ )
427
+
428
  for name in _init_positional_fields_names[: len(argv)]:
429
  if name in kwargs:
430
  raise TypeError(