| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """This module contains the user- and codegen-facing API for DiastaticMalt.""" |
|
|
| import functools |
| import importlib |
| import inspect |
| import os |
| import sys |
| import textwrap |
| import traceback |
|
|
| from malt import operators |
| from malt import utils |
| from malt.converters import asserts |
| from malt.converters import break_statements |
| from malt.converters import call_trees |
| from malt.converters import conditional_expressions |
| from malt.converters import continue_statements |
| from malt.converters import control_flow |
| from malt.converters import directives |
| from malt.converters import functions |
| from malt.converters import lists |
| from malt.converters import logical_expressions |
| from malt.converters import return_statements |
| from malt.converters import slices |
| from malt.converters import variables |
| from malt.core import ag_ctx |
| from malt.core import converter |
| from malt.core import unsupported_features_checker |
| from malt.impl import conversion |
| from malt.lang import special_functions |
| from malt.operators import py_builtins |
| from malt.pyct import anno |
| from malt.pyct import cfg |
| from malt.pyct import error_utils |
| from malt.pyct import errors |
| from malt.pyct import inspect_utils |
| from malt.pyct import qual_names |
| from malt.pyct import transpiler |
| from malt.pyct.static_analysis import activity |
| from malt.pyct.static_analysis import reaching_definitions |
| from malt.utils import ag_logging as logging |
|
|
|
|
def is_autograph_strict_conversion_mode():
    """Reports whether strict conversion mode was requested via the environment.

    Strict mode makes the callers in this module re-raise conversion errors
    instead of falling back to running the function unconverted.

    The ``AUTOGRAPH_STRICT_CONVERSION`` variable is read (defaulting to ``'0'``
    when unset) and parsed with ``int``.

    Returns:
      bool: True when the parsed value is strictly positive.

    Raises:
      ValueError: if the variable is set to a string ``int`` cannot parse.
    """
    strict_flag = int(os.environ.get('AUTOGRAPH_STRICT_CONVERSION', '0'))
    return strict_flag > 0
|
|
|
|
| |
| |
| |
|
|
|
|
| |
class AutoGraphError(errors.PyCTError):
    """Root of the AutoGraph exception hierarchy; all AutoGraph errors derive from it."""
|
|
|
|
class ConversionError(AutoGraphError):
    """Signals a failure while converting (transpiling) a Python entity."""
|
|
|
|
class StagingError(AutoGraphError):
    """Signals a failure while staging, i.e. while executing converted code in Python."""
|
|
|
|
class _ErrorMetadata(error_utils.ErrorMetadataBase):
    """Error metadata specialized for AutoGraph; see the base class for details."""

    def create_exception(self, source_error):
        """Builds a fresh exception that carries the rewritten error message.

        Args:
          source_error: the exception originally raised.

        Returns:
          For PyCT/AutoGraph error types, a new instance of the same type.
          Otherwise, whatever the base class builds, or a StagingError when the
          base class returns None.
        """
        own_error_types = (
            errors.PyCTError, AutoGraphError, ConversionError, StagingError)
        source_type = type(source_error)
        if source_type in own_error_types:
            # These types take just a message, so they can be rebuilt directly.
            return source_type(self.get_message())

        rebuilt = super().create_exception(source_error)
        if rebuilt is None:
            # The base class could not rebuild the error; report it as staging.
            return StagingError(self.get_message())
        return rebuilt
|
|
|
|
| def _attach_error_metadata(e, f): |
| """Augments an error with the metadata necessary for rewrite.""" |
| if hasattr(e, 'ag_pass_through'): |
| return |
|
|
| metadata = getattr(e, 'ag_error_metadata', None) |
| source_map = f.ag_source_map |
|
|
| if metadata is None: |
| logging.log(1, 'Caught error in user callable %s', f, exc_info=True) |
| message = '{}: {}'.format(e.__class__.__name__, e) |
| else: |
| message = None |
|
|
| cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:] |
|
|
| e.ag_error_metadata = _ErrorMetadata(cause_tb, metadata, message, source_map, |
| __file__) |
|
|
|
|
| |
| |
| |
|
|
|
|
class PyToPy(transpiler.PyToPy):
    """A generic AutoGraph transformer to subclass from or replace."""

    def __init__(self):
        super(PyToPy, self).__init__()
        # Built lazily by get_extra_locals() and reused afterwards.
        self._extra_locals = None

    def get_transformed_name(self, node):
        """Returns the base transpiler's name for `node`, prefixed with 'ag__'."""
        return 'ag__' + super(PyToPy, self).get_transformed_name(node)

    def get_extra_locals(self):
        """Returns the extra locals made visible to generated code.

        On first use this builds a synthetic module. It exposes the contents of
        this module, the conversion option types, `utils`, `special_functions`
        and `operators`. The result is cached as `{'ag__': module}`.

        Returns:
          dict mapping 'ag__' to the synthetic module.
        """
        if self._extra_locals is None:
            # `import importlib` alone does not guarantee that the `machinery`
            # and `util` submodules are loaded as attributes of the package, so
            # import them explicitly instead of relying on some other module
            # having done so first.
            import importlib.machinery  # pylint: disable=g-import-not-at-top
            import importlib.util  # pylint: disable=g-import-not-at-top

            module_spec = importlib.machinery.ModuleSpec('malt', None)
            ag_internal = importlib.util.module_from_spec(module_spec)
            # Everything defined in this module (converted_call, etc.).
            ag_internal.__dict__.update(inspect.getmodule(PyToPy).__dict__)
            ag_internal.ConversionOptions = converter.ConversionOptions
            ag_internal.STD = converter.STANDARD_OPTIONS
            ag_internal.Feature = converter.Feature
            ag_internal.utils = utils
            # Later updates win on name clashes: operators override
            # special_functions, which override this module's names.
            ag_internal.__dict__.update(special_functions.__dict__)
            ag_internal.__dict__.update(operators.__dict__)

            self._extra_locals = {'ag__': ag_internal}
        return self._extra_locals

    def get_caching_key(self, ctx):
        """Returns the conversion options of `ctx`, used as the caching key."""
        return ctx.options

    def initial_analysis(self, node, ctx):
        """Runs the static analyses that the converters rely on.

        Builds control-flow graphs, then resolves qualified names, activity and
        reaching definitions. Finally, `anno.dup` duplicates the DEFINITIONS
        annotation under ORIG_DEFINITIONS.
        """
        graphs = cfg.build(node)
        node = qual_names.resolve(node)
        node = activity.resolve(node, ctx, None)
        node = reaching_definitions.resolve(node, ctx, graphs)
        anno.dup(
            node,
            {
                anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
            },
        )
        return node

    def transform_ast(self, node, ctx):
        """Applies the AutoGraph converter passes to `node`, in sequence.

        Raises whatever `unsupported_features_checker.verify` raises for
        unsupported language features.
        """
        unsupported_features_checker.verify(node)
        node = self.initial_analysis(node, ctx)

        node = functions.transform(node, ctx)
        node = directives.transform(node, ctx)
        node = break_statements.transform(node, ctx)
        if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
            node = asserts.transform(node, ctx)
        # NOTE(review): the pass order below is presumably significant (e.g.
        # continue canonicalization before control flow); keep it as is.
        node = continue_statements.transform(node, ctx)
        node = return_statements.transform(node, ctx)
        if ctx.user.options.uses(converter.Feature.LISTS):
            node = lists.transform(node, ctx)
            # NOTE(review): confirm slices should be gated on Feature.LISTS.
            node = slices.transform(node, ctx)
        node = call_trees.transform(node, ctx)
        node = control_flow.transform(node, ctx)
        node = conditional_expressions.transform(node, ctx)
        node = logical_expressions.transform(node, ctx)
        node = variables.transform(node, ctx)
        return node
|
|
|
|
| def _convert_actual(entity, program_ctx): |
| """Applies AutoGraph to entity.""" |
|
|
| |
| if not hasattr(entity, '__code__'): |
| raise ValueError('Cannot apply autograph to a function that doesn\'t ' |
| 'expose a __code__ object. If this is a @tf.function,' |
| ' try passing f.python_function instead.') |
|
|
| transformed, module, source_map = _TRANSPILER.transform(entity, program_ctx) |
|
|
| assert not hasattr(transformed, 'ag_module') |
| assert not hasattr(transformed, 'ag_source_map') |
| transformed.ag_module = module |
| transformed.ag_source_map = source_map |
| return transformed |
|
|
|
|
| |
| |
| |
|
|
|
|
def autograph_artifact(entity, extras=None):
    """Tags `entity` as an AutoGraph artifact and returns it unchanged.

    For bound methods the tag goes on the underlying function (`__func__`),
    because attributes cannot be set on the bound method object itself.

    Args:
      entity: the callable to tag.
      extras: value stored in the `autograph_info__` attribute.

    Returns:
      `entity` itself.
    """
    tag_target = entity.__func__ if inspect.ismethod(entity) else entity
    setattr(tag_target, 'autograph_info__', extras)
    return entity
|
|
|
|
def is_autograph_artifact(entity):
    """Returns True if `entity` carries an `autograph_info__` attribute."""
    try:
        getattr(entity, 'autograph_info__')
    except AttributeError:
        return False
    return True
|
|
|
|
def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
    """Converts a function call inline.

    For internal use only.

    Note: The argument list is optimized for readability of generated code, which
    may look like this:

      ag__.converted_call(f, (arg1, arg2), None, fscope)
      ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
      ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope)

    Args:
      f: The function to convert.
      args: Tuple, the original positional arguments of f
      kwargs: Optional[Dict], the original keyword arguments of f
      caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
        scope of the converted function in which this call was originally made.
      options: Optional[converter.ConversionOptions], conversion options. If not
        specified, the value of caller_fn_scope.callopts is used. Either options
        or caller_fn_scope must be present.

    Returns:
      Any, the result of executing a possibly-converted `f` with the given
      arguments.

    Raises:
      ValueError: if neither `options` nor `caller_fn_scope` is given.
    """
    logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args,
                kwargs)

    if options is None:
        if caller_fn_scope is None:
            raise ValueError('either caller_fn_scope or options must have a value')
        options = caller_fn_scope.callopts

    if conversion.is_in_allowlist_cache(f, options):
        logging.log(2, 'Allowlisted %s: from cache', f)
        return _call_unconverted(f, args, kwargs, options, False)

    if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
        logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)
        return _call_unconverted(f, args, kwargs, options, False)

    if is_autograph_artifact(f):
        logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f)
        return _call_unconverted(f, args, kwargs, options)

    # Partials are unwrapped and all checks are redone on the wrapped callable.
    if isinstance(f, functools.partial):
        new_kwargs = {}
        if f.keywords is not None:
            # Copy so that the partial's stored keywords are not mutated.
            new_kwargs = f.keywords.copy()
        if kwargs is not None:
            new_kwargs.update(kwargs)
        new_args = f.args + args
        logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f, new_args,
                    new_kwargs)
        return converted_call(
            f.func,
            new_args,
            new_kwargs,
            caller_fn_scope=caller_fn_scope,
            options=options)

    # Builtins are dispatched to their overloads; scope-sensitive ones
    # (eval, super, globals, locals) are evaluated against the caller's scope.
    if inspect_utils.isbuiltin(f):
        if f is eval:
            return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
        if f is super:
            return py_builtins.super_in_original_context(f, args, caller_fn_scope)
        if f is globals:
            return py_builtins.globals_in_original_context(caller_fn_scope)
        if f is locals:
            return py_builtins.locals_in_original_context(caller_fn_scope)
        if kwargs:
            return py_builtins.overload_of(f)(*args, **kwargs)
        else:
            return py_builtins.overload_of(f)(*args)

    if conversion.is_unsupported(f):
        return _call_unconverted(f, args, kwargs, options)

    if not options.user_requested and conversion.is_allowlisted(f):
        return _call_unconverted(f, args, kwargs, options)

    # When internal_convert_user_code is off, user code is not recursed into
    # (builtins above are still handled).
    if not options.internal_convert_user_code:
        return _call_unconverted(f, args, kwargs, options)

    # Bind target_entity before the try: if an introspection call below raises
    # before any branch assigns it, the except clause must still be able to
    # log it rather than failing with UnboundLocalError.
    target_entity = f
    try:
        if inspect.ismethod(f) or inspect.isfunction(f):
            effective_args = args

            f_self = getattr(f, '__self__', None)
            if f_self is not None:
                # Bound methods: the receiver is passed explicitly as the first
                # positional argument of the converted function.
                effective_args = (f_self,) + effective_args

        elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'):
            # Callable objects: dunder methods are looked up on the type, so the
            # class's __call__ is converted and the instance passed as `self`.
            target_entity = f.__class__.__call__
            effective_args = (f,) + args

        else:
            raise NotImplementedError('unknown callable type "%s"' % type(f))

    except Exception as e:  # pylint:disable=broad-except
        logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
        if is_autograph_strict_conversion_mode():
            raise
        return _fall_back_unconverted(f, args, kwargs, options, e)

    if not hasattr(target_entity, '__code__'):
        logging.log(2, 'Permanently allowed: %s: native binding', target_entity)
        return _call_unconverted(f, args, kwargs, options)
    elif (hasattr(target_entity.__code__, 'co_filename') and
          target_entity.__code__.co_filename == '<string>'):
        # Code created from a string (e.g. via exec) has no source to convert.
        logging.log(2, 'Permanently allowed: %s: dynamic code (exec?)',
                    target_entity)
        return _call_unconverted(f, args, kwargs, options)

    try:
        program_ctx = converter.ProgramContext(options=options)
        converted_f = _convert_actual(target_entity, program_ctx)
        if logging.has_verbosity(2):
            _log_callargs(converted_f, effective_args, kwargs)
    except Exception as e:  # pylint:disable=broad-except
        logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
        if is_autograph_strict_conversion_mode():
            raise
        return _fall_back_unconverted(f, args, kwargs, options, e)

    # Errors raised by the converted code get rewrite metadata attached before
    # propagating.
    try:
        if kwargs is not None:
            result = converted_f(*effective_args, **kwargs)
        else:
            result = converted_f(*effective_args)
    except Exception as e:
        _attach_error_metadata(e, converted_f)
        raise

    return result
|
|
|
|
| def _call_unconverted(f, args, kwargs, options, update_cache=True): |
| """Calls the original function without converting with AutoGraph.""" |
| if update_cache: |
| conversion.cache_allowlisted(f, options) |
|
|
| |
|
|
| if kwargs is not None: |
| return f(*args, **kwargs) |
| return f(*args) |
|
|
|
|
def _fall_back_unconverted(f, args, kwargs, options, exc):
    """Runs `f` unconverted after its conversion failed, warning when useful.

    Args:
      f: the callable whose conversion failed.
      args: tuple of positional arguments for `f`.
      kwargs: optional dict of keyword arguments for `f`.
      options: the conversion options in effect.
      exc: the exception raised while attempting the conversion.

    Returns:
      Whatever `f` returns when called with `args` and `kwargs`.
    """
    warning_template = (
        'AutoGraph could not transform %s and will run it as-is.\n'
        '%s'
        'Cause: %s\n'
        'To silence this warning, decorate the function with'
        ' @tf.autograph.experimental.do_not_convert')

    if isinstance(exc, errors.InaccessibleSourceCodeError):
        # Warn only when ag_ctx reports that source inspection is supported.
        emit_warning = ag_ctx.INSPECT_SOURCE_SUPPORTED
        extra_note = ''
    elif isinstance(exc, errors.UnsupportedLanguageElementError):
        # _call_unconverted below caches `f` as allowlisted, so this warns only
        # until the first fallback for `f` has happened.
        emit_warning = not conversion.is_in_allowlist_cache(f, options)
        extra_note = ''
    else:
        # Any other error is unexpected: always warn, asking for a bug report.
        emit_warning = True
        extra_note = (
            'Please report this to the TensorFlow team. When filing the bug, set'
            ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
            ' attach the full output.\n')

    if emit_warning:
        logging.warning(warning_template, f, extra_note, exc)
    return _call_unconverted(f, args, kwargs, options)
|
|
|
|
| |
| |
| |
|
|
|
|
def internal_convert(f, ctx, convert_by_default=True, user_requested=False):
    """Decorator that applies AutoGraph to a function.

    Use in internal APIs.

    This API is suitable for high order functions internal to the TensorFlow API,
    and more generally any function to which AutoGraph is not applied.

    Guidance: `convert` was a decorator meant for use directly by developers, but
    most of today's uses go through `tf.function`. `tf_convert` is to be called
    from high order functions internal to TF. By default, all the internal
    TensorFlow functions are skipped when AutoGraph processes the code. This may
    cause user-supplied functions to be incorrectly skipped as well.
    `tf_convert` helps avoid that. See the following example for more details.

    ```
    =====tf_internal_module.py=====

    def unconverted(input_fn):
      return input_fn()

    def converted(input_fn):
      return tf.__internal__.autograph.tf_convert(
          input_fn, ctx=tf.__internal__.autograph.control_status_ctx())()

    ======user_module.py======

    @tf.function
    def foo(input_fn):
      return unconverted(input_fn)

    @tf.function
    def bar(input_fn):
      return converted(input_fn)

    @tf.function(autograph=False)
    def baz(input_fn):
      return converted(input_fn)
    ```

    The `foo` method above will execute the `input_fn` without autograph
    conversion, while the `bar` method will run an autographed `input_fn`. The
    `baz` method will run an unconverted `input_fn`, since `tf_convert` respects
    the control status context.

    Note that both methods in `tf_internal_module` are skipped by autograph when
    tracing the `tf.function`. The configuration of whether a module/package
    should be skipped by autograph is controlled in
    tensorflow/python/autograph/core/config.py.

    Args:
      f: Callable.
      ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
      convert_by_default: bool, whether to use AutoGraph when the context doesn't
        specify.
      user_requested: bool, whether to ignore the conversion allowlist. See
        ConversionOptions.user_requested.

    Returns:
      Either `f` or the converted version of `f`.

    Raises:
      AssertionError: if `ctx.status` is none of ENABLED, DISABLED or
        UNSPECIFIED.
    """
    # Artifacts are already wrapped by AutoGraph; return them unchanged.
    if is_autograph_artifact(f):
        return f

    # `ctx` is forwarded to `convert`, whose wrapper enters it around the call,
    # so conversion runs under the caller's control status context.
    if ctx.status == ag_ctx.Status.ENABLED:
        wrapper_factory = convert(
            recursive=True, user_requested=user_requested, conversion_ctx=ctx)
    elif ctx.status == ag_ctx.Status.DISABLED:
        wrapper_factory = do_not_convert
    elif ctx.status == ag_ctx.Status.UNSPECIFIED:
        if convert_by_default:
            wrapper_factory = convert(
                recursive=True, user_requested=user_requested, conversion_ctx=ctx)
        else:
            wrapper_factory = call_with_unspecified_conversion_status
    else:
        # Explicit raise instead of `assert False`: asserts are stripped under
        # `python -O`, which would leave `wrapper_factory` unbound below.
        raise AssertionError('This switch contains all possible cases!')
    wrapper = wrapper_factory(f)

    # Tag the result so that converted_call and this function skip it later.
    return autograph_artifact(wrapper)
|
|
|
|
def call_with_unspecified_conversion_status(func):
    """Wraps `func` so that each call runs under an UNSPECIFIED AutoGraph status.

    Args:
      func: the callable to wrap.

    Returns:
      A wrapper, tagged as an AutoGraph artifact, that enters
      `ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED)` around every
      call to `func`. When `func` is a function or method, its metadata is
      copied onto the wrapper.
    """

    def wrapper(*args, **kwargs):
        status_ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED)
        with status_ctx:
            return func(*args, **kwargs)

    copies_metadata = inspect.isfunction(func) or inspect.ismethod(func)
    if copies_metadata:
        functools.update_wrapper(wrapper, func)
    return autograph_artifact(wrapper)
|
|
|
|
| def _log_callargs(f, args, kwargs): |
| """Logging helper.""" |
| logging.log(2, 'Defaults of %s : %s', f, f.__defaults__) |
| logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__) |
|
|
| |
| if kwargs is not None: |
| callargs = inspect.getcallargs(f, *args, **kwargs) |
| else: |
| callargs = inspect.getcallargs(f, *args) |
|
|
| formatted_callargs = '\n'.join( |
| ' {}: {}'.format(k, v) for k, v in callargs.items()) |
| logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs) |
|
|
|
|
| |
| |
| |
|
|
|
|
def do_not_convert(func=None):
    """Decorator that suppresses the conversion of a function.

    Args:
      func: function to decorate.

    Returns:
      If `func` is not None, returns a `Callable` which is equivalent to
      `func`, but is not converted by AutoGraph.
      If `func` is None, returns a decorator that, when invoked with a
      single `func` argument, returns a `Callable` equivalent to the
      above case.
    """
    if func is not None:

        def wrapper(*args, **kwargs):
            # Disable AutoGraph for the duration of the call.
            with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
                return func(*args, **kwargs)

        if inspect.ismethod(func) or inspect.isfunction(func):
            functools.update_wrapper(wrapper, func)
        return autograph_artifact(wrapper)

    # Used as `@do_not_convert()`: hand back the decorator itself.
    return do_not_convert
|
|
|
|
| |
def convert(recursive=False,
            optional_features=None,
            user_requested=True,
            conversion_ctx=ag_ctx.NullCtx()):
    """Decorator that compiles a function to use AutoGraph operators.

    The decorator is dynamic - it recompiles the target whenever the decorated
    function is called. This means the parameter values are known at conversion.
    It also means that repeated calls with different types of parameters will be
    correctly processed.

    Args:
      recursive: bool, whether to recursively convert any functions or classes
        that the converted function may use.
      optional_features: converted.Feature, allows toggling optional or
        experimental features. When set to None, only the core features are
        enabled.
      user_requested: bool, whether this is a function that the user explicitly
        asked to be converted. See ConversionOptions.user_requested.
      conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context in
        which `f` is used.

    Returns:
      Callable, a decorator that converts the given function into an equivalent
      function that uses TensorFlow ops.
    """

    def decorator(f):
        """Wraps `f` so that every call goes through `converted_call`."""

        def wrapper(*args, **kwargs):
            """Calls `f` via `converted_call` inside `conversion_ctx`."""
            call_options = converter.ConversionOptions(
                recursive=recursive,
                user_requested=user_requested,
                optional_features=optional_features)
            try:
                with conversion_ctx:
                    return converted_call(f, args, kwargs, options=call_options)
            except Exception as e:  # pylint: disable=broad-except
                if not hasattr(e, 'ag_error_metadata'):
                    raise
                # Surface the rewritten error built from the attached metadata.
                raise e.ag_error_metadata.to_exception(e)

        if inspect.isfunction(f) or inspect.ismethod(f):
            functools.update_wrapper(wrapper, f)

        # Tag the wrapper so converted_call does not convert it again.
        return autograph_artifact(wrapper)

    return decorator
|
|
|
|
| |
def to_graph(entity, recursive=True, experimental_optional_features=None):
    """Converts a Python entity into a "Auto-"graph.

    Also see: `malt.to_code`.

    Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
    Python code to TensorFlow graph code. It does not implement any caching,
    variable management or create any actual ops, and is best used where greater
    control over the generated TensorFlow graph is desired. Another difference
    from `tf.function` is that `to_graph` will not wrap the graph into a
    TensorFlow function or a Python callable. Internally, `tf.function` uses
    `to_graph`.

    Example usage:

    >>> def f(x):
    ...   if x > 0:
    ...     y = x * x
    ...   else:
    ...     y = -x
    ...   return y
    ...
    >>> converted_f = to_graph(f)
    >>> x = tf.constant(2)
    >>> converted_f(x)  # converted_f is like a TensorFlow Op.
    <tf.Tensor: shape=(), dtype=int32, numpy=4>

    Supported Python entities include:
      * functions
      * classes
      * object methods

    Functions are converted into new functions with converted code.

    Classes are converted by generating a new class whose methods use converted
    code.

    Methods are converted into unbound functions that have an additional first
    argument called `self`.

    For a tutorial, see the
    [tf.function and AutoGraph guide](https://www.tensorflow.org/guide/function).
    For more detailed information, see the
    [reference documentation](https://github.com/pennylaneai/diastatic-malt/blob/main/malt/g3doc/reference/index.md).

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the converted
        function may call.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value.

    Returns:
      Same as `entity`, the converted Python function or class.

    Raises:
      ConversionError: If the entity could not be converted. It wraps a
        ValueError, AttributeError, KeyError, NameError or AssertionError raised
        during conversion, which is chained as its `__cause__`.
    """
    try:
        program_ctx = converter.ProgramContext(
            options=converter.ConversionOptions(
                recursive=recursive,
                user_requested=True,
                optional_features=experimental_optional_features))
        return autograph_artifact(_convert_actual(entity, program_ctx))
    except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
        logging.error(1, 'Error converting %s', entity, exc_info=True)
        # Chain explicitly so the original failure is reported as the direct
        # cause rather than as an error raised while handling another one.
        raise ConversionError('converting {}: {}: {}'.format(
            entity, e.__class__.__name__, str(e))) from e
|
|
|
|
def to_code(entity, recursive=True, experimental_optional_features=None):
    """Returns the source code generated by DiastaticMalt, as a string.

    Example usage:

    >>> def f(x):
    ...   if x < 0:
    ...     x = -x
    ...   return x
    >>> malt.to_code(f)
    "...def ag__f(x):..."

    Also see: `malt.to_graph`.

    Note: If a function has been decorated with `tf.function`, pass its
    underlying Python function, rather than the callable that `tf.function`
    creates:

    >>> @tf.function
    ... def f(x):
    ...   if x < 0:
    ...     x = -x
    ...   return x
    >>> malt.to_code(f.python_function)
    "...def ag__f(x):..."

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the converted
        function may call.
      experimental_optional_features: `None`, a tuple of, or a single
        `malt.experimental.Feature` value.

    Returns:
      The converted code as string.

    Raises:
      ConversionError: propagated from `to_graph` when conversion fails.
    """
    converted = to_graph(
        entity,
        recursive=recursive,
        experimental_optional_features=experimental_optional_features)
    # Dedent in case the retrieved source is indented.
    return textwrap.dedent(inspect.getsource(converted))
|
|
|
|
# Module-wide transpiler instance used by `_convert_actual`. Its extra locals
# are built on first use and then reused (see PyToPy.get_extra_locals).
_TRANSPILER = PyToPy()
|
|