|
"""Adapted from https://github.com/arogozhnikov/einops/blob/36c7bb16e57d6e57f8f3050f9e07abdf3f00469f/einops/parsing.py. |
|
|
|
MIT License |
|
|
|
Copyright (c) 2018 Alex Rogozhnikov |
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy |
|
of this software and associated documentation files (the "Software"), to deal |
|
in the Software without restriction, including without limitation the rights |
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
|
copies of the Software, and to permit persons to whom the Software is |
|
furnished to do so, subject to the following conditions: |
|
|
|
The above copyright notice and this permission notice shall be included in all |
|
copies or substantial portions of the Software. |
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
|
SOFTWARE. |
|
""" |
|
from __future__ import annotations |
|
|
|
import keyword |
|
import warnings |
|
from typing import Collection, List, Mapping, Optional, Set, Tuple, Union |
|
|
|
# Single-character stand-in for "...": `ParsedExpression` replaces the (single) "..." in an expression
# with this character before tokenizing, so the ellipsis is scanned as one identifier token.
_ellipsis: str = "…"
|
|
|
|
|
class AnonymousAxis:
    """An axis given only by a numeric size, with no identifier, in an `einops`-style expression.

    `ParsedExpression` creates one of these for every numeric axis other than ``1`` (unit axes are
    handled separately); the constructor itself accepts any size of at least 1.

    Note: no ``__eq__`` is defined, so equality is identity-based. Two instances never compare equal,
    even when they carry the same size, which keeps every numeric occurrence a distinct identifier.
    """

    def __init__(self, value: str) -> None:
        size = int(value)
        self.value = size
        # Sizes must be strictly positive.
        if size <= 0:
            raise ValueError(
                f"Anonymous axis should have positive length, not {self.value}"
            )

    def __repr__(self) -> str:
        return "{}-axis".format(self.value)
|
|
|
|
|
class ParsedExpression:
    """Structure containing information about one side of an `einops`-style pattern (e.g. 'b c (h w)')."""

    def __init__(
        self,
        expression: str,
        *,
        allow_underscore: bool = False,
        allow_duplicates: bool = False,
    ) -> None:
        """Parse the expression and store relevant metadata.

        Args:
            expression (str): the `einops`-pattern to parse
            allow_underscore (bool): whether to accept the bare `_` as an axis name (it may then also
                appear more than once); other names starting or ending with `_` are always rejected
            allow_duplicates (bool): whether to allow an identifier to appear more than once in the expression

        Raises:
            ValueError: if the expression contains dots outside a single "...", more than one ellipsis,
                nested or unbalanced parentheses, an unknown character, an invalid or (unless allowed)
                duplicate axis identifier, or a numeric axis of size 0
        """
        # `expression` is rewritten below (ellipsis substitution). Keep the caller's text so that
        # error messages echo what was actually passed in.
        original_expression = expression
        self.has_ellipsis: bool = False
        # None while no ellipsis has been seen; afterwards, whether it appeared inside parentheses.
        self.has_ellipsis_parenthesized: Optional[bool] = None
        self.identifiers: Set[Union[str, AnonymousAxis]] = set()

        # Set once any numeric axis other than 1 appears.
        self.has_non_unitary_anonymous_axes: bool = False

        # One entry per top-level position: the list of axes grouped there (empty for a top-level
        # "1"), or the bare ellipsis marker when "..." appears outside parentheses.
        self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
        if "." in expression:
            if "..." not in expression:
                raise ValueError(
                    "Expression may contain dots only inside ellipsis (...)"
                )
            if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
                raise ValueError(
                    "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
                )
            # Collapse "..." to one character so the tokenizer below treats it as a single token.
            expression = expression.replace("...", _ellipsis)
            self.has_ellipsis = True

        # The group being filled while inside "(...)"; None at top level.
        bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None

        def add_axis_name(x: str) -> None:
            # Record one completed token in `identifiers` and in `composition` (or the open group).
            if x in self.identifiers:
                if not (allow_underscore and x == "_") and not allow_duplicates:
                    raise ValueError(
                        f"Indexing expression contains duplicate dimension '{x}'"
                    )
            if x == _ellipsis:
                self.identifiers.add(_ellipsis)
                if bracket_group is None:
                    self.composition.append(_ellipsis)
                    self.has_ellipsis_parenthesized = False
                else:
                    bracket_group.append(_ellipsis)
                    self.has_ellipsis_parenthesized = True
            else:
                is_number = str.isdecimal(x)
                if is_number and int(x) == 1:
                    # A unit axis is not an identifier. At top level it becomes an empty group;
                    # inside parentheses it contributes nothing.
                    if bracket_group is None:
                        self.composition.append([])
                    return
                is_axis_name, reason = self.check_axis_name_return_reason(
                    x, allow_underscore=allow_underscore
                )
                if not (is_number or is_axis_name):
                    raise ValueError(f"Invalid axis identifier: {x}\n{reason}")
                # Numbers other than 1 become AnonymousAxis objects, each distinct from the others.
                axis_name: Union[str, AnonymousAxis] = (
                    AnonymousAxis(x) if is_number else x
                )
                self.identifiers.add(axis_name)
                if is_number:
                    self.has_non_unitary_anonymous_axes = True
                if bracket_group is None:
                    self.composition.append([axis_name])
                else:
                    bracket_group.append(axis_name)

        current_identifier = None
        for char in expression:
            if char in "() ":
                # Any delimiter terminates the token being accumulated.
                if current_identifier is not None:
                    add_axis_name(current_identifier)
                current_identifier = None
                if char == "(":
                    if bracket_group is not None:
                        raise ValueError(
                            "Axis composition is one-level (brackets inside brackets not allowed)"
                        )
                    bracket_group = []
                elif char == ")":
                    if bracket_group is None:
                        raise ValueError("Brackets are not balanced")
                    self.composition.append(bracket_group)
                    bracket_group = None
            elif str.isalnum(char) or char in ["_", _ellipsis]:
                if current_identifier is None:
                    current_identifier = char
                else:
                    current_identifier += char
            else:
                raise ValueError(f"Unknown character '{char}'")

        if bracket_group is not None:
            raise ValueError(
                f"Imbalanced parentheses in expression: '{original_expression}'"
            )
        if current_identifier is not None:
            add_axis_name(current_identifier)

    @staticmethod
    def check_axis_name_return_reason(
        name: str, allow_underscore: bool = False
    ) -> Tuple[bool, str]:
        """Check if the given axis name is valid, and a message explaining why if not.

        Valid axis names are Python identifiers that neither start nor end with an underscore. The bare
        `_` is accepted when `allow_underscore` is set. Python keywords are accepted but emit a
        RuntimeWarning; the name 'axis' is accepted but emits a FutureWarning.

        Args:
            name (str): the axis name to check
            allow_underscore (bool): whether the bare `_` is accepted as an axis name

        Returns:
            Tuple[bool, str]: whether the axis name is valid, a message explaining why if not
        """
        if not str.isidentifier(name):
            return False, "not a valid python identifier"
        elif name[0] == "_" or name[-1] == "_":
            # Non-empty here: "" is not an identifier, so indexing is safe.
            if name == "_" and allow_underscore:
                return True, ""
            return False, "axis name should not start or end with underscore"
        else:
            if keyword.iskeyword(name):
                warnings.warn(
                    f"It is discouraged to use axes names that are keywords: {name}",
                    RuntimeWarning,
                )
            if name in ["axis"]:
                warnings.warn(
                    "It is discouraged to use 'axis' as an axis name and will raise an error in future",
                    FutureWarning,
                )
            return True, ""

    @staticmethod
    def check_axis_name(name: str, allow_underscore: bool = False) -> bool:
        """Check if the name is a valid axis name.

        Args:
            name (str): the axis name to check
            allow_underscore (bool): whether the bare `_` is accepted as an axis name

        Returns:
            bool: whether the axis name is valid
        """
        is_valid, _ = ParsedExpression.check_axis_name_return_reason(
            name, allow_underscore=allow_underscore
        )
        return is_valid
|
|
|
|
|
def parse_pattern(
    pattern: str, axes_lengths: Mapping[str, int]
) -> Tuple[ParsedExpression, ParsedExpression]:
    """Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.

    Args:
        pattern (str): the `einops`-style rearrangement pattern
        axes_lengths (Mapping[str, int]): any additional length specifications for dimensions

    Returns:
        Tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions

    Raises:
        ValueError: if the pattern does not contain exactly one '->'; if `axes_lengths` uses the internal
            ellipsis placeholder as a key; if either side fails to parse; if an ellipsis appears only on
            the right-hand side; or if the left-hand side places its ellipsis inside parentheses
    """
    try:
        left_str, right_str = pattern.split("->")
    except ValueError:
        # Zero or several "->" make the two-way unpacking fail; report it in pattern terms.
        raise ValueError("Pattern must contain a single '->' separator") from None

    # Only the reserved placeholder is screened here. Other checks on `axes_lengths` are left to
    # callers (e.g. `validate_rearrange_expressions`).
    if _ellipsis in axes_lengths:
        raise ValueError(f"'{_ellipsis}' is not an allowed axis identifier")

    left = ParsedExpression(left_str)
    right = ParsedExpression(right_str)

    if not left.has_ellipsis and right.has_ellipsis:
        raise ValueError(
            f"Ellipsis found in right side, but not left side of a pattern {pattern}"
        )
    if left.has_ellipsis and left.has_ellipsis_parenthesized:
        raise ValueError(
            f"Ellipsis inside parenthesis in the left side is not allowed: {pattern}"
        )

    return left, right
|
|
|
|
|
def validate_rearrange_expressions(
    left: ParsedExpression, right: ParsedExpression, axes_lengths: Mapping[str, int]
) -> None:
    """Enforce the extra constraints that the `rearrange` operation places on a parsed pattern.

    Args:
        left (ParsedExpression): left-hand side expression
        right (ParsedExpression): right-hand side expression
        axes_lengths (Mapping[str, int]): any additional length specifications for dimensions

    Raises:
        TypeError: if a value in `axes_lengths` is not exactly of type `int` (subclasses such as
            `bool` are rejected as well)
        ValueError: if either side contains a numeric axis other than 1, if an identifier appears on
            only one side, or if `axes_lengths` names an identifier missing from the left-hand side
    """
    # The first value whose type is not exactly `int`, if there is one.
    offending_type = next(
        (type(size) for size in axes_lengths.values() if type(size) is not int), None
    )
    if offending_type is not None:
        raise TypeError(
            f"rearrange axis lengths must be integers, got: {offending_type}"
        )

    if any(side.has_non_unitary_anonymous_axes for side in (left, right)):
        raise ValueError("rearrange only supports unnamed axes of size 1")

    one_sided = left.identifiers.symmetric_difference(right.identifiers)
    if one_sided:
        raise ValueError(
            f"Identifiers only on one side of rearrange expression (should be on both): {one_sided}"
        )

    missing = axes_lengths.keys() - left.identifiers
    if missing:
        raise ValueError(f"Identifiers not found in rearrange expression: {missing}")
|
|
|
|
|
def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
    """Join first-class dim names with ", ", rendering nested collections as parenthesized tuples.

    Strings are emitted as-is. Any non-string item is rendered recursively inside parentheses, and a
    trailing comma is added when it holds exactly one element, matching Python tuple syntax.

    Args:
        collection (Collection[Union[str, Collection[str]]]): the collection of strings to convert

    Returns:
        str: the comma-separated string

    Examples:
        >>> comma_separate(('d0',))
        'd0'

        >>> comma_separate(('d0', 'd1', 'd2', 'd3'))
        'd0, d1, d2, d3'

        >>> comma_separate([('d1', 'd4')])
        '(d1, d4)'

        >>> comma_separate([('d0',), (), ('d1',), ('d2',), ('d3', 'd4')])
        '(d0,), (), (d1,), (d2,), (d3, d4)'
    """
    pieces: List[str] = []
    for entry in collection:
        if isinstance(entry, str):
            pieces.append(entry)
            continue
        inner = comma_separate(entry)
        # A one-element group needs a trailing comma to read as a tuple.
        trailing = "," if len(entry) == 1 else ""
        pieces.append(f"({inner}{trailing})")
    return ", ".join(pieces)
|
|