Optimization
Transformation
class optimum.fx.optimization.Transformation( )
A torch.fx graph transformation.
It must implement the transform() method, and be used as a callable.
__call__
( graph_module: GraphModule, lint_and_recompile: bool = True ) → torch.fx.GraphModule
Parameters
- graph_module (torch.fx.GraphModule) — The module to transform.
- lint_and_recompile (bool, defaults to True) — Whether the transformed module should be linted and recompiled. This can be set to False when chaining transformations together to perform this operation only once.
Returns
torch.fx.GraphModule
The transformed module.
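To make the lint_and_recompile flag concrete, the following is a minimal sketch of what linting and recompiling a torch.fx.GraphModule involves, using only torch.fx. The AddOne module and the manual graph edit are hypothetical stand-ins for the work a transformation does, and the sketch assumes that linting corresponds to Graph.lint() and recompiling to GraphModule.recompile().

```python
import operator

import torch
from torch import fx


class AddOne(torch.nn.Module):  # hypothetical toy module
    def forward(self, x):
        return x + 1


gm = fx.symbolic_trace(AddOne())

# Edit the graph by hand: turn the addition into a subtraction.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is operator.add:
        node.target = operator.sub

x = torch.tensor([1.0, 2.0])
stale = gm(x)    # forward() was generated before the edit, so it still adds

gm.graph.lint()  # check that the edited graph is well-formed
gm.recompile()   # regenerate forward() from the edited graph
fresh = gm(x)    # the regenerated forward() subtracts
```

Because the generated forward() only changes on recompilation, several graph edits can be applied back to back and followed by a single lint and recompile at the end.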
get_transformed_nodes
( graph_module: GraphModule ) → List[torch.fx.Node]
mark_as_transformed
( node: Node )
Marks a node as transformed by this transformation.
transform
( graph_module: GraphModule ) → torch.fx.GraphModule
transformed
( node: Node ) → bool
Reversible transformation
class optimum.fx.optimization.ReversibleTransformation( )
A torch.fx graph transformation that is reversible.
It must implement the transform() and reverse() methods, and be used as a callable.
__call__
( graph_module: GraphModule, lint_and_recompile: bool = True, reverse: bool = False ) → torch.fx.GraphModule
Parameters
- graph_module (torch.fx.GraphModule) — The module to transform.
- lint_and_recompile (bool, defaults to True) — Whether the transformed module should be linted and recompiled. This can be set to False when chaining transformations together to perform this operation only once.
- reverse (bool, defaults to False) — If True, the reverse transformation is performed.
Returns
torch.fx.GraphModule
The transformed module.
mark_as_restored
( node: Node )
Marks a node as restored back to its original state.
reverse
( graph_module: GraphModule ) → torch.fx.GraphModule
optimum.fx.optimization.compose
( *args: Transformation, inplace: bool = True )
Parameters
- args (Transformation) — The transformations to compose together.
- inplace (bool, defaults to True) — Whether the resulting transformation should be inplace, or create a new graph module.
Composes a list of transformations together.
Example:
>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> from optimum.fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose
>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
... model,
... input_names=["input_ids", "attention_mask", "token_type_ids"],
... )
>>> composition = compose(ChangeTrueDivToMulByInverse(), MergeLinears())
>>> transformed_model = composition(traced)
Transformations
class optimum.fx.optimization.MergeLinears( )
Transformation that merges linear layers that take the same input into one big linear layer.
Example:
>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> from optimum.fx.optimization import MergeLinears
>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
... model,
... input_names=["input_ids", "attention_mask", "token_type_ids"],
... )
>>> transformation = MergeLinears()
>>> transformed_model = transformation(traced)
>>> restored_model = transformation(transformed_model, reverse=True)
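The idea behind merging can be shown without any graph manipulation. The sketch below is not the library's implementation; it only illustrates, with hypothetical layer names and sizes, why several linear layers reading the same input are equivalent to one larger linear layer whose output is split afterwards.

```python
import torch

torch.manual_seed(0)

# Three hypothetical projections that read the same input, as in attention.
query = torch.nn.Linear(8, 4)
key = torch.nn.Linear(8, 4)
value = torch.nn.Linear(8, 4)

# Stack the weights and biases along the output dimension.
merged = torch.nn.Linear(8, 12)
with torch.no_grad():
    merged.weight.copy_(torch.cat([query.weight, key.weight, value.weight], dim=0))
    merged.bias.copy_(torch.cat([query.bias, key.bias, value.bias], dim=0))

x = torch.randn(2, 8)
with torch.no_grad():
    # One matrix multiplication, then slice the result back into three outputs.
    q_out, k_out, v_out = merged(x).split([4, 4, 4], dim=-1)
    separate = (query(x), key(x), value(x))
```

Each slice of the merged output matches the corresponding separate layer up to floating-point rounding.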
class optimum.fx.optimization.FuseBiasInLinear( )
Transformation that fuses the bias into the weight in torch.nn.Linear.
Example:
>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> from optimum.fx.optimization import FuseBiasInLinear
>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
... model,
... input_names=["input_ids", "attention_mask", "token_type_ids"],
... )
>>> transformation = FuseBiasInLinear()
>>> transformed_model = transformation(traced)
>>> restored_model = transformation(transformed_model, reverse=True)
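One way to picture fusing a bias into a weight is the standard augmented-matrix trick, sketched below. This is an assumption about the general technique, not a description of the library's internals: the bias becomes an extra weight column, and the input gets a matching column of ones.

```python
import torch

torch.manual_seed(0)

linear = torch.nn.Linear(8, 4)

# Append the bias as an extra weight column: shape (4, 8) becomes (4, 9).
fused = torch.nn.Linear(9, 4, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([linear.weight, linear.bias[:, None]], dim=1))

x = torch.randn(2, 8)
# Append a constant 1 to every input row so the extra column adds the bias.
x_augmented = torch.cat([x, torch.ones(2, 1)], dim=-1)

with torch.no_grad():
    expected = linear(x)
    actual = fused(x_augmented)
```

The fused layer has no separate bias term, at the cost of one extra input feature.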
class optimum.fx.optimization.ChangeTrueDivToMulByInverse( )
Transformation that replaces truediv nodes with a multiplication by the inverse when the denominator is static. This is sometimes the case, for example, for the scaling factor in attention layers.
Example:
>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> from optimum.fx.optimization import ChangeTrueDivToMulByInverse
>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
... model,
... input_names=["input_ids", "attention_mask", "token_type_ids"],
... )
>>> transformation = ChangeTrueDivToMulByInverse()
>>> transformed_model = transformation(traced)
>>> restored_model = transformation(transformed_model, reverse=True)
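The rewrite itself can be sketched directly with torch.fx. The code below is a simplified illustration rather than the library's implementation: Scale is a hypothetical toy module, and "static" is assumed here to mean that the denominator is a plain Python number in the traced graph.

```python
import operator

import torch
from torch import fx


class Scale(torch.nn.Module):  # hypothetical toy module
    def forward(self, x):
        return x / 8.0


gm = fx.symbolic_trace(Scale())

for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is operator.truediv:
        numerator, denominator = node.args
        # Only rewrite when the denominator is a plain Python number.
        if isinstance(denominator, (int, float)):
            node.target = operator.mul
            node.args = (numerator, 1 / denominator)

gm.graph.lint()
gm.recompile()

x = torch.tensor([8.0, 16.0])
result = gm(x)
targets = [node.target for node in gm.graph.nodes]
```

The inverse is computed once while rewriting the graph, so each forward pass performs a multiplication instead of a division.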
class optimum.fx.optimization.FuseBatchNorm2dInConv2d( )
Transformation that fuses nn.BatchNorm2d following nn.Conv2d into a single nn.Conv2d.
The fusion will be done only if the convolution has the batch normalization as sole following node.
For example, fusion will not be done when the output of the convolution also feeds a node other than the batch normalization.
Example:
>>> from transformers.utils.fx import symbolic_trace
>>> from transformers import AutoModelForImageClassification
>>> from optimum.fx.optimization import FuseBatchNorm2dInConv2d
>>> model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50")
>>> model.eval()
>>> traced_model = symbolic_trace(
... model,
... input_names=["pixel_values"],
... disable_check=True
... )
>>> transformation = FuseBatchNorm2dInConv2d()
>>> transformed_model = transformation(traced_model)
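The arithmetic behind this kind of fusion can be checked on a single pair of layers. The sketch below uses the standard folding formula for a batch normalization in eval mode and is not taken from the library's source; the layer sizes and the randomized statistics are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)

conv = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1).eval()
bn = torch.nn.BatchNorm2d(4).eval()
fused = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1).eval()

x = torch.randn(2, 3, 5, 5)
with torch.no_grad():
    # Give the batch norm non-trivial statistics and affine parameters.
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2)
    bn.weight.uniform_(0.5, 2)
    bn.bias.uniform_(-1, 1)

    # Per-output-channel scale applied by the batch norm in eval mode.
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)

    # Fold the scale into the kernel and the shift into the bias.
    fused.weight.copy_(conv.weight * scale[:, None, None, None])
    fused.bias.copy_((conv.bias - bn.running_mean) * scale + bn.bias)

    expected = bn(conv(x))
    actual = fused(x)
```

The folding relies on the fixed running statistics, so it only holds in eval mode, where the batch normalization does not recompute statistics from the current batch.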
class optimum.fx.optimization.FuseBatchNorm1dInLinear( )
Transformation that fuses nn.BatchNorm1d following or preceding nn.Linear into a single nn.Linear.
The fusion will be done only if the linear layer has the batch normalization as sole following node, or the batch normalization has the linear layer as sole following node.
For example, fusion will not be done when the output of the linear layer also feeds a node other than the batch normalization.
Example:
>>> from transformers.utils.fx import symbolic_trace
>>> from transformers import AutoModel
>>> from optimum.fx.optimization import FuseBatchNorm1dInLinear
>>> model = AutoModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
>>> model.eval()
>>> traced_model = symbolic_trace(
... model,
... input_names=["input_ids", "attention_mask", "pixel_values"],
... disable_check=True
... )
>>> transformation = FuseBatchNorm1dInLinear()
>>> transformed_model = transformation(traced_model)
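Both orderings can be verified numerically on small layers. The sketch below applies the standard eval-mode folding formulas and is not taken from the library's source; random_bn is a hypothetical helper, and the layer sizes are arbitrary.

```python
import torch

torch.manual_seed(0)


def random_bn(features):
    """Hypothetical helper: an eval-mode BatchNorm1d with non-trivial parameters."""
    bn = torch.nn.BatchNorm1d(features)
    with torch.no_grad():
        bn.running_mean.uniform_(-1, 1)
        bn.running_var.uniform_(0.5, 2)
        bn.weight.uniform_(0.5, 2)
        bn.bias.uniform_(-1, 1)
    return bn.eval()


linear = torch.nn.Linear(8, 4).eval()
x = torch.randn(2, 8)

with torch.no_grad():
    # Case 1: Linear followed by BatchNorm1d (scale the output rows).
    bn_after = random_bn(4)
    scale = bn_after.weight / torch.sqrt(bn_after.running_var + bn_after.eps)
    fused_after = torch.nn.Linear(8, 4)
    fused_after.weight.copy_(linear.weight * scale[:, None])
    fused_after.bias.copy_((linear.bias - bn_after.running_mean) * scale + bn_after.bias)

    # Case 2: BatchNorm1d followed by Linear (scale the input columns).
    bn_before = random_bn(8)
    scale = bn_before.weight / torch.sqrt(bn_before.running_var + bn_before.eps)
    shift = bn_before.bias - bn_before.running_mean * scale
    fused_before = torch.nn.Linear(8, 4)
    fused_before.weight.copy_(linear.weight * scale[None, :])
    fused_before.bias.copy_(linear.bias + linear.weight @ shift)

    after_ok = torch.allclose(fused_after(x), bn_after(linear(x)), atol=1e-5)
    before_ok = torch.allclose(fused_before(x), linear(bn_before(x)), atol=1e-5)
```

When the batch normalization comes first, its per-feature scale multiplies the columns of the weight matrix and its shift is pushed through the weight into the bias.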