| | |
| | |
| | import collections |
| | import contextlib |
| | import copy |
| | import inspect |
| | import json |
| | import sys |
| | import textwrap |
| | from typing import ( |
| | Any, |
| | Sequence, |
| | List, |
| | Dict, |
| | Optional, |
| | DefaultDict, |
| | Tuple, |
| | Iterable, |
| | Type, |
| | Generator, |
| | Union, |
| | overload, |
| | Literal, |
| | TypeVar, |
| | ) |
| | from itertools import zip_longest |
| | from importlib.metadata import version as importlib_version |
| | from typing import Final |
| |
|
| | import jsonschema |
| | import jsonschema.exceptions |
| | import jsonschema.validators |
| | import numpy as np |
| | import pandas as pd |
| | from packaging.version import Version |
| |
|
| | |
| | |
| | |
| | from altair import vegalite |
| |
|
| | if sys.version_info >= (3, 11): |
| | from typing import Self |
| | else: |
| | from typing_extensions import Self |
| |
|
| | TSchemaBase = TypeVar("TSchemaBase", bound=Type["SchemaBase"]) |
| |
|
| | ValidationErrorList = List[jsonschema.exceptions.ValidationError] |
| | GroupedValidationErrors = Dict[str, ValidationErrorList] |
| |
|
| | |
| | |
| | |
| | _VEGA_LITE_ROOT_URI: Final = "urn:vega-lite-schema" |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | _DEFAULT_JSON_SCHEMA_DRAFT_URL: Final = "http://json-schema.org/draft-07/schema#" |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | DEBUG_MODE: bool = True |
| |
|
| |
|
| | def enable_debug_mode() -> None: |
| | global DEBUG_MODE |
| | DEBUG_MODE = True |
| |
|
| |
|
| | def disable_debug_mode() -> None: |
| | global DEBUG_MODE |
| | DEBUG_MODE = False |
| |
|
| |
|
| | @contextlib.contextmanager |
| | def debug_mode(arg: bool) -> Generator[None, None, None]: |
| | global DEBUG_MODE |
| | original = DEBUG_MODE |
| | DEBUG_MODE = arg |
| | try: |
| | yield |
| | finally: |
| | DEBUG_MODE = original |
| |
|
| |
|
| | @overload |
| | def validate_jsonschema( |
| | spec: Dict[str, Any], |
| | schema: Dict[str, Any], |
| | rootschema: Optional[Dict[str, Any]] = ..., |
| | *, |
| | raise_error: Literal[True] = ..., |
| | ) -> None: |
| | ... |
| |
|
| |
|
| | @overload |
| | def validate_jsonschema( |
| | spec: Dict[str, Any], |
| | schema: Dict[str, Any], |
| | rootschema: Optional[Dict[str, Any]] = ..., |
| | *, |
| | raise_error: Literal[False], |
| | ) -> Optional[jsonschema.exceptions.ValidationError]: |
| | ... |
| |
|
| |
|
| | def validate_jsonschema( |
| | spec, |
| | schema, |
| | rootschema=None, |
| | *, |
| | raise_error=True, |
| | ): |
| | """Validates the passed in spec against the schema in the context of the |
| | rootschema. If any errors are found, they are deduplicated and prioritized |
| | and only the most relevant errors are kept. Errors are then either raised |
| | or returned, depending on the value of `raise_error`. |
| | """ |
| | errors = _get_errors_from_spec(spec, schema, rootschema=rootschema) |
| | if errors: |
| | leaf_errors = _get_leaves_of_error_tree(errors) |
| | grouped_errors = _group_errors_by_json_path(leaf_errors) |
| | grouped_errors = _subset_to_most_specific_json_paths(grouped_errors) |
| | grouped_errors = _deduplicate_errors(grouped_errors) |
| |
|
| | |
| | |
| | main_error = list(grouped_errors.values())[0][0] |
| | |
| | |
| | |
| | |
| | |
| | main_error._all_errors = grouped_errors |
| | if raise_error: |
| | raise main_error |
| | else: |
| | return main_error |
| | else: |
| | return None |
| |
|
| |
|
| | def _get_errors_from_spec( |
| | spec: Dict[str, Any], |
| | schema: Dict[str, Any], |
| | rootschema: Optional[Dict[str, Any]] = None, |
| | ) -> ValidationErrorList: |
| | """Uses the relevant jsonschema validator to validate the passed in spec |
| | against the schema using the rootschema to resolve references. |
| | The schema and rootschema themselves are not validated but instead considered |
| | as valid. |
| | """ |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | json_schema_draft_url = _get_json_schema_draft_url(rootschema or schema) |
| | validator_cls = jsonschema.validators.validator_for( |
| | {"$schema": json_schema_draft_url} |
| | ) |
| | validator_kwargs: Dict[str, Any] = {} |
| | if hasattr(validator_cls, "FORMAT_CHECKER"): |
| | validator_kwargs["format_checker"] = validator_cls.FORMAT_CHECKER |
| |
|
| | if _use_referencing_library(): |
| | schema = _prepare_references_in_schema(schema) |
| | validator_kwargs["registry"] = _get_referencing_registry( |
| | rootschema or schema, json_schema_draft_url |
| | ) |
| |
|
| | else: |
| | |
| | validator_kwargs["resolver"] = ( |
| | jsonschema.RefResolver.from_schema(rootschema) |
| | if rootschema is not None |
| | else None |
| | ) |
| |
|
| | validator = validator_cls(schema, **validator_kwargs) |
| | errors = list(validator.iter_errors(spec)) |
| | return errors |
| |
|
| |
|
| | def _get_json_schema_draft_url(schema: dict) -> str: |
| | return schema.get("$schema", _DEFAULT_JSON_SCHEMA_DRAFT_URL) |
| |
|
| |
|
| | def _use_referencing_library() -> bool: |
| | """In version 4.18.0, the jsonschema package deprecated RefResolver in |
| | favor of the referencing library.""" |
| | jsonschema_version_str = importlib_version("jsonschema") |
| | return Version(jsonschema_version_str) >= Version("4.18") |
| |
|
| |
|
| | def _prepare_references_in_schema(schema: Dict[str, Any]) -> Dict[str, Any]: |
| | |
| | |
| | |
| | schema = copy.deepcopy(schema) |
| |
|
| | def _prepare_refs(d: Dict[str, Any]) -> Dict[str, Any]: |
| | """Add _VEGA_LITE_ROOT_URI in front of all $ref values. This function |
| | recursively iterates through the whole dictionary.""" |
| | for key, value in d.items(): |
| | if key == "$ref": |
| | d[key] = _VEGA_LITE_ROOT_URI + d[key] |
| | else: |
| | |
| | |
| | |
| | |
| | if isinstance(value, dict): |
| | d[key] = _prepare_refs(value) |
| | elif isinstance(value, list): |
| | prepared_values = [] |
| | for v in value: |
| | if isinstance(v, dict): |
| | v = _prepare_refs(v) |
| | prepared_values.append(v) |
| | d[key] = prepared_values |
| | return d |
| |
|
| | schema = _prepare_refs(schema) |
| | return schema |
| |
|
| |
|
| | |
| | |
| | def _get_referencing_registry( |
| | rootschema: Dict[str, Any], json_schema_draft_url: Optional[str] = None |
| | ): |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | import referencing |
| | import referencing.jsonschema |
| |
|
| | if json_schema_draft_url is None: |
| | json_schema_draft_url = _get_json_schema_draft_url(rootschema) |
| |
|
| | specification = referencing.jsonschema.specification_with(json_schema_draft_url) |
| | resource = specification.create_resource(rootschema) |
| | return referencing.Registry().with_resource( |
| | uri=_VEGA_LITE_ROOT_URI, resource=resource |
| | ) |
| |
|
| |
|
| | def _json_path(err: jsonschema.exceptions.ValidationError) -> str: |
| | """Drop in replacement for the .json_path property of the jsonschema |
| | ValidationError class, which is not available as property for |
| | ValidationError with jsonschema<4.0.1. |
| | More info, see https://github.com/altair-viz/altair/issues/3038 |
| | """ |
| | path = "$" |
| | for elem in err.absolute_path: |
| | if isinstance(elem, int): |
| | path += "[" + str(elem) + "]" |
| | else: |
| | path += "." + elem |
| | return path |
| |
|
| |
|
| | def _group_errors_by_json_path( |
| | errors: ValidationErrorList, |
| | ) -> GroupedValidationErrors: |
| | """Groups errors by the `json_path` attribute of the jsonschema ValidationError |
| | class. This attribute contains the path to the offending element within |
| | a chart specification and can therefore be considered as an identifier of an |
| | 'issue' in the chart that needs to be fixed. |
| | """ |
| | errors_by_json_path = collections.defaultdict(list) |
| | for err in errors: |
| | err_key = getattr(err, "json_path", _json_path(err)) |
| | errors_by_json_path[err_key].append(err) |
| | return dict(errors_by_json_path) |
| |
|
| |
|
| | def _get_leaves_of_error_tree( |
| | errors: ValidationErrorList, |
| | ) -> ValidationErrorList: |
| | """For each error in `errors`, it traverses down the "error tree" that is generated |
| | by the jsonschema library to find and return all "leaf" errors. These are errors |
| | which have no further errors that caused it and so they are the most specific errors |
| | with the most specific error messages. |
| | """ |
| | leaves: ValidationErrorList = [] |
| | for err in errors: |
| | if err.context: |
| | |
| | |
| | |
| | leaves.extend(_get_leaves_of_error_tree(err.context)) |
| | else: |
| | leaves.append(err) |
| | return leaves |
| |
|
| |
|
| | def _subset_to_most_specific_json_paths( |
| | errors_by_json_path: GroupedValidationErrors, |
| | ) -> GroupedValidationErrors: |
| | """Removes key (json path), value (errors) pairs where the json path is fully |
| | contained in another json path. For example if `errors_by_json_path` has two |
| | keys, `$.encoding.X` and `$.encoding.X.tooltip`, then the first one will be removed |
| | and only the second one is returned. This is done under the assumption that |
| | more specific json paths give more helpful error messages to the user. |
| | """ |
| | errors_by_json_path_specific: GroupedValidationErrors = {} |
| | for json_path, errors in errors_by_json_path.items(): |
| | if not _contained_at_start_of_one_of_other_values( |
| | json_path, list(errors_by_json_path.keys()) |
| | ): |
| | errors_by_json_path_specific[json_path] = errors |
| | return errors_by_json_path_specific |
| |
|
| |
|
| | def _contained_at_start_of_one_of_other_values(x: str, values: Sequence[str]) -> bool: |
| | |
| | |
| | return any(value.startswith(x) for value in values if x != value) |
| |
|
| |
|
| | def _deduplicate_errors( |
| | grouped_errors: GroupedValidationErrors, |
| | ) -> GroupedValidationErrors: |
| | """Some errors have very similar error messages or are just in general not helpful |
| | for a user. This function removes as many of these cases as possible and |
| | can be extended over time to handle new cases that come up. |
| | """ |
| | grouped_errors_deduplicated: GroupedValidationErrors = {} |
| | for json_path, element_errors in grouped_errors.items(): |
| | errors_by_validator = _group_errors_by_validator(element_errors) |
| |
|
| | deduplication_functions = { |
| | "enum": _deduplicate_enum_errors, |
| | "additionalProperties": _deduplicate_additional_properties_errors, |
| | } |
| | deduplicated_errors: ValidationErrorList = [] |
| | for validator, errors in errors_by_validator.items(): |
| | deduplication_func = deduplication_functions.get(validator, None) |
| | if deduplication_func is not None: |
| | errors = deduplication_func(errors) |
| | deduplicated_errors.extend(_deduplicate_by_message(errors)) |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | deduplicated_errors = [ |
| | err for err in deduplicated_errors if not _is_required_value_error(err) |
| | ] |
| |
|
| | grouped_errors_deduplicated[json_path] = deduplicated_errors |
| | return grouped_errors_deduplicated |
| |
|
| |
|
| | def _is_required_value_error(err: jsonschema.exceptions.ValidationError) -> bool: |
| | return err.validator == "required" and err.validator_value == ["value"] |
| |
|
| |
|
| | def _group_errors_by_validator(errors: ValidationErrorList) -> GroupedValidationErrors: |
| | """Groups the errors by the json schema "validator" that casued the error. For |
| | example if the error is that a value is not one of an enumeration in the json schema |
| | then the "validator" is `"enum"`, if the error is due to an unknown property that |
| | was set although no additional properties are allowed then "validator" is |
| | `"additionalProperties`, etc. |
| | """ |
| | errors_by_validator: DefaultDict[ |
| | str, ValidationErrorList |
| | ] = collections.defaultdict(list) |
| | for err in errors: |
| | |
| | |
| | |
| | errors_by_validator[err.validator].append(err) |
| | return dict(errors_by_validator) |
| |
|
| |
|
| | def _deduplicate_enum_errors(errors: ValidationErrorList) -> ValidationErrorList: |
| | """Deduplicate enum errors by removing the errors where the allowed values |
| | are a subset of another error. For example, if `enum` contains two errors |
| | and one has `validator_value` (i.e. accepted values) ["A", "B"] and the |
| | other one ["A", "B", "C"] then the first one is removed and the final |
| | `enum` list only contains the error with ["A", "B", "C"]. |
| | """ |
| | if len(errors) > 1: |
| | |
| | |
| | |
| | value_strings = [",".join(err.validator_value) for err in errors] |
| | longest_enums: ValidationErrorList = [] |
| | for value_str, err in zip(value_strings, errors): |
| | if not _contained_at_start_of_one_of_other_values(value_str, value_strings): |
| | longest_enums.append(err) |
| | errors = longest_enums |
| | return errors |
| |
|
| |
|
| | def _deduplicate_additional_properties_errors( |
| | errors: ValidationErrorList, |
| | ) -> ValidationErrorList: |
| | """If there are multiple additional property errors it usually means that |
| | the offending element was validated against multiple schemas and |
| | its parent is a common anyOf validator. |
| | The error messages produced from these cases are usually |
| | very similar and we just take the shortest one. For example, |
| | the following 3 errors are raised for the `unknown` channel option in |
| | `alt.X("variety", unknown=2)`: |
| | - "Additional properties are not allowed ('unknown' was unexpected)" |
| | - "Additional properties are not allowed ('field', 'unknown' were unexpected)" |
| | - "Additional properties are not allowed ('field', 'type', 'unknown' were unexpected)" |
| | """ |
| | if len(errors) > 1: |
| | |
| | |
| | |
| | parent = errors[0].parent |
| | if ( |
| | parent is not None |
| | and parent.validator == "anyOf" |
| | |
| | |
| | and all(err.parent is parent for err in errors[1:]) |
| | ): |
| | errors = [min(errors, key=lambda x: len(x.message))] |
| | return errors |
| |
|
| |
|
| | def _deduplicate_by_message(errors: ValidationErrorList) -> ValidationErrorList: |
| | """Deduplicate errors by message. This keeps the original order in case |
| | it was chosen intentionally. |
| | """ |
| | return list({e.message: e for e in errors}.values()) |
| |
|
| |
|
| | def _subclasses(cls: type) -> Generator[type, None, None]: |
| | """Breadth-first sequence of all classes which inherit from cls.""" |
| | seen = set() |
| | current_set = {cls} |
| | while current_set: |
| | seen |= current_set |
| | current_set = set.union(*(set(cls.__subclasses__()) for cls in current_set)) |
| | for cls in current_set - seen: |
| | yield cls |
| |
|
| |
|
| | def _todict(obj: Any, context: Optional[Dict[str, Any]]) -> Any: |
| | """Convert an object to a dict representation.""" |
| | if isinstance(obj, SchemaBase): |
| | return obj.to_dict(validate=False, context=context) |
| | elif isinstance(obj, (list, tuple, np.ndarray)): |
| | return [_todict(v, context) for v in obj] |
| | elif isinstance(obj, dict): |
| | return {k: _todict(v, context) for k, v in obj.items() if v is not Undefined} |
| | elif hasattr(obj, "to_dict"): |
| | return obj.to_dict() |
| | elif isinstance(obj, np.number): |
| | return float(obj) |
| | elif isinstance(obj, (pd.Timestamp, np.datetime64)): |
| | return pd.Timestamp(obj).isoformat() |
| | else: |
| | return obj |
| |
|
| |
|
| | def _resolve_references( |
| | schema: Dict[str, Any], rootschema: Optional[Dict[str, Any]] = None |
| | ) -> Dict[str, Any]: |
| | """Resolve schema references until there is no $ref anymore |
| | in the top-level of the dictionary. |
| | """ |
| | if _use_referencing_library(): |
| | registry = _get_referencing_registry(rootschema or schema) |
| | |
| | |
| | |
| | referencing_resolver = registry.resolver() |
| | while "$ref" in schema: |
| | schema = referencing_resolver.lookup( |
| | _VEGA_LITE_ROOT_URI + schema["$ref"] |
| | ).contents |
| | else: |
| | resolver = jsonschema.RefResolver.from_schema(rootschema or schema) |
| | while "$ref" in schema: |
| | with resolver.resolving(schema["$ref"]) as resolved: |
| | schema = resolved |
| | return schema |
| |
|
| |
|
| | class SchemaValidationError(jsonschema.ValidationError): |
| | """A wrapper for jsonschema.ValidationError with friendlier traceback""" |
| |
|
| | def __init__(self, obj: "SchemaBase", err: jsonschema.ValidationError) -> None: |
| | super().__init__(**err._contents()) |
| | self.obj = obj |
| | self._errors: GroupedValidationErrors = getattr( |
| | err, "_all_errors", {getattr(err, "json_path", _json_path(err)): [err]} |
| | ) |
| | |
| | self._original_message = self.message |
| | self.message = self._get_message() |
| |
|
| | def __str__(self) -> str: |
| | return self.message |
| |
|
| | def _get_message(self) -> str: |
| | def indent_second_line_onwards(message: str, indent: int = 4) -> str: |
| | modified_lines: List[str] = [] |
| | for idx, line in enumerate(message.split("\n")): |
| | if idx > 0 and len(line) > 0: |
| | line = " " * indent + line |
| | modified_lines.append(line) |
| | return "\n".join(modified_lines) |
| |
|
| | error_messages: List[str] = [] |
| | |
| | |
| | for errors in list(self._errors.values())[:3]: |
| | error_messages.append(self._get_message_for_errors_group(errors)) |
| |
|
| | message = "" |
| | if len(error_messages) > 1: |
| | error_messages = [ |
| | indent_second_line_onwards(f"Error {error_id}: {m}") |
| | for error_id, m in enumerate(error_messages, start=1) |
| | ] |
| | message += "Multiple errors were found.\n\n" |
| | message += "\n\n".join(error_messages) |
| | return message |
| |
|
| | def _get_message_for_errors_group( |
| | self, |
| | errors: ValidationErrorList, |
| | ) -> str: |
| | if errors[0].validator == "additionalProperties": |
| | |
| | |
| | |
| | |
| | |
| | message = self._get_additional_properties_error_message(errors[0]) |
| | else: |
| | message = self._get_default_error_message(errors=errors) |
| |
|
| | return message.strip() |
| |
|
| | def _get_additional_properties_error_message( |
| | self, |
| | error: jsonschema.exceptions.ValidationError, |
| | ) -> str: |
| | """Output all existing parameters when an unknown parameter is specified.""" |
| | altair_cls = self._get_altair_class_for_error(error) |
| | param_dict_keys = inspect.signature(altair_cls).parameters.keys() |
| | param_names_table = self._format_params_as_table(param_dict_keys) |
| |
|
| | |
| | |
| | |
| | parameter_name = error.message.split("('")[-1].split("'")[0] |
| | message = f"""\ |
| | `{altair_cls.__name__}` has no parameter named '{parameter_name}' |
| | |
| | Existing parameter names are: |
| | {param_names_table} |
| | See the help for `{altair_cls.__name__}` to read the full description of these parameters""" |
| | return message |
| |
|
| | def _get_altair_class_for_error( |
| | self, error: jsonschema.exceptions.ValidationError |
| | ) -> Type["SchemaBase"]: |
| | """Try to get the lowest class possible in the chart hierarchy so |
| | it can be displayed in the error message. This should lead to more informative |
| | error messages pointing the user closer to the source of the issue. |
| | """ |
| | for prop_name in reversed(error.absolute_path): |
| | |
| | if isinstance(prop_name, str): |
| | potential_class_name = prop_name[0].upper() + prop_name[1:] |
| | cls = getattr(vegalite, potential_class_name, None) |
| | if cls is not None: |
| | break |
| | else: |
| | |
| | |
| | |
| | cls = self.obj.__class__ |
| | return cls |
| |
|
| | @staticmethod |
| | def _format_params_as_table(param_dict_keys: Iterable[str]) -> str: |
| | """Format param names into a table so that they are easier to read""" |
| | param_names: Tuple[str, ...] |
| | name_lengths: Tuple[int, ...] |
| | param_names, name_lengths = zip( |
| | *[ |
| | (name, len(name)) |
| | for name in param_dict_keys |
| | if name not in ["kwds", "self"] |
| | ] |
| | ) |
| | |
| | |
| | max_name_length = max(name_lengths) |
| | max_column_width = 80 |
| | |
| | num_param_names = len(param_names) |
| | square_columns = int(np.ceil(num_param_names**0.5)) |
| | columns = min(max_column_width // max_name_length, square_columns) |
| |
|
| | |
| | def split_into_equal_parts(n: int, p: int) -> List[int]: |
| | return [n // p + 1] * (n % p) + [n // p] * (p - n % p) |
| |
|
| | column_heights = split_into_equal_parts(num_param_names, columns) |
| |
|
| | |
| | param_names_columns: List[Tuple[str, ...]] = [] |
| | column_max_widths: List[int] = [] |
| | last_end_idx: int = 0 |
| | for ch in column_heights: |
| | param_names_columns.append(param_names[last_end_idx : last_end_idx + ch]) |
| | column_max_widths.append( |
| | max([len(param_name) for param_name in param_names_columns[-1]]) |
| | ) |
| | last_end_idx = ch + last_end_idx |
| |
|
| | |
| | param_names_rows: List[Tuple[str, ...]] = [] |
| | for li in zip_longest(*param_names_columns, fillvalue=""): |
| | param_names_rows.append(li) |
| | |
| | param_names_table: str = "" |
| | for param_names_row in param_names_rows: |
| | for num, param_name in enumerate(param_names_row): |
| | |
| | max_name_length_column = column_max_widths[num] |
| | column_pad = 3 |
| | param_names_table += "{:<{}}".format( |
| | param_name, max_name_length_column + column_pad |
| | ) |
| | |
| | if num == (len(param_names_row) - 1): |
| | param_names_table += "\n" |
| | return param_names_table |
| |
|
| | def _get_default_error_message( |
| | self, |
| | errors: ValidationErrorList, |
| | ) -> str: |
| | bullet_points: List[str] = [] |
| | errors_by_validator = _group_errors_by_validator(errors) |
| | if "enum" in errors_by_validator: |
| | for error in errors_by_validator["enum"]: |
| | bullet_points.append(f"one of {error.validator_value}") |
| |
|
| | if "type" in errors_by_validator: |
| | types = [f"'{err.validator_value}'" for err in errors_by_validator["type"]] |
| | point = "of type " |
| | if len(types) == 1: |
| | point += types[0] |
| | elif len(types) == 2: |
| | point += f"{types[0]} or {types[1]}" |
| | else: |
| | point += ", ".join(types[:-1]) + f", or {types[-1]}" |
| | bullet_points.append(point) |
| |
|
| | |
| | |
| | |
| | error = errors[0] |
| | |
| | |
| | message = f"'{error.instance}' is an invalid value" |
| | if error.absolute_path: |
| | message += f" for `{error.absolute_path[-1]}`" |
| |
|
| | |
| | if len(bullet_points) == 0: |
| | message += ".\n\n" |
| | elif len(bullet_points) == 1: |
| | message += f". Valid values are {bullet_points[0]}.\n\n" |
| | else: |
| | |
| | |
| | bullet_points = [point[0].upper() + point[1:] for point in bullet_points] |
| | message += ". Valid values are:\n\n" |
| | message += "\n".join([f"- {point}" for point in bullet_points]) |
| | message += "\n\n" |
| |
|
| | |
| | |
| | |
| | for validator, errors in errors_by_validator.items(): |
| | if validator not in ("enum", "type"): |
| | message += "\n".join([e.message for e in errors]) |
| |
|
| | return message |
| |
|
| |
|
| | class UndefinedType: |
| | """A singleton object for marking undefined parameters""" |
| |
|
| | __instance = None |
| |
|
| | def __new__(cls, *args, **kwargs): |
| | if not isinstance(cls.__instance, cls): |
| | cls.__instance = object.__new__(cls, *args, **kwargs) |
| | return cls.__instance |
| |
|
| | def __repr__(self): |
| | return "Undefined" |
| |
|
| |
|
| | Undefined = UndefinedType() |
| |
|
| |
|
| | class SchemaBase: |
| | """Base class for schema wrappers. |
| | |
| | Each derived class should set the _schema class attribute (and optionally |
| | the _rootschema class attribute) which is used for validation. |
| | """ |
| |
|
| | _schema: Optional[Dict[str, Any]] = None |
| | _rootschema: Optional[Dict[str, Any]] = None |
| | _class_is_valid_at_instantiation: bool = True |
| |
|
| | def __init__(self, *args: Any, **kwds: Any) -> None: |
| | |
| | |
| | |
| | |
| | if self._schema is None: |
| | raise ValueError( |
| | "Cannot instantiate object of type {}: " |
| | "_schema class attribute is not defined." |
| | "".format(self.__class__) |
| | ) |
| |
|
| | if kwds: |
| | assert len(args) == 0 |
| | else: |
| | assert len(args) in [0, 1] |
| |
|
| | |
| | object.__setattr__(self, "_args", args) |
| | object.__setattr__(self, "_kwds", kwds) |
| |
|
| | if DEBUG_MODE and self._class_is_valid_at_instantiation: |
| | self.to_dict(validate=True) |
| |
|
| | def copy( |
| | self, deep: Union[bool, Iterable] = True, ignore: Optional[list] = None |
| | ) -> Self: |
| | """Return a copy of the object |
| | |
| | Parameters |
| | ---------- |
| | deep : boolean or list, optional |
| | If True (default) then return a deep copy of all dict, list, and |
| | SchemaBase objects within the object structure. |
| | If False, then only copy the top object. |
| | If a list or iterable, then only copy the listed attributes. |
| | ignore : list, optional |
| | A list of keys for which the contents should not be copied, but |
| | only stored by reference. |
| | """ |
| |
|
| | def _shallow_copy(obj): |
| | if isinstance(obj, SchemaBase): |
| | return obj.copy(deep=False) |
| | elif isinstance(obj, list): |
| | return obj[:] |
| | elif isinstance(obj, dict): |
| | return obj.copy() |
| | else: |
| | return obj |
| |
|
| | def _deep_copy(obj, ignore: Optional[list] = None): |
| | if ignore is None: |
| | ignore = [] |
| | if isinstance(obj, SchemaBase): |
| | args = tuple(_deep_copy(arg) for arg in obj._args) |
| | kwds = { |
| | k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) |
| | for k, v in obj._kwds.items() |
| | } |
| | with debug_mode(False): |
| | return obj.__class__(*args, **kwds) |
| | elif isinstance(obj, list): |
| | return [_deep_copy(v, ignore=ignore) for v in obj] |
| | elif isinstance(obj, dict): |
| | return { |
| | k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) |
| | for k, v in obj.items() |
| | } |
| | else: |
| | return obj |
| |
|
| | try: |
| | deep = list(deep) |
| | except TypeError: |
| | deep_is_list = False |
| | else: |
| | deep_is_list = True |
| |
|
| | if deep and not deep_is_list: |
| | return _deep_copy(self, ignore=ignore) |
| |
|
| | with debug_mode(False): |
| | copy = self.__class__(*self._args, **self._kwds) |
| | if deep_is_list: |
| | |
| | assert isinstance(deep, list) |
| | for attr in deep: |
| | copy[attr] = _shallow_copy(copy._get(attr)) |
| | return copy |
| |
|
| | def _get(self, attr, default=Undefined): |
| | """Get an attribute, returning default if not present.""" |
| | attr = self._kwds.get(attr, Undefined) |
| | if attr is Undefined: |
| | attr = default |
| | return attr |
| |
|
| | def __getattr__(self, attr): |
| | |
| | if attr == "_kwds": |
| | raise AttributeError() |
| | if attr in self._kwds: |
| | return self._kwds[attr] |
| | else: |
| | try: |
| | _getattr = super(SchemaBase, self).__getattr__ |
| | except AttributeError: |
| | _getattr = super(SchemaBase, self).__getattribute__ |
| | return _getattr(attr) |
| |
|
| | def __setattr__(self, item, val): |
| | self._kwds[item] = val |
| |
|
| | def __getitem__(self, item): |
| | return self._kwds[item] |
| |
|
| | def __setitem__(self, item, val): |
| | self._kwds[item] = val |
| |
|
| | def __repr__(self): |
| | if self._kwds: |
| | args = ( |
| | "{}: {!r}".format(key, val) |
| | for key, val in sorted(self._kwds.items()) |
| | if val is not Undefined |
| | ) |
| | args = "\n" + ",\n".join(args) |
| | return "{0}({{{1}\n}})".format( |
| | self.__class__.__name__, args.replace("\n", "\n ") |
| | ) |
| | else: |
| | return "{}({!r})".format(self.__class__.__name__, self._args[0]) |
| |
|
| | def __eq__(self, other): |
| | return ( |
| | type(self) is type(other) |
| | and self._args == other._args |
| | and self._kwds == other._kwds |
| | ) |
| |
|
| | def to_dict( |
| | self, |
| | validate: bool = True, |
| | *, |
| | ignore: Optional[List[str]] = None, |
| | context: Optional[Dict[str, Any]] = None, |
| | ) -> dict: |
| | """Return a dictionary representation of the object |
| | |
| | Parameters |
| | ---------- |
| | validate : bool, optional |
| | If True (default), then validate the output dictionary |
| | against the schema. |
| | ignore : list[str], optional |
| | A list of keys to ignore. It is usually not needed |
| | to specify this argument as a user. |
| | context : dict[str, Any], optional |
| | A context dictionary. It is usually not needed |
| | to specify this argument as a user. |
| | |
| | Notes |
| | ----- |
| | Technical: The ignore parameter will *not* be passed to child to_dict |
| | function calls. |
| | |
| | Returns |
| | ------- |
| | dict |
| | The dictionary representation of this object |
| | |
| | Raises |
| | ------ |
| | SchemaValidationError : |
| | if validate=True and the dict does not conform to the schema |
| | """ |
| | if context is None: |
| | context = {} |
| | if ignore is None: |
| | ignore = [] |
| |
|
| | if self._args and not self._kwds: |
| | result = _todict(self._args[0], context=context) |
| | elif not self._args: |
| | kwds = self._kwds.copy() |
| | |
| | |
| | |
| | |
| | parsed_shorthand = context.pop("parsed_shorthand", {}) |
| | |
| | |
| | |
| | if "sort" in parsed_shorthand and ( |
| | "sort" not in kwds or kwds["type"] not in ["ordinal", Undefined] |
| | ): |
| | parsed_shorthand.pop("sort") |
| |
|
| | kwds.update( |
| | { |
| | k: v |
| | for k, v in parsed_shorthand.items() |
| | if kwds.get(k, Undefined) is Undefined |
| | } |
| | ) |
| | kwds = { |
| | k: v for k, v in kwds.items() if k not in list(ignore) + ["shorthand"] |
| | } |
| | if "mark" in kwds and isinstance(kwds["mark"], str): |
| | kwds["mark"] = {"type": kwds["mark"]} |
| | result = _todict( |
| | kwds, |
| | context=context, |
| | ) |
| | else: |
| | raise ValueError( |
| | "{} instance has both a value and properties : " |
| | "cannot serialize to dict".format(self.__class__) |
| | ) |
| | if validate: |
| | try: |
| | self.validate(result) |
| | except jsonschema.ValidationError as err: |
| | |
| | |
| | |
| | |
| | |
| | raise SchemaValidationError(self, err) from None |
| | return result |
| |
|
| | def to_json( |
| | self, |
| | validate: bool = True, |
| | indent: int = 2, |
| | sort_keys: bool = True, |
| | *, |
| | ignore: Optional[List[str]] = None, |
| | context: Optional[Dict[str, Any]] = None, |
| | **kwargs, |
| | ) -> str: |
| | """Emit the JSON representation for this object as a string. |
| | |
| | Parameters |
| | ---------- |
| | validate : bool, optional |
| | If True (default), then validate the output dictionary |
| | against the schema. |
| | indent : int, optional |
| | The number of spaces of indentation to use. The default is 2. |
| | sort_keys : bool, optional |
| | If True (default), sort keys in the output. |
| | ignore : list[str], optional |
| | A list of keys to ignore. It is usually not needed |
| | to specify this argument as a user. |
| | context : dict[str, Any], optional |
| | A context dictionary. It is usually not needed |
| | to specify this argument as a user. |
| | **kwargs |
| | Additional keyword arguments are passed to ``json.dumps()`` |
| | |
| | Notes |
| | ----- |
| | Technical: The ignore parameter will *not* be passed to child to_dict |
| | function calls. |
| | |
| | Returns |
| | ------- |
| | str |
| | The JSON specification of the chart object. |
| | """ |
| | if ignore is None: |
| | ignore = [] |
| | if context is None: |
| | context = {} |
| | dct = self.to_dict(validate=validate, ignore=ignore, context=context) |
| | return json.dumps(dct, indent=indent, sort_keys=sort_keys, **kwargs) |
| |
|
| | @classmethod |
| | def _default_wrapper_classes(cls) -> Generator[Type["SchemaBase"], None, None]: |
| | """Return the set of classes used within cls.from_dict()""" |
| | return _subclasses(SchemaBase) |
| |
|
| | @classmethod |
| | def from_dict( |
| | cls, |
| | dct: dict, |
| | validate: bool = True, |
| | _wrapper_classes: Optional[Iterable[Type["SchemaBase"]]] = None, |
| | |
| | |
| | ) -> "SchemaBase": |
| | """Construct class from a dictionary representation |
| | |
| | Parameters |
| | ---------- |
| | dct : dictionary |
| | The dict from which to construct the class |
| | validate : boolean |
| | If True (default), then validate the input against the schema. |
| | _wrapper_classes : iterable (optional) |
| | The set of SchemaBase classes to use when constructing wrappers |
| | of the dict inputs. If not specified, the result of |
| | cls._default_wrapper_classes will be used. |
| | |
| | Returns |
| | ------- |
| | obj : Schema object |
| | The wrapped schema |
| | |
| | Raises |
| | ------ |
| | jsonschema.ValidationError : |
| | if validate=True and dct does not conform to the schema |
| | """ |
| | if validate: |
| | cls.validate(dct) |
| | if _wrapper_classes is None: |
| | _wrapper_classes = cls._default_wrapper_classes() |
| | converter = _FromDict(_wrapper_classes) |
| | return converter.from_dict(dct, cls) |
| |
|
| | @classmethod |
| | def from_json( |
| | cls, |
| | json_string: str, |
| | validate: bool = True, |
| | **kwargs: Any, |
| | |
| | |
| | ) -> Any: |
| | """Instantiate the object from a valid JSON string |
| | |
| | Parameters |
| | ---------- |
| | json_string : string |
| | The string containing a valid JSON chart specification. |
| | validate : boolean |
| | If True (default), then validate the input against the schema. |
| | **kwargs : |
| | Additional keyword arguments are passed to json.loads |
| | |
| | Returns |
| | ------- |
| | chart : Chart object |
| | The altair Chart object built from the specification. |
| | """ |
| | dct = json.loads(json_string, **kwargs) |
| | return cls.from_dict(dct, validate=validate) |
| |
|
| | @classmethod |
| | def validate( |
| | cls, instance: Dict[str, Any], schema: Optional[Dict[str, Any]] = None |
| | ) -> None: |
| | """ |
| | Validate the instance against the class schema in the context of the |
| | rootschema. |
| | """ |
| | if schema is None: |
| | schema = cls._schema |
| | |
| | assert schema is not None |
| | return validate_jsonschema( |
| | instance, schema, rootschema=cls._rootschema or cls._schema |
| | ) |
| |
|
| | @classmethod |
| | def resolve_references(cls, schema: Optional[dict] = None) -> dict: |
| | """Resolve references in the context of this object's schema or root schema.""" |
| | schema_to_pass = schema or cls._schema |
| | |
| | assert schema_to_pass is not None |
| | return _resolve_references( |
| | schema=schema_to_pass, |
| | rootschema=(cls._rootschema or cls._schema or schema), |
| | ) |
| |
|
| | @classmethod |
| | def validate_property( |
| | cls, name: str, value: Any, schema: Optional[dict] = None |
| | ) -> None: |
| | """ |
| | Validate a property against property schema in the context of the |
| | rootschema |
| | """ |
| | value = _todict(value, context={}) |
| | props = cls.resolve_references(schema or cls._schema).get("properties", {}) |
| | return validate_jsonschema( |
| | value, props.get(name, {}), rootschema=cls._rootschema or cls._schema |
| | ) |
| |
|
| | def __dir__(self) -> list: |
| | return sorted(list(super().__dir__()) + list(self._kwds.keys())) |
| |
|
| |
|
| | def _passthrough(*args, **kwds): |
| | return args[0] if args else kwds |
| |
|
| |
|
| | class _FromDict: |
| | """Class used to construct SchemaBase class hierarchies from a dict |
| | |
| | The primary purpose of using this class is to be able to build a hash table |
| | that maps schemas to their wrapper classes. The candidate classes are |
| | specified in the ``class_list`` argument to the constructor. |
| | """ |
| |
|
| | _hash_exclude_keys = ("definitions", "title", "description", "$schema", "id") |
| |
|
| | def __init__(self, class_list: Iterable[Type[SchemaBase]]) -> None: |
| | |
| | |
| | self.class_dict = collections.defaultdict(list) |
| | for cls in class_list: |
| | if cls._schema is not None: |
| | self.class_dict[self.hash_schema(cls._schema)].append(cls) |
| |
|
| | @classmethod |
| | def hash_schema(cls, schema: dict, use_json: bool = True) -> int: |
| | """ |
| | Compute a python hash for a nested dictionary which |
| | properly handles dicts, lists, sets, and tuples. |
| | |
| | At the top level, the function excludes from the hashed schema all keys |
| | listed in `exclude_keys`. |
| | |
| | This implements two methods: one based on conversion to JSON, and one based |
| | on recursive conversions of unhashable to hashable types; the former seems |
| | to be slightly faster in several benchmarks. |
| | """ |
| | if cls._hash_exclude_keys and isinstance(schema, dict): |
| | schema = { |
| | key: val |
| | for key, val in schema.items() |
| | if key not in cls._hash_exclude_keys |
| | } |
| | if use_json: |
| | s = json.dumps(schema, sort_keys=True) |
| | return hash(s) |
| | else: |
| |
|
| | def _freeze(val): |
| | if isinstance(val, dict): |
| | return frozenset((k, _freeze(v)) for k, v in val.items()) |
| | elif isinstance(val, set): |
| | return frozenset(map(_freeze, val)) |
| | elif isinstance(val, list) or isinstance(val, tuple): |
| | return tuple(map(_freeze, val)) |
| | else: |
| | return val |
| |
|
| | return hash(_freeze(schema)) |
| |
|
| | def from_dict( |
| | self, |
| | dct: dict, |
| | cls: Optional[Type[SchemaBase]] = None, |
| | schema: Optional[dict] = None, |
| | rootschema: Optional[dict] = None, |
| | default_class=_passthrough, |
| | |
| | |
| | ) -> Any: |
| | """Construct an object from a dict representation""" |
| | if (schema is None) == (cls is None): |
| | raise ValueError("Must provide either cls or schema, but not both.") |
| | if schema is None: |
| | |
| | schema = cls._schema |
| | |
| | assert schema is not None |
| | if rootschema: |
| | rootschema = rootschema |
| | elif cls is not None and cls._rootschema is not None: |
| | rootschema = cls._rootschema |
| | else: |
| | rootschema = None |
| | rootschema = rootschema or schema |
| |
|
| | if isinstance(dct, SchemaBase): |
| | return dct |
| |
|
| | if cls is None: |
| | |
| | |
| | |
| | matches = self.class_dict[self.hash_schema(schema)] |
| | if matches: |
| | cls = matches[0] |
| | else: |
| | cls = default_class |
| | schema = _resolve_references(schema, rootschema) |
| |
|
| | if "anyOf" in schema or "oneOf" in schema: |
| | schemas = schema.get("anyOf", []) + schema.get("oneOf", []) |
| | for possible_schema in schemas: |
| | try: |
| | validate_jsonschema(dct, possible_schema, rootschema=rootschema) |
| | except jsonschema.ValidationError: |
| | continue |
| | else: |
| | return self.from_dict( |
| | dct, |
| | schema=possible_schema, |
| | rootschema=rootschema, |
| | default_class=cls, |
| | ) |
| |
|
| | if isinstance(dct, dict): |
| | |
| | props = schema.get("properties", {}) |
| | kwds = {} |
| | for key, val in dct.items(): |
| | if key in props: |
| | val = self.from_dict(val, schema=props[key], rootschema=rootschema) |
| | kwds[key] = val |
| | return cls(**kwds) |
| |
|
| | elif isinstance(dct, list): |
| | item_schema = schema.get("items", {}) |
| | dct = [ |
| | self.from_dict(val, schema=item_schema, rootschema=rootschema) |
| | for val in dct |
| | ] |
| | return cls(dct) |
| | else: |
| | return cls(dct) |
| |
|
| |
|
| | class _PropertySetter: |
| | def __init__(self, prop: str, schema: dict) -> None: |
| | self.prop = prop |
| | self.schema = schema |
| |
|
| | def __get__(self, obj, cls): |
| | self.obj = obj |
| | self.cls = cls |
| | |
| | |
| | self.__doc__ = self.schema["description"].replace("__", "**") |
| | property_name = f"{self.prop}"[0].upper() + f"{self.prop}"[1:] |
| | if hasattr(vegalite, property_name): |
| | altair_prop = getattr(vegalite, property_name) |
| | |
| | |
| | |
| | parameter_index = altair_prop.__doc__.find("Parameters\n") |
| | if parameter_index > -1: |
| | self.__doc__ = ( |
| | altair_prop.__doc__[:parameter_index].replace(" ", "") |
| | + self.__doc__ |
| | + textwrap.dedent( |
| | f"\n\n {altair_prop.__doc__[parameter_index:]}" |
| | ) |
| | ) |
| | |
| | else: |
| | self.__doc__ = ( |
| | altair_prop.__doc__.replace(" ", "") + "\n" + self.__doc__ |
| | ) |
| | |
| | self.__signature__ = inspect.signature(altair_prop) |
| | self.__wrapped__ = inspect.getfullargspec(altair_prop) |
| | self.__name__ = altair_prop.__name__ |
| | else: |
| | |
| | |
| | pass |
| | return self |
| |
|
| | def __call__(self, *args, **kwargs): |
| | obj = self.obj.copy() |
| | |
| | obj[self.prop] = args[0] if args else kwargs |
| | return obj |
| |
|
| |
|
| | def with_property_setters(cls: TSchemaBase) -> TSchemaBase: |
| | """ |
| | Decorator to add property setters to a Schema class. |
| | """ |
| | schema = cls.resolve_references() |
| | for prop, propschema in schema.get("properties", {}).items(): |
| | setattr(cls, prop, _PropertySetter(prop, propschema)) |
| | return cls |
| |
|