Elron commited on
Commit
0e3c8e6
1 Parent(s): 18480db

Upload artifact.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. artifact.py +37 -2
artifact.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import inspect
2
  import json
3
  import os
@@ -44,6 +45,31 @@ def map_values_in_place(object, mapper):
44
  return mapper(object)
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  class Artifact(Dataclass):
48
  type: str = Field(default=None, final=True, init=False)
49
 
@@ -53,6 +79,15 @@ class Artifact(Dataclass):
53
  def is_artifact_dict(cls, d):
54
  return isinstance(d, dict) and "type" in d and d["type"] in cls._class_register
55
 
 
 
 
 
 
 
 
 
 
56
  @classmethod
57
  def get_artifact_type(cls):
58
  return camel_to_snake_case(cls.__name__)
@@ -106,7 +141,7 @@ class Artifact(Dataclass):
106
 
107
  @classmethod
108
  def from_dict(cls, d):
109
- assert cls.is_artifact_dict(d), "Input must be a dict with type field"
110
  return cls._recursive_load(d)
111
 
112
  @classmethod
@@ -127,7 +162,7 @@ class Artifact(Dataclass):
127
 
128
  @final
129
  def __pre_init__(self, **kwargs):
130
- self._init_dict = kwargs
131
 
132
  @final
133
  def __post_init__(self):
 
1
+ import difflib
2
  import inspect
3
  import json
4
  import os
 
45
  return mapper(object)
46
 
47
 
48
+ def get_closest_artifact_type(type):
49
+ artifact_type_options = list(Artifact._class_register.keys())
50
+ matches = difflib.get_close_matches(type, artifact_type_options)
51
+ if matches:
52
+ return matches[0] # Return the closest match
53
+ return None
54
+
55
+
56
+ class UnrecognizedArtifactType(ValueError):
57
+ def __init__(self, type) -> None:
58
+ closest_artifact_type = get_closest_artifact_type(type)
59
+ message = (
60
+ f"'{type}' is not a recognized value for 'type' parameter."
61
+ "\n\n"
62
+ f"Did you mean '{closest_artifact_type}'?"
63
+ )
64
+ super().__init__(message)
65
+
66
+
67
+ class MissingArtifactType(ValueError):
68
+ def __init__(self, dic) -> None:
69
+ message = f"Missing 'type' parameter. Expected 'type' in artifact dict, got {dic}"
70
+ super().__init__(message)
71
+
72
+
73
  class Artifact(Dataclass):
74
  type: str = Field(default=None, final=True, init=False)
75
 
 
79
  def is_artifact_dict(cls, d):
80
  return isinstance(d, dict) and "type" in d and d["type"] in cls._class_register
81
 
82
+ @classmethod
83
+ def verify_is_artifact_dict(cls, d):
84
+ if not isinstance(d, dict):
85
+ raise ValueError(f"Artifact dict <{d}> must be of type 'dict', got '{type(d)}'.")
86
+ if "type" not in d:
87
+ raise MissingArtifactType(d)
88
+ if d["type"] not in cls._class_register:
89
+ raise UnrecognizedArtifactType(d["type"])
90
+
91
  @classmethod
92
  def get_artifact_type(cls):
93
  return camel_to_snake_case(cls.__name__)
 
141
 
142
  @classmethod
143
  def from_dict(cls, d):
144
+ cls.verify_is_artifact_dict(d)
145
  return cls._recursive_load(d)
146
 
147
  @classmethod
 
162
 
163
  @final
164
  def __pre_init__(self, **kwargs):
165
+ self._init_dict = deepcopy(kwargs)
166
 
167
  @final
168
  def __post_init__(self):