|
|
|
|
|
|
|
from typing import Any, Dict, List, Optional

from . import debug_configs
|
|
|
# Public names exported by ``from <this module> import *``.
__all__ = ['Operation', 'Cell']
|
|
|
|
|
def _convert_name(name: str) -> str: |
|
""" |
|
Convert the names using separator '.' to valid variable name in code |
|
""" |
|
return name.replace('.', '__') |
|
|
|
|
|
class Operation:
    """
    Calculation logic of a graph node.

    The constructor is private. Use `Operation.new()` to create operation object.

    `Operation` is a naive record.
    Do not "mutate" its attributes or store information related to a specific node.
    All complex logic should be implemented in `Node` class.

    Attributes
    ----------
    type
        Operation type name (e.g. Conv2D).
        If it starts with underscore, the "operation" is a special one (e.g. subgraph, input/output).
    parameters
        Arbitrary key-value parameters (e.g. kernel_size).
    """

    def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False):
        # ``_internal`` guards the constructor so that callers go through `Operation.new()`.
        assert _internal, '`Operation()` is private, use `Operation.new()` instead'
        self.type: str = type_name
        self.parameters: Dict[str, Any] = parameters

    def to_init_code(self, field: str) -> str:
        """Return the code line that creates this operation as member ``field``; implemented by subclasses."""
        raise NotImplementedError()

    def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
        """Return the code line computing ``output`` from ``inputs`` via member ``field``; implemented by subclasses."""
        raise NotImplementedError()

    def _to_class_name(self) -> str:
        """Return the class name used in generated code; implemented by subclasses."""
        raise NotImplementedError()

    def __bool__(self) -> bool:
        # Real operations are always truthy; pseudo I/O operations override this to return False.
        return True

    @staticmethod
    def new(type_name: str, parameters: Optional[Dict[str, Any]] = None,
            cell_name: Optional[str] = None) -> 'Operation':
        """
        Create an operation object of the appropriate concrete class.

        Parameters
        ----------
        type_name
            Operation type name. ``'_cell'`` creates a `Cell`; any other name is resolved to a
            framework-specific subclass according to ``debug_configs.framework``.
        parameters
            Key-value parameters of the operation. Defaults to a new empty dict.
        cell_name
            Name of the referenced subgraph; only used when ``type_name == '_cell'``.

        Raises
        ------
        ValueError
            If ``debug_configs.framework`` names neither PyTorch nor TensorFlow.
        """
        if parameters is None:
            # Fresh dict per call: a ``{}`` default would be one object shared by every
            # operation created without explicit parameters.
            parameters = {}
        if type_name == '_cell':
            return Cell(cell_name, parameters)
        framework = debug_configs.framework.lower()
        if framework in ('torch', 'pytorch'):
            # Imported only for its side effects (presumably it defines the subclasses that
            # `_find_subclass` searches) -- TODO confirm.
            from .operation_def import torch_op_def
            cls = PyTorchOperation._find_subclass(type_name)
        elif framework in ('tf', 'tensorflow'):
            # Imported only for its side effects, as above.
            from .operation_def import tf_op_def
            cls = TensorFlowOperation._find_subclass(type_name)
        else:
            raise ValueError(f'Unsupported framework: {debug_configs.framework}')
        return cls(type_name, parameters, _internal=True)

    @classmethod
    def _find_subclass(cls, subclass_name):
        """Return the direct subclass whose ``__name__`` is ``subclass_name``, or ``cls`` if none matches."""
        for subclass in cls.__subclasses__():
            if subclass.__name__ == subclass_name:
                return subclass
        return cls

    def __repr__(self):
        type_name = type(self).__name__
        args = [f'{key}={repr(value)}' for key, value in self.parameters.items()]
        # Show ``type`` explicitly only when the class name does not already convey it.
        if type_name != self.type:
            args = [f'type="{self.type}"'] + args
        return f'{type_name}({", ".join(args)})'

    # Note: defining __eq__ without __hash__ makes instances unhashable.
    def __eq__(self, other):
        return type(other) is type(self) and other.type == self.type and other.parameters == self.parameters
|
|
|
|
|
class PyTorchOperation(Operation):
    """
    Base class of PyTorch operations.

    `_find_subclass` picks the concrete class for a type name by matching it against the
    ``_ori_type_name`` attribute of each direct subclass:

    * names prefixed with ``__torch__.`` or ``__mutated__.`` are looked up as ``'ModuleOperator'``;
    * names starting with ``Function.`` are looked up as ``'FunctionalOperator'``;
    * any other name is looked up as-is.

    When no subclass matches, `PyTorchOperation` itself is returned.
    """

    @classmethod
    def _find_subclass(cls, subclass_name):
        if cls.to_class_name(subclass_name) is not None:
            subclass_name = 'ModuleOperator'
        if cls.is_functional(subclass_name):
            subclass_name = 'FunctionalOperator'
        for subclass in cls.__subclasses__():
            # Subclasses opt in by declaring ``_ori_type_name`` (tested with ``in``).
            if hasattr(subclass, '_ori_type_name') and \
                    subclass_name in subclass._ori_type_name:
                return subclass
        return cls

    @classmethod
    def to_class_name(cls, type_name: str) -> Optional[str]:
        """
        Strip the ``__torch__.`` or ``__mutated__.`` prefix from ``type_name``.

        Returns the remaining (class) name, or None when ``type_name`` has neither prefix.
        """
        for prefix in ('__torch__.', '__mutated__.'):
            if type_name.startswith(prefix):
                return type_name[len(prefix):]
        return None

    @classmethod
    def is_functional(cls, type_name: str) -> bool:
        """Return True when ``type_name`` starts with ``Function.``."""
        return type_name.startswith('Function.')

    def _to_class_name(self) -> Optional[str]:
        # Same prefix stripping as `to_class_name`, applied to this operation's own type.
        return self.to_class_name(self.type)

    def get_import_pkg(self) -> Optional[str]:
        """
        Return the first dotted component of the class name (e.g. ``torch`` for
        ``__torch__.torch.nn.Conv2d``), or None when the type has no module prefix.
        """
        # Uses the prefix rule directly rather than `_to_class_name` (which `Cell` overrides),
        # so operations of type ``_cell`` yield None.
        class_name = self.to_class_name(self.type)
        if class_name is None:
            return None
        return class_name.split('.')[0]

    def to_init_code(self, field: str) -> Optional[str]:
        """
        Return the line ``self.<field> = <ClassName>(key=value, ...)`` built from ``parameters``,
        or None when this operation has no class name to instantiate.
        """
        class_name = self._to_class_name()
        if class_name is None:
            return None
        # Only keyword arguments can be rendered below, so positional arguments are rejected.
        assert 'positional_args' not in self.parameters
        kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
        return f'self.{field} = {class_name}({kw_params})'

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        """
        Parameters
        ----------
        field : str
            the name of member submodule
        output : str
            the output name (lvalue) of this line of code
        inputs : List[str]
            variables used in this line of code
        inputs_value : List[Any]
            some variables are actually constant, their real values are recorded in ```inputs_value```.
            if not constant, we simply put None at the corresponding index

        Returns
        -------
        str
            generated code line

        Raises
        ------
        RuntimeError
            Always; this base implementation supports no operation type and subclasses override it.
        """
        if self.type == 'aten::slice':
            raise RuntimeError('not supposed to have aten::slice operation')
        raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
|
|
|
|
|
class TensorFlowOperation(Operation):
    """Base class of TensorFlow operations; class names are looked up under ``K.layers``."""

    def _to_class_name(self) -> str:
        # e.g. type "Conv2D" -> "K.layers.Conv2D"
        return '.'.join(('K', 'layers', self.type))
|
|
|
|
|
class Cell(PyTorchOperation):
    """
    TODO: this is pytorch cell

    An operation reference to a subgraph.

    Example code:
    ```
    def __init__(...):
        ...
        self.cell = CustomCell(...)
        self.relu = K.layers.ReLU()
        ...

    def forward(...):
        ...
        x = self.cell(x)
        ...
    ```

    In above example, node `self.cell`'s operation is `Cell(cell_name='CustomCell')`.
    For comparison, `self.relu`'s operation is `Operation(type='ReLU')`.

    TODO: parameters of subgraph (see `Node` class)

    Attributes
    ----------
    type
        Always "_cell".
    cell_name
        Name of the referenced subgraph. `_to_class_name()` turns it (dots replaced by "__")
        into the class name used in generated code.
    parameters
        Arbitrary key-value parameters; the inherited `to_init_code` passes them as keyword
        arguments to the cell's constructor.
    """

    def __init__(self, cell_name: str, parameters: Optional[Dict[str, Any]] = None):
        # A fresh dict per instance: a ``{}`` default would be shared by every Cell.
        super().__init__('_cell', {} if parameters is None else parameters, _internal=True)
        self.cell_name = cell_name

    def _to_class_name(self):
        return _convert_name(self.cell_name)

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        """Return ``<output> = self.<field>(<inputs...>)``; ``inputs_value`` is ignored."""
        return f'{output} = self.{field}({", ".join(inputs)})'
|
|
|
class _IOPseudoOperation(Operation):
    """
    This is the pseudo operation used by I/O nodes.
    The benefit is that users no longer need to verify `Node.operation is not None`,
    especially in static type checking.

    Instances are falsy, and both code-generation methods raise ValueError.
    """

    def __init__(self, type_name: str, io_names: Optional[List] = None):
        # Special (non-computational) operation types start with an underscore.
        assert type_name.startswith('_')
        super().__init__(type_name, {}, _internal=True)
        self.io_names = io_names

    def to_init_code(self, field: str) -> str:
        raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')

    def to_forward_code(self, field: str, output: str, inputs: List[str],
                        inputs_value: Optional[List[Any]] = None) -> str:
        # Accepts ``inputs_value`` like the PyTorch operations do, so a caller passing it gets
        # the intended ValueError rather than a TypeError.
        raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')

    def __bool__(self) -> bool:
        return False
|
|