SimplifyMe's picture
Upload folder using huggingface_hub
9cddcfd
raw
history blame contribute delete
No virus
18.1 kB
from typing import List, Optional, Tuple, Dict, Iterable, overload, Union
from altair import (
Chart,
FacetChart,
LayerChart,
HConcatChart,
VConcatChart,
ConcatChart,
TopLevelUnitSpec,
FacetedUnitSpec,
UnitSpec,
UnitSpecWithFrame,
NonNormalizedSpec,
TopLevelLayerSpec,
LayerSpec,
TopLevelConcatSpec,
ConcatSpecGenericSpec,
TopLevelHConcatSpec,
HConcatSpecGenericSpec,
TopLevelVConcatSpec,
VConcatSpecGenericSpec,
TopLevelFacetSpec,
FacetSpec,
data_transformers,
)
from altair.utils._vegafusion_data import get_inline_tables, import_vegafusion
from altair.utils.core import _DataFrameLike
from altair.utils.schemapi import Undefined
# A scope identifies a group mark inside a Vega spec: each entry is the index
# of a group mark among the "group"-type marks of its parent, and the empty
# tuple denotes the top level of the spec (see get_group_mark_for_scope).
Scope = Tuple[int, ...]
# Maps (facet_name, facet_scope) to the (dataset_name, dataset_scope) that the
# facet is built from (produced by get_facet_mapping).
FacetMapping = Dict[Tuple[str, Scope], Tuple[str, Scope]]
# For the transformed_data functionality, the chart classes in the values
# can be considered equivalent to the chart class in the key.
# name_views passes these tuples to isinstance() so that the listed schema
# spec classes are dispatched exactly like the corresponding chart class.
_chart_class_mapping = {
    # Single-view (unit) charts
    Chart: (
        Chart,
        TopLevelUnitSpec,
        FacetedUnitSpec,
        UnitSpec,
        UnitSpecWithFrame,
        NonNormalizedSpec,
    ),
    # Compound charts whose sub-views live in .layer / .concat / .hconcat / .vconcat
    LayerChart: (LayerChart, TopLevelLayerSpec, LayerSpec),
    ConcatChart: (ConcatChart, TopLevelConcatSpec, ConcatSpecGenericSpec),
    HConcatChart: (HConcatChart, TopLevelHConcatSpec, HConcatSpecGenericSpec),
    VConcatChart: (VConcatChart, TopLevelVConcatSpec, VConcatSpecGenericSpec),
    # Faceted charts are treated as a single view by name_views
    FacetChart: (FacetChart, TopLevelFacetSpec, FacetSpec),
}
# Typing-only overloads of transformed_data (the implementation follows):
# a single-view chart (Chart or FacetChart) yields one DataFrame or None,
# while compound charts yield a list of DataFrames.
@overload
def transformed_data(
    chart: Union[Chart, FacetChart],
    row_limit: Optional[int] = None,
    exclude: Optional[Iterable[str]] = None,
) -> Optional[_DataFrameLike]:
    ...


@overload
def transformed_data(
    chart: Union[LayerChart, HConcatChart, VConcatChart, ConcatChart],
    row_limit: Optional[int] = None,
    exclude: Optional[Iterable[str]] = None,
) -> List[_DataFrameLike]:
    ...
def transformed_data(chart, row_limit=None, exclude=None):
    """Evaluate a Chart's transforms

    Evaluate the data transforms associated with a Chart and return the
    transformed data as one or more DataFrames

    Parameters
    ----------
    chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart
        Altair chart to evaluate transforms on
    row_limit : int (optional)
        Maximum number of rows to return for each DataFrame. None (default) for unlimited
    exclude : iterable of str
        Set of the names of charts to exclude

    Returns
    -------
    DataFrame or list of DataFrames or None
        If input chart is a Chart or Facet Chart, returns a DataFrame of the
        transformed data, or None if that chart was excluded. Otherwise,
        returns a list of DataFrames of the transformed data, one per
        non-excluded view, in the order the views appear in the chart.

    Raises
    ------
    ValueError
        If ``chart`` is not a supported chart type (raised by name_views), or
        if the dataset backing one of the views cannot be located in the
        compiled Vega specification.
    """
    vf = import_vegafusion()

    # Add mark if none is specified to satisfy Vega-Lite
    if isinstance(chart, Chart) and chart.mark is Undefined:
        chart = chart.mark_point()

    # Deep copy chart so that we can rename views without affecting the caller
    chart = chart.copy(deep=True)

    # Ensure that all views are named so that we can look them up in the
    # resulting Vega specification
    chart_names = name_views(chart, 0, exclude=exclude)

    # Compile to Vega and extract inline DataFrames
    with data_transformers.enable("vegafusion"):
        vega_spec = chart.to_dict(format="vega", context={"pre_transform": False})
        inline_datasets = get_inline_tables(vega_spec)

    # Build mapping from view names to scoped Vega datasets
    facet_mapping = get_facet_mapping(vega_spec)
    dataset_mapping = get_datasets_for_view_names(vega_spec, chart_names, facet_mapping)

    # Every named view must resolve to a dataset; the resulting list of
    # dataset names follows the order of the chart components.
    if any(chart_name not in dataset_mapping for chart_name in chart_names):
        raise ValueError("Failed to locate all datasets")
    dataset_names = [dataset_mapping[chart_name] for chart_name in chart_names]

    # Extract transformed datasets with VegaFusion. The returned warnings are
    # not surfaced to the caller (bound to a throwaway name so the stdlib
    # ``warnings`` module name is not shadowed).
    datasets, _warnings = vf.runtime.pre_transform_datasets(
        vega_spec,
        dataset_names,
        row_limit=row_limit,
        inline_datasets=inline_datasets,
    )

    if isinstance(chart, (Chart, FacetChart)):
        # Return DataFrame (or None if it was excluded) if input was a simple Chart
        return datasets[0] if datasets else None

    # Otherwise return the list of DataFrames
    return datasets
# The equivalent classes from _chart_class_mapping should also be added
# to the type hints below for `chart` as the function would also work for them.
# However, this was not possible so far as mypy then complains about
# "Overloaded function signatures 1 and 2 overlap with incompatible return types [misc]"
# This might be due to the complex type hierarchy of the chart classes.
# See also https://github.com/python/mypy/issues/5119
# and https://github.com/python/mypy/issues/4020 which show that mypy might not have
# a very consistent behavior for overloaded functions.
# The same error appeared when trying it with Protocols for the concat and layer charts.
# This function is only used internally and so we accept this inconsistency for now.
def name_views(
    chart: Union[
        Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, ConcatChart
    ],
    i: int = 0,
    exclude: Optional[Iterable[str]] = None,
) -> List[str]:
    """Name unnamed chart views

    Name unnamed charts views so that we can look them up later in
    the compiled Vega spec.

    Note: This function mutates the input chart by applying names to
    unnamed views.

    Parameters
    ----------
    chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart
        Altair chart to apply names to
    i : int (default 0)
        Starting chart index. It is only propagated to the recursive calls;
        new names come from ``Chart._get_name()``.
    exclude : iterable of str
        Names of charts to exclude

    Returns
    -------
    list of str
        List of the names of the leaf views (unit and facet charts), in
        depth-first order, omitting excluded views

    Raises
    ------
    ValueError
        If ``chart`` (or any nested subchart) is not a supported chart type.
    """
    exclude = set(exclude) if exclude is not None else set()

    # Leaf views: unit charts and facet charts (and their equivalent spec classes)
    leaf_classes = _chart_class_mapping[Chart] + _chart_class_mapping[FacetChart]
    if isinstance(chart, leaf_classes):
        if chart.name in exclude:
            return []
        if chart.name in (None, Undefined):
            # Add name since none is specified
            chart.name = Chart._get_name()
        return [chart.name]

    # Compound views: collect the subcharts to recurse into
    if isinstance(chart, _chart_class_mapping[LayerChart]):
        subcharts = chart.layer
    elif isinstance(chart, _chart_class_mapping[HConcatChart]):
        subcharts = chart.hconcat
    elif isinstance(chart, _chart_class_mapping[VConcatChart]):
        subcharts = chart.vconcat
    elif isinstance(chart, _chart_class_mapping[ConcatChart]):
        subcharts = chart.concat
    else:
        raise ValueError(
            "transformed_data accepts an instance of "
            "Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart\n"
            f"Received value of type: {type(chart)}"
        )

    chart_names: List[str] = []
    for subchart in subcharts:
        chart_names.extend(
            name_views(subchart, i=i + len(chart_names), exclude=exclude)
        )
    return chart_names
def get_group_mark_for_scope(vega_spec: dict, scope: Scope) -> Optional[dict]:
    """Get the group mark at a particular scope

    Each entry of ``scope`` selects a child group mark by its position among
    the ``"group"``-type marks of the current level; marks of other types
    are not counted.

    Parameters
    ----------
    vega_spec : dict
        Top-level Vega specification dictionary
    scope : tuple of int
        Scope tuple. If empty, the original Vega specification is returned.
        Otherwise, the nested group mark at the scope specified is returned.

    Returns
    -------
    dict or None
        Top-level Vega spec (if scope is empty)
        or group mark (if scope is non-empty)
        or None (if group mark at scope does not exist)

    Examples
    --------
    >>> spec = {
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "marks": [{"type": "symbol"}]
    ...         },
    ...         {
    ...             "type": "group",
    ...             "marks": [{"type": "rect"}]}
    ...     ]
    ... }
    >>> get_group_mark_for_scope(spec, (1,))
    {'type': 'group', 'marks': [{'type': 'rect'}]}
    """
    current = vega_spec
    for wanted_position in scope:
        group_marks = (
            candidate
            for candidate in current.get("marks", [])
            if candidate.get("type") == "group"
        )
        selected = None
        for position, candidate in enumerate(group_marks):
            if position == wanted_position:
                selected = candidate
                break
        # No group mark at this position: the scope does not exist
        if selected is None:
            return None
        current = selected
    return current
def get_datasets_for_scope(vega_spec: dict, scope: Scope) -> List[str]:
    """Get the names of the datasets that are defined at a given scope

    A scope defines the datasets listed in its ``"data"`` array and, for a
    faceted group mark, the facet dataset named in ``from.facet.name``.

    Parameters
    ----------
    vega_spec : dict
        Top-level Vega specification
    scope : tuple of int
        Scope tuple. If empty, the names of top-level datasets are returned
        Otherwise, the names of the datasets defined in the nested group mark
        at the specified scope are returned.

    Returns
    -------
    list of str
        List of the names of the datasets defined at the specified scope

    Examples
    --------
    >>> spec = {
    ...     "data": [
    ...         {"name": "data1"}
    ...     ],
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "data": [
    ...                 {"name": "data2"}
    ...             ],
    ...             "marks": [{"type": "symbol"}]
    ...         },
    ...         {
    ...             "type": "group",
    ...             "data": [
    ...                 {"name": "data3"},
    ...                 {"name": "data4"},
    ...             ],
    ...             "marks": [{"type": "rect"}]
    ...         }
    ...     ]
    ... }

    >>> get_datasets_for_scope(spec, ())
    ['data1']

    >>> get_datasets_for_scope(spec, (0,))
    ['data2']

    >>> get_datasets_for_scope(spec, (1,))
    ['data3', 'data4']

    Returns empty when no group mark exists at scope

    >>> get_datasets_for_scope(spec, (1, 3))
    []
    """
    group = get_group_mark_for_scope(vega_spec, scope) or {}

    # Datasets declared explicitly in the group's "data" array
    datasets = [dataset["name"] for dataset in group.get("data", [])]

    # A faceted group mark also defines the facet dataset it produces
    facet_dataset = group.get("from", {}).get("facet", {}).get("name", None)
    if facet_dataset:
        datasets.append(facet_dataset)
    return datasets
def get_definition_scope_for_data_reference(
    vega_spec: dict, data_name: str, usage_scope: Scope
) -> Optional[Scope]:
    """Return the scope that a dataset is defined at, for a given usage scope

    The usage scope and each of its ancestors are searched, innermost first,
    so a dataset defined closer to the usage shadows one of the same name
    defined further out.

    Parameters
    ----------
    vega_spec: dict
        Top-level Vega specification
    data_name: str
        The name of a dataset reference
    usage_scope: tuple of int
        The scope that the dataset is referenced in

    Returns
    -------
    tuple of int
        The scope where the referenced dataset is defined,
        or None if no such dataset is found

    Examples
    --------
    >>> spec = {
    ...     "data": [
    ...         {"name": "data1"}
    ...     ],
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "data": [
    ...                 {"name": "data2"}
    ...             ],
    ...             "marks": [{
    ...                 "type": "symbol",
    ...                 "encode": {
    ...                     "update": {
    ...                         "x": {"field": "x", "data": "data1"},
    ...                         "y": {"field": "y", "data": "data2"},
    ...                     }
    ...                 }
    ...             }]
    ...         }
    ...     ]
    ... }

    data1 is referenced at scope [0] and defined at scope []

    >>> get_definition_scope_for_data_reference(spec, "data1", (0,))
    ()

    data2 is referenced at scope [0] and defined at scope [0]

    >>> get_definition_scope_for_data_reference(spec, "data2", (0,))
    (0,)

    data2 is not visible at scope [] (the top level),
    because it's defined in scope [0]

    >>> repr(get_definition_scope_for_data_reference(spec, "data2", ()))
    'None'
    """
    # Candidate scopes from the usage scope itself out to the top level ()
    enclosing_scopes = (
        usage_scope[:depth] for depth in range(len(usage_scope), -1, -1)
    )
    return next(
        (
            candidate
            for candidate in enclosing_scopes
            if data_name in get_datasets_for_scope(vega_spec, candidate)
        ),
        None,
    )
def get_facet_mapping(group: dict, scope: Scope = ()) -> FacetMapping:
    """Create mapping from facet definitions to source datasets

    Every group mark at ``scope`` (and, recursively, below it) that declares
    ``from.facet`` with both a ``name`` and a ``data`` entry contributes one
    item, provided the ``data`` reference resolves to a dataset visible from
    ``scope``.

    Parameters
    ----------
    group : dict
        Top-level Vega spec or nested group mark
    scope : tuple of int
        Scope of the group dictionary within a top-level Vega spec

    Returns
    -------
    dict
        Dictionary from (facet_name, facet_scope) to (dataset_name, dataset_scope)

    Examples
    --------
    >>> spec = {
    ...     "data": [
    ...         {"name": "data1"}
    ...     ],
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "from": {
    ...                 "facet": {
    ...                     "name": "facet1",
    ...                     "data": "data1",
    ...                     "groupby": ["colA"]
    ...                 }
    ...             }
    ...         }
    ...     ]
    ... }
    >>> get_facet_mapping(spec)
    {('facet1', (0,)): ('data1', ())}
    """
    mapping = {}
    container = get_group_mark_for_scope(group, scope) or {}
    child_groups = [
        child
        for child in container.get("marks", [])
        if child.get("type", None) == "group"
    ]
    for position, child in enumerate(child_groups):
        child_scope = scope + (position,)

        facet_spec = child.get("from", {}).get("facet", None)
        if facet_spec is not None:
            facet_name = facet_spec.get("name", None)
            source_name = facet_spec.get("data", None)
            if facet_name is not None and source_name is not None:
                # The facet's source is referenced from the enclosing scope
                source_scope = get_definition_scope_for_data_reference(
                    group, source_name, scope
                )
                if source_scope is not None:
                    mapping[(facet_name, child_scope)] = (source_name, source_scope)

        # Facets declared deeper inside this group
        mapping.update(get_facet_mapping(group, scope=child_scope))
    return mapping
def get_from_facet_mapping(
    scoped_dataset: Tuple[str, Scope], facet_mapping: FacetMapping
) -> Tuple[str, Scope]:
    """Resolve a scoped dataset through the facet mapping

    Each lookup replaces a facet dataset by the dataset it is derived from.
    Lookups repeat until the current dataset is no longer a key of the
    mapping. The mapping is expected to be acyclic; a cycle would make this
    loop forever.

    Parameters
    ----------
    scoped_dataset : (str, tuple of int)
        A dataset name and scope tuple
    facet_mapping : dict from (str, tuple of int) to (str, tuple of int)
        The facet mapping produced by get_facet_mapping

    Returns
    -------
    (str, tuple of int)
        Dataset name and scope tuple that has been mapped as many times as possible

    Examples
    --------
    Facet mapping as produced by get_facet_mapping

    >>> facet_mapping = {("facet1", (0,)): ("data1", ()), ("facet2", (0, 1)): ("facet1", (0,))}
    >>> get_from_facet_mapping(("facet2", (0, 1)), facet_mapping)
    ('data1', ())
    """
    current = scoped_dataset
    while True:
        if current not in facet_mapping:
            return current
        current = facet_mapping[current]
def _view_name_for_marks(mark_name: str, vl_chart_names: List[str]) -> Optional[str]:
    """Return the view name that a ``*_marks`` mark belongs to, or None.

    A mark belongs to a view when its name ends with ``"_marks"`` and starts
    with the view name followed by ``"_"``. Requiring the separator keeps
    view ``"view_1"`` from claiming ``"view_10_marks"``. If several view
    names still match, the longest (most specific) one wins.
    """
    if not mark_name.endswith("_marks"):
        return None
    matches = [
        view_name for view_name in vl_chart_names if mark_name.startswith(f"{view_name}_")
    ]
    return max(matches, key=len, default=None)


def get_datasets_for_view_names(
    group: dict,
    vl_chart_names: List[str],
    facet_mapping: FacetMapping,
    scope: Scope = (),
) -> Dict[str, Tuple[str, Scope]]:
    """Get the Vega datasets that correspond to the provided Altair view names

    Parameters
    ----------
    group : dict
        Top-level Vega spec or nested group mark
    vl_chart_names : list of str
        List of the Vega-Lite view names to look up (e.g. as returned by
        name_views)
    facet_mapping : dict from (str, tuple of int) to (str, tuple of int)
        The facet mapping produced by get_facet_mapping
    scope : tuple of int
        Scope of the group dictionary within a top-level Vega spec

    Returns
    -------
    dict from str to (str, tuple of int)
        Dict from Altair view names to scoped datasets. Views whose dataset
        could not be located are absent from the dict.
    """
    datasets: Dict[str, Tuple[str, Scope]] = {}
    group_index = 0
    mark_group = get_group_mark_for_scope(group, scope) or {}
    for mark in mark_group.get("marks", []):
        name = mark.get("name", "")

        # Marks named "<view>_cell" (presumably the group mark Vega-Lite emits
        # for a faceted view) take their data from from.facet.data.
        for vl_chart_name in vl_chart_names:
            if name == f"{vl_chart_name}_cell":
                data_name = mark.get("from", {}).get("facet", {}).get("data", None)
                if data_name is not None:
                    # Resolve where the referenced data is defined, like the
                    # "_marks" branch below; fall back to the current scope.
                    data_scope = get_definition_scope_for_data_reference(
                        group, data_name, scope
                    )
                    if data_scope is None:
                        data_scope = scope
                    datasets[vl_chart_name] = get_from_facet_mapping(
                        (data_name, data_scope), facet_mapping
                    )
                break

        if mark.get("type", "") == "group":
            # Recurse into the nested group; entries already recorded at
            # this level are kept in preference to nested ones.
            group_data_names = get_datasets_for_view_names(
                group, vl_chart_names, facet_mapping, scope=scope + (group_index,)
            )
            for view_name, scoped_data in group_data_names.items():
                datasets.setdefault(view_name, scoped_data)
            group_index += 1
        else:
            vl_chart_name = _view_name_for_marks(name, vl_chart_names)
            if vl_chart_name is None:
                continue
            data_name = mark.get("from", {}).get("data", None)
            if data_name is None:
                continue
            data_scope = get_definition_scope_for_data_reference(
                group, data_name, scope
            )
            if data_scope is not None:
                datasets[vl_chart_name] = get_from_facet_mapping(
                    (data_name, data_scope), facet_mapping
                )
    return datasets