from typing import List, Optional, Tuple, Dict, Iterable, overload, Union

from altair import (
    Chart,
    FacetChart,
    LayerChart,
    HConcatChart,
    VConcatChart,
    ConcatChart,
    TopLevelUnitSpec,
    FacetedUnitSpec,
    UnitSpec,
    UnitSpecWithFrame,
    NonNormalizedSpec,
    TopLevelLayerSpec,
    LayerSpec,
    TopLevelConcatSpec,
    ConcatSpecGenericSpec,
    TopLevelHConcatSpec,
    HConcatSpecGenericSpec,
    TopLevelVConcatSpec,
    VConcatSpecGenericSpec,
    TopLevelFacetSpec,
    FacetSpec,
    data_transformers,
)
from altair.utils._vegafusion_data import get_inline_tables, import_vegafusion
from altair.utils.core import _DataFrameLike
from altair.utils.schemapi import Undefined

# A scope addresses a (possibly nested) group mark inside a Vega spec: each
# entry is the index of a group mark among the group marks of its parent.
# The empty tuple refers to the top level of the specification.
Scope = Tuple[int, ...]

# Mapping from (facet_name, facet_scope) to (dataset_name, dataset_scope),
# as produced by get_facet_mapping.
FacetMapping = Dict[Tuple[str, Scope], Tuple[str, Scope]]


# For the transformed_data functionality, the chart classes in the values
# can be considered equivalent to the chart class in the key.
# Each value is a tuple of classes so that it can be passed directly as the
# second argument of isinstance().
_chart_class_mapping = {
    Chart: (
        Chart,
        TopLevelUnitSpec,
        FacetedUnitSpec,
        UnitSpec,
        UnitSpecWithFrame,
        NonNormalizedSpec,
    ),
    LayerChart: (LayerChart, TopLevelLayerSpec, LayerSpec),
    ConcatChart: (ConcatChart, TopLevelConcatSpec, ConcatSpecGenericSpec),
    HConcatChart: (HConcatChart, TopLevelHConcatSpec, HConcatSpecGenericSpec),
    VConcatChart: (VConcatChart, TopLevelVConcatSpec, VConcatSpecGenericSpec),
    FacetChart: (FacetChart, TopLevelFacetSpec, FacetSpec),
}


# Typing overload: a single Chart or FacetChart yields one DataFrame
# (or None when the chart was excluded).
@overload
def transformed_data(
    chart: Union[Chart, FacetChart],
    row_limit: Optional[int] = None,
    exclude: Optional[Iterable[str]] = None,
) -> Optional[_DataFrameLike]:
    ...


# Typing overload: layered and concatenated charts yield a list of DataFrames.
@overload
def transformed_data(
    chart: Union[LayerChart, HConcatChart, VConcatChart, ConcatChart],
    row_limit: Optional[int] = None,
    exclude: Optional[Iterable[str]] = None,
) -> List[_DataFrameLike]:
    ...
def transformed_data(chart, row_limit=None, exclude=None):
    """Evaluate a Chart's transforms

    Evaluate the data transforms associated with a Chart and return the
    transformed data as one or more DataFrames

    Parameters
    ----------
    chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart
        Altair chart to evaluate transforms on
    row_limit : int (optional)
        Maximum number of rows to return for each DataFrame. None (default) for unlimited
    exclude : iterable of str
        Set of the names of charts to exclude

    Returns
    -------
    DataFrame or list of DataFrames or None
        If input chart is a Chart or Facet Chart, returns a DataFrame of the
        transformed data (or None if no dataset was produced, e.g. because
        the chart was excluded).
        Otherwise, returns a list of DataFrames of the transformed data

    Raises
    ------
    ValueError
        If the chart is of an unsupported type, or if the dataset of one of
        the (non-excluded) views cannot be located in the compiled Vega spec.
    """
    vf = import_vegafusion()

    if isinstance(chart, Chart):
        # Add mark if none is specified to satisfy Vega-Lite
        if chart.mark is Undefined:
            chart = chart.mark_point()

    # Deep copy chart so that we can rename marks without affecting caller
    chart = chart.copy(deep=True)

    # Ensure that all views are named so that we can look them up in the
    # resulting Vega specification
    chart_names = name_views(chart, 0, exclude=exclude)

    # Compile to Vega and extract inline DataFrames
    with data_transformers.enable("vegafusion"):
        vega_spec = chart.to_dict(format="vega", context={"pre_transform": False})
        inline_datasets = get_inline_tables(vega_spec)

    # Build mapping from mark names to vega datasets
    facet_mapping = get_facet_mapping(vega_spec)
    dataset_mapping = get_datasets_for_view_names(vega_spec, chart_names, facet_mapping)

    # Build a list of vega dataset names that corresponds to the order
    # of the chart components
    try:
        dataset_names = [dataset_mapping[chart_name] for chart_name in chart_names]
    except KeyError:
        raise ValueError("Failed to locate all datasets") from None

    # Extract transformed datasets with VegaFusion
    # (the second element of the result holds warnings, which are not used here)
    datasets, _warnings = vf.runtime.pre_transform_datasets(
        vega_spec,
        dataset_names,
        row_limit=row_limit,
        inline_datasets=inline_datasets,
    )

    if isinstance(chart, (Chart, FacetChart)):
        # Return DataFrame (or None if it was excluded) if input was a simple Chart
        if not datasets:
            return None
        return datasets[0]

    # Otherwise return the list of DataFrames
    return datasets


# The equivalent classes from _chart_class_mapping should also be added
# to the type hints below for `chart` as the function would also work for them.
# However, this was not possible so far as mypy then complains about
# "Overloaded function signatures 1 and 2 overlap with incompatible return types [misc]"
# This might be due to the complex type hierarchy of the chart classes.
# See also https://github.com/python/mypy/issues/5119
# and https://github.com/python/mypy/issues/4020 which show that mypy might not have
# a very consistent behavior for overloaded functions.
# The same error appeared when trying it with Protocols for the concat and layer charts.
# This function is only used internally and so we accept this inconsistency for now.
def name_views(
    chart: Union[
        Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, ConcatChart
    ],
    i: int = 0,
    exclude: Optional[Iterable[str]] = None,
) -> List[str]:
    """Name unnamed chart views

    Name unnamed charts views so that we can look them up later in
    the compiled Vega spec.

    Note: This function mutates the input chart by applying names to
    unnamed views.

    Parameters
    ----------
    chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart
        Altair chart to apply names to
    i : int (default 0)
        Starting chart index. It is only forwarded to the recursive calls;
        generated names come from ``Chart._get_name()`` and do not depend on it.
    exclude : iterable of str
        Names of charts to exclude

    Returns
    -------
    list of str
        List of the names of the charts and subcharts

    Raises
    ------
    ValueError
        If ``chart`` (or one of its subcharts) is not one of the supported
        chart types.
    """
    exclude = set(exclude) if exclude is not None else set()

    if isinstance(chart, _chart_class_mapping[Chart]) or isinstance(
        chart, _chart_class_mapping[FacetChart]
    ):
        # Leaf view: the exclusion test uses the name as given by the caller,
        # i.e. before a name is generated for an unnamed view.
        if chart.name in exclude:
            return []
        if chart.name in (None, Undefined):
            # Add name since none is specified
            chart.name = Chart._get_name()
        return [chart.name]

    if isinstance(chart, _chart_class_mapping[LayerChart]):
        subcharts = chart.layer
    elif isinstance(chart, _chart_class_mapping[HConcatChart]):
        subcharts = chart.hconcat
    elif isinstance(chart, _chart_class_mapping[VConcatChart]):
        subcharts = chart.vconcat
    elif isinstance(chart, _chart_class_mapping[ConcatChart]):
        subcharts = chart.concat
    else:
        raise ValueError(
            "transformed_data accepts an instance of "
            "Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart\n"
            f"Received value of type: {type(chart)}"
        )

    chart_names: List[str] = []
    for subchart in subcharts:
        chart_names.extend(
            name_views(subchart, i=i + len(chart_names), exclude=exclude)
        )
    return chart_names


def get_group_mark_for_scope(vega_spec: dict, scope: Scope) -> Optional[dict]:
    """Get the group mark at a particular scope

    Parameters
    ----------
    vega_spec : dict
        Top-level Vega specification dictionary
    scope : tuple of int
        Scope tuple. If empty, the original Vega specification is returned.
        Otherwise, the nested group mark at the scope specified is returned.
        Each entry indexes the group marks only (marks of other types are
        not counted).

    Returns
    -------
    dict or None
        Top-level Vega spec (if scope is empty)
        or group mark (if scope is non-empty)
        or None (if group mark at scope does not exist)

    Examples
    --------
    >>> spec = {
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "marks": [{"type": "symbol"}]
    ...         },
    ...         {
    ...             "type": "group",
    ...             "marks": [{"type": "rect"}]}
    ...     ]
    ... }
    >>> get_group_mark_for_scope(spec, (1,))
    {'type': 'group', 'marks': [{'type': 'rect'}]}
    """
    group = vega_spec

    # Descend one nesting level per scope entry
    for scope_value in scope:
        child_groups = [
            mark for mark in group.get("marks", []) if mark.get("type") == "group"
        ]
        # Explicit bounds check so that negative indices are never treated
        # as "count from the end"
        if not 0 <= scope_value < len(child_groups):
            return None
        group = child_groups[scope_value]

    return group


def get_datasets_for_scope(vega_spec: dict, scope: Scope) -> List[str]:
    """Get the names of the datasets that are defined at a given scope

    Parameters
    ----------
    vega_spec : dict
        Top-level Vega specification
    scope : tuple of int
        Scope tuple. If empty, the names of top-level datasets are returned
        Otherwise, the names of the datasets defined in the nested group mark
        at the specified scope are returned.

    Returns
    -------
    list of str
        List of the names of the datasets defined at the specified scope.
        If the group mark is faceted, the name of its facet dataset is
        included as the last entry.

    Examples
    --------
    >>> spec = {
    ...     "data": [
    ...         {"name": "data1"}
    ...     ],
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "data": [
    ...                 {"name": "data2"}
    ...             ],
    ...             "marks": [{"type": "symbol"}]
    ...         },
    ...         {
    ...             "type": "group",
    ...             "data": [
    ...                 {"name": "data3"},
    ...                 {"name": "data4"},
    ...             ],
    ...             "marks": [{"type": "rect"}]
    ...         }
    ...     ]
    ... }

    >>> get_datasets_for_scope(spec, ())
    ['data1']

    >>> get_datasets_for_scope(spec, (0,))
    ['data2']

    >>> get_datasets_for_scope(spec, (1,))
    ['data3', 'data4']

    Returns empty when no group mark exists at scope
    >>> get_datasets_for_scope(spec, (1, 3))
    []
    """
    group = get_group_mark_for_scope(vega_spec, scope) or {}

    # get datasets from group
    datasets = [dataset["name"] for dataset in group.get("data", [])]

    # Add facet dataset
    facet_dataset = group.get("from", {}).get("facet", {}).get("name", None)
    if facet_dataset:
        datasets.append(facet_dataset)
    return datasets


def get_definition_scope_for_data_reference(
    vega_spec: dict, data_name: str, usage_scope: Scope
) -> Optional[Scope]:
    """Return the scope that a dataset is defined at, for a given usage scope

    Parameters
    ----------
    vega_spec: dict
        Top-level Vega specification
    data_name: str
        The name of a dataset reference
    usage_scope: tuple of int
        The scope that the dataset is referenced in

    Returns
    -------
    tuple of int
        The scope where the referenced dataset is defined,
        or None if no such dataset is found

    Examples
    --------
    >>> spec = {
    ...     "data": [
    ...         {"name": "data1"}
    ...     ],
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "data": [
    ...                 {"name": "data2"}
    ...             ],
    ...             "marks": [{
    ...                 "type": "symbol",
    ...                 "encode": {
    ...                     "update": {
    ...                         "x": {"field": "x", "data": "data1"},
    ...                         "y": {"field": "y", "data": "data2"},
    ...                     }
    ...                 }
    ...             }]
    ...         }
    ...     ]
    ... }

    data1 is referenced at scope [0] and defined at scope []
    >>> get_definition_scope_for_data_reference(spec, "data1", (0,))
    ()

    data2 is referenced at scope [0] and defined at scope [0]
    >>> get_definition_scope_for_data_reference(spec, "data2", (0,))
    (0,)

    data2 is not visible at scope [] (the top level),
    because it is defined in scope [0]
    >>> repr(get_definition_scope_for_data_reference(spec, "data2", ()))
    'None'
    """
    # Search from the usage scope outwards so that the innermost definition
    # of the name wins
    for depth in range(len(usage_scope), -1, -1):
        scope = usage_scope[:depth]
        if data_name in get_datasets_for_scope(vega_spec, scope):
            return scope
    return None


def get_facet_mapping(group: dict, scope: Scope = ()) -> FacetMapping:
    """Create mapping from facet definitions to source datasets

    Parameters
    ----------
    group : dict
        Top-level Vega spec or nested group mark
    scope : tuple of int
        Scope of the group dictionary within a top-level Vega spec

    Returns
    -------
    dict
        Dictionary from (facet_name, facet_scope) to (dataset_name, dataset_scope)

    Examples
    --------
    >>> spec = {
    ...     "data": [
    ...         {"name": "data1"}
    ...     ],
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "from": {
    ...                 "facet": {
    ...                     "name": "facet1",
    ...                     "data": "data1",
    ...                     "groupby": ["colA"]
    ...                 }
    ...             }
    ...         }
    ...     ]
    ... }
    >>> get_facet_mapping(spec)
    {('facet1', (0,)): ('data1', ())}
    """
    facet_mapping: FacetMapping = {}
    # Counts group marks only, so that it lines up with scope indices
    group_index = 0
    mark_group = get_group_mark_for_scope(group, scope) or {}
    for mark in mark_group.get("marks", []):
        if mark.get("type", None) != "group":
            continue

        # Get facet for this group
        group_scope = scope + (group_index,)
        facet = mark.get("from", {}).get("facet", None)
        if facet is not None:
            facet_name = facet.get("name", None)
            facet_data = facet.get("data", None)
            if facet_name is not None and facet_data is not None:
                # The facet's source dataset is referenced from the scope
                # that contains the group mark
                definition_scope = get_definition_scope_for_data_reference(
                    group, facet_data, scope
                )
                if definition_scope is not None:
                    facet_mapping[(facet_name, group_scope)] = (
                        facet_data,
                        definition_scope,
                    )

        # Handle children recursively
        facet_mapping.update(get_facet_mapping(group, scope=group_scope))
        group_index += 1

    return facet_mapping


def get_from_facet_mapping(
    scoped_dataset: Tuple[str, Scope], facet_mapping: FacetMapping
) -> Tuple[str, Scope]:
    """Apply facet mapping to a scoped dataset

    Parameters
    ----------
    scoped_dataset : (str, tuple of int)
        A dataset name and scope tuple
    facet_mapping : dict from (str, tuple of int) to (str, tuple of int)
        The facet mapping produced by get_facet_mapping

    Returns
    -------
    (str, tuple of int)
        Dataset name and scope tuple that has been mapped as many times as possible

    Examples
    --------
    Facet mapping as produced by get_facet_mapping
    >>> facet_mapping = {("facet1", (0,)): ("data1", ()), ("facet2", (0, 1)): ("facet1", (0,))}
    >>> get_from_facet_mapping(("facet2", (0, 1)), facet_mapping)
    ('data1', ())
    """
    # Follow the chain of facet definitions until a non-facet dataset is reached
    while scoped_dataset in facet_mapping:
        scoped_dataset = facet_mapping[scoped_dataset]
    return scoped_dataset


def get_datasets_for_view_names(
    group: dict,
    vl_chart_names: List[str],
    facet_mapping: FacetMapping,
    scope: Scope = (),
) -> Dict[str, Tuple[str, Scope]]:
    """Get the Vega datasets that correspond to the provided Altair view names

    Parameters
    ----------
    group : dict
        Top-level Vega spec or nested group mark
    vl_chart_names : list of str
        List of the Vega-Lite (Altair) view names to look up
    facet_mapping : dict from (str, tuple of int) to (str, tuple of int)
        The facet mapping produced by get_facet_mapping
    scope : tuple of int
        Scope of the group dictionary within a top-level Vega spec

    Returns
    -------
    dict from str to (str, tuple of int)
        Dict from Altair view names to scoped datasets. View names for which
        no dataset could be found are omitted.
    """
    datasets: Dict[str, Tuple[str, Scope]] = {}
    # Counts group marks only, so that it lines up with scope indices
    group_index = 0
    mark_group = get_group_mark_for_scope(group, scope) or {}
    for mark in mark_group.get("marks", []):
        name = mark.get("name", "")

        # Faceted views: the "<view name>_cell" group mark is driven by the
        # facet's source dataset
        for vl_chart_name in vl_chart_names:
            if name == f"{vl_chart_name}_cell":
                # A cell mark without a facet (or a facet without a source
                # dataset) carries no dataset information, so it is skipped
                facet = mark.get("from", {}).get("facet", None) or {}
                data_name = facet.get("data", None)
                if data_name is not None:
                    datasets[vl_chart_name] = get_from_facet_mapping(
                        (data_name, scope), facet_mapping
                    )
                break

        if mark.get("type", "") == "group":
            group_data_names = get_datasets_for_view_names(
                group, vl_chart_names, facet_mapping, scope=scope + (group_index,)
            )
            # Datasets found at this level take precedence over nested ones
            for k, v in group_data_names.items():
                datasets.setdefault(k, v)
            group_index += 1
        elif name.endswith("_marks"):
            for vl_chart_name in vl_chart_names:
                # Require the "_" separator after the view name so that e.g.
                # "view_1" does not claim the marks of "view_10"
                if not name.startswith(f"{vl_chart_name}_"):
                    continue
                data_name = mark.get("from", {}).get("data", None)
                scoped_data = get_definition_scope_for_data_reference(
                    group, data_name, scope
                )
                if scoped_data is not None:
                    datasets[vl_chart_name] = get_from_facet_mapping(
                        (data_name, scoped_data), facet_mapping
                    )
                    break

    return datasets