import difflib
import inspect
import json
import os
import pkgutil
from abc import abstractmethod
from copy import deepcopy
from typing import Dict, List, Optional, Union, final

from .dataclass import (
    AbstractField,
    Dataclass,
    Field,
    InternalField,
    NonPositionalField,
    fields,
)
from .logging_utils import get_logger
from .parsing_utils import (
    separate_inside_and_outside_square_brackets,
)
from .settings_utils import get_settings
from .text_utils import camel_to_snake_case, is_camel_case
from .type_utils import issubtype
from .utils import artifacts_json_cache, save_json

logger = get_logger()
settings = get_settings()


class Artifactories:
    """Process-wide registry of ``Artifactory`` objects.

    ``__new__`` always hands back the same instance (stored on the class as
    ``instance``), so every ``Artifactories()`` call sees the same list.

    The object is its own iterator: the cursor (``_index``) lives on the
    shared instance, so starting a second iteration while one is in progress
    resets the first one.
    """

    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super().__new__(cls)
            cls.instance.artifactories = []

        return cls.instance

    def __iter__(self):
        # Initialize/reset the index for iteration.
        self._index = 0
        return self

    def __next__(self):
        """Return the next artifactory, skipping non-local ones when
        ``settings.use_only_local_catalogs`` is set.
        """
        while self._index < len(self.artifactories):
            artifactory = self.artifactories[self._index]
            self._index += 1
            if settings.use_only_local_catalogs and not artifactory.is_local:
                continue
            return artifactory
        raise StopIteration

    def register(self, artifactory):
        """Add ``artifactory`` at the front of the registry.

        Prepending means the most recently registered artifactory is the
        first one yielded during iteration.
        """
        assert isinstance(
            artifactory, Artifactory
        ), "Artifactory must be an instance of Artifactory"
        assert hasattr(
            artifactory, "__contains__"
        ), "Artifactory must have __contains__ method"
        assert hasattr(
            artifactory, "__getitem__"
        ), "Artifactory must have __getitem__ method"
        self.artifactories = [artifactory, *self.artifactories]

    def unregister(self, artifactory):
        """Remove ``artifactory`` from the registry.

        Raises:
            ValueError: from ``list.remove`` if it is not registered.
        """
        assert isinstance(
            artifactory, Artifactory
        ), "Artifactory must be an instance of Artifactory"
        assert hasattr(
            artifactory, "__contains__"
        ), "Artifactory must have __contains__ method"
        assert hasattr(
            artifactory, "__getitem__"
        ), "Artifactory must have __getitem__ method"
        self.artifactories.remove(artifactory)

    def reset(self):
        """Drop every registered artifactory."""
        self.artifactories = []


def map_values_in_place(object, mapper):
    """Apply ``mapper`` to the values of a dict / items of a list, in place.

    For a dict or a list the same container is mutated and returned. Any
    other object is passed to ``mapper`` directly and the result returned.
    Only the top level is mapped; nested containers are handed to ``mapper``
    as-is.
    """
    if isinstance(object, dict):
        # Only values are replaced (keys are untouched), so mutating while
        # iterating over items() is safe.
        for key, value in object.items():
            object[key] = mapper(value)
        return object
    if isinstance(object, list):
        for index, item in enumerate(object):
            object[index] = mapper(item)
        return object
    return mapper(object)


def get_closest_artifact_type(type):
    """Return the registered artifact type name closest to ``type``.

    Uses ``difflib.get_close_matches`` against the keys of
    ``Artifact._class_register``; returns ``None`` when nothing is close
    enough.
    """
    artifact_type_options = list(Artifact._class_register.keys())
    matches = difflib.get_close_matches(type, artifact_type_options)
    if matches:
        return matches[0]  # Return the closest match
    return None


class UnrecognizedArtifactTypeError(ValueError):
    """Raised when an artifact dict names a ``type`` that is not registered.

    The message guesses the expected class name and, when a similar
    registered type exists, appends a "Did you mean" suggestion.
    """

    def __init__(self, type) -> None:
        maybe_class = "".join(word.capitalize() for word in type.split("_"))
        message = f"'{type}' is not a recognized artifact 'type'. Make sure a the class defined this type (Probably called '{maybe_class}' or similar) is defined and/or imported anywhere in the code executed."
        closest_artifact_type = get_closest_artifact_type(type)
        if closest_artifact_type is not None:
            message += "\n\n" f"Did you mean '{closest_artifact_type}'?"
        super().__init__(message)


class MissingArtifactTypeError(ValueError):
    """Raised when an artifact dict has no ``'type'`` key."""

    def __init__(self, dic) -> None:
        message = (
            f"Missing 'type' parameter. Expected 'type' in artifact dict, got {dic}"
        )
        super().__init__(message)


class Artifact(Dataclass):
    """Base class for objects that can be built from / described by a dict.

    Every subclass is recorded in ``_class_register`` under the snake_case
    form of its class name (see ``__init_subclass__``); that key is the
    ``'type'`` value expected in artifact dicts.
    """

    # Maps snake_case type name -> artifact class. Shared by all subclasses.
    _class_register = {}

    type: str = Field(default=None, final=True, init=False)
    __description__: str = NonPositionalField(
        default=None, required=False, also_positional=False
    )
    __tags__: Dict[str, str] = NonPositionalField(
        default_factory=dict, required=False, also_positional=False
    )
    __id__: str = InternalField(default=None, required=False, also_positional=False)

    @classmethod
    def is_artifact_dict(cls, d):
        """Return True if ``d`` is a dict containing a ``'type'`` key."""
        return isinstance(d, dict) and "type" in d

    @classmethod
    def verify_artifact_dict(cls, d):
        """Validate that ``d`` is a dict naming a registered artifact type.

        Raises:
            ValueError: ``d`` is not a dict.
            MissingArtifactTypeError: ``d`` has no ``'type'`` key.
            UnrecognizedArtifactTypeError: ``d['type']`` is not registered.
        """
        if not isinstance(d, dict):
            raise ValueError(
                f"Artifact dict <{d}> must be of type 'dict', got '{type(d)}'."
            )
        if "type" not in d:
            raise MissingArtifactTypeError(d)
        if not cls.is_registered_type(d["type"]):
            raise UnrecognizedArtifactTypeError(d["type"])

    @classmethod
    def get_artifact_type(cls):
        """Return the snake_case type name derived from the class name."""
        return camel_to_snake_case(cls.__name__)

    @classmethod
    def register_class(cls, artifact_class):
        """Register ``artifact_class`` and return its snake_case type key.

        Registering a class whose key already exists is accepted only when
        the string form of the existing class equals that of the new one;
        in that case the existing entry is kept.
        """
        assert issubclass(
            artifact_class, Artifact
        ), f"Artifact class must be a subclass of Artifact, got '{artifact_class}'"
        assert is_camel_case(
            artifact_class.__name__
        ), f"Artifact class name must be legal camel case, got '{artifact_class.__name__}'"

        snake_case_key = camel_to_snake_case(artifact_class.__name__)

        if cls.is_registered_type(snake_case_key):
            assert (
                str(cls._class_register[snake_case_key]) == str(artifact_class)
            ), f"Artifact class name must be unique, '{snake_case_key}' already exists for {cls._class_register[snake_case_key]}. Cannot be overridden by {artifact_class}."

            return snake_case_key

        cls._class_register[snake_case_key] = artifact_class

        return snake_case_key

    def __init_subclass__(cls, **kwargs):
        # Auto-register every subclass at class-creation time.
        super().__init_subclass__(**kwargs)
        cls.register_class(cls)

    @classmethod
    def is_artifact_file(cls, path):
        """Return True if ``path`` is an existing file holding an artifact dict.

        The file is parsed with ``json.load``; a file that exists but is not
        valid JSON lets the parser's exception propagate.
        """
        if not os.path.exists(path) or not os.path.isfile(path):
            return False
        with open(path) as f:
            d = json.load(f)
        return cls.is_artifact_dict(d)

    @classmethod
    def is_registered_type(cls, type: str):
        """Return True if ``type`` is a key of the class register."""
        return type in cls._class_register

    @classmethod
    def is_registered_class_name(cls, class_name: str):
        """Return True if the snake_case form of ``class_name`` is registered."""
        snake_case_key = camel_to_snake_case(class_name)
        return cls.is_registered_type(snake_case_key)

    @classmethod
    def is_registered_class(cls, clz: object):
        """Return True if ``clz`` is one of the registered classes."""
        return clz in set(cls._class_register.values())

    @classmethod
    def _recursive_load(cls, obj):
        """Instantiate artifacts bottom-up from nested dicts and lists.

        Children are converted first; then, if the (rebuilt) dict carries a
        ``'type'`` key, it is verified and turned into an instance of the
        registered class with the remaining items as keyword arguments.
        """
        if isinstance(obj, dict):
            # Build a new dict so the caller's dict is not mutated by the
            # pop("type") below.
            obj = {key: cls._recursive_load(value) for key, value in obj.items()}
        elif isinstance(obj, list):
            obj = [cls._recursive_load(value) for value in obj]

        if cls.is_artifact_dict(obj):
            cls.verify_artifact_dict(obj)
            return cls._class_register[obj.pop("type")](**obj)

        return obj

    @classmethod
    def from_dict(cls, d, overwrite_args=None):
        """Build an artifact from dict ``d``.

        Args:
            d: artifact dict; must contain a registered ``'type'``.
            overwrite_args: optional dict whose items override those of
                ``d`` (top level only) before loading.
        """
        if overwrite_args is not None:
            d = {**d, **overwrite_args}
        cls.verify_artifact_dict(d)
        return cls._recursive_load(d)

    @classmethod
    def load(cls, path, artifact_identifier=None, overwrite_args=None):
        """Load an artifact from ``path`` via ``artifacts_json_cache``.

        The resulting artifact's ``__id__`` is set to
        ``artifact_identifier``.
        """
        d = artifacts_json_cache(path)
        new_artifact = cls.from_dict(d, overwrite_args=overwrite_args)
        new_artifact.__id__ = artifact_identifier
        return new_artifact

    def prepare(self):
        """Hook called from ``__post_init__``; no-op by default."""
        pass

    def verify(self):
        """Hook called from ``__post_init__`` after ``prepare``; no-op by default."""
        pass

    @final
    def __pre_init__(self, **kwargs):
        # Keep a raw (artifact-free, deep-copied) snapshot of the init
        # arguments; _to_raw_dict() serializes from it.
        # NOTE(review): presumably invoked by Dataclass before field
        # assignment - confirm against Dataclass.
        self._init_dict = get_raw(kwargs)

    @final
    def __post_init__(self):
        """Finish construction: set ``type``, resolve references, run hooks.

        For every field whose declared type is an ``Artifact`` (or a list /
        str-keyed dict of them), string values are replaced by the artifact
        fetched for that string. Then ``prepare()`` and ``verify()`` run.
        """
        self.type = self.register_class(self.__class__)

        for field in fields(self):
            if issubtype(
                field.type, Union[Artifact, List[Artifact], Dict[str, Artifact]]
            ):
                value = getattr(self, field.name)
                value = map_values_in_place(value, maybe_recover_artifact)
                setattr(self, field.name, value)

        self.prepare()
        self.verify()

    def _to_raw_dict(self):
        """Return ``{'type': ..., **init arguments}`` in raw (plain) form."""
        return {"type": self.type, **self._init_dict}

    def save(self, path):
        """Serialize ``self.to_dict()`` to ``path`` with ``save_json``."""
        data = self.to_dict()
        save_json(path, data)


def get_raw(obj):
    """Return a plain-data copy of ``obj`` with artifacts replaced by dicts.

    Artifacts become their ``_to_raw_dict()``; named tuples, lists, tuples
    and dicts are rebuilt with the same type and converted recursively
    (dict keys included); anything else is deep-copied.
    """
    if isinstance(obj, Artifact):
        return obj._to_raw_dict()

    if isinstance(obj, tuple) and hasattr(obj, "_fields"):  # named tuple
        # Named tuples take their fields as positional arguments.
        return type(obj)(*[get_raw(v) for v in obj])

    if isinstance(obj, (list, tuple)):
        return type(obj)([get_raw(v) for v in obj])

    if isinstance(obj, dict):
        return type(obj)({get_raw(k): get_raw(v) for k, v in obj.items()})

    return deepcopy(obj)


class ArtifactList(list, Artifact):
    """A list that is itself an artifact; preparing it prepares each item."""

    def prepare(self):
        for artifact in self:
            artifact.prepare()


class Artifactory(Artifact):
    """Abstract source of artifacts, looked up by name.

    Subclasses provide ``is_local`` and the three lookup methods below.
    """

    # Checked by Artifactories.__next__ when only local catalogs are allowed.
    is_local: bool = AbstractField()

    @abstractmethod
    def __contains__(self, name: str) -> bool:
        pass

    @abstractmethod
    def __getitem__(self, name) -> Artifact:
        pass

    @abstractmethod
    def get_with_overwrite(self, name, overwrite_args) -> Artifact:
        pass


class UnitxtArtifactNotFoundError(Exception):
    """Raised when no artifactory contains the requested artifact name."""

    def __init__(self, name, artifactories):
        self.name = name
        self.artifactories = artifactories

    def __str__(self):
        msg = f"Artifact {self.name} does not exist, in artifactories:{self.artifactories}."
        if settings.use_only_local_catalogs:
            msg += f" Notice that unitxt.settings.use_only_local_catalogs is set to True, if you want to use remote catalogs set this settings or the environment variable {settings.use_only_local_catalogs_key}."
        # Return the assembled message so the local-catalogs hint is
        # actually shown (previously it was built and then discarded).
        return msg


def fetch_artifact(name):
    """Resolve ``name`` to an artifact.

    Returns:
        ``(artifact, artifactory)``. If ``name`` is a path to an artifact
        file it is loaded directly and ``artifactory`` is ``None``;
        otherwise the first registered artifactory containing the name is
        used.
    """
    if Artifact.is_artifact_file(name):
        return Artifact.load(name), None

    artifactory, name, args = get_artifactory_name_and_args(name=name)

    return artifactory.get_with_overwrite(name, overwrite_args=args), artifactory


def get_artifactory_name_and_args(
    name: str, artifactories: Optional[List[Artifactory]] = None
):
    """Find the artifactory holding ``name``.

    ``name`` is first split by
    ``separate_inside_and_outside_square_brackets`` into the bare name and
    the bracketed arguments.

    Args:
        name: artifact name, optionally followed by ``[...]`` arguments.
        artifactories: artifactories to search, in order; defaults to all
            registered ones.

    Returns:
        ``(artifactory, name, args)`` for the first artifactory containing
        the bare name.

    Raises:
        UnitxtArtifactNotFoundError: no artifactory contains the name.
    """
    name, args = separate_inside_and_outside_square_brackets(name)

    if artifactories is None:
        artifactories = list(Artifactories())

    for artifactory in artifactories:
        if name in artifactory:
            return artifactory, name, args

    raise UnitxtArtifactNotFoundError(name, artifactories)


def verbosed_fetch_artifact(identifier):
    """Fetch an artifact and log which artifactory it came from."""
    artifact, artifactory = fetch_artifact(identifier)
    logger.info(f"Artifact {identifier} is fetched from {artifactory}")
    return artifact


def reset_artifacts_json_cache():
    """Clear the cache behind ``artifacts_json_cache``."""
    artifacts_json_cache.cache_clear()


def maybe_recover_artifact(artifact):
    """Fetch the artifact if given a string reference; otherwise pass through."""
    if isinstance(artifact, str):
        return verbosed_fetch_artifact(artifact)

    return artifact


def register_all_artifacts(path):
    """Import every module found under ``path`` and log its Artifact subclasses.

    Importing the modules is what registers their artifact classes (via
    ``Artifact.__init_subclass__``); this function itself only logs them.
    """
    for loader, module_name, _is_pkg in pkgutil.walk_packages(path):
        logger.info(__name__)
        if module_name == __name__:
            continue
        logger.info(f"Loading {module_name}")
        # Import the module
        # NOTE(review): find_module()/load_module() are legacy import-system
        # APIs; confirm they are available on every supported Python version.
        module = loader.find_module(module_name).load_module(module_name)
        # Iterate over every object in the module
        for _name, obj in inspect.getmembers(module):
            # Make sure the object is a class
            if inspect.isclass(obj):
                # Make sure the class is a subclass of Artifact (but not Artifact itself)
                if issubclass(obj, Artifact) and obj is not Artifact:
                    logger.info(obj)