|
|
|
|
|
import collections
import contextlib
import copy
import functools
import inspect
import json
import sys
import textwrap
|
from typing import ( |
|
Any, |
|
Sequence, |
|
List, |
|
Dict, |
|
Optional, |
|
DefaultDict, |
|
Tuple, |
|
Iterable, |
|
Type, |
|
Generator, |
|
Union, |
|
overload, |
|
Literal, |
|
TypeVar, |
|
) |
|
from itertools import zip_longest |
|
from importlib.metadata import version as importlib_version |
|
from typing import Final |
|
|
|
import jsonschema |
|
import jsonschema.exceptions |
|
import jsonschema.validators |
|
import numpy as np |
|
import pandas as pd |
|
from packaging.version import Version |
|
|
|
|
|
|
|
|
|
from altair import vegalite |
|
|
|
if sys.version_info >= (3, 11): |
|
from typing import Self |
|
else: |
|
from typing_extensions import Self |
|
|
|
# TypeVar bound to SchemaBase (given as a string forward reference because the
# class is defined further down in this module).
# NOTE(review): not referenced in this section; presumably used by generic
# annotations elsewhere in the module — confirm.
_TSchemaBase = TypeVar("_TSchemaBase", bound="SchemaBase")

# A flat list of jsonschema validation errors.
ValidationErrorList = List[jsonschema.exceptions.ValidationError]
# Validation errors keyed by a string. In this module the key is either the
# JSON path of the offending element (see _group_errors_by_json_path) or the
# name of the failing JSON schema validator (see _group_errors_by_validator).
GroupedValidationErrors = Dict[str, ValidationErrorList]

# URI under which the root schema is registered in the `referencing` registry
# (see _get_referencing_registry). Every "$ref" value is prefixed with it (see
# _prepare_references_in_schema and _resolve_references) so that references
# resolve against that registered resource.
_VEGA_LITE_ROOT_URI: Final = "urn:vega-lite-schema"

# JSON Schema draft assumed when a schema has no "$schema" key
# (see _get_json_schema_draft_url).
_DEFAULT_JSON_SCHEMA_DRAFT_URL: Final = "http://json-schema.org/draft-07/schema#"

# When True, SchemaBase.__init__ immediately converts each new object to a dict
# and validates it (unless the class sets _class_is_valid_at_instantiation to
# False), so invalid input is reported where the object is created.
# Change it with enable_debug_mode(), disable_debug_mode() or the debug_mode()
# context manager.
DEBUG_MODE: bool = True
|
|
|
|
|
def _set_debug_mode(value: bool) -> None:
    """Assign the module-level DEBUG_MODE flag."""
    global DEBUG_MODE
    DEBUG_MODE = value


def enable_debug_mode() -> None:
    """Turn on validation of schema objects at instantiation time."""
    _set_debug_mode(True)


def disable_debug_mode() -> None:
    """Turn off validation of schema objects at instantiation time."""
    _set_debug_mode(False)


@contextlib.contextmanager
def debug_mode(arg: bool) -> Generator[None, None, None]:
    """Temporarily set DEBUG_MODE to ``arg`` for the duration of a ``with`` block.

    The previous value is restored on exit, also when the block raises.
    """
    previous_value = DEBUG_MODE
    _set_debug_mode(arg)
    try:
        yield
    finally:
        _set_debug_mode(previous_value)
|
|
|
|
|
@overload
def validate_jsonschema(
    spec: Dict[str, Any],
    schema: Dict[str, Any],
    rootschema: Optional[Dict[str, Any]] = ...,
    *,
    raise_error: Literal[True] = ...,
) -> None:
    ...


@overload
def validate_jsonschema(
    spec: Dict[str, Any],
    schema: Dict[str, Any],
    rootschema: Optional[Dict[str, Any]] = ...,
    *,
    raise_error: Literal[False],
) -> Optional[jsonschema.exceptions.ValidationError]:
    ...


def validate_jsonschema(
    spec,
    schema,
    rootschema=None,
    *,
    raise_error=True,
):
    """Validates the passed in spec against the schema in the context of the
    rootschema. If any errors are found, they are deduplicated and prioritized
    and only the most relevant errors are kept. Errors are then either raised
    or returned, depending on the value of `raise_error`.

    Parameters
    ----------
    spec : dict
        The instance to validate.
    schema : dict
        The schema to validate against.
    rootschema : dict, optional
        The schema used to resolve references.
    raise_error : bool
        If True (default), raise the main error; otherwise return it.

    Returns
    -------
    None if ``spec`` is valid, otherwise (when ``raise_error`` is False) one
    representative ValidationError. All relevant errors, grouped by JSON path,
    are attached to it as the ``_all_errors`` attribute.

    Raises
    ------
    jsonschema.exceptions.ValidationError
        If ``spec`` is invalid and ``raise_error`` is True.
    """
    errors = _get_errors_from_spec(spec, schema, rootschema=rootschema)
    if not errors:
        return None

    leaf_errors = _get_leaves_of_error_tree(errors)
    grouped_errors = _group_errors_by_json_path(leaf_errors)
    grouped_errors = _subset_to_most_specific_json_paths(grouped_errors)

    # Deduplication may filter every error out of a group (e.g. a group that
    # only held "required: ['value']" errors). Drop such empty groups: an empty
    # first group would make picking the main error raise IndexError, and empty
    # groups cannot be turned into a message downstream.
    deduplicated_errors = {
        json_path: path_errors
        for json_path, path_errors in _deduplicate_errors(grouped_errors).items()
        if path_errors
    }
    # The spec is still invalid even if deduplication removed everything, so in
    # that case fall back to the non-deduplicated groups (never empty here).
    if deduplicated_errors:
        grouped_errors = deduplicated_errors

    # Nothing special about this first error but we need one that can be raised.
    main_error = next(iter(grouped_errors.values()))[0]
    # Attach all errors so that SchemaValidationError can build a more helpful
    # message from them.
    main_error._all_errors = grouped_errors
    if raise_error:
        raise main_error
    return main_error
|
|
|
|
|
def _get_errors_from_spec(
    spec: Dict[str, Any],
    schema: Dict[str, Any],
    rootschema: Optional[Dict[str, Any]] = None,
) -> ValidationErrorList:
    """Collect all validation errors of ``spec`` against ``schema``.

    The jsonschema validator class is chosen from the "$schema" draft URL of
    ``rootschema`` (or of ``schema`` when no rootschema is given), and
    references are resolved through ``rootschema``. The schema and rootschema
    themselves are assumed to be valid and are not validated.
    """
    draft_url = _get_json_schema_draft_url(rootschema or schema)
    validator_cls = jsonschema.validators.validator_for({"$schema": draft_url})

    validator_kwargs: Dict[str, Any] = {}
    if hasattr(validator_cls, "FORMAT_CHECKER"):
        validator_kwargs["format_checker"] = validator_cls.FORMAT_CHECKER

    if _use_referencing_library():
        # Refs in `schema` are rewritten first; note that the registry below is
        # built from the rewritten schema when no rootschema is given.
        schema = _prepare_references_in_schema(schema)
        validator_kwargs["registry"] = _get_referencing_registry(
            rootschema or schema, draft_url
        )
    elif rootschema is not None:
        # Older jsonschema versions resolve references with RefResolver.
        validator_kwargs["resolver"] = jsonschema.RefResolver.from_schema(rootschema)
    else:
        validator_kwargs["resolver"] = None

    validator = validator_cls(schema, **validator_kwargs)
    return list(validator.iter_errors(spec))
|
|
|
|
|
def _get_json_schema_draft_url(schema: dict) -> str:
    """Return the "$schema" URL of ``schema``, or the module default draft URL."""
    if "$schema" in schema:
        return schema["$schema"]
    return _DEFAULT_JSON_SCHEMA_DRAFT_URL
|
|
|
|
|
@functools.lru_cache(maxsize=None)
def _use_referencing_library() -> bool:
    """In version 4.18.0, the jsonschema package deprecated RefResolver in
    favor of the referencing library.

    Returns True if the installed jsonschema version is at least 4.18.

    The result is cached. This function is consulted for every validation and
    reference resolution, and each uncached call performs an
    ``importlib.metadata`` lookup, while the installed jsonschema version does
    not change within a running process. Use ``cache_clear()`` to force a
    re-check.
    """
    jsonschema_version_str = importlib_version("jsonschema")
    return Version(jsonschema_version_str) >= Version("4.18")
|
|
|
|
|
def _prepare_references_in_schema(schema: Dict[str, Any]) -> Dict[str, Any]: |
|
|
|
|
|
|
|
schema = copy.deepcopy(schema) |
|
|
|
def _prepare_refs(d: Dict[str, Any]) -> Dict[str, Any]: |
|
"""Add _VEGA_LITE_ROOT_URI in front of all $ref values. This function |
|
recursively iterates through the whole dictionary.""" |
|
for key, value in d.items(): |
|
if key == "$ref": |
|
d[key] = _VEGA_LITE_ROOT_URI + d[key] |
|
else: |
|
|
|
|
|
|
|
|
|
if isinstance(value, dict): |
|
d[key] = _prepare_refs(value) |
|
elif isinstance(value, list): |
|
prepared_values = [] |
|
for v in value: |
|
if isinstance(v, dict): |
|
v = _prepare_refs(v) |
|
prepared_values.append(v) |
|
d[key] = prepared_values |
|
return d |
|
|
|
schema = _prepare_refs(schema) |
|
return schema |
|
|
|
|
|
|
|
|
|
def _get_referencing_registry(
    rootschema: Dict[str, Any], json_schema_draft_url: Optional[str] = None
):
    """Build a ``referencing.Registry`` holding ``rootschema`` under
    _VEGA_LITE_ROOT_URI.

    The JSON Schema specification is taken from ``json_schema_draft_url``, or
    from the "$schema" of ``rootschema`` when no URL is passed.
    """
    # Imported lazily: the module only takes this code path when
    # _use_referencing_library() is True, in which case `referencing` is
    # presumably installed alongside jsonschema>=4.18 — confirm in packaging.
    import referencing
    import referencing.jsonschema

    draft_url = (
        json_schema_draft_url
        if json_schema_draft_url is not None
        else _get_json_schema_draft_url(rootschema)
    )
    specification = referencing.jsonschema.specification_with(draft_url)
    root_resource = specification.create_resource(rootschema)
    return referencing.Registry().with_resource(
        uri=_VEGA_LITE_ROOT_URI, resource=root_resource
    )
|
|
|
|
|
def _json_path(err: jsonschema.exceptions.ValidationError) -> str: |
|
"""Drop in replacement for the .json_path property of the jsonschema |
|
ValidationError class, which is not available as property for |
|
ValidationError with jsonschema<4.0.1. |
|
More info, see https://github.com/altair-viz/altair/issues/3038 |
|
""" |
|
path = "$" |
|
for elem in err.absolute_path: |
|
if isinstance(elem, int): |
|
path += "[" + str(elem) + "]" |
|
else: |
|
path += "." + elem |
|
return path |
|
|
|
|
|
def _group_errors_by_json_path(
    errors: ValidationErrorList,
) -> GroupedValidationErrors:
    """Group errors by the JSON path of the element they refer to.

    The JSON path points to the offending element within a chart
    specification, so it identifies one 'issue' in the chart that needs to be
    fixed. Groups keep the order in which their paths first appear.
    """
    grouped: GroupedValidationErrors = {}
    for err in errors:
        # jsonschema<4.0.1 has no `json_path` property; compute it ourselves then.
        path_key = err.json_path if hasattr(err, "json_path") else _json_path(err)
        grouped.setdefault(path_key, []).append(err)
    return grouped
|
|
|
|
|
def _get_leaves_of_error_tree( |
|
errors: ValidationErrorList, |
|
) -> ValidationErrorList: |
|
"""For each error in `errors`, it traverses down the "error tree" that is generated |
|
by the jsonschema library to find and return all "leaf" errors. These are errors |
|
which have no further errors that caused it and so they are the most specific errors |
|
with the most specific error messages. |
|
""" |
|
leaves: ValidationErrorList = [] |
|
for err in errors: |
|
if err.context: |
|
|
|
|
|
|
|
leaves.extend(_get_leaves_of_error_tree(err.context)) |
|
else: |
|
leaves.append(err) |
|
return leaves |
|
|
|
|
|
def _subset_to_most_specific_json_paths( |
|
errors_by_json_path: GroupedValidationErrors, |
|
) -> GroupedValidationErrors: |
|
"""Removes key (json path), value (errors) pairs where the json path is fully |
|
contained in another json path. For example if `errors_by_json_path` has two |
|
keys, `$.encoding.X` and `$.encoding.X.tooltip`, then the first one will be removed |
|
and only the second one is returned. This is done under the assumption that |
|
more specific json paths give more helpful error messages to the user. |
|
""" |
|
errors_by_json_path_specific: GroupedValidationErrors = {} |
|
for json_path, errors in errors_by_json_path.items(): |
|
if not _contained_at_start_of_one_of_other_values( |
|
json_path, list(errors_by_json_path.keys()) |
|
): |
|
errors_by_json_path_specific[json_path] = errors |
|
return errors_by_json_path_specific |
|
|
|
|
|
def _contained_at_start_of_one_of_other_values(x: str, values: Sequence[str]) -> bool: |
|
|
|
|
|
return any(value.startswith(x) for value in values if x != value) |
|
|
|
|
|
def _deduplicate_errors(
    grouped_errors: GroupedValidationErrors,
) -> GroupedValidationErrors:
    """Reduce each JSON-path group to the errors that are useful to a user.

    Within a group, errors are split by validator. Validator-specific
    deduplication is applied to "enum" and "additionalProperties" errors, and
    errors with identical messages are collapsed. Finally, errors matched by
    _is_required_value_error are discarded. This can be extended over time to
    handle new cases that come up.
    """
    # Validator name -> function that reduces that validator's errors.
    dedup_by_validator = {
        "enum": _deduplicate_enum_errors,
        "additionalProperties": _deduplicate_additional_properties_errors,
    }

    result: GroupedValidationErrors = {}
    for json_path, element_errors in grouped_errors.items():
        kept: ValidationErrorList = []
        for validator, errors in _group_errors_by_validator(element_errors).items():
            reducer = dedup_by_validator.get(validator, None)
            if reducer is not None:
                errors = reducer(errors)
            kept.extend(_deduplicate_by_message(errors))

        # Errors of a "required" check for exactly ["value"] are dropped as unhelpful.
        result[json_path] = [err for err in kept if not _is_required_value_error(err)]
    return result
|
|
|
|
|
def _is_required_value_error(err: jsonschema.exceptions.ValidationError) -> bool: |
|
return err.validator == "required" and err.validator_value == ["value"] |
|
|
|
|
|
def _group_errors_by_validator(errors: ValidationErrorList) -> GroupedValidationErrors: |
|
"""Groups the errors by the json schema "validator" that casued the error. For |
|
example if the error is that a value is not one of an enumeration in the json schema |
|
then the "validator" is `"enum"`, if the error is due to an unknown property that |
|
was set although no additional properties are allowed then "validator" is |
|
`"additionalProperties`, etc. |
|
""" |
|
errors_by_validator: DefaultDict[ |
|
str, ValidationErrorList |
|
] = collections.defaultdict(list) |
|
for err in errors: |
|
|
|
|
|
|
|
errors_by_validator[err.validator].append(err) |
|
return dict(errors_by_validator) |
|
|
|
|
|
def _deduplicate_enum_errors(errors: ValidationErrorList) -> ValidationErrorList: |
|
"""Deduplicate enum errors by removing the errors where the allowed values |
|
are a subset of another error. For example, if `enum` contains two errors |
|
and one has `validator_value` (i.e. accepted values) ["A", "B"] and the |
|
other one ["A", "B", "C"] then the first one is removed and the final |
|
`enum` list only contains the error with ["A", "B", "C"]. |
|
""" |
|
if len(errors) > 1: |
|
|
|
|
|
|
|
value_strings = [",".join(err.validator_value) for err in errors] |
|
longest_enums: ValidationErrorList = [] |
|
for value_str, err in zip(value_strings, errors): |
|
if not _contained_at_start_of_one_of_other_values(value_str, value_strings): |
|
longest_enums.append(err) |
|
errors = longest_enums |
|
return errors |
|
|
|
|
|
def _deduplicate_additional_properties_errors( |
|
errors: ValidationErrorList, |
|
) -> ValidationErrorList: |
|
"""If there are multiple additional property errors it usually means that |
|
the offending element was validated against multiple schemas and |
|
its parent is a common anyOf validator. |
|
The error messages produced from these cases are usually |
|
very similar and we just take the shortest one. For example, |
|
the following 3 errors are raised for the `unknown` channel option in |
|
`alt.X("variety", unknown=2)`: |
|
- "Additional properties are not allowed ('unknown' was unexpected)" |
|
- "Additional properties are not allowed ('field', 'unknown' were unexpected)" |
|
- "Additional properties are not allowed ('field', 'type', 'unknown' were unexpected)" |
|
""" |
|
if len(errors) > 1: |
|
|
|
|
|
|
|
parent = errors[0].parent |
|
if ( |
|
parent is not None |
|
and parent.validator == "anyOf" |
|
|
|
|
|
and all(err.parent is parent for err in errors[1:]) |
|
): |
|
errors = [min(errors, key=lambda x: len(x.message))] |
|
return errors |
|
|
|
|
|
def _deduplicate_by_message(errors: ValidationErrorList) -> ValidationErrorList: |
|
"""Deduplicate errors by message. This keeps the original order in case |
|
it was chosen intentionally. |
|
""" |
|
return list({e.message: e for e in errors}.values()) |
|
|
|
|
|
def _subclasses(cls: type) -> Generator[type, None, None]: |
|
"""Breadth-first sequence of all classes which inherit from cls.""" |
|
seen = set() |
|
current_set = {cls} |
|
while current_set: |
|
seen |= current_set |
|
current_set = set.union(*(set(cls.__subclasses__()) for cls in current_set)) |
|
for cls in current_set - seen: |
|
yield cls |
|
|
|
|
|
def _todict(obj: Any, context: Optional[Dict[str, Any]]) -> Any:
    """Convert an object to a dict representation.

    SchemaBase objects are serialized without validation and receive
    ``context``. Lists, tuples and numpy arrays become lists, and dicts are
    converted value by value with Undefined entries dropped. Other objects
    with a ``to_dict`` method use it; numpy numbers become floats; pandas
    Timestamps and numpy datetime64 values become ISO strings. Anything else
    is returned unchanged.
    """
    if isinstance(obj, SchemaBase):
        return obj.to_dict(validate=False, context=context)
    if isinstance(obj, (list, tuple, np.ndarray)):
        return [_todict(item, context) for item in obj]
    if isinstance(obj, dict):
        return {
            key: _todict(val, context)
            for key, val in obj.items()
            if val is not Undefined
        }
    if hasattr(obj, "to_dict"):
        return obj.to_dict()
    if isinstance(obj, np.number):
        return float(obj)
    if isinstance(obj, (pd.Timestamp, np.datetime64)):
        return pd.Timestamp(obj).isoformat()
    return obj
|
|
|
|
|
def _resolve_references(
    schema: Dict[str, Any], rootschema: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Follow "$ref" pointers until the top level of ``schema`` has no "$ref".

    References are resolved against ``rootschema``, or against ``schema``
    itself when no rootschema is given.
    """
    if not _use_referencing_library():
        # Older jsonschema: resolve with the (since deprecated) RefResolver.
        ref_resolver = jsonschema.RefResolver.from_schema(rootschema or schema)
        while "$ref" in schema:
            with ref_resolver.resolving(schema["$ref"]) as resolved:
                schema = resolved
        return schema

    registry = _get_referencing_registry(rootschema or schema)
    # A referencing.Resolver, not a jsonschema.RefResolver. The root schema is
    # registered under _VEGA_LITE_ROOT_URI, hence the prefix on every lookup.
    referencing_resolver = registry.resolver()
    while "$ref" in schema:
        schema = referencing_resolver.lookup(
            _VEGA_LITE_ROOT_URI + schema["$ref"]
        ).contents
    return schema
|
|
|
|
|
class SchemaValidationError(jsonschema.ValidationError):
    """A wrapper for jsonschema.ValidationError with friendlier traceback"""

    def __init__(self, obj: "SchemaBase", err: jsonschema.ValidationError) -> None:
        super().__init__(**err._contents())
        self.obj = obj
        # validate_jsonschema attaches all grouped errors as `_all_errors`. An
        # error without them is treated as a single group under its own path.
        self._errors: GroupedValidationErrors = getattr(
            err, "_all_errors", {getattr(err, "json_path", _json_path(err)): [err]}
        )
        # Keep the raw jsonschema message; `message` is replaced by the
        # friendlier one below.
        self._original_message = self.message
        self.message = self._get_message()

    def __str__(self) -> str:
        return self.message

    def _get_message(self) -> str:
        """Build the full message from (at most) the first three error groups."""

        def indent_second_line_onwards(message: str, indent: int = 4) -> str:
            modified_lines: List[str] = []
            for idx, line in enumerate(message.split("\n")):
                if idx > 0 and len(line) > 0:
                    line = " " * indent + line
                modified_lines.append(line)
            return "\n".join(modified_lines)

        # Empty groups have no error to describe and are skipped. At most three
        # groups are shown, as else the message could get very long.
        error_groups = [errors for errors in self._errors.values() if errors][:3]
        error_messages: List[str] = [
            self._get_message_for_errors_group(errors) for errors in error_groups
        ]

        message = ""
        if len(error_messages) > 1:
            error_messages = [
                indent_second_line_onwards(f"Error {error_id}: {m}")
                for error_id, m in enumerate(error_messages, start=1)
            ]
            message += "Multiple errors were found.\n\n"
        message += "\n\n".join(error_messages)
        return message

    def _get_message_for_errors_group(
        self,
        errors: ValidationErrorList,
    ) -> str:
        """Return the message for one non-empty group of errors (one JSON path)."""
        if errors[0].validator == "additionalProperties":
            # Only the first error of the group is used for this message.
            message = self._get_additional_properties_error_message(errors[0])
        else:
            message = self._get_default_error_message(errors=errors)

        return message.strip()

    def _get_additional_properties_error_message(
        self,
        error: jsonschema.exceptions.ValidationError,
    ) -> str:
        """Output all existing parameters when an unknown parameter is specified."""
        altair_cls = self._get_altair_class_for_error(error)
        param_dict_keys = inspect.signature(altair_cls).parameters.keys()
        param_names_table = self._format_params_as_table(param_dict_keys)

        # The messages look like
        # "Additional properties are not allowed ('unknown' was unexpected)";
        # this takes the first name listed inside the parentheses ("unknown").
        parameter_name = error.message.split("('")[-1].split("'")[0]
        message = f"""\
`{altair_cls.__name__}` has no parameter named '{parameter_name}'

Existing parameter names are:
{param_names_table}
See the help for `{altair_cls.__name__}` to read the full description of these parameters"""
        return message

    def _get_altair_class_for_error(
        self, error: jsonschema.exceptions.ValidationError
    ) -> Type["SchemaBase"]:
        """Try to get the lowest class possible in the chart hierarchy so
        it can be displayed in the error message. This should lead to more informative
        error messages pointing the user closer to the source of the issue.
        """
        for prop_name in reversed(error.absolute_path):
            # Path items can be array indices (ints). Empty names are skipped
            # because they have no first letter to capitalize.
            if isinstance(prop_name, str) and prop_name:
                potential_class_name = prop_name[0].upper() + prop_name[1:]
                cls = getattr(vegalite, potential_class_name, None)
                if cls is not None:
                    break
        else:
            # No class found along the path: fall back on the class of the
            # top-level object that created this SchemaValidationError.
            cls = self.obj.__class__
        return cls

    @staticmethod
    def _format_params_as_table(param_dict_keys: Iterable[str]) -> str:
        """Format param names into a table so that they are easier to read"""
        param_names: Tuple[str, ...] = tuple(
            name for name in param_dict_keys if name not in ["kwds", "self"]
        )
        if not param_names:
            # Nothing to tabulate (e.g. a signature with only `self`/`kwds`).
            return ""

        # Worst case scenario with the same longest param name in the same
        # row for all columns.
        max_name_length = max(len(name) for name in param_names)
        max_column_width = 80
        # Output a square table if not too big (since it is easier to read).
        num_param_names = len(param_names)
        square_columns = int(np.ceil(num_param_names**0.5))
        # At least one column, even if a single name is wider than the limit.
        columns = max(1, min(max_column_width // max_name_length, square_columns))

        # Compute roughly equal column heights to evenly divide the param names.
        def split_into_equal_parts(n: int, p: int) -> List[int]:
            return [n // p + 1] * (n % p) + [n // p] * (p - n % p)

        column_heights = split_into_equal_parts(num_param_names, columns)

        # Section the param names into columns and compute their widths.
        param_names_columns: List[Tuple[str, ...]] = []
        column_max_widths: List[int] = []
        last_end_idx: int = 0
        for ch in column_heights:
            param_names_columns.append(param_names[last_end_idx : last_end_idx + ch])
            column_max_widths.append(
                max([len(param_name) for param_name in param_names_columns[-1]])
            )
            last_end_idx = ch + last_end_idx

        # Transpose the param name columns into rows to facilitate looping.
        param_names_rows: List[Tuple[str, ...]] = list(
            zip_longest(*param_names_columns, fillvalue="")
        )

        # Build the table as a string by iterating over and formatting the rows.
        param_names_table: str = ""
        for param_names_row in param_names_rows:
            for num, param_name in enumerate(param_names_row):
                # Set column width based on the longest param in the column.
                max_name_length_column = column_max_widths[num]
                column_pad = 3
                param_names_table += "{:<{}}".format(
                    param_name, max_name_length_column + column_pad
                )
                # Insert a newline after the last element in each row.
                if num == (len(param_names_row) - 1):
                    param_names_table += "\n"
        return param_names_table

    def _get_default_error_message(
        self,
        errors: ValidationErrorList,
    ) -> str:
        """Describe the invalid value and list valid alternatives for one group."""
        bullet_points: List[str] = []
        errors_by_validator = _group_errors_by_validator(errors)
        if "enum" in errors_by_validator:
            for error in errors_by_validator["enum"]:
                bullet_points.append(f"one of {error.validator_value}")

        if "type" in errors_by_validator:
            types = [f"'{err.validator_value}'" for err in errors_by_validator["type"]]
            point = "of type "
            if len(types) == 1:
                point += types[0]
            elif len(types) == 2:
                point += f"{types[0]} or {types[1]}"
            else:
                point += ", ".join(types[:-1]) + f", or {types[-1]}"
            bullet_points.append(point)

        # All errors in the group refer to the same JSON path, so the first one
        # is used to describe the offending value.
        error = errors[0]
        # Summary line, e.g. "'asdf' is an invalid value for `stack`".
        message = f"'{error.instance}' is an invalid value"
        if error.absolute_path:
            message += f" for `{error.absolute_path[-1]}`"

        # Add bullet points.
        if len(bullet_points) == 0:
            message += ".\n\n"
        elif len(bullet_points) == 1:
            message += f". Valid values are {bullet_points[0]}.\n\n"
        else:
            # Not using .capitalize() as that would lowercase the rest of the text.
            bullet_points = [point[0].upper() + point[1:] for point in bullet_points]
            message += ". Valid values are:\n\n"
            message += "\n".join([f"- {point}" for point in bullet_points])
            message += "\n\n"

        # Fallback: append the raw messages of errors from any other validator,
        # one per line (also between different validators, so that messages
        # never run together).
        other_messages = [
            e.message
            for validator, validator_errors in errors_by_validator.items()
            if validator not in ("enum", "type")
            for e in validator_errors
        ]
        message += "\n".join(other_messages)

        return message
|
|
|
|
|
class UndefinedType:
    """A singleton object for marking undefined parameters"""

    __instance = None

    def __new__(cls, *args, **kwargs):
        existing = cls.__instance
        if isinstance(existing, cls):
            return existing
        created = object.__new__(cls, *args, **kwargs)
        cls.__instance = created
        return created

    def __repr__(self):
        return "Undefined"


# The shared UndefinedType instance, used as the "not set" marker (e.g. _todict
# drops dict entries whose value is Undefined).
Undefined: Any = UndefinedType()
|
|
|
|
|
class SchemaBase: |
|
"""Base class for schema wrappers. |
|
|
|
Each derived class should set the _schema class attribute (and optionally |
|
the _rootschema class attribute) which is used for validation. |
|
""" |
|
|
|
_schema: Optional[Dict[str, Any]] = None |
|
_rootschema: Optional[Dict[str, Any]] = None |
|
_class_is_valid_at_instantiation: bool = True |
|
|
|
def __init__(self, *args: Any, **kwds: Any) -> None: |
|
|
|
|
|
|
|
|
|
if self._schema is None: |
|
raise ValueError( |
|
"Cannot instantiate object of type {}: " |
|
"_schema class attribute is not defined." |
|
"".format(self.__class__) |
|
) |
|
|
|
if kwds: |
|
assert len(args) == 0 |
|
else: |
|
assert len(args) in [0, 1] |
|
|
|
|
|
object.__setattr__(self, "_args", args) |
|
object.__setattr__(self, "_kwds", kwds) |
|
|
|
if DEBUG_MODE and self._class_is_valid_at_instantiation: |
|
self.to_dict(validate=True) |
|
|
|
def copy( |
|
self, deep: Union[bool, Iterable] = True, ignore: Optional[list] = None |
|
) -> Self: |
|
"""Return a copy of the object |
|
|
|
Parameters |
|
---------- |
|
deep : boolean or list, optional |
|
If True (default) then return a deep copy of all dict, list, and |
|
SchemaBase objects within the object structure. |
|
If False, then only copy the top object. |
|
If a list or iterable, then only copy the listed attributes. |
|
ignore : list, optional |
|
A list of keys for which the contents should not be copied, but |
|
only stored by reference. |
|
""" |
|
|
|
def _shallow_copy(obj): |
|
if isinstance(obj, SchemaBase): |
|
return obj.copy(deep=False) |
|
elif isinstance(obj, list): |
|
return obj[:] |
|
elif isinstance(obj, dict): |
|
return obj.copy() |
|
else: |
|
return obj |
|
|
|
def _deep_copy(obj, ignore: Optional[list] = None): |
|
if ignore is None: |
|
ignore = [] |
|
if isinstance(obj, SchemaBase): |
|
args = tuple(_deep_copy(arg) for arg in obj._args) |
|
kwds = { |
|
k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) |
|
for k, v in obj._kwds.items() |
|
} |
|
with debug_mode(False): |
|
return obj.__class__(*args, **kwds) |
|
elif isinstance(obj, list): |
|
return [_deep_copy(v, ignore=ignore) for v in obj] |
|
elif isinstance(obj, dict): |
|
return { |
|
k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) |
|
for k, v in obj.items() |
|
} |
|
else: |
|
return obj |
|
|
|
try: |
|
deep = list(deep) |
|
except TypeError: |
|
deep_is_list = False |
|
else: |
|
deep_is_list = True |
|
|
|
if deep and not deep_is_list: |
|
return _deep_copy(self, ignore=ignore) |
|
|
|
with debug_mode(False): |
|
copy = self.__class__(*self._args, **self._kwds) |
|
if deep_is_list: |
|
|
|
assert isinstance(deep, list) |
|
for attr in deep: |
|
copy[attr] = _shallow_copy(copy._get(attr)) |
|
return copy |
|
|
|
def _get(self, attr, default=Undefined): |
|
"""Get an attribute, returning default if not present.""" |
|
attr = self._kwds.get(attr, Undefined) |
|
if attr is Undefined: |
|
attr = default |
|
return attr |
|
|
|
def __getattr__(self, attr): |
|
|
|
if attr == "_kwds": |
|
raise AttributeError() |
|
if attr in self._kwds: |
|
return self._kwds[attr] |
|
else: |
|
try: |
|
_getattr = super(SchemaBase, self).__getattr__ |
|
except AttributeError: |
|
_getattr = super(SchemaBase, self).__getattribute__ |
|
return _getattr(attr) |
|
|
|
def __setattr__(self, item, val): |
|
self._kwds[item] = val |
|
|
|
def __getitem__(self, item): |
|
return self._kwds[item] |
|
|
|
def __setitem__(self, item, val): |
|
self._kwds[item] = val |
|
|
|
def __repr__(self): |
|
if self._kwds: |
|
args = ( |
|
"{}: {!r}".format(key, val) |
|
for key, val in sorted(self._kwds.items()) |
|
if val is not Undefined |
|
) |
|
args = "\n" + ",\n".join(args) |
|
return "{0}({{{1}\n}})".format( |
|
self.__class__.__name__, args.replace("\n", "\n ") |
|
) |
|
else: |
|
return "{}({!r})".format(self.__class__.__name__, self._args[0]) |
|
|
|
def __eq__(self, other): |
|
return ( |
|
type(self) is type(other) |
|
and self._args == other._args |
|
and self._kwds == other._kwds |
|
) |
|
|
|
def to_dict( |
|
self, |
|
validate: bool = True, |
|
*, |
|
ignore: Optional[List[str]] = None, |
|
context: Optional[Dict[str, Any]] = None, |
|
) -> dict: |
|
"""Return a dictionary representation of the object |
|
|
|
Parameters |
|
---------- |
|
validate : bool, optional |
|
If True (default), then validate the output dictionary |
|
against the schema. |
|
ignore : list[str], optional |
|
A list of keys to ignore. It is usually not needed |
|
to specify this argument as a user. |
|
context : dict[str, Any], optional |
|
A context dictionary. It is usually not needed |
|
to specify this argument as a user. |
|
|
|
Notes |
|
----- |
|
Technical: The ignore parameter will *not* be passed to child to_dict |
|
function calls. |
|
|
|
Returns |
|
------- |
|
dict |
|
The dictionary representation of this object |
|
|
|
Raises |
|
------ |
|
SchemaValidationError : |
|
if validate=True and the dict does not conform to the schema |
|
""" |
|
if context is None: |
|
context = {} |
|
if ignore is None: |
|
ignore = [] |
|
|
|
if self._args and not self._kwds: |
|
result = _todict(self._args[0], context=context) |
|
elif not self._args: |
|
kwds = self._kwds.copy() |
|
|
|
|
|
|
|
|
|
parsed_shorthand = context.pop("parsed_shorthand", {}) |
|
|
|
|
|
|
|
if "sort" in parsed_shorthand and ( |
|
"sort" not in kwds or kwds["type"] not in ["ordinal", Undefined] |
|
): |
|
parsed_shorthand.pop("sort") |
|
|
|
kwds.update( |
|
{ |
|
k: v |
|
for k, v in parsed_shorthand.items() |
|
if kwds.get(k, Undefined) is Undefined |
|
} |
|
) |
|
kwds = { |
|
k: v for k, v in kwds.items() if k not in list(ignore) + ["shorthand"] |
|
} |
|
if "mark" in kwds and isinstance(kwds["mark"], str): |
|
kwds["mark"] = {"type": kwds["mark"]} |
|
result = _todict( |
|
kwds, |
|
context=context, |
|
) |
|
else: |
|
raise ValueError( |
|
"{} instance has both a value and properties : " |
|
"cannot serialize to dict".format(self.__class__) |
|
) |
|
if validate: |
|
try: |
|
self.validate(result) |
|
except jsonschema.ValidationError as err: |
|
|
|
|
|
|
|
|
|
|
|
raise SchemaValidationError(self, err) from None |
|
return result |
|
|
|
def to_json( |
|
self, |
|
validate: bool = True, |
|
indent: int = 2, |
|
sort_keys: bool = True, |
|
*, |
|
ignore: Optional[List[str]] = None, |
|
context: Optional[Dict[str, Any]] = None, |
|
**kwargs, |
|
) -> str: |
|
"""Emit the JSON representation for this object as a string. |
|
|
|
Parameters |
|
---------- |
|
validate : bool, optional |
|
If True (default), then validate the output dictionary |
|
against the schema. |
|
indent : int, optional |
|
The number of spaces of indentation to use. The default is 2. |
|
sort_keys : bool, optional |
|
If True (default), sort keys in the output. |
|
ignore : list[str], optional |
|
A list of keys to ignore. It is usually not needed |
|
to specify this argument as a user. |
|
context : dict[str, Any], optional |
|
A context dictionary. It is usually not needed |
|
to specify this argument as a user. |
|
**kwargs |
|
Additional keyword arguments are passed to ``json.dumps()`` |
|
|
|
Notes |
|
----- |
|
Technical: The ignore parameter will *not* be passed to child to_dict |
|
function calls. |
|
|
|
Returns |
|
------- |
|
str |
|
The JSON specification of the chart object. |
|
""" |
|
if ignore is None: |
|
ignore = [] |
|
if context is None: |
|
context = {} |
|
dct = self.to_dict(validate=validate, ignore=ignore, context=context) |
|
return json.dumps(dct, indent=indent, sort_keys=sort_keys, **kwargs) |
|
|
|
@classmethod |
|
def _default_wrapper_classes(cls) -> Generator[Type["SchemaBase"], None, None]: |
|
"""Return the set of classes used within cls.from_dict()""" |
|
return _subclasses(SchemaBase) |
|
|
|
@classmethod |
|
def from_dict( |
|
cls, |
|
dct: dict, |
|
validate: bool = True, |
|
_wrapper_classes: Optional[Iterable[Type["SchemaBase"]]] = None, |
|
|
|
|
|
) -> "SchemaBase": |
|
"""Construct class from a dictionary representation |
|
|
|
Parameters |
|
---------- |
|
dct : dictionary |
|
The dict from which to construct the class |
|
validate : boolean |
|
If True (default), then validate the input against the schema. |
|
_wrapper_classes : iterable (optional) |
|
The set of SchemaBase classes to use when constructing wrappers |
|
of the dict inputs. If not specified, the result of |
|
cls._default_wrapper_classes will be used. |
|
|
|
Returns |
|
------- |
|
obj : Schema object |
|
The wrapped schema |
|
|
|
Raises |
|
------ |
|
jsonschema.ValidationError : |
|
if validate=True and dct does not conform to the schema |
|
""" |
|
if validate: |
|
cls.validate(dct) |
|
if _wrapper_classes is None: |
|
_wrapper_classes = cls._default_wrapper_classes() |
|
converter = _FromDict(_wrapper_classes) |
|
return converter.from_dict(dct, cls) |
|
|
|
@classmethod |
|
def from_json( |
|
cls, |
|
json_string: str, |
|
validate: bool = True, |
|
**kwargs: Any |
|
|
|
|
|
) -> Any: |
|
"""Instantiate the object from a valid JSON string |
|
|
|
Parameters |
|
---------- |
|
json_string : string |
|
The string containing a valid JSON chart specification. |
|
validate : boolean |
|
If True (default), then validate the input against the schema. |
|
**kwargs : |
|
Additional keyword arguments are passed to json.loads |
|
|
|
Returns |
|
------- |
|
chart : Chart object |
|
The altair Chart object built from the specification. |
|
""" |
|
dct = json.loads(json_string, **kwargs) |
|
return cls.from_dict(dct, validate=validate) |
|
|
|
@classmethod |
|
def validate( |
|
cls, instance: Dict[str, Any], schema: Optional[Dict[str, Any]] = None |
|
) -> None: |
|
""" |
|
Validate the instance against the class schema in the context of the |
|
rootschema. |
|
""" |
|
if schema is None: |
|
schema = cls._schema |
|
|
|
assert schema is not None |
|
return validate_jsonschema( |
|
instance, schema, rootschema=cls._rootschema or cls._schema |
|
) |
|
|
|
@classmethod
def resolve_references(cls, schema: Optional[dict] = None) -> dict:
    """Resolve ``$ref`` references in a schema using this class's root schema.

    A falsy ``schema`` (None or empty) falls back to ``cls._schema``. The
    root used for resolution is the first truthy value among
    ``cls._rootschema``, ``cls._schema`` and ``schema``.
    """
    target = schema if schema else cls._schema

    # Narrows the Optional for type checkers.
    assert target is not None

    root = cls._rootschema or cls._schema or schema
    return _resolve_references(schema=target, rootschema=root)
|
|
|
@classmethod
def validate_property(
    cls, name: str, value: Any, schema: Optional[dict] = None
) -> None:
    """
    Check a single property value against that property's sub-schema,
    resolving references in the context of the root schema.

    ``value`` is first converted with ``_todict``. The sub-schema is looked
    up under ``properties[name]`` of the resolved ``schema`` (default:
    ``cls._schema``). An unknown ``name`` is checked against the empty
    schema ``{}``.
    """
    value = _todict(value, context={})
    resolved = cls.resolve_references(schema or cls._schema)
    property_schema = resolved.get("properties", {}).get(name, {})
    root = cls._rootschema or cls._schema
    return validate_jsonschema(value, property_schema, rootschema=root)
|
|
|
def __dir__(self) -> list:
    """List the regular attributes plus the keyword names held in ``_kwds``, sorted."""
    names = list(super().__dir__())
    names.extend(self._kwds.keys())
    return sorted(names)
|
|
|
|
|
def _passthrough(*args, **kwds):
    """Identity-like fallback constructor.

    Returns the first positional argument if any were given; otherwise
    returns the dict of keyword arguments (``{}`` when called bare).
    """
    if not args:
        return kwds
    return args[0]
|
|
|
|
|
class _FromDict:
    """Class used to construct SchemaBase class hierarchies from a dict.

    The primary purpose of using this class is to be able to build a hash table
    that maps schemas to their wrapper classes. The candidate classes are
    specified in the ``class_list`` argument to the constructor.
    """

    # Top-level schema keys ignored by ``hash_schema``, so that schemas that
    # differ only in these keys hash identically.
    _hash_exclude_keys = ("definitions", "title", "description", "$schema", "id")

    def __init__(self, class_list: Iterable[Type[SchemaBase]]) -> None:
        """Index ``class_list`` by the hash of each class's ``_schema``.

        Classes whose ``_schema`` is None are skipped. Several classes may
        share one hash; they are stored in the order ``class_list`` yields
        them, and ``from_dict`` picks the first.
        """
        self.class_dict = collections.defaultdict(list)
        for cls in class_list:
            if cls._schema is not None:
                self.class_dict[self.hash_schema(cls._schema)].append(cls)

    @classmethod
    def hash_schema(cls, schema: dict, use_json: bool = True) -> int:
        """
        Compute a python hash for a nested dictionary which
        properly handles dicts, lists, sets, and tuples.

        At the top level, the function excludes from the hashed schema all keys
        listed in ``_hash_exclude_keys``.

        This implements two methods: one based on conversion to JSON, and one based
        on recursive conversions of unhashable to hashable types; the former seems
        to be slightly faster in several benchmarks.

        String hashing in Python is randomized per process by default, so the
        returned values should only be compared within a single process.
        """
        if cls._hash_exclude_keys and isinstance(schema, dict):
            schema = {
                key: val
                for key, val in schema.items()
                if key not in cls._hash_exclude_keys
            }
        if use_json:
            # sort_keys makes the serialization independent of key order.
            s = json.dumps(schema, sort_keys=True)
            return hash(s)
        else:

            def _freeze(val):
                # Recursively convert containers into hashable equivalents.
                if isinstance(val, dict):
                    return frozenset((k, _freeze(v)) for k, v in val.items())
                elif isinstance(val, set):
                    return frozenset(map(_freeze, val))
                elif isinstance(val, (list, tuple)):
                    return tuple(map(_freeze, val))
                else:
                    return val

            return hash(_freeze(schema))

    def from_dict(
        self,
        dct: Any,
        cls: Optional[Type[SchemaBase]] = None,
        schema: Optional[dict] = None,
        rootschema: Optional[dict] = None,
        default_class=_passthrough,
    ) -> Any:
        """Construct an object from a dict representation.

        Exactly one of ``cls`` or ``schema`` must be given.

        Parameters
        ----------
        dct : any
            The value to convert. It may be a dict, a list or a scalar. A
            value that is already a ``SchemaBase`` instance is returned
            unchanged.
        cls : SchemaBase subclass, optional
            Wrapper class to build; its ``_schema`` is used as the schema.
        schema : dict, optional
            Schema describing ``dct``. The wrapper class is looked up in
            ``class_dict`` by the hash of this schema, falling back to
            ``default_class`` when nothing is registered for it.
        rootschema : dict, optional
            Schema used to resolve references. If falsy, ``cls._rootschema``
            is used (when ``cls`` is given and it is truthy), else ``schema``.
        default_class : callable
            Constructor used when no registered class matches ``schema``.

        Returns
        -------
        The constructed object. Dict values whose keys appear in the schema's
        ``properties``, and list items, are converted recursively first.

        Raises
        ------
        ValueError
            If both or neither of ``cls`` and ``schema`` are given.
        """
        if (schema is None) == (cls is None):
            raise ValueError("Must provide either cls or schema, but not both.")
        if schema is None:
            # cls is not None here, guaranteed by the check above.
            schema = cls._schema
            # Narrows the Optional for type checkers.
            assert schema is not None
        if not rootschema:
            rootschema = (cls._rootschema if cls is not None else None) or schema

        if isinstance(dct, SchemaBase):
            return dct

        if cls is None:
            # Several classes may share a schema hash; the first one registered
            # in __init__ wins. Use .get() rather than indexing so that a miss
            # does not insert an empty entry into the defaultdict.
            matches = self.class_dict.get(self.hash_schema(schema))
            cls = matches[0] if matches else default_class
        schema = _resolve_references(schema, rootschema)

        if "anyOf" in schema or "oneOf" in schema:
            # Try each alternative in order and recurse into the first one that
            # validates; the current cls is passed along as the fallback class
            # in case the alternative has no registered wrapper of its own.
            schemas = schema.get("anyOf", []) + schema.get("oneOf", [])
            for possible_schema in schemas:
                try:
                    validate_jsonschema(dct, possible_schema, rootschema=rootschema)
                except jsonschema.ValidationError:
                    continue
                else:
                    return self.from_dict(
                        dct,
                        schema=possible_schema,
                        rootschema=rootschema,
                        default_class=cls,
                    )

        if isinstance(dct, dict):
            # Only keys listed under "properties" are converted recursively;
            # any other key is passed through unchanged.
            props = schema.get("properties", {})
            kwds = {}
            for key, val in dct.items():
                if key in props:
                    val = self.from_dict(val, schema=props[key], rootschema=rootschema)
                kwds[key] = val
            return cls(**kwds)

        elif isinstance(dct, list):
            item_schema = schema.get("items", {})
            dct = [
                self.from_dict(val, schema=item_schema, rootschema=rootschema)
                for val in dct
            ]
            return cls(dct)
        else:
            return cls(dct)
|
|
|
|
|
class _PropertySetter:
    """Descriptor exposing one schema property as a chainable setter method.

    Reading the attribute binds the descriptor to the owning object and
    returns it. Calling the result returns a copy of that object with the
    property set.
    """

    def __init__(self, prop: str, schema: dict) -> None:
        # prop: the property name; schema: the (resolved) schema of that property.
        self.prop = prop
        self.schema = schema

    def __get__(self, obj, cls):
        # NOTE(review): the binding is stored on the descriptor itself, which
        # is shared by every instance of ``cls``. A setter fetched from one
        # object and called after the attribute was read from another object
        # will act on the latter.
        self.obj = obj
        self.cls = cls

        # Start from the property's own description. Not every property
        # schema is guaranteed to carry one, so fall back to an empty string
        # instead of raising KeyError out of attribute access.
        self.__doc__ = self.schema.get("description", "").replace("__", "**")

        # Helper classes in ``vegalite`` are looked up under the property name
        # with its first letter capitalized (e.g. ``bin`` -> ``Bin``).
        name = str(self.prop)
        property_name = name[:1].upper() + name[1:]
        if hasattr(vegalite, property_name):
            altair_prop = getattr(vegalite, property_name)
            # cleandoc removes only the docstring's indentation (never the
            # spaces between words). The ``or ""`` guards against a None
            # docstring, e.g. when running under ``python -OO``.
            helper_doc = inspect.cleandoc(altair_prop.__doc__ or "")
            parameter_index = helper_doc.find("Parameters\n")
            if parameter_index > -1:
                # Splice the property description between the helper's summary
                # and its Parameters section, so the helper's parameter names
                # also appear in the final docstring.
                self.__doc__ = (
                    helper_doc[:parameter_index]
                    + self.__doc__
                    + "\n\n"
                    + helper_doc[parameter_index:]
                )
            elif helper_doc:
                # Short helper docstrings without a Parameters section.
                self.__doc__ = helper_doc + "\n\n" + self.__doc__

            # Mirror the helper's signature and name for help() / tab completion.
            self.__signature__ = inspect.signature(altair_prop)
            self.__wrapped__ = inspect.getfullargspec(altair_prop)
            self.__name__ = altair_prop.__name__
        return self

    def __call__(self, *args, **kwargs):
        """Return a copy of the bound object with this property set.

        A positional argument is used as the value. Otherwise the keyword
        arguments, collected into a dict, become the value. The bound object
        itself is not modified.
        """
        obj = self.obj.copy()
        # NOTE: the value is not validated against self.schema here.
        obj[self.prop] = args[0] if args else kwargs
        return obj
|
|
|
|
|
def with_property_setters(cls: Type[_TSchemaBase]) -> Type[_TSchemaBase]:
    """
    Decorator to add property setters to a Schema class.

    For every entry under ``properties`` in the class's resolved schema, a
    ``_PropertySetter`` descriptor with the same name is attached to ``cls``.
    The class is modified in place and returned, so this can be used as a
    class decorator. It receives the class object itself, hence the
    ``Type[...]`` annotation.
    """
    schema = cls.resolve_references()
    for prop, propschema in schema.get("properties", {}).items():
        setattr(cls, prop, _PropertySetter(prop, propschema))
    return cls
|
|