Upload dataclass.py with huggingface_hub
Browse files- dataclass.py +27 -3
dataclass.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
import copy
|
2 |
import dataclasses
|
3 |
import functools
|
|
|
4 |
from abc import ABCMeta
|
5 |
from inspect import Parameter, Signature
|
6 |
-
from typing import Any, final
|
7 |
|
8 |
_FIELDS = "__fields__"
|
9 |
|
@@ -36,8 +37,10 @@ class Field:
|
|
36 |
final: bool = False
|
37 |
abstract: bool = False
|
38 |
required: bool = False
|
|
|
39 |
internal: bool = False
|
40 |
origin_cls: type = None
|
|
|
41 |
|
42 |
def get_default(self):
|
43 |
if self.default_factory is not None:
|
@@ -51,6 +54,12 @@ class FinalField(Field):
|
|
51 |
self.final = True
|
52 |
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
@dataclasses.dataclass
|
55 |
class RequiredField(Field):
|
56 |
def __post_init__(self):
|
@@ -111,7 +120,7 @@ class UnexpectedArgumentError(TypeError):
|
|
111 |
pass
|
112 |
|
113 |
|
114 |
-
|
115 |
|
116 |
|
117 |
def is_possible_field(field_name, field_value):
|
@@ -125,7 +134,7 @@ def is_possible_field(field_name, field_value):
|
|
125 |
bool: True if the name-value pair can represent a field, False otherwise.
|
126 |
"""
|
127 |
return (
|
128 |
-
field_name not in
|
129 |
and not field_name.startswith("__")
|
130 |
and not callable(field_value)
|
131 |
)
|
@@ -229,6 +238,10 @@ def required_fields(cls):
|
|
229 |
return [field for field in fields(cls) if field.required]
|
230 |
|
231 |
|
|
|
|
|
|
|
|
|
232 |
def abstract_fields(cls):
|
233 |
return [field for field in fields(cls) if field.abstract]
|
234 |
|
@@ -241,6 +254,10 @@ def is_final_field(field):
|
|
241 |
return field.final
|
242 |
|
243 |
|
|
|
|
|
|
|
|
|
244 |
def get_field_default(field):
|
245 |
if field.default_factory is not None:
|
246 |
return field.default_factory()
|
@@ -394,6 +411,7 @@ class Dataclass(metaclass=DataclassMeta):
|
|
394 |
"""Initialize fields based on kwargs.
|
395 |
|
396 |
Checks for abstract fields when an instance is created.
|
|
|
397 |
"""
|
398 |
_init_fields = [field for field in fields(self) if field.init]
|
399 |
_init_fields_names = [field.name for field in _init_fields]
|
@@ -401,6 +419,12 @@ class Dataclass(metaclass=DataclassMeta):
|
|
401 |
field.name for field in _init_fields if field.also_positional
|
402 |
]
|
403 |
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
for name in _init_positional_fields_names[: len(argv)]:
|
405 |
if name in kwargs:
|
406 |
raise TypeError(
|
|
|
1 |
import copy
|
2 |
import dataclasses
|
3 |
import functools
|
4 |
+
import warnings
|
5 |
from abc import ABCMeta
|
6 |
from inspect import Parameter, Signature
|
7 |
+
from typing import Any, Dict, final
|
8 |
|
9 |
_FIELDS = "__fields__"
|
10 |
|
|
|
37 |
final: bool = False
|
38 |
abstract: bool = False
|
39 |
required: bool = False
|
40 |
+
deprecated: bool = False
|
41 |
internal: bool = False
|
42 |
origin_cls: type = None
|
43 |
+
metadata: Dict[str, str] = dataclasses.field(default_factory=dict)
|
44 |
|
45 |
def get_default(self):
|
46 |
if self.default_factory is not None:
|
|
|
54 |
self.final = True
|
55 |
|
56 |
|
57 |
+
@dataclasses.dataclass
|
58 |
+
class DeprecatedField(Field):
|
59 |
+
def __post_init__(self):
|
60 |
+
self.deprecated = True
|
61 |
+
|
62 |
+
|
63 |
@dataclasses.dataclass
|
64 |
class RequiredField(Field):
|
65 |
def __post_init__(self):
|
|
|
120 |
pass
|
121 |
|
122 |
|
123 |
+
standard_variables = dir(object)
|
124 |
|
125 |
|
126 |
def is_possible_field(field_name, field_value):
|
|
|
134 |
bool: True if the name-value pair can represent a field, False otherwise.
|
135 |
"""
|
136 |
return (
|
137 |
+
field_name not in standard_variables
|
138 |
and not field_name.startswith("__")
|
139 |
and not callable(field_value)
|
140 |
)
|
|
|
238 |
return [field for field in fields(cls) if field.required]
|
239 |
|
240 |
|
241 |
+
def deprecated_fields(cls):
|
242 |
+
return [field for field in fields(cls) if field.deprecated]
|
243 |
+
|
244 |
+
|
245 |
def abstract_fields(cls):
|
246 |
return [field for field in fields(cls) if field.abstract]
|
247 |
|
|
|
254 |
return field.final
|
255 |
|
256 |
|
257 |
+
def is_deprecated_field(field):
|
258 |
+
return field.deprecated
|
259 |
+
|
260 |
+
|
261 |
def get_field_default(field):
|
262 |
if field.default_factory is not None:
|
263 |
return field.default_factory()
|
|
|
411 |
"""Initialize fields based on kwargs.
|
412 |
|
413 |
Checks for abstract fields when an instance is created.
|
414 |
+
Warn when a deprecated is used
|
415 |
"""
|
416 |
_init_fields = [field for field in fields(self) if field.init]
|
417 |
_init_fields_names = [field.name for field in _init_fields]
|
|
|
419 |
field.name for field in _init_fields if field.also_positional
|
420 |
]
|
421 |
|
422 |
+
_init_deprecated_fields = [field for field in _init_fields if field.deprecated]
|
423 |
+
for dep_field in _init_deprecated_fields:
|
424 |
+
warnings.warn(
|
425 |
+
dep_field.metadata["deprecation_msg"], DeprecationWarning, stacklevel=2
|
426 |
+
)
|
427 |
+
|
428 |
for name in _init_positional_fields_names[: len(argv)]:
|
429 |
if name in kwargs:
|
430 |
raise TypeError(
|