Optimum documentation

Optimization


The optimum.fx.optimization module provides a set of torch.fx graph transformations, along with classes and functions to write your own transformations and compose them.

The transformation guide

In 🤗 Optimum, there are two kinds of transformations: reversible and non-reversible.

Write a non-reversible transformation

The most basic kind of transformation is the non-reversible transformation. Once applied to a graph module, it cannot be undone: there is no way to recover the original model. Implementing one in 🤗 Optimum is straightforward: subclass Transformation and implement the transform() method.

For instance, the following transformation changes all the multiplications to additions:

>>> import operator
>>> from optimum.fx.optimization import Transformation

>>> class ChangeMulToAdd(Transformation):
>>>     def transform(self, graph_module):
>>>         for node in graph_module.graph.nodes:
>>>             if node.op == "call_function" and node.target == operator.mul:
>>>                 node.target = operator.add
>>>         return graph_module

After implementing it, your transformation can be used as a regular function:

>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace

>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
>>>     model,
>>>     input_names=["input_ids", "attention_mask", "token_type_ids"],
>>> )

>>> transformation = ChangeMulToAdd()
>>> transformed_model = transformation(traced)
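To see the effect of such a rewrite without loading a full pretrained model, the same node surgery can be checked on a tiny hand-written module with plain torch.fx (TinyModel is purely illustrative and not part of Optimum):

```python
import operator

import torch
from torch import fx

class TinyModel(torch.nn.Module):
    def forward(self, x, y):
        return x * y

# Trace the model, then rewrite every multiplication into an addition,
# mirroring what ChangeMulToAdd does.
traced = fx.symbolic_trace(TinyModel())
for node in traced.graph.nodes:
    if node.op == "call_function" and node.target == operator.mul:
        node.target = operator.add
traced.graph.lint()
traced.recompile()

out = traced(torch.tensor([2.0]), torch.tensor([3.0]))  # computes 2 + 3 instead of 2 * 3
```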

Write a reversible transformation

A reversible transformation implements both the transformation and its reverse, allowing you to retrieve the original model from the transformed one. To implement such a transformation, subclass ReversibleTransformation and implement the transform() and reverse() methods.

For instance, the following transformation is reversible:

>>> import operator
>>> from optimum.fx.optimization import ReversibleTransformation

>>> class MulToMulTimesTwo(ReversibleTransformation):
>>>     def transform(self, graph_module):
>>>         for node in graph_module.graph.nodes:
>>>             if node.op == "call_function" and node.target == operator.mul:
>>>                 x, y = node.args
>>>                 node.args = (2 * x, y)
>>>         return graph_module

>>>     def reverse(self, graph_module):
>>>         for node in graph_module.graph.nodes:
>>>             if node.op == "call_function" and node.target == operator.mul:
>>>                 x, y = node.args
>>>                 node.args = (x / 2, y)
>>>         return graph_module
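The transform/reverse round trip can be illustrated on a tiny module with plain torch.fx. For simplicity, this sketch reverses a mul-to-add swap rather than the argument rewriting above; TinyModel and swap_targets are illustrative names only:

```python
import operator

import torch
from torch import fx

class TinyModel(torch.nn.Module):
    def forward(self, x, y):
        return x * y

def swap_targets(graph_module, old, new):
    # Replace every call to `old` with a call to `new`.
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == old:
            node.target = new
    graph_module.graph.lint()
    graph_module.recompile()
    return graph_module

traced = fx.symbolic_trace(TinyModel())
x, y = torch.tensor([4.0]), torch.tensor([5.0])
original = traced(x, y)                                         # 4 * 5 = 20

transformed = swap_targets(traced, operator.mul, operator.add)  # now computes 4 + 5
restored = swap_targets(transformed, operator.add, operator.mul)
# After reversing, the module computes the original result again.
```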

Composing transformations together

Since applying multiple transformations in a chain is often needed, compose() is provided. It is a utility function that lets you create a transformation by chaining several others.

>>> from optimum.fx.optimization import compose
>>> composition = compose(MulToMulTimesTwo(), ChangeMulToAdd())
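Under the hood, a compose-style helper only needs to thread the graph module through each transformation, deferring linting and recompilation to the final step. A stdlib-only sketch of that idea (not Optimum's actual implementation):

```python
def compose_sketch(*transformations):
    """Chain transformations, linting/recompiling only after the last one."""
    def composition(graph_module):
        last = len(transformations) - 1
        for i, transformation in enumerate(transformations):
            # Intermediate steps pass lint_and_recompile=False to avoid redundant work.
            graph_module = transformation(graph_module, lint_and_recompile=(i == last))
        return graph_module
    return composition

# Dummy "transformations" on a plain list stand in for real graph rewrites.
append_a = lambda gm, lint_and_recompile=True: gm + ["a"]
append_b = lambda gm, lint_and_recompile=True: gm + ["b"]
result = compose_sketch(append_a, append_b)([])  # ["a", "b"]
```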

The Optimization API

Main classes and functions

class optimum.fx.optimization.Transformation

( )

Parameters

  • preserves_computation (bool, defaults to False) — Whether the transformation preserves the graph computation or not. If True, the original and the transformed graph should produce the same outputs.

A torch.fx graph transformation.

It must implement the transform() method and is used as a callable.

__call__

( graph_module: GraphModule lint_and_recompile: bool = True ) β†’ torch.fx.GraphModule

Parameters

  • graph_module (torch.fx.GraphModule) — The module to transform.
  • lint_and_recompile (bool, defaults to True) — Whether the transformed module should be linted and recompiled. This can be set to False when chaining transformations together to perform this operation only once.

Returns

torch.fx.GraphModule

The transformed module.

get_transformed_nodes

( graph_module: GraphModule ) β†’ List[torch.fx.Node]

Parameters

  • graph_module (torch.fx.GraphModule) — The graph_module to get the nodes from.

Returns

List[torch.fx.Node]

Gives the list of nodes that were transformed by the transformation.

mark_as_transformed

( node: Node )

Parameters

  • node (torch.fx.Node) — The node to mark as transformed.

Marks a node as transformed by this transformation.

transform

( graph_module: GraphModule ) β†’ torch.fx.GraphModule

Parameters

  • graph_module (torch.fx.GraphModule) — The module to transform.

Returns

torch.fx.GraphModule

The transformed module.

transformed

( node: Node ) β†’ bool

Parameters

  • node (torch.fx.Node) — The node to check.

Returns

bool

Specifies whether the node was transformed by this transformation or not.
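The marking helpers above follow a simple pattern: a transformation tags the nodes it touches so they can be queried afterwards. A stdlib-only sketch of that bookkeeping (illustrative names and storage; Optimum's actual implementation may attach the mark to the node itself):

```python
class MarkingSketch:
    """Minimal stand-in for mark_as_transformed / transformed / get_transformed_nodes."""

    def __init__(self):
        self._marked = set()

    def mark_as_transformed(self, node):
        # Record that this transformation changed the node.
        self._marked.add(node)

    def transformed(self, node):
        return node in self._marked

    def get_transformed_nodes(self, nodes):
        return [node for node in nodes if self.transformed(node)]

t = MarkingSketch()
t.mark_as_transformed("mul_1")
result = t.get_transformed_nodes(["mul_1", "add_2"])  # ["mul_1"]
```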

class optimum.fx.optimization.ReversibleTransformation

( )

Parameters

  • preserves_computation (bool, defaults to False) — Whether the transformation preserves the graph computation or not. If True, the original and the transformed graph should produce the same outputs.

A torch.fx graph transformation that is reversible.

It must implement the transform() and reverse() methods and is used as a callable.

__call__

( graph_module: GraphModule lint_and_recompile: bool = True reverse: bool = False ) β†’ torch.fx.GraphModule

Parameters

  • graph_module (torch.fx.GraphModule) — The module to transform.
  • lint_and_recompile (bool, defaults to True) — Whether the transformed module should be linted and recompiled. This can be set to False when chaining transformations together to perform this operation only once.
  • reverse (bool, defaults to False) — If True, the reverse transformation is performed.

Returns

torch.fx.GraphModule

The transformed module.

mark_as_restored

( node: Node )

Parameters

  • node (torch.fx.Node) — The node to mark as restored.

Marks a node as restored back to its original state.

reverse

( graph_module: GraphModule ) β†’ torch.fx.GraphModule

Parameters

  • graph_module (torch.fx.GraphModule) — The module to transform.

Returns

torch.fx.GraphModule

The reverse transformed module.

optimum.fx.optimization.compose

( *args: Transformation inplace: bool = True )

Parameters

  • args (Transformation) — The transformations to compose together.
  • inplace (bool, defaults to True) — Whether the resulting transformation should be inplace, or create a new graph module.

Composes a list of transformations together.

Example:

>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> from optimum.fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose

>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
>>>     model,
>>>     input_names=["input_ids", "attention_mask", "token_type_ids"],
>>> )
>>> composition = compose(ChangeTrueDivToMulByInverse(), MergeLinears())
>>> transformed_model = composition(traced)

Transformations

class optimum.fx.optimization.MergeLinears

( )

Parameters

  • preserves_computation (bool, defaults to False) — Whether the transformation preserves the graph computation or not. If True, the original and the transformed graph should produce the same outputs.

Transformation that merges linear layers that take the same input into one big linear layer.

Example:

>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> from optimum.fx.optimization import MergeLinears

>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
>>>     model,
>>>     input_names=["input_ids", "attention_mask", "token_type_ids"],
>>> )
>>> transformation = MergeLinears()
>>> transformed_model = transformation(traced)
>>> restored_model = transformation(transformed_model, reverse=True)
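Numerically, merging linear layers that consume the same input amounts to concatenating their weights and biases along the output dimension and splitting the fused output afterwards. A sketch of that equivalence in plain torch (an illustration of the mechanism, not Optimum's actual code):

```python
import torch

lin1 = torch.nn.Linear(4, 3)
lin2 = torch.nn.Linear(4, 5)

# Build one "big" linear layer holding both weight matrices stacked row-wise.
merged = torch.nn.Linear(4, 3 + 5)
with torch.no_grad():
    merged.weight.copy_(torch.cat([lin1.weight, lin2.weight], dim=0))
    merged.bias.copy_(torch.cat([lin1.bias, lin2.bias], dim=0))

x = torch.randn(2, 4)
out = merged(x)
y1, y2 = out[:, :3], out[:, 3:]  # split back into the two original outputs
```

One matrix multiplication replaces two, which is why this rewrite tends to help on hardware that favors larger GEMMs.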

class optimum.fx.optimization.ChangeTrueDivToMulByInverse

( )

Parameters

  • preserves_computation (bool, defaults to False) — Whether the transformation preserves the graph computation or not. If True, the original and the transformed graph should produce the same outputs.

Transformation that changes truediv nodes to multiplication-by-the-inverse nodes when the denominator is static. For example, this is often the case for the scaling factor in attention layers.

Example:

>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> from optimum.fx.optimization import ChangeTrueDivToMulByInverse

>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
>>>     model,
>>>     input_names=["input_ids", "attention_mask", "token_type_ids"],
>>> )
>>> transformation = ChangeTrueDivToMulByInverse()
>>> transformed_model = transformation(traced)
>>> restored_model = transformation(transformed_model, reverse=True)
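The rewrite can be sketched with plain torch.fx on a toy module: when the denominator is a Python constant rather than another node, the truediv node can be replaced by a multiplication with its reciprocal (TinyScale is illustrative only):

```python
import operator

import torch
from torch import fx

class TinyScale(torch.nn.Module):
    def forward(self, x):
        return x / 8.0  # static denominator, like an attention scaling factor

traced = fx.symbolic_trace(TinyScale())
for node in traced.graph.nodes:
    if node.op == "call_function" and node.target == operator.truediv:
        numerator, denominator = node.args
        if not isinstance(denominator, fx.Node):  # only rewrite static denominators
            node.target = operator.mul
            node.args = (numerator, 1.0 / denominator)
traced.graph.lint()
traced.recompile()

out = traced(torch.tensor([16.0]))  # now computes 16 * 0.125
```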