| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import numbers |
| import operator |
| from collections.abc import Iterable, Mapping, MutableMapping |
| from functools import partial |
|
|
| |
| from typing import Any, Callable, Generic, Optional, TypeVar, Union |
|
|
| import numpy as np |
| import pandas as pd |
| import pyarrow as pa |
|
|
| from ..features import Features |
| from ..features.features import _ArrayXDExtensionType, _is_zero_copy_only, decode_nested_example, pandas_types_mapper |
| from ..table import Table |
| from ..utils.py_utils import no_op_if_value_is_null |
|
|
|
|
# Generic type variable used by helpers such as ``_unnest``.
T = TypeVar("T")


# Type variables describing the output types of the three extraction/formatting
# modes (row, column, batch) for a given formatter implementation.
RowFormat = TypeVar("RowFormat")
ColumnFormat = TypeVar("ColumnFormat")
BatchFormat = TypeVar("BatchFormat")
|
|
|
|
| def _is_range_contiguous(key: range) -> bool: |
| return key.step == 1 and key.stop >= key.start |
|
|
|
|
| def _raise_bad_key_type(key: Any): |
| raise TypeError( |
| f"Wrong key type: '{key}' of type '{type(key)}'. Expected one of int, slice, range, str or Iterable." |
| ) |
|
|
|
|
def _query_table_with_indices_mapping(
    table: Table, key: Union[int, slice, range, str, Iterable], indices: Table
) -> pa.Table:
    """
    Query a pyarrow Table to extract the subtable that correspond to the given key.
    The :obj:`indices` parameter corresponds to the indices mapping in case we want to take into
    account a shuffling or an indices selection for example.
    The indices table must contain one column named "indices" of type uint64.
    """
    if isinstance(key, int):
        # Resolve the (possibly negative) position through the indices mapping,
        # then query the underlying table with the resolved row index.
        key = indices.fast_slice(key % indices.num_rows, 1).column(0)[0].as_py()
        return _query_table(table, key)
    if isinstance(key, slice):
        # Normalize the slice into a range bounded by the indices table's length.
        key = range(*key.indices(indices.num_rows))
    if isinstance(key, range):
        if _is_range_contiguous(key) and key.start >= 0:
            # Contiguous range: slice the indices mapping once and gather those rows.
            return _query_table(
                table, [i.as_py() for i in indices.fast_slice(key.start, key.stop - key.start).column(0)]
            )
        else:
            # Non-contiguous range: fall through to the Iterable case below
            # (a range is an Iterable), which resolves indices one by one.
            pass
    if isinstance(key, str):
        # Column query: keep only the requested column, then apply the full indices mapping.
        table = table.select([key])
        return _query_table(table, indices.column(0).to_pylist())
    if isinstance(key, Iterable):
        # Resolve each requested position through the indices mapping.
        return _query_table(table, [indices.fast_slice(i, 1).column(0)[0].as_py() for i in key])


    _raise_bad_key_type(key)
|
|
|
|
def _query_table(table: Table, key: Union[int, slice, range, str, Iterable]) -> pa.Table:
    """
    Query a pyarrow Table to extract the subtable that correspond to the given key.
    """
    if isinstance(key, int):
        # Modulo makes negative indices wrap around (callers validate bounds beforehand).
        return table.fast_slice(key % table.num_rows, 1)
    if isinstance(key, slice):
        # Normalize the slice into a range bounded by the table's length.
        key = range(*key.indices(table.num_rows))
    if isinstance(key, range):
        if _is_range_contiguous(key) and key.start >= 0:
            # Contiguous range: a single zero-copy slice is enough.
            return table.fast_slice(key.start, key.stop - key.start)
        else:
            # Non-contiguous range: fall through to the Iterable case below
            # (a range is an Iterable), which gathers rows individually.
            pass
    if isinstance(key, str):
        # Column query: drop every column except the requested one.
        return table.table.drop([column for column in table.column_names if column != key])
    if isinstance(key, Iterable):
        key = np.fromiter(key, np.int64)
        if len(key) == 0:
            # Empty selection: return an empty table with the same schema.
            return table.table.slice(0, 0)
        # Rows may be non-contiguous, so gather them by index (negative indices wrap).
        return table.fast_gather(key % table.num_rows)


    _raise_bad_key_type(key)
|
|
|
|
def _is_array_with_nulls(pa_array: pa.Array) -> bool:
    """Return True when the Arrow array contains at least one null value."""
    null_count = pa_array.null_count
    return null_count > 0
|
|
|
|
class BaseArrowExtractor(Generic[RowFormat, ColumnFormat, BatchFormat]):
    """Abstract interface for pulling data out of pyarrow tables.

    Concrete subclasses must implement the three extraction modes:
    a single row, a single column, and a batch of rows.
    """

    def extract_row(self, pa_table: pa.Table) -> RowFormat:
        """Extract a single row from ``pa_table``."""
        raise NotImplementedError

    def extract_column(self, pa_table: pa.Table) -> ColumnFormat:
        """Extract a single column from ``pa_table``."""
        raise NotImplementedError

    def extract_batch(self, pa_table: pa.Table) -> BatchFormat:
        """Extract a batch of rows from ``pa_table``."""
        raise NotImplementedError
|
|
|
|
def _unnest(py_dict: dict[str, list[T]]) -> dict[str, T]:
    """Turn a one-row batch (dict of single-element lists) into a row (dict of values)."""
    return {column: values[0] for column, values in py_dict.items()}
|
|
|
|
class SimpleArrowExtractor(BaseArrowExtractor[pa.Table, pa.Array, pa.Table]):
    """Extractor that keeps the pyarrow representation as-is."""

    def extract_row(self, pa_table: pa.Table) -> pa.Table:
        # A row is represented by the (single-row) table itself.
        return pa_table

    def extract_column(self, pa_table: pa.Table) -> pa.Array:
        # A column query produces a table with exactly one column.
        return pa_table.column(0)

    def extract_batch(self, pa_table: pa.Table) -> pa.Table:
        # A batch is the table itself.
        return pa_table
|
|
|
|
class PythonArrowExtractor(BaseArrowExtractor[dict, list, dict]):
    """Extractor that converts pyarrow data to plain python objects."""

    def extract_row(self, pa_table: pa.Table) -> dict:
        # Convert the single-row table to a batch dict, then keep each column's first value.
        return _unnest(pa_table.to_pydict())

    def extract_column(self, pa_table: pa.Table) -> list:
        first_column = pa_table.column(0)
        return first_column.to_pylist()

    def extract_batch(self, pa_table: pa.Table) -> dict:
        return pa_table.to_pydict()
|
|
|
|
class NumpyArrowExtractor(BaseArrowExtractor[dict, np.ndarray, dict]):
    """Extractor that converts pyarrow data to numpy arrays."""

    def __init__(self, **np_array_kwargs):
        # NOTE(review): np_array_kwargs is stored but never applied in
        # _arrow_array_to_numpy below — confirm whether it is still needed.
        self.np_array_kwargs = np_array_kwargs

    def extract_row(self, pa_table: pa.Table) -> dict:
        # A row is the first element of the extracted batch.
        return _unnest(self.extract_batch(pa_table))

    def extract_column(self, pa_table: pa.Table) -> np.ndarray:
        return self._arrow_array_to_numpy(pa_table[pa_table.column_names[0]])

    def extract_batch(self, pa_table: pa.Table) -> dict:
        return {col: self._arrow_array_to_numpy(pa_table[col]) for col in pa_table.column_names}

    def _arrow_array_to_numpy(self, pa_array: pa.Array) -> np.ndarray:
        """Convert an Arrow (chunked) array to a numpy array, zero-copy when possible."""
        if isinstance(pa_array, pa.ChunkedArray):
            if isinstance(pa_array.type, _ArrayXDExtensionType):
                # Fixed-shape extension arrays: zero-copy eligibility depends on the storage dtype.
                zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True)
                array: list = [
                    row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)
                ]
            else:
                # Zero-copy is only possible when the dtype allows it AND no chunk has nulls.
                zero_copy_only = _is_zero_copy_only(pa_array.type) and all(
                    not _is_array_with_nulls(chunk) for chunk in pa_array.chunks
                )
                array: list = [
                    row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)
                ]
        else:
            if isinstance(pa_array.type, _ArrayXDExtensionType):
                zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True)
                array: list = pa_array.to_numpy(zero_copy_only=zero_copy_only)
            else:
                zero_copy_only = _is_zero_copy_only(pa_array.type) and not _is_array_with_nulls(pa_array)
                array: list = pa_array.to_numpy(zero_copy_only=zero_copy_only).tolist()


        if len(array) > 0:
            # Ragged shapes, object dtypes or NaN placeholders cannot form a regular
            # ndarray: fall back to an object-dtype array of the elements.
            if any(
                (isinstance(x, np.ndarray) and (x.dtype == object or x.shape != array[0].shape))
                or (isinstance(x, float) and np.isnan(x))
                for x in array
            ):
                # numpy >= 2.0 removed copy=False semantics used below; use asarray instead.
                if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
                    return np.asarray(array, dtype=object)
                return np.array(array, copy=False, dtype=object)
        if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
            return np.asarray(array)
        else:
            return np.array(array, copy=False)
|
|
|
|
class PandasArrowExtractor(BaseArrowExtractor[pd.DataFrame, pd.Series, pd.DataFrame]):
    """Extractor that converts pyarrow data to pandas objects."""

    def extract_row(self, pa_table: pa.Table) -> pd.DataFrame:
        # A row is represented as a single-row DataFrame.
        first_row = pa_table.slice(length=1)
        return first_row.to_pandas(types_mapper=pandas_types_mapper)

    def extract_column(self, pa_table: pa.Table) -> pd.Series:
        column_name = pa_table.column_names[0]
        as_df = pa_table.select([0]).to_pandas(types_mapper=pandas_types_mapper)
        return as_df[column_name]

    def extract_batch(self, pa_table: pa.Table) -> pd.DataFrame:
        return pa_table.to_pandas(types_mapper=pandas_types_mapper)
|
|
|
|
class PythonFeaturesDecoder:
    """Decode python rows/columns/batches using the dataset's ``Features``.

    When no features are provided, data is passed through unchanged.
    """

    def __init__(
        self, features: Optional[Features], token_per_repo_id: Optional[dict[str, Union[str, bool, None]]] = None
    ):
        self.features = features
        self.token_per_repo_id = token_per_repo_id

    def decode_row(self, row: dict) -> dict:
        if not self.features:
            return row
        return self.features.decode_example(row, token_per_repo_id=self.token_per_repo_id)

    def decode_column(self, column: list, column_name: str) -> list:
        if not self.features:
            return column
        return self.features.decode_column(column, column_name, token_per_repo_id=self.token_per_repo_id)

    def decode_batch(self, batch: dict) -> dict:
        if not self.features:
            return batch
        return self.features.decode_batch(batch, token_per_repo_id=self.token_per_repo_id)
|
|
|
|
class PandasFeaturesDecoder:
    """Decode pandas rows/columns/batches using the dataset's ``Features``."""

    def __init__(self, features: Optional[Features]):
        self.features = features

    def decode_row(self, row: pd.DataFrame) -> pd.DataFrame:
        # Build a decoder per column that actually requires decoding.
        if self.features:
            decode = {
                column_name: no_op_if_value_is_null(partial(decode_nested_example, feature))
                for column_name, feature in self.features.items()
                if self.features._column_requires_decoding[column_name]
            }
        else:
            decode = {}
        if decode:
            # Only overwrite the columns that have a decoder.
            row[list(decode.keys())] = row.transform(decode)
        return row

    def decode_column(self, column: pd.Series, column_name: str) -> pd.Series:
        needs_decoding = (
            self.features and column_name in self.features and self.features._column_requires_decoding[column_name]
        )
        if needs_decoding:
            decoder = no_op_if_value_is_null(partial(decode_nested_example, self.features[column_name]))
            column = column.transform(decoder)
        return column

    def decode_batch(self, batch: pd.DataFrame) -> pd.DataFrame:
        # A batch is a DataFrame just like a row, so the same path applies.
        return self.decode_row(batch)
|
|
|
|
class LazyDict(MutableMapping):
    """A dictionary backed by Arrow data. The values are formatted on-the-fly when accessing the dictionary."""

    def __init__(self, pa_table: pa.Table, formatter: "Formatter"):
        # Underlying Arrow data and the formatter used to materialize values.
        self.pa_table = pa_table
        self.formatter = formatter

        # Cache of materialized values; every key starts as a None placeholder.
        self.data = dict.fromkeys(pa_table.column_names)
        # Keys whose value has not been formatted (and cached) yet.
        self.keys_to_format = set(self.data.keys())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, key):
        value = self.data[key]
        if key in self.keys_to_format:
            # Format on first access and cache the result.
            value = self.format(key)
            self.data[key] = value
            self.keys_to_format.remove(key)
        return value

    def __setitem__(self, key, value):
        # An explicitly assigned value no longer needs formatting.
        if key in self.keys_to_format:
            self.keys_to_format.remove(key)
        self.data[key] = value

    def __delitem__(self, key) -> None:
        if key in self.keys_to_format:
            self.keys_to_format.remove(key)
        del self.data[key]

    def __iter__(self):
        return iter(self.data)

    def __contains__(self, key):
        return key in self.data

    def __repr__(self):
        # Format everything first so the repr shows real values, not None placeholders.
        self._format_all()
        return repr(self.data)

    def __or__(self, other):
        """``self | other``: other's keys win; they arrive already formatted."""
        if isinstance(other, LazyDict):
            inst = self.copy()
            other = other.copy()
            other._format_all()
            # Keys provided by ``other`` no longer need formatting in the result.
            inst.keys_to_format -= other.data.keys()
            inst.data = inst.data | other.data
            return inst
        if isinstance(other, dict):
            inst = self.copy()
            inst.keys_to_format -= other.keys()
            inst.data = inst.data | other
            return inst
        return NotImplemented

    def __ror__(self, other):
        """``other | self``: self's keys win on conflicts."""
        if isinstance(other, LazyDict):
            inst = self.copy()
            other = other.copy()
            other._format_all()
            # NOTE(review): for overlapping keys, self's value (possibly a still-None
            # placeholder) wins the merge below while its pending-format flag is
            # dropped here — verify this is the intended behavior.
            inst.keys_to_format -= other.data.keys()
            inst.data = other.data | inst.data
            return inst
        if isinstance(other, dict):
            inst = self.copy()
            inst.keys_to_format -= other.keys()
            inst.data = other | inst.data
            return inst
        return NotImplemented

    def __ior__(self, other):
        """In-place merge: other's keys win and need no further formatting."""
        if isinstance(other, LazyDict):
            other = other.copy()
            other._format_all()
            self.keys_to_format -= other.data.keys()
            self.data |= other.data
        else:
            self.keys_to_format -= other.keys()
            self.data |= other
        return self

    def __copy__(self):
        # Shallow copy: share pa_table/formatter but give the copy its own
        # mutable bookkeeping containers so the two instances track state independently.
        inst = self.__class__.__new__(self.__class__)
        inst.__dict__.update(self.__dict__)

        inst.__dict__["data"] = self.__dict__["data"].copy()
        inst.__dict__["keys_to_format"] = self.__dict__["keys_to_format"].copy()
        return inst

    def copy(self):
        import copy

        return copy.copy(self)

    @classmethod
    def fromkeys(cls, iterable, value=None):
        # A LazyDict is always backed by a pa.Table, so building one from bare keys is unsupported.
        raise NotImplementedError

    def format(self, key):
        """Materialize the value for ``key`` from the backing table. Implemented by subclasses."""
        raise NotImplementedError

    def _format_all(self):
        # Eagerly format and cache every remaining key.
        for key in self.keys_to_format:
            self.data[key] = self.format(key)
        self.keys_to_format.clear()
|
|
|
|
class LazyRow(LazyDict):
    """Lazy dict over a single-row table: values are scalars."""

    def format(self, key):
        # Format the one-column subtable for ``key`` and keep only its first (and only) value.
        return self.formatter.format_column(self.pa_table.select([key]))[0]
|
|
|
|
class LazyBatch(LazyDict):
    """Lazy dict over a multi-row table: values are whole formatted columns."""

    def format(self, key):
        # Format the one-column subtable for ``key`` and return the full column.
        return self.formatter.format_column(self.pa_table.select([key]))
|
|
|
|
class Formatter(Generic[RowFormat, ColumnFormat, BatchFormat]):
    """Base class for objects that extract data from pyarrow tables and format it.

    Subclasses define how rows, columns and batches are rendered.
    """

    # Extractor classes shared by all formatter implementations.
    simple_arrow_extractor = SimpleArrowExtractor
    python_arrow_extractor = PythonArrowExtractor
    numpy_arrow_extractor = NumpyArrowExtractor
    pandas_arrow_extractor = PandasArrowExtractor

    def __init__(
        self,
        features: Optional[Features] = None,
        token_per_repo_id: Optional[dict[str, Union[str, bool, None]]] = None,
    ):
        self.features = features
        self.token_per_repo_id = token_per_repo_id
        # Decoders reuse the same features / repo tokens as the formatter.
        self.python_features_decoder = PythonFeaturesDecoder(self.features, self.token_per_repo_id)
        self.pandas_features_decoder = PandasFeaturesDecoder(self.features)

    def __call__(self, pa_table: pa.Table, query_type: str) -> Union[RowFormat, ColumnFormat, BatchFormat]:
        # Dispatch on the query type; an unknown type falls through and returns None.
        dispatch = {
            "row": self.format_row,
            "column": self.format_column,
            "batch": self.format_batch,
        }
        handler = dispatch.get(query_type)
        if handler is not None:
            return handler(pa_table)

    def format_row(self, pa_table: pa.Table) -> RowFormat:
        """Format a single-row table. Implemented by subclasses."""
        raise NotImplementedError

    def format_column(self, pa_table: pa.Table) -> ColumnFormat:
        """Format a single-column table. Implemented by subclasses."""
        raise NotImplementedError

    def format_batch(self, pa_table: pa.Table) -> BatchFormat:
        """Format a batch of rows. Implemented by subclasses."""
        raise NotImplementedError
|
|
|
|
class TensorFormatter(Formatter[RowFormat, ColumnFormat, BatchFormat]):
    """Base class for formatters whose output values are framework tensors."""

    def recursive_tensorize(self, data_struct: dict):
        # Convert the values of a (possibly nested) structure to the target tensor type.
        raise NotImplementedError
|
|
|
|
class TableFormatter(Formatter[RowFormat, ColumnFormat, BatchFormat]):
    """Base class for formatters that return table-like objects."""

    # Human-readable names of the output types, set by subclasses
    # (e.g. "arrow table"/"arrow array", "pandas dataframe"/"pandas series").
    table_type: str
    column_type: str
|
|
|
|
class ArrowFormatter(TableFormatter[pa.Table, pa.Array, pa.Table]):
    """Formatter that returns raw pyarrow objects (no feature decoding)."""

    table_type = "arrow table"
    column_type = "arrow array"

    def format_row(self, pa_table: pa.Table) -> pa.Table:
        extractor = self.simple_arrow_extractor()
        return extractor.extract_row(pa_table)

    def format_column(self, pa_table: pa.Table) -> pa.Array:
        extractor = self.simple_arrow_extractor()
        return extractor.extract_column(pa_table)

    def format_batch(self, pa_table: pa.Table) -> pa.Table:
        extractor = self.simple_arrow_extractor()
        return extractor.extract_batch(pa_table)
|
|
|
|
class PythonFormatter(Formatter[Mapping, list, Mapping]):
    """Formatter that returns plain python objects, optionally as lazy dicts."""

    def __init__(self, features=None, lazy=False, token_per_repo_id=None):
        super().__init__(features, token_per_repo_id)
        # When lazy, rows/batches are LazyDicts whose values are formatted on first access.
        self.lazy = lazy

    def format_row(self, pa_table: pa.Table) -> Mapping:
        if self.lazy:
            return LazyRow(pa_table, self)
        extracted = self.python_arrow_extractor().extract_row(pa_table)
        return self.python_features_decoder.decode_row(extracted)

    def format_column(self, pa_table: pa.Table) -> list:
        extracted = self.python_arrow_extractor().extract_column(pa_table)
        return self.python_features_decoder.decode_column(extracted, pa_table.column_names[0])

    def format_batch(self, pa_table: pa.Table) -> Mapping:
        if self.lazy:
            return LazyBatch(pa_table, self)
        extracted = self.python_arrow_extractor().extract_batch(pa_table)
        return self.python_features_decoder.decode_batch(extracted)
|
|
|
|
class PandasFormatter(TableFormatter[pd.DataFrame, pd.Series, pd.DataFrame]):
    """Formatter that returns pandas objects decoded with the dataset features."""

    table_type = "pandas dataframe"
    column_type = "pandas series"

    def format_row(self, pa_table: pa.Table) -> pd.DataFrame:
        extracted = self.pandas_arrow_extractor().extract_row(pa_table)
        return self.pandas_features_decoder.decode_row(extracted)

    def format_column(self, pa_table: pa.Table) -> pd.Series:
        extracted = self.pandas_arrow_extractor().extract_column(pa_table)
        return self.pandas_features_decoder.decode_column(extracted, pa_table.column_names[0])

    def format_batch(self, pa_table: pa.Table) -> pd.DataFrame:
        extracted = self.pandas_arrow_extractor().extract_batch(pa_table)
        return self.pandas_features_decoder.decode_batch(extracted)
|
|
|
|
class CustomFormatter(Formatter[dict, ColumnFormat, dict]):
    """
    A user-defined custom formatter function defined by a ``transform``.
    The transform must take as input a batch of data extracted for an arrow table using the python extractor,
    and return a batch.
    If the output batch is not a dict, then output_all_columns won't work.
    If the output batch has several fields, then querying a single column won't work since we don't know which field
    to return.
    """

    def __init__(self, transform: Callable[[dict], dict], features=None, token_per_repo_id=None, **kwargs):
        super().__init__(features=features, token_per_repo_id=token_per_repo_id)
        # The user-supplied batch -> batch transform.
        self.transform = transform

    def format_row(self, pa_table: pa.Table) -> dict:
        # A row is the first element of the transformed batch.
        formatted_batch = self.format_batch(pa_table)
        try:
            return _unnest(formatted_batch)
        except Exception as exc:
            raise TypeError(
                f"Custom formatting function must return a dict of sequences to be able to pick a row, but got {formatted_batch}"
            ) from exc

    def format_column(self, pa_table: pa.Table) -> ColumnFormat:
        formatted_batch = self.format_batch(pa_table)
        # The transform must produce a mapping with exactly one column.
        if not hasattr(formatted_batch, "keys"):
            raise TypeError(
                f"Custom formatting function must return a dict to be able to pick a row, but got {formatted_batch}"
            )
        if len(formatted_batch.keys()) > 1:
            raise TypeError(
                "Tried to query a column but the custom formatting function returns too many columns. "
                f"Only one column was expected but got columns {list(formatted_batch.keys())}."
            )
        try:
            return formatted_batch[pa_table.column_names[0]]
        except Exception as exc:
            raise TypeError(
                f"Custom formatting function must return a dict to be able to pick a row, but got {formatted_batch}"
            ) from exc

    def format_batch(self, pa_table: pa.Table) -> dict:
        # Extract + decode with the python pipeline, then hand the batch to the transform.
        extracted = self.python_arrow_extractor().extract_batch(pa_table)
        decoded = self.python_features_decoder.decode_batch(extracted)
        return self.transform(decoded)
|
|
|
|
| def _check_valid_column_key(key: str, columns: list[str]) -> None: |
| if key not in columns: |
| raise KeyError(f"Column {key} not in the dataset. Current columns in the dataset: {columns}") |
|
|
|
|
| def _check_valid_index_key(key: Union[int, slice, range, Iterable], size: int) -> None: |
| if isinstance(key, int): |
| if (key < 0 and key + size < 0) or (key >= size): |
| raise IndexError(f"Invalid key: {key} is out of bounds for size {size}") |
| return |
| elif isinstance(key, slice): |
| pass |
| elif isinstance(key, range): |
| if len(key) > 0: |
| _check_valid_index_key(max(key), size=size) |
| _check_valid_index_key(min(key), size=size) |
| elif isinstance(key, Iterable): |
| if len(key) > 0: |
| _check_valid_index_key(int(max(key)), size=size) |
| _check_valid_index_key(int(min(key)), size=size) |
| else: |
| _raise_bad_key_type(key) |
|
|
|
|
def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
    """Classify a query key as "row", "column" or "batch"."""
    # Integers (including numpy integer scalars) select a single row.
    if isinstance(key, numbers.Integral):
        return "row"
    # Strings name a single column.
    if isinstance(key, str):
        return "column"
    # Anything iterable (slice, range, list, array...) selects several rows.
    if isinstance(key, (slice, range, Iterable)):
        return "batch"
    _raise_bad_key_type(key)
|
|
|
|
def query_table(
    table: Table,
    key: Union[int, slice, range, str, Iterable],
    indices: Optional[Table] = None,
) -> pa.Table:
    """
    Query a Table to extract the subtable that correspond to the given key.

    Args:
        table (``datasets.table.Table``): The input Table to query from
        key (``Union[int, slice, range, str, Iterable]``): The key can be of different types:
            - an integer i: the subtable containing only the i-th row
            - a slice [i:j:k]: the subtable containing the rows that correspond to this slice
            - a range(i, j, k): the subtable containing the rows that correspond to this range
            - a string c: the subtable containing all the rows but only the column c
            - an iterable l: the subtable that is the concatenation of all the i-th rows for all i in the iterable
        indices (Optional ``datasets.table.Table``): If not None, it is used to re-map the given key to the table rows.
            The indices table must contain one column named "indices" of type uint64.
            This is used in case of shuffling or rows selection.


    Returns:
        ``pyarrow.Table``: the result of the query on the input table
    """
    # Coerce unsupported key types (e.g. numpy integer scalars) to int if possible.
    if not isinstance(key, (int, slice, range, str, Iterable)):
        try:
            key = operator.index(key)
        except TypeError:
            _raise_bad_key_type(key)
    # Validate the key before querying: column names against the schema,
    # row indices against the effective number of rows.
    if isinstance(key, str):
        _check_valid_column_key(key, table.column_names)
    else:
        # With an indices mapping, the visible size is the mapping's length.
        size = indices.num_rows if indices is not None else table.num_rows
        _check_valid_index_key(key, size)
    # Query the underlying table, going through the indices mapping if there is one.
    if indices is None:
        pa_subtable = _query_table(table, key)
    else:
        pa_subtable = _query_table_with_indices_mapping(table, key, indices=indices)
    return pa_subtable
|
|
|
|
def format_table(
    table: Table,
    key: Union[int, slice, range, str, Iterable],
    formatter: Formatter,
    format_columns: Optional[list] = None,
    output_all_columns=False,
):
    """
    Format a Table depending on the key that was used and a Formatter object.

    Args:
        table (``datasets.table.Table``): The input Table to format
        key (``Union[int, slice, range, str, Iterable]``): Depending on the key that was used, the formatter formats
            the table as either a row, a column or a batch.
        formatter (``datasets.formatting.formatting.Formatter``): Any subclass of a Formatter such as
            PythonFormatter, NumpyFormatter, etc.
        format_columns (:obj:`List[str]`, optional): if not None, it defines the columns that will be formatted using the
            given formatter. Other columns are discarded (unless ``output_all_columns`` is True)
        output_all_columns (:obj:`bool`, defaults to False). If True, the formatted output is completed using the columns
            that are not in the ``format_columns`` list. For these columns, the PythonFormatter is used.


    Returns:
        A row, column or batch formatted object defined by the Formatter:
        - the PythonFormatter returns a dictionary for a row or a batch, and a list for a column.
        - the NumpyFormatter returns a dictionary for a row or a batch, and a np.array for a column.
        - the PandasFormatter returns a pd.DataFrame for a row or a batch, and a pd.Series for a column.
        - the TorchFormatter returns a dictionary for a row or a batch, and a torch.Tensor for a column.
        - the TFFormatter returns a dictionary for a row or a batch, and a tf.Tensor for a column.
    """
    # Work on the raw pyarrow table.
    if isinstance(table, Table):
        pa_table = table.table
    else:
        pa_table = table
    query_type = key_to_query_type(key)
    # Fallback formatter for the columns outside ``format_columns``. It must share
    # both the features AND the repo tokens of the main formatter, so that decoding
    # of gated/remote content behaves the same for every column.
    python_formatter = PythonFormatter(features=formatter.features, token_per_repo_id=formatter.token_per_repo_id)
    if format_columns is None:
        # No column restriction: format everything with the requested formatter.
        return formatter(pa_table, query_type=query_type)
    elif query_type == "column":
        # A single-column query is formatted with the requested formatter only if
        # that column is in ``format_columns``; otherwise fall back to python objects.
        if key in format_columns:
            return formatter(pa_table, query_type)
        else:
            return python_formatter(pa_table, query_type=query_type)
    else:
        # Row/batch query: format only the requested columns...
        pa_table_to_format = pa_table.drop(col for col in pa_table.column_names if col not in format_columns)
        formatted_output = formatter(pa_table_to_format, query_type=query_type)
        if output_all_columns:
            # ...and optionally complete the output with the remaining columns,
            # formatted as plain python objects.
            if isinstance(formatted_output, MutableMapping):
                pa_table_with_remaining_columns = pa_table.drop(
                    col for col in pa_table.column_names if col in format_columns
                )
                remaining_columns_dict = python_formatter(pa_table_with_remaining_columns, query_type=query_type)
                formatted_output.update(remaining_columns_dict)
            else:
                raise TypeError(
                    f"Custom formatting function must return a dict to work with output_all_columns=True, but got {formatted_output}"
                )
        return formatted_output
|
|