""" Utility routines """ from collections.abc import Mapping, MutableMapping from copy import deepcopy import json import itertools import re import sys import traceback import warnings from typing import ( Callable, TypeVar, Any, Union, Dict, Optional, Tuple, Sequence, Type, cast, ) from types import ModuleType import jsonschema import pandas as pd import numpy as np from pandas.api.types import infer_dtype from altair.utils.schemapi import SchemaBase from altair.utils._dfi_types import Column, DtypeKind, DataFrame as DfiDataFrame if sys.version_info >= (3, 10): from typing import ParamSpec else: from typing_extensions import ParamSpec from typing import Literal, Protocol, TYPE_CHECKING if TYPE_CHECKING: from pandas.core.interchange.dataframe_protocol import Column as PandasColumn _V = TypeVar("_V") _P = ParamSpec("_P") class _DataFrameLike(Protocol): def __dataframe__(self, *args, **kwargs) -> DfiDataFrame: ... TYPECODE_MAP = { "ordinal": "O", "nominal": "N", "quantitative": "Q", "temporal": "T", "geojson": "G", } INV_TYPECODE_MAP = {v: k for k, v in TYPECODE_MAP.items()} # aggregates from vega-lite version 4.6.0 AGGREGATES = [ "argmax", "argmin", "average", "count", "distinct", "max", "mean", "median", "min", "missing", "product", "q1", "q3", "ci0", "ci1", "stderr", "stdev", "stdevp", "sum", "valid", "values", "variance", "variancep", ] # window aggregates from vega-lite version 4.6.0 WINDOW_AGGREGATES = [ "row_number", "rank", "dense_rank", "percent_rank", "cume_dist", "ntile", "lag", "lead", "first_value", "last_value", "nth_value", ] # timeUnits from vega-lite version 4.17.0 TIMEUNITS = [ "year", "quarter", "month", "week", "day", "dayofyear", "date", "hours", "minutes", "seconds", "milliseconds", "yearquarter", "yearquartermonth", "yearmonth", "yearmonthdate", "yearmonthdatehours", "yearmonthdatehoursminutes", "yearmonthdatehoursminutesseconds", "yearweek", "yearweekday", "yearweekdayhours", "yearweekdayhoursminutes", "yearweekdayhoursminutesseconds", "yeardayofyear", "quartermonth", "monthdate", "monthdatehours", "monthdatehoursminutes", "monthdatehoursminutesseconds", "weekday", "weeksdayhours", "weekdayhoursminutes", "weekdayhoursminutesseconds", "dayhours", "dayhoursminutes", "dayhoursminutesseconds", "hoursminutes", "hoursminutesseconds", "minutesseconds", "secondsmilliseconds", "utcyear", "utcquarter", "utcmonth", "utcweek", "utcday", "utcdayofyear", "utcdate", "utchours", "utcminutes", "utcseconds", "utcmilliseconds", "utcyearquarter", "utcyearquartermonth", "utcyearmonth", "utcyearmonthdate", "utcyearmonthdatehours", "utcyearmonthdatehoursminutes", "utcyearmonthdatehoursminutesseconds", "utcyearweek", "utcyearweekday", "utcyearweekdayhours", "utcyearweekdayhoursminutes", "utcyearweekdayhoursminutesseconds", "utcyeardayofyear", "utcquartermonth", "utcmonthdate", "utcmonthdatehours", "utcmonthdatehoursminutes", "utcmonthdatehoursminutesseconds", "utcweekday", "utcweeksdayhours", "utcweekdayhoursminutes", "utcweekdayhoursminutesseconds", "utcdayhours", "utcdayhoursminutes", "utcdayhoursminutesseconds", "utchoursminutes", "utchoursminutesseconds", "utcminutesseconds", "utcsecondsmilliseconds", ] _InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"] def infer_vegalite_type( data: object, ) -> Union[_InferredVegaLiteType, Tuple[_InferredVegaLiteType, list]]: """ From an array-like input, infer the correct vega typecode ('ordinal', 'nominal', 'quantitative', or 'temporal') Parameters ---------- data: object """ typ = infer_dtype(data, skipna=False) if typ in [ "floating", "mixed-integer-float", "integer", "mixed-integer", "complex", ]: return "quantitative" elif typ == "categorical" and hasattr(data, "cat") and data.cat.ordered: return ("ordinal", data.cat.categories.tolist()) elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]: return "nominal" elif typ in [ "datetime", "datetime64", "timedelta", "timedelta64", "date", "time", "period", ]: return "temporal" else: warnings.warn( "I don't know how to infer vegalite type from '{}'. " "Defaulting to nominal.".format(typ), stacklevel=1, ) return "nominal" def merge_props_geom(feat: dict) -> dict: """ Merge properties with geometry * Overwrites 'type' and 'geometry' entries if existing """ geom = {k: feat[k] for k in ("type", "geometry")} try: feat["properties"].update(geom) props_geom = feat["properties"] except (AttributeError, KeyError): # AttributeError when 'properties' equals None # KeyError when 'properties' is non-existing props_geom = geom return props_geom def sanitize_geo_interface(geo: MutableMapping) -> dict: """Santize a geo_interface to prepare it for serialization. * Make a copy * Convert type array or _Array to list * Convert tuples to lists (using json.loads/dumps) * Merge properties with geometry """ geo = deepcopy(geo) # convert type _Array or array to list for key in geo.keys(): if str(type(geo[key]).__name__).startswith(("_Array", "array")): geo[key] = geo[key].tolist() # convert (nested) tuples to lists geo_dct: dict = json.loads(json.dumps(geo)) # sanitize features if geo_dct["type"] == "FeatureCollection": geo_dct = geo_dct["features"] if len(geo_dct) > 0: for idx, feat in enumerate(geo_dct): geo_dct[idx] = merge_props_geom(feat) elif geo_dct["type"] == "Feature": geo_dct = merge_props_geom(geo_dct) else: geo_dct = {"type": "Feature", "geometry": geo_dct} return geo_dct def numpy_is_subtype(dtype: Any, subtype: Any) -> bool: try: return np.issubdtype(dtype, subtype) except (NotImplementedError, TypeError): return False def sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame: # noqa: C901 """Sanitize a DataFrame to prepare it for serialization. * Make a copy * Convert RangeIndex columns to strings * Raise ValueError if column names are not strings * Raise ValueError if it has a hierarchical index. * Convert categoricals to strings. * Convert np.bool_ dtypes to Python bool objects * Convert np.int dtypes to Python int objects * Convert floats to objects and replace NaNs/infs with None. * Convert DateTime dtypes into appropriate string representations * Convert Nullable integers to objects and replace NaN with None * Convert Nullable boolean to objects and replace NaN with None * convert dedicated string column to objects and replace NaN with None * Raise a ValueError for TimeDelta dtypes """ df = df.copy() if isinstance(df.columns, pd.RangeIndex): df.columns = df.columns.astype(str) for col_name in df.columns: if not isinstance(col_name, str): raise ValueError( "Dataframe contains invalid column name: {0!r}. " "Column names must be strings".format(col_name) ) if isinstance(df.index, pd.MultiIndex): raise ValueError("Hierarchical indices not supported") if isinstance(df.columns, pd.MultiIndex): raise ValueError("Hierarchical indices not supported") def to_list_if_array(val): if isinstance(val, np.ndarray): return val.tolist() else: return val for dtype_item in df.dtypes.items(): # We know that the column names are strings from the isinstance check # further above but mypy thinks it is of type Hashable and therefore does not # let us assign it to the col_name variable which is already of type str. col_name = cast(str, dtype_item[0]) dtype = dtype_item[1] dtype_name = str(dtype) if dtype_name == "category": # Work around bug in to_json for categorical types in older versions # of pandas as they do not properly convert NaN values to null in to_json. # We can probably remove this part once we require Pandas >= 1.0 col = df[col_name].astype(object) df[col_name] = col.where(col.notnull(), None) elif dtype_name == "string": # dedicated string datatype (since 1.0) # https://pandas.pydata.org/pandas-docs/version/1.0.0/whatsnew/v1.0.0.html#dedicated-string-data-type col = df[col_name].astype(object) df[col_name] = col.where(col.notnull(), None) elif dtype_name == "bool": # convert numpy bools to objects; np.bool is not JSON serializable df[col_name] = df[col_name].astype(object) elif dtype_name == "boolean": # dedicated boolean datatype (since 1.0) # https://pandas.io/docs/user_guide/boolean.html col = df[col_name].astype(object) df[col_name] = col.where(col.notnull(), None) elif dtype_name.startswith("datetime") or dtype_name.startswith("timestamp"): # Convert datetimes to strings. This needs to be a full ISO string # with time, which is why we cannot use ``col.astype(str)``. # This is because Javascript parses date-only times in UTC, but # parses full ISO-8601 dates as local time, and dates in Vega and # Vega-Lite are displayed in local time by default. # (see https://github.com/altair-viz/altair/issues/1027) df[col_name] = ( df[col_name].apply(lambda x: x.isoformat()).replace("NaT", "") ) elif dtype_name.startswith("timedelta"): raise ValueError( 'Field "{col_name}" has type "{dtype}" which is ' "not supported by Altair. Please convert to " "either a timestamp or a numerical value." "".format(col_name=col_name, dtype=dtype) ) elif dtype_name.startswith("geometry"): # geopandas >=0.6.1 uses the dtype geometry. Continue here # otherwise it will give an error on np.issubdtype(dtype, np.integer) continue elif dtype_name in { "Int8", "Int16", "Int32", "Int64", "UInt8", "UInt16", "UInt32", "UInt64", "Float32", "Float64", }: # nullable integer datatypes (since 24.0) and nullable float datatypes (since 1.2.0) # https://pandas.pydata.org/pandas-docs/version/0.25/whatsnew/v0.24.0.html#optional-integer-na-support col = df[col_name].astype(object) df[col_name] = col.where(col.notnull(), None) elif numpy_is_subtype(dtype, np.integer): # convert integers to objects; np.int is not JSON serializable df[col_name] = df[col_name].astype(object) elif numpy_is_subtype(dtype, np.floating): # For floats, convert to Python float: np.float is not JSON serializable # Also convert NaN/inf values to null, as they are not JSON serializable col = df[col_name] bad_values = col.isnull() | np.isinf(col) df[col_name] = col.astype(object).where(~bad_values, None) elif dtype == object: # Convert numpy arrays saved as objects to lists # Arrays are not JSON serializable col = df[col_name].astype(object).apply(to_list_if_array) df[col_name] = col.where(col.notnull(), None) return df def sanitize_arrow_table(pa_table): """Sanitize arrow table for JSON serialization""" import pyarrow as pa import pyarrow.compute as pc arrays = [] schema = pa_table.schema for name in schema.names: array = pa_table[name] dtype = schema.field(name).type if str(dtype).startswith("timestamp"): arrays.append(pc.strftime(array)) elif str(dtype).startswith("duration"): raise ValueError( 'Field "{col_name}" has type "{dtype}" which is ' "not supported by Altair. Please convert to " "either a timestamp or a numerical value." "".format(col_name=name, dtype=dtype) ) else: arrays.append(array) return pa.Table.from_arrays(arrays, names=schema.names) def parse_shorthand( shorthand: Union[Dict[str, Any], str], data: Optional[Union[pd.DataFrame, _DataFrameLike]] = None, parse_aggregates: bool = True, parse_window_ops: bool = False, parse_timeunits: bool = True, parse_types: bool = True, ) -> Dict[str, Any]: """General tool to parse shorthand values These are of the form: - "col_name" - "col_name:O" - "average(col_name)" - "average(col_name):O" Optionally, a dataframe may be supplied, from which the type will be inferred if not specified in the shorthand. Parameters ---------- shorthand : dict or string The shorthand representation to be parsed data : DataFrame, optional If specified and of type DataFrame, then use these values to infer the column type if not provided by the shorthand. parse_aggregates : boolean If True (default), then parse aggregate functions within the shorthand. parse_window_ops : boolean If True then parse window operations within the shorthand (default:False) parse_timeunits : boolean If True (default), then parse timeUnits from within the shorthand parse_types : boolean If True (default), then parse typecodes within the shorthand Returns ------- attrs : dict a dictionary of attributes extracted from the shorthand Examples -------- >>> data = pd.DataFrame({'foo': ['A', 'B', 'A', 'B'], ... 'bar': [1, 2, 3, 4]}) >>> parse_shorthand('name') == {'field': 'name'} True >>> parse_shorthand('name:Q') == {'field': 'name', 'type': 'quantitative'} True >>> parse_shorthand('average(col)') == {'aggregate': 'average', 'field': 'col'} True >>> parse_shorthand('foo:O') == {'field': 'foo', 'type': 'ordinal'} True >>> parse_shorthand('min(foo):Q') == {'aggregate': 'min', 'field': 'foo', 'type': 'quantitative'} True >>> parse_shorthand('month(col)') == {'field': 'col', 'timeUnit': 'month', 'type': 'temporal'} True >>> parse_shorthand('year(col):O') == {'field': 'col', 'timeUnit': 'year', 'type': 'ordinal'} True >>> parse_shorthand('foo', data) == {'field': 'foo', 'type': 'nominal'} True >>> parse_shorthand('bar', data) == {'field': 'bar', 'type': 'quantitative'} True >>> parse_shorthand('bar:O', data) == {'field': 'bar', 'type': 'ordinal'} True >>> parse_shorthand('sum(bar)', data) == {'aggregate': 'sum', 'field': 'bar', 'type': 'quantitative'} True >>> parse_shorthand('count()', data) == {'aggregate': 'count', 'type': 'quantitative'} True """ from altair.utils._importers import pyarrow_available if not shorthand: return {} valid_typecodes = list(TYPECODE_MAP) + list(INV_TYPECODE_MAP) units = { "field": "(?P.*)", "type": "(?P{})".format("|".join(valid_typecodes)), "agg_count": "(?Pcount)", "op_count": "(?Pcount)", "aggregate": "(?P{})".format("|".join(AGGREGATES)), "window_op": "(?P{})".format("|".join(AGGREGATES + WINDOW_AGGREGATES)), "timeUnit": "(?P{})".format("|".join(TIMEUNITS)), } patterns = [] if parse_aggregates: patterns.extend([r"{agg_count}\(\)"]) patterns.extend([r"{aggregate}\({field}\)"]) if parse_window_ops: patterns.extend([r"{op_count}\(\)"]) patterns.extend([r"{window_op}\({field}\)"]) if parse_timeunits: patterns.extend([r"{timeUnit}\({field}\)"]) patterns.extend([r"{field}"]) if parse_types: patterns = list(itertools.chain(*((p + ":{type}", p) for p in patterns))) regexps = ( re.compile(r"\A" + p.format(**units) + r"\Z", re.DOTALL) for p in patterns ) # find matches depending on valid fields passed if isinstance(shorthand, dict): attrs = shorthand else: attrs = next( exp.match(shorthand).groupdict() # type: ignore[union-attr] for exp in regexps if exp.match(shorthand) is not None ) # Handle short form of the type expression if "type" in attrs: attrs["type"] = INV_TYPECODE_MAP.get(attrs["type"], attrs["type"]) # counts are quantitative by default if attrs == {"aggregate": "count"}: attrs["type"] = "quantitative" # times are temporal by default if "timeUnit" in attrs and "type" not in attrs: attrs["type"] = "temporal" # if data is specified and type is not, infer type from data if "type" not in attrs: if pyarrow_available() and data is not None and hasattr(data, "__dataframe__"): dfi = data.__dataframe__() if "field" in attrs: unescaped_field = attrs["field"].replace("\\", "") if unescaped_field in dfi.column_names(): column = dfi.get_column_by_name(unescaped_field) try: attrs["type"] = infer_vegalite_type_for_dfi_column(column) except (NotImplementedError, AttributeError, ValueError): # Fall back to pandas-based inference. # Note: The AttributeError catch is a workaround for # https://github.com/pandas-dev/pandas/issues/55332 if isinstance(data, pd.DataFrame): attrs["type"] = infer_vegalite_type(data[unescaped_field]) else: raise if isinstance(attrs["type"], tuple): attrs["sort"] = attrs["type"][1] attrs["type"] = attrs["type"][0] elif isinstance(data, pd.DataFrame): # Fallback if pyarrow is not installed or if pandas is older than 1.5 # # Remove escape sequences so that types can be inferred for columns with special characters if "field" in attrs and attrs["field"].replace("\\", "") in data.columns: attrs["type"] = infer_vegalite_type( data[attrs["field"].replace("\\", "")] ) # ordered categorical dataframe columns return the type and sort order as a tuple if isinstance(attrs["type"], tuple): attrs["sort"] = attrs["type"][1] attrs["type"] = attrs["type"][0] # If an unescaped colon is still present, it's often due to an incorrect data type specification # but could also be due to using a column name with ":" in it. if ( "field" in attrs and ":" in attrs["field"] and attrs["field"][attrs["field"].rfind(":") - 1] != "\\" ): raise ValueError( '"{}" '.format(attrs["field"].split(":")[-1]) + "is not one of the valid encoding data types: {}.".format( ", ".join(TYPECODE_MAP.values()) ) + "\nFor more details, see https://altair-viz.github.io/user_guide/encodings/index.html#encoding-data-types. " + "If you are trying to use a column name that contains a colon, " + 'prefix it with a backslash; for example "column\\:name" instead of "column:name".' ) return attrs def infer_vegalite_type_for_dfi_column( column: Union[Column, "PandasColumn"], ) -> Union[_InferredVegaLiteType, Tuple[_InferredVegaLiteType, list]]: from pyarrow.interchange.from_dataframe import column_to_array try: kind = column.dtype[0] except NotImplementedError as e: # Edge case hack: # dtype access fails for pandas column with datetime64[ns, UTC] type, # but all we need to know is that its temporal, so check the # error message for the presence of datetime64. # # See https://github.com/pandas-dev/pandas/issues/54239 if "datetime64" in e.args[0] or "timestamp" in e.args[0]: return "temporal" raise e if ( kind == DtypeKind.CATEGORICAL and column.describe_categorical["is_ordered"] and column.describe_categorical["categories"] is not None ): # Treat ordered categorical column as Vega-Lite ordinal categories_column = column.describe_categorical["categories"] categories_array = column_to_array(categories_column) return "ordinal", categories_array.to_pylist() if kind in (DtypeKind.STRING, DtypeKind.CATEGORICAL, DtypeKind.BOOL): return "nominal" elif kind in (DtypeKind.INT, DtypeKind.UINT, DtypeKind.FLOAT): return "quantitative" elif kind == DtypeKind.DATETIME: return "temporal" else: raise ValueError(f"Unexpected DtypeKind: {kind}") def use_signature(Obj: Callable[_P, Any]): """Apply call signature and documentation of Obj to the decorated method""" def decorate(f: Callable[..., _V]) -> Callable[_P, _V]: # call-signature of f is exposed via __wrapped__. # we want it to mimic Obj.__init__ f.__wrapped__ = Obj.__init__ # type: ignore f._uses_signature = Obj # type: ignore # Supplement the docstring of f with information from Obj if Obj.__doc__: # Patch in a reference to the class this docstring is copied from, # to generate a hyperlink. doclines = Obj.__doc__.splitlines() doclines[0] = f"Refer to :class:`{Obj.__name__}`" if f.__doc__: doc = f.__doc__ + "\n".join(doclines[1:]) else: doc = "\n".join(doclines) try: f.__doc__ = doc except AttributeError: # __doc__ is not modifiable for classes in Python < 3.3 pass return f return decorate def update_nested( original: MutableMapping, update: Mapping, copy: bool = False ) -> MutableMapping: """Update nested dictionaries Parameters ---------- original : MutableMapping the original (nested) dictionary, which will be updated in-place update : Mapping the nested dictionary of updates copy : bool, default False if True, then copy the original dictionary rather than modifying it Returns ------- original : MutableMapping a reference to the (modified) original dict Examples -------- >>> original = {'x': {'b': 2, 'c': 4}} >>> update = {'x': {'b': 5, 'd': 6}, 'y': 40} >>> update_nested(original, update) # doctest: +SKIP {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40} >>> original # doctest: +SKIP {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40} """ if copy: original = deepcopy(original) for key, val in update.items(): if isinstance(val, Mapping): orig_val = original.get(key, {}) if isinstance(orig_val, MutableMapping): original[key] = update_nested(orig_val, val) else: original[key] = val else: original[key] = val return original def display_traceback(in_ipython: bool = True): exc_info = sys.exc_info() if in_ipython: from IPython.core.getipython import get_ipython ip = get_ipython() else: ip = None if ip is not None: ip.showtraceback(exc_info) else: traceback.print_exception(*exc_info) def infer_encoding_types(args: Sequence, kwargs: MutableMapping, channels: ModuleType): """Infer typed keyword arguments for args and kwargs Parameters ---------- args : Sequence Sequence of function args kwargs : MutableMapping Dict of function kwargs channels : ModuleType The module containing all altair encoding channel classes. Returns ------- kwargs : dict All args and kwargs in a single dict, with keys and types based on the channels mapping. """ # Construct a dictionary of channel type to encoding name # TODO: cache this somehow? channel_objs = (getattr(channels, name) for name in dir(channels)) channel_objs = ( c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase) ) channel_to_name: Dict[Type[SchemaBase], str] = { c: c._encoding_name for c in channel_objs } name_to_channel: Dict[str, Dict[str, Type[SchemaBase]]] = {} for chan, name in channel_to_name.items(): chans = name_to_channel.setdefault(name, {}) if chan.__name__.endswith("Datum"): key = "datum" elif chan.__name__.endswith("Value"): key = "value" else: key = "field" chans[key] = chan # First use the mapping to convert args to kwargs based on their types. for arg in args: if isinstance(arg, (list, tuple)) and len(arg) > 0: type_ = type(arg[0]) else: type_ = type(arg) encoding = channel_to_name.get(type_, None) if encoding is None: raise NotImplementedError("positional of type {}" "".format(type_)) if encoding in kwargs: raise ValueError("encoding {} specified twice.".format(encoding)) kwargs[encoding] = arg def _wrap_in_channel_class(obj, encoding): if isinstance(obj, SchemaBase): return obj if isinstance(obj, str): obj = {"shorthand": obj} if isinstance(obj, (list, tuple)): return [_wrap_in_channel_class(subobj, encoding) for subobj in obj] if encoding not in name_to_channel: warnings.warn( "Unrecognized encoding channel '{}'".format(encoding), stacklevel=1 ) return obj classes = name_to_channel[encoding] cls = classes["value"] if "value" in obj else classes["field"] try: # Don't force validation here; some objects won't be valid until # they're created in the context of a chart. return cls.from_dict(obj, validate=False) except jsonschema.ValidationError: # our attempts at finding the correct class have failed return obj return { encoding: _wrap_in_channel_class(obj, encoding) for encoding, obj in kwargs.items() }