import builtins |
|
import collections |
|
import functools |
|
import inspect |
|
import math |
|
import operator |
|
import os |
|
import random |
|
import warnings |
|
from typing import Any, Callable, Dict, List, Optional, Type, Union |
|
|
|
import torch |
|
from torch import nn |
|
from torch.fx import Graph, GraphModule, Proxy, Tracer |
|
from torch.fx._compatibility import compatibility |
|
from torch.fx.proxy import ParameterProxy |
|
|
|
from .. import PretrainedConfig, PreTrainedModel, logging |
|
from ..models.auto import get_values |
|
from ..models.auto.modeling_auto import ( |
|
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, |
|
MODEL_FOR_BACKBONE_MAPPING_NAMES, |
|
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, |
|
MODEL_FOR_CTC_MAPPING_NAMES, |
|
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES, |
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, |
|
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES, |
|
MODEL_FOR_MASKED_LM_MAPPING_NAMES, |
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES, |
|
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES, |
|
MODEL_FOR_PRETRAINING_MAPPING_NAMES, |
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, |
|
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, |
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, |
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, |
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, |
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, |
|
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, |
|
MODEL_MAPPING_NAMES, |
|
) |
|
from ..utils import ( |
|
ENV_VARS_TRUE_VALUES, |
|
TORCH_FX_REQUIRED_VERSION, |
|
get_torch_version, |
|
is_peft_available, |
|
is_torch_fx_available, |
|
) |
|
|
|
|
|
if is_peft_available(): |
|
from peft import PeftModel |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES |
|
|
|
|
|
def _generate_supported_model_class_names( |
|
    model_name: str,
|
supported_tasks: Optional[Union[str, List[str]]] = None, |
|
) -> List[str]: |
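    """
    Returns the names of all model classes associated with `model_name` for the requested tasks (every known task
    if `supported_tasks` is None), looked up in the auto-model task mappings.
    """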
|
task_mapping = { |
|
"default": MODEL_MAPPING_NAMES, |
|
"pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES, |
|
"next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES, |
|
"masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES, |
|
"causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, |
|
"seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, |
|
"speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, |
|
"multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES, |
|
"document-question-answering": MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES, |
|
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, |
|
"sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, |
|
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, |
|
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES, |
|
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, |
|
"zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, |
|
"ctc": MODEL_FOR_CTC_MAPPING_NAMES, |
|
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, |
|
"semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, |
|
"backbone": MODEL_FOR_BACKBONE_MAPPING_NAMES, |
|
} |
|
|
|
if supported_tasks is None: |
|
supported_tasks = task_mapping.keys() |
|
if isinstance(supported_tasks, str): |
|
supported_tasks = [supported_tasks] |
|
|
|
model_class_names = [] |
|
for task in supported_tasks: |
|
class_name = task_mapping[task].get(model_name, None) |
|
if class_name: |
|
model_class_names.append(class_name) |
|
|
|
return model_class_names |
|
|
|
|
|
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ |
|
"altclip", |
|
"albert", |
|
"bart", |
|
"bert", |
|
"blenderbot", |
|
"blenderbot-small", |
|
"bloom", |
|
"clip", |
|
"convnext", |
|
"deberta", |
|
"deberta-v2", |
|
"distilbert", |
|
"donut-swin", |
|
"electra", |
|
"gpt2", |
|
"gpt_neo", |
|
"gptj", |
|
"hubert", |
|
"layoutlm", |
|
"lxmert", |
|
"m2m_100", |
|
"marian", |
|
"mbart", |
|
"megatron-bert", |
|
"mobilebert", |
|
"mt5", |
|
"nezha", |
|
"opt", |
|
"pegasus", |
|
"plbart", |
|
"resnet", |
|
"roberta", |
|
"segformer", |
|
"speech_to_text", |
|
"speech_to_text_2", |
|
"swin", |
|
"t5", |
|
"trocr", |
|
"vit", |
|
"xglm", |
|
"wav2vec2", |
|
|
|
] |
|
|
|
_REGULAR_SUPPORTED_MODELS = [] |
|
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS: |
|
if isinstance(item, dict): |
|
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item)) |
|
else: |
|
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item)) |
|
|
|
_SPECIAL_SUPPORTED_MODELS = [ |
|
"CLIPTextModel", |
|
"CLIPTextModelWithProjection", |
|
"CLIPVisionModel", |
|
"CLIPVisionModelWithProjection", |
|
"AltCLIPTextModel", |
|
"AltCLIPVisionModel", |
|
"GitVisionModel", |
|
"GPT2DoubleHeadsModel", |
|
"Speech2Text2Decoder", |
|
"TrOCRDecoder", |
|
"PeftModelForCausalLM", |
|
"PeftModelForSeq2SeqLM" |
|
|
|
|
|
] |
|
_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS))) |
|
|
|
|
|
def torch_nn_embedding(self, input): |
|
return torch.empty(*input.shape, self.weight.shape[-1], device="meta", dtype=self.weight.dtype) |
|
|
|
|
|
def torch_nn_functional_embedding( |
|
input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False |
|
): |
|
return torch.empty(*input.shape, weight.shape[-1], device="meta", dtype=weight.dtype) |
|
|
|
|
|
def torch_nn_layernorm(self, input): |
|
return input |
|
|
|
|
|
def torch_nn_groupnorm(self, input): |
|
return input |
|
|
|
|
|
def torch_nn_linear(self, input): |
|
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta") |
|
|
|
|
|
def torch_relu(x): |
|
return x |
|
|
|
|
|
def torch_nn_relu(self, x): |
|
return x |
|
|
|
|
|
def torch_nn_functional_relu(x, inplace=False): |
|
    if inplace:
|
raise ValueError("Don't support in-place functional.relu for MetaTensor analysis") |
|
return x |
|
|
|
|
|
def torch_where(condition, x, y): |
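    # torch.where returns the broadcasted tensor of condition, x and y,
    # so simulate it with a broadcasting addition on the meta device.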
|
|
|
|
|
return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta") |
|
|
|
|
|
def torch_abs(input, *, out=None): |
|
if out is not None: |
|
raise ValueError("Don't support in-place abs for MetaTensor analysis") |
|
return input |
|
|
|
|
|
def torch_arange(*args, **kwargs): |
|
n = len(args) |
|
step = 1 |
|
if n == 1: |
|
start = 0 |
|
end = args[0] |
|
elif n == 2: |
|
start, end = args |
|
else: |
|
start, end, step = args |
|
if isinstance(start, float): |
|
start = int(start) |
|
if isinstance(end, float): |
|
        end = int(end)
|
if isinstance(step, float): |
|
step = int(step) |
|
step = kwargs.get("step", step) |
|
dtype = kwargs.get("dtype") |
|
return torch.empty((end - start) // step, dtype=dtype, device="meta") |
|
|
|
|
|
def torch_full(*args, **kwargs): |
|
args = list(args) |
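    # The fill value is irrelevant for shape inference; replace a meta-tensor
    # fill value with a plain scalar so torch.full can execute on CPU.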
|
if isinstance(args[1], torch.Tensor) and args[1].device == torch.device("meta"): |
|
args[1] = 1 |
|
kwargs_without_device = dict(kwargs) |
|
kwargs_without_device.pop("device", None) |
|
return torch.full(*args, **kwargs_without_device) |
|
|
|
|
|
def torch_cat(tensors, dim=None, axis=None, *, out=None): |
|
if dim is None and axis is None: |
|
dim = 0 |
|
if dim is None and axis is not None: |
|
dim = axis |
|
if dim < 0: |
|
dim = tensors[0].dim() + dim |
|
shapes = [t.shape for t in tensors] |
|
shape = list(shapes[0]) |
|
concatenated_dim = sum(shape[dim] for shape in shapes) |
|
final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :] |
|
return torch.empty(final_shape, device="meta") |
|
|
|
|
|
def torch_stack(tensors, dim=None, axis=None, *, out=None): |
|
if dim is None and axis is None: |
|
dim = 0 |
|
if dim is None and axis is not None: |
|
dim = axis |
|
if dim < 0: |
|
dim = tensors[0].dim() + 1 + dim |
|
shape = list(tensors[0].shape) |
|
shape.insert(dim, len(tensors)) |
|
return torch.empty(shape, device="meta") |
|
|
|
|
|
def torch_add(input, other, *, alpha=1, out=None): |
|
if not isinstance(input, torch.Tensor): |
|
return torch.empty_like(other, device="meta") |
|
if not isinstance(other, torch.Tensor): |
|
return torch.empty_like(input, device="meta") |
|
max_length = max(input.dim(), other.dim()) |
|
input_shape = list(input.shape) + [1] * (max_length - input.dim()) |
|
other_shape = list(other.shape) + [1] * (max_length - other.dim()) |
|
shape = [] |
|
for i in range(max_length): |
|
shape.append(max(input_shape[i], other_shape[i])) |
|
return torch.empty(shape, device="meta") |
|
|
|
|
|
def torch_mul(input, other, *, out=None): |
|
return torch_add(input, other, out=out) |
|
|
|
|
|
def torch_tensor_mul(self, other): |
|
return torch_mul(self, other) |
|
|
|
|
|
def torch_matmul(input, other, *, out=None): |
|
d1 = input.dim() |
|
d2 = other.dim() |
|
shape = None |
|
if d1 == 1 and d2 == 1: |
|
shape = None |
|
elif d1 == 2 and d2 == 2: |
|
shape = (input.size(0), other.size(1)) |
|
elif d1 == 1 and d2 == 2: |
|
shape = (other.size(1),) |
|
    elif d1 == 2 and d2 == 1:
|
shape = (input.size(0),) |
|
else: |
|
max_length = max(input.dim(), other.dim()) |
|
shape1 = list(input.shape) |
|
shape2 = list(other.shape) |
|
if d1 == 1: |
|
shape1 = [1] + shape1 |
|
if d2 == 1: |
|
shape2.append(1) |
|
shape1 = [-1] * (max_length - d1) + list(input.shape) |
|
shape2 = [-1] * (max_length - d2) + list(other.shape) |
|
shape = [] |
|
for i in range(max_length): |
|
shape.append(max(shape1[i], shape2[i])) |
|
shape[-2] = shape1[-2] |
|
shape[-1] = shape2[-1] |
|
if d1 == 1: |
|
shape.pop(-2) |
|
if d2 == 1: |
|
shape.pop(-1) |
|
if shape is None: |
|
return torch.tensor(0.0, device="meta") |
|
return torch.empty(*shape, device="meta") |
|
|
|
|
|
def torch_bmm(input, mat2, *, out=None): |
|
if out is not None: |
|
raise ValueError("Don't support in-place bmm for MetaTensor analysis") |
|
batch_size, n, m = input.shape |
|
_, _, p = mat2.shape |
|
return torch.empty(batch_size, n, p, device="meta") |
|
|
|
|
|
def torch_baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None): |
|
if out is not None: |
|
raise ValueError("Don't support in-place baddbmm for MetaTensor analysis") |
|
return torch_bmm(batch1, batch2) |
|
|
|
|
|
def torch_tensor_baddbmm(self, batch1, batch2, *, beta=1, alpha=1, out=None): |
|
return torch_baddbmm(self, batch1, batch2, beta=beta, alpha=alpha, out=out) |
|
|
|
|
|
def torch_einsum(equation, *operands): |
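    # Inferring the output shape of einsum symbolically is hard, so run it on
    # concrete CPU tensors of the same shapes and move the result to meta.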
|
|
|
concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands) |
|
return torch.einsum(equation, *concrete_operands).to("meta") |
|
|
|
|
|
def torch_tensor_repeat(self, *sizes): |
|
shape = list(self.shape) |
|
for i, x in enumerate(sizes): |
|
shape[i] *= x |
|
return torch.empty(shape, device="meta") |
|
|
|
|
|
def torch_repeat_interleave(*args, dim=None, output_size=None): |
|
num_args = len(args) |
|
if num_args == 1: |
|
shape = [output_size if output_size is not None else args[0].sum()] |
|
else: |
|
shape = list(args[0].shape) |
|
if dim is None: |
|
if num_args > 2: |
|
dim = args[2] |
|
else: |
|
shape = [sum(shape)] |
|
dim = 0 |
|
repeats = args[1] |
|
if isinstance(repeats, int) or torch.numel(repeats) == 1: |
|
shape[dim] *= int(repeats) |
|
else: |
|
shape[dim] = output_size if output_size is not None else repeats.sum() |
|
return torch.empty(*shape, device="meta") |
|
|
|
|
|
def torch_index_select(input, dim, index, *, out=None): |
|
shape = list(input.shape) |
|
shape[dim] = len(index) |
|
return torch.empty(*shape, device="meta") |
|
|
|
|
|
def torch_tensor_index_select(self, dim, index): |
|
return torch_index_select(self, dim, index) |
|
|
|
|
|
def torch_gather(input, dim, index, *, sparse_grad=False, out=None): |
|
shape = list(input.shape) |
|
shape[dim] = index.shape[dim] |
|
return torch.empty(*shape, device="meta") |
|
|
|
|
|
def torch_tensor_gather(self, dim, index): |
|
return torch_gather(self, dim, index) |
|
|
|
|
|
def torch_roll(input, shifts, dims=None): |
|
return input |
|
|
|
|
|
def torch_flip(input, dims): |
|
return input |
|
|
|
|
|
def torch_tensor_flip(self, dims): |
|
return self |
|
|
|
|
|
def torch_nn_conv1d(self, input): |
|
l_in = input.shape[-1] |
|
shape = None |
|
padding = self.padding |
|
if padding == "valid": |
|
padding = (0, 0) |
|
if padding == "same": |
|
shape = list(input.shape) |
|
if shape is None: |
|
shape = list(input.shape) |
|
l_out = math.floor( |
|
(l_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 |
|
) |
|
shape[-1] = l_out |
|
shape[-2] = self.out_channels |
|
return torch.empty(shape, device="meta") |
|
|
|
|
|
def torch_nn_conv2d(self, input): |
|
h_in, w_in = input.shape[-2:] |
|
shape = None |
|
padding = self.padding |
|
if padding == "valid": |
|
padding = (0, 0) |
|
if padding == "same": |
|
shape = list(input.shape) |
|
if shape is None: |
|
shape = list(input.shape) |
|
h_out = math.floor( |
|
(h_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 |
|
) |
|
w_out = math.floor( |
|
(w_in + 2 * padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1 |
|
) |
|
shape[-2:] = [h_out, w_out] |
|
shape[-3] = self.out_channels |
|
return torch.empty(shape, device="meta") |
|
|
|
|
|
def torch_squeeze(input, dim=None): |
|
shape = list(input.shape) |
|
if dim is not None: |
|
if dim < 0: |
|
dim = input.dim() + dim |
|
if shape[dim] == 1: |
|
shape.pop(dim) |
|
else: |
|
new_shape = [] |
|
for dim_value in shape: |
|
if dim_value == 1: |
|
continue |
|
new_shape.append(dim_value) |
|
shape = new_shape |
|
return torch.empty(shape, device="meta") |
|
|
|
|
|
def torch_tensor_squeeze(self, dim=None): |
|
return torch_squeeze(self, dim) |
|
|
|
|
|
def torch_unsqueeze(input, dim): |
|
shape = list(input.shape) |
|
if dim < 0: |
|
dim = input.dim() + 1 + dim |
|
shape.insert(dim, 1) |
|
return torch.empty(shape, device="meta") |
|
|
|
|
|
def torch_tensor_unsqueeze(self, dim): |
|
return torch_unsqueeze(self, dim) |
|
|
|
|
|
def torch_unique_consecutive(input, **kwargs): |
|
output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs) |
|
if isinstance(output, torch.Tensor): |
|
return output.to("meta") |
|
else: |
|
        return tuple(map(lambda x: x.to("meta"), output))
|
|
|
|
|
def torch_nn_functional_one_hot(tensor, num_classes=-1): |
|
if num_classes < 0: |
|
raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis") |
|
shape = list(tensor.shape) + [num_classes] |
|
return torch.empty(shape, device="meta") |
|
|
|
|
|
def torch_nn_mseloss(self, input, target): |
|
if self.reduction == "none": |
|
shape = target.shape |
|
else: |
|
shape = (1,) |
|
return torch.empty(shape, device="meta") |
|
|
|
|
|
def torch_nn_crossentropyloss(self, input, target): |
|
if self.reduction == "none": |
|
shape = target.shape |
|
else: |
|
shape = (1,) |
|
return torch.empty(shape, device="meta") |
|
|
|
|
|
def torch_nn_bcewithlogitsloss(self, input, target): |
|
if self.reduction == "none": |
|
shape = target.shape |
|
else: |
|
shape = (1,) |
|
return torch.empty(shape, device="meta") |
|
|
|
|
|
def operator_getitem(a, b): |
|
def to_concrete(t): |
|
if isinstance(t, torch.Tensor): |
|
concrete = torch.ones_like(t, device="cpu") |
|
if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]: |
|
concrete = concrete.to(torch.int64) |
|
return concrete |
|
return t |
|
|
|
if isinstance(a, torch.Tensor): |
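        # Perform the indexing on a concrete CPU tensor of the same shape,
        # then move the result back to the meta device.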
|
|
|
if isinstance(b, tuple): |
|
b = tuple(map(to_concrete, b)) |
|
else: |
|
b = to_concrete(b) |
|
return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta") |
|
return operator.getitem(a, b) |
|
|
|
|
|
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = { |
|
torch.nn.Embedding: torch_nn_embedding, |
|
torch.nn.functional.embedding: torch_nn_functional_embedding, |
|
torch.nn.LayerNorm: torch_nn_layernorm, |
|
torch.nn.GroupNorm: torch_nn_groupnorm, |
|
torch.nn.Linear: torch_nn_linear, |
|
torch.relu: torch_relu, |
|
torch.nn.functional.relu: torch_nn_functional_relu, |
|
torch.nn.ReLU: torch_nn_relu, |
|
torch.where: torch_where, |
|
torch.abs: torch_abs, |
|
torch.arange: torch_arange, |
|
torch.full: torch_full, |
|
torch.cat: torch_cat, |
|
torch.stack: torch_stack, |
|
torch.add: torch_add, |
|
torch.mul: torch_mul, |
|
torch.Tensor.mul: torch_tensor_mul, |
|
torch.matmul: torch_matmul, |
|
torch.bmm: torch_bmm, |
|
torch.baddbmm: torch_baddbmm, |
|
torch.Tensor.baddbmm: torch_tensor_baddbmm, |
|
torch.einsum: torch_einsum, |
|
torch.Tensor.repeat: torch_tensor_repeat, |
|
torch.repeat_interleave: torch_repeat_interleave, |
|
torch.roll: torch_roll, |
|
torch.flip: torch_flip, |
|
torch.Tensor.flip: torch_tensor_flip, |
|
torch.index_select: torch_index_select, |
|
torch.Tensor.index_select: torch_tensor_index_select, |
|
torch.gather: torch_gather, |
|
torch.Tensor.gather: torch_tensor_gather, |
|
torch.nn.Conv1d: torch_nn_conv1d, |
|
torch.nn.Conv2d: torch_nn_conv2d, |
|
torch.squeeze: torch_squeeze, |
|
torch.Tensor.squeeze: torch_tensor_squeeze, |
|
torch.unsqueeze: torch_unsqueeze, |
|
torch.Tensor.unsqueeze: torch_tensor_unsqueeze, |
|
torch.unique_consecutive: torch_unique_consecutive, |
|
torch.nn.functional.one_hot: torch_nn_functional_one_hot, |
|
torch.nn.MSELoss: torch_nn_mseloss, |
|
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss, |
|
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss, |
|
operator.getitem: operator_getitem, |
|
} |
|
|
|
|
|
class HFProxy(Proxy): |
|
""" |
|
Proxy that uses metadata to handle data-dependent control-flow. |
|
""" |
|
|
|
def install_metadata(self, metadata): |
|
self._metadata = metadata |
|
|
|
@property |
|
def shape(self): |
|
return self.tracer.create_proxy("call_method", "size", (self,), {}) |
|
|
|
@property |
|
def device(self): |
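        # Hack so we can track when devices are used. During meta-tensor
        # propagation, these values are replaced with the constant "meta".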
|
|
|
|
|
return MetaDeviceAttribute(self, "device") |
|
|
|
def __len__(self): |
|
if hasattr(self, "_metadata") and self._metadata is not None: |
|
return len(self._metadata) |
|
return super().__len__() |
|
|
|
def __bool__(self): |
|
if hasattr(self, "_metadata") and self._metadata is not None: |
|
return self._metadata |
|
return super().__bool__() |
|
|
|
def __getattr__(self, k): |
|
if k == "_metadata": |
|
return self.__getattribute__(k) |
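        # Attribute access is resolved lazily: HFAttribute only adds a node to
        # the graph once the attribute is actually used (e.g. called as a method).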
|
|
|
|
|
return HFAttribute(self, k) |
|
|
|
def __setitem__(self, indices, values): |
|
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {}) |
|
|
|
def __contains__(self, key): |
|
if hasattr(self, "_metadata") and self._metadata is not None: |
|
return key in self._metadata |
|
return super().__contains__(key) |
|
|
|
|
|
class HFAttribute(HFProxy): |
|
def __init__(self, root, attr: str): |
|
self.root = root |
|
self.attr = attr |
|
self.tracer = root.tracer |
|
self._node = None |
|
|
|
if hasattr(self.root, "_metadata"): |
|
self.install_metadata(getattr(self.root._metadata, attr)) |
|
|
|
@property |
|
def node(self): |
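        # The node for attributes is created lazily, since most attribute
        # accesses are method calls that never need a standalone getattr node.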
|
|
|
|
|
if self._node is None: |
|
self._node = self.tracer.create_proxy("call_function", builtins.getattr, (self.root, self.attr), {}).node |
|
return self._node |
|
|
|
def __call__(self, *args, **kwargs): |
|
return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs) |
|
|
|
|
|
class MetaDeviceAttribute(HFAttribute): |
|
pass |
|
|
|
|
|
def _proxies_to_metas(v): |
|
"""Returns the underlying metadata for HFProxies, and behaves like the identity for the others.""" |
|
if isinstance(v, MetaDeviceAttribute): |
|
return "meta" |
|
if isinstance(v, torch.fx.Proxy): |
|
if not (isinstance(v, HFProxy) and hasattr(v, "_metadata")): |
|
raise RuntimeError(f"No metadata was found for {v}") |
|
return v._metadata |
|
return v |
|
|
|
|
|
def _gen_constructor_wrapper(target): |
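    """
    Wraps a torch function (e.g. `torch.ones`) so that calls receiving at least one `Proxy` argument are recorded as
    `call_function` nodes instead of being executed eagerly. Returns both the wrapper and the original function.
    """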
|
@functools.wraps(target) |
|
def wrapper(*args, **kwargs): |
|
proxy = None |
|
|
|
def check_has_proxy(v): |
|
if isinstance(v, Proxy): |
|
nonlocal proxy |
|
proxy = v |
|
|
|
torch.fx.node.map_aggregate(args, check_has_proxy) |
|
torch.fx.node.map_aggregate(kwargs, check_has_proxy) |
|
|
|
if proxy is not None: |
|
return proxy.tracer.create_proxy("call_function", target, args, kwargs) |
|
else: |
|
return target(*args, **kwargs) |
|
|
|
return wrapper, target |
|
|
|
|
|
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None): |
|
if forbidden_values is None: |
|
forbidden_values = [] |
|
value = random.randint(low, high) |
|
while value in forbidden_values: |
|
value = random.randint(low, high) |
|
return value |
|
|
|
|
|
class HFTracer(Tracer): |
|
""" |
|
Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the |
|
regular PyTorch torch.fx.Proxy. |
|
""" |
|
|
|
|
|
proxy_buffer_attributes: bool = True |
|
allow_insert_stateless_mods: bool = True |
|
_TORCH_METHODS_TO_PATCH = [ |
|
"arange", |
|
"zeros", |
|
"ones", |
|
"full", |
|
"full_like", |
|
"eye", |
|
"empty", |
|
"tensor", |
|
"clamp", |
|
"finfo", |
|
] |
|
supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel) |
|
|
|
def __init__(self, autowrap_modules=(math,), autowrap_functions=()): |
|
super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions) |
|
|
|
if not is_torch_fx_available(): |
|
raise ImportError( |
|
f"Found an incompatible version of torch. Found version {get_torch_version()}, but only version " |
|
f"{TORCH_FX_REQUIRED_VERSION} is supported." |
|
) |
|
|
|
def _generate_dummy_input( |
|
self, model: PreTrainedModel, input_name: str, shape: List[int] |
|
) -> Dict[str, torch.Tensor]: |
|
"""Generates dummy input for model inference recording.""" |
|
|
|
|
|
model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__ |
|
device = model.device |
|
inputs_dict = {} |
|
|
|
if input_name in ["labels", "start_positions", "end_positions"]: |
|
batch_size = shape[0] |
|
if model_class_name in [ |
|
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES), |
|
*get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES), |
|
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES), |
|
*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES), |
|
*get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES), |
|
]: |
|
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) |
|
elif model_class_name in [ |
|
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES), |
|
*get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES), |
|
"XLNetForQuestionAnswering", |
|
]: |
|
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) |
|
inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) |
|
elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES): |
|
if not hasattr(model.config, "problem_type") or model.config.problem_type is None: |
|
raise ValueError( |
|
"Could not retrieve the problem type for the sequence classification task, please set " |
|
'model.config.problem_type to one of the following values: "regression", ' |
|
'"single_label_classification", or "multi_label_classification".' |
|
) |
|
|
|
if model.config.problem_type == "regression": |
|
labels_shape = (batch_size, model.config.num_labels) |
|
labels_dtype = torch.float32 |
|
elif model.config.problem_type == "single_label_classification": |
|
labels_shape = (batch_size,) |
|
labels_dtype = torch.long |
|
elif model.config.problem_type == "multi_label_classification": |
|
labels_shape = (batch_size, model.config.num_labels) |
|
labels_dtype = torch.float32 |
|
else: |
|
raise ValueError( |
|
'Expected model.config.problem_type to be either: "regression", "single_label_classification"' |
|
f', or "multi_label_classification", but "{model.config.problem_type}" was provided.' |
|
) |
|
inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device) |
|
|
|
elif model_class_name in [ |
|
*get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES), |
|
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES), |
|
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES), |
|
*get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES), |
|
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES), |
|
*get_values(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES), |
|
"GPT2DoubleHeadsModel", |
|
"PeftModelForCausalLM", |
|
"PeftModelForSeq2SeqLM", |
|
]: |
|
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device) |
|
elif model_class_name in [*get_values(MODEL_FOR_CTC_MAPPING_NAMES)]: |
|
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.float32, device=device) |
|
else: |
|
raise NotImplementedError( |
|
f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet." |
|
) |
|
elif "pixel_values" in input_name: |
|
batch_size = shape[0] |
|
image_size = getattr(model.config, "image_size", None) |
|
if image_size is None: |
|
if hasattr(model.config, "vision_config"): |
|
image_size = model.config.vision_config.image_size |
|
elif hasattr(model.config, "encoder"): |
|
image_size = model.config.encoder.image_size |
|
else: |
|
image_size = (_generate_random_int(), _generate_random_int()) |
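            # If the config does not specify the number of channels, default to 3 (RGB).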
|
|
|
|
|
num_channels = getattr(model.config, "num_channels", 3) |
|
if not isinstance(image_size, collections.abc.Iterable): |
|
image_size = (image_size, image_size) |
|
height, width = image_size |
|
inputs_dict[input_name] = torch.zeros( |
|
batch_size, num_channels, height, width, dtype=torch.float32, device=device |
|
) |
|
elif "bbox" in input_name: |
|
inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device) |
|
elif "input_features" in input_name: |
|
inputs_dict[input_name] = torch.zeros( |
|
*shape, model.config.input_feat_per_channel, dtype=torch.float, device=device |
|
) |
|
elif "visual_feats" in input_name: |
|
inputs_dict[input_name] = torch.zeros( |
|
shape |
|
+ [ |
|
model.config.visual_feat_dim, |
|
], |
|
dtype=torch.float, |
|
device=device, |
|
) |
|
elif "visual_pos" in input_name: |
|
inputs_dict[input_name] = torch.zeros( |
|
shape |
|
+ [ |
|
model.config.visual_pos_dim, |
|
], |
|
dtype=torch.float, |
|
device=device, |
|
) |
|
elif "inputs" in input_name: |
|
inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device) |
|
elif "input_values" in input_name: |
|
batch_size, _ = shape |
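            # Audio models expect a raw waveform, so sample a large sequence length.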
|
|
|
seq_length = _generate_random_int(low=10000, high=20000) |
|
inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device) |
|
elif "mask" in input_name or "ids" in input_name: |
|
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device) |
|
else: |
|
shape_with_hidden_size = shape + [model.config.hidden_size] |
|
inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device) |
|
|
|
return inputs_dict |
|
|
|
def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None): |
|
rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) |
|
|
|
if kind == "placeholder" and target in self.meta_args: |
|
rv.install_metadata(self.meta_args[target]) |
|
return rv |
|
|
|
if target in self.orig_fns: |
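            # NOTE: tensor constructors in PyTorch define the `device` argument as
            # keyword-only, which is why rewriting kwargs here is sufficient. If a
            # patched method accepted `device` positionally, it would slip through.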
|
|
|
|
|
|
|
|
|
|
|
if "device" in kwargs: |
|
kwargs["device"] = "meta" |
|
|
|
try: |
|
args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas) |
|
kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas) |
|
|
|
if kind == "call_function": |
|
meta_target = _MANUAL_META_OVERRIDES.get(target, target) |
|
meta_out = meta_target(*args_metas, **kwargs_metas) |
|
if isinstance(meta_out, torch.Tensor): |
|
meta_out = meta_out.to(device="meta") |
|
elif kind == "call_method": |
|
method = getattr(args_metas[0].__class__, target) |
|
meta_target = _MANUAL_META_OVERRIDES.get(method, method) |
|
meta_out = meta_target(*args_metas, **kwargs_metas) |
|
elif kind == "call_module": |
|
if not hasattr(self, "orig_forward"): |
|
raise AttributeError(f"{self} does not have an attribute called orig_forward") |
|
self._disable_module_getattr = True |
|
try: |
|
mod = self.root.get_submodule(target) |
|
mod_type = type(mod) |
|
if mod_type in _MANUAL_META_OVERRIDES: |
|
meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas) |
|
else: |
|
meta_out = self.orig_forward(*args_metas, **kwargs_metas) |
|
finally: |
|
self._disable_module_getattr = False |
|
elif kind == "get_attr": |
|
self._disable_module_getattr = True |
|
try: |
|
attr_itr = self.root |
|
atoms = target.split(".") |
|
for atom in atoms: |
|
attr_itr = getattr(attr_itr, atom) |
|
if isinstance(attr_itr, torch.Tensor): |
|
meta_out = attr_itr.to(device="meta") |
|
else: |
|
meta_out = attr_itr |
|
finally: |
|
self._disable_module_getattr = False |
|
else: |
|
return rv |
|
|
|
if not isinstance(rv, Proxy): |
|
raise ValueError("Don't support composite output yet") |
|
rv.install_metadata(meta_out) |
|
except Exception as e: |
|
if _IS_IN_DEBUG_MODE: |
|
warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}") |
|
|
|
return rv |
|
|
|
|
|
def _module_getattr(self, attr, attr_val, parameter_proxy_cache): |
|
if getattr(self, "_disable_module_getattr", False): |
|
return attr_val |
|
else: |
|
|
|
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache): |
|
for n, p in collection_to_search: |
|
if attr_val is p: |
|
if n not in parameter_proxy_cache: |
|
kwargs = {} |
|
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters: |
|
kwargs["proxy_factory_fn"] = ( |
|
None |
|
if not self.param_shapes_constant |
|
else lambda node: ParameterProxy(self, node, n, attr_val) |
|
) |
|
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) |
|
parameter_proxy_cache[n] = val_proxy |
|
return parameter_proxy_cache[n] |
|
return None |
|
|
|
if isinstance(attr_val, torch.nn.Parameter): |
|
maybe_parameter_proxy = maybe_get_proxy_for_attr( |
|
attr_val, self.root.named_parameters(), parameter_proxy_cache |
|
) |
|
if maybe_parameter_proxy is not None: |
|
return maybe_parameter_proxy |
|
|
|
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): |
|
maybe_buffer_proxy = maybe_get_proxy_for_attr( |
|
attr_val, self.root.named_buffers(), parameter_proxy_cache |
|
) |
|
if maybe_buffer_proxy is not None: |
|
return maybe_buffer_proxy |
|
|
|
return attr_val |
|
|
|
|
|
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]): |
|
return self._module_getattr(attr, attr_val, parameter_proxy_cache) |
|
|
|
def call_module(self, m, forward, args, kwargs): |
|
self.orig_forward = forward |
|
return super().call_module(m, forward, args, kwargs) |
|
|
|
def proxy(self, node): |
|
return HFProxy(node, self) |
|
|
|
def trace( |
|
self, |
|
root: Union[torch.nn.Module, Callable[..., Any]], |
|
concrete_args: Optional[Dict[str, Any]] = None, |
|
dummy_inputs: Optional[Dict[str, Any]] = None, |
|
complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True, |
|
) -> Graph: |
|
""" |
|
Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a |
|
`torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from |
|
the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a |
|
`torch.nn.Module` instance to use as the root and add embedded constants to. |
|
|
|
Args: |
|
root (`torch.nn.Module` or `Callable`): |
|
                Either a `torch.nn.Module` or a function to be traced through. If root is not a
|
[`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail. |
|
            concrete_args (`Dict[str, Any]`, *optional*):
                Concrete arguments that should not be treated as Proxies.
|
dummy_inputs (`Dict[str, Any]`, *optional*): |
|
The dummy inputs needed to handle data-dependent control-flow if `root` is not a |
|
[`~transformers.PreTrainedModel`]. It can also be used when `root` is a |
|
[`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs. |
|
complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`): |
|
If `True`, and `dummy_inputs` is specified, every argument that `root` can take that is not in |
|
`dummy_inputs` and not in `concrete_args` will be added to `concrete_args`, otherwise does nothing. |
|
|
|
Returns: |
|
`torch.fx.Graph`: |
|
A FX `torch.fx.Graph` representing the semantics of the passed-in `root`. |
|
|
|
""" |
|
sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root) |
|
|
|
if concrete_args is None: |
|
concrete_args = {} |
|
|
|
if dummy_inputs is not None and complete_concrete_args_with_inputs_not_in_dummy_inputs: |
|
for param in sig.parameters.values(): |
|
if param.name in dummy_inputs: |
|
continue |
|
if param.default is inspect.Parameter.empty: |
|
raise ValueError(f"You need to specify a default value for the parameter {param.name}.") |
|
concrete_args.update( |
|
{ |
|
p.name: p.default |
|
for p in sig.parameters.values() |
|
if (p.name not in dummy_inputs and p.name not in concrete_args) |
|
} |
|
) |
|
|
|
input_names = sig.parameters.keys() - concrete_args.keys() |
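        # Creating a random input shape to generate dummy inputs with.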
|
|
|
|
|
batch_size = _generate_random_int() |
|
sequence_length = _generate_random_int() |
|
shape = [batch_size, sequence_length] |
|
|
|
if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES): |
|
num_choices = _generate_random_int(low=2, high=5) |
|
shape.insert(1, num_choices) |
|
|
|
inputs = dict(dummy_inputs) if dummy_inputs is not None else {} |
|
for input_name in input_names: |
|
if input_name in inputs: |
|
continue |
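            # Root must either be a PreTrainedModel or deserialized from a serialized
            # traced model to be able to use HFTracer._generate_dummy_input.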
|
|
|
|
|
if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith( |
|
("_deserialize_graph_module", "_CodeOnlyModule") |
|
): |
|
inputs.update(self._generate_dummy_input(root, input_name, shape)) |
|
else: |
|
raise RuntimeError( |
|
f"Could not generate input named {input_name} for because root is not a" |
|
" transformers.PreTrainedModel." |
|
) |
|
|
|
concrete_metas = { |
|
input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_ |
|
for input_name, input_ in inputs.items() |
|
} |
|
for param in sig.parameters.values(): |
|
if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names: |
|
concrete_metas[f"**{param.name}"] = {} |
|
self.meta_args = concrete_metas |
|
self.patched_torch_methods = { |
|
target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH |
|
} |
|
self.orig_fns = set() |
|
|
|
for name, (wrapper, orig) in self.patched_torch_methods.items(): |
|
setattr(torch, name, wrapper) |
|
self.orig_fns.add(orig) |
|
|
|
try: |
|
self.graph = super().trace(root, concrete_args=concrete_args) |
|
finally: |
|
for name, (_, orig) in self.patched_torch_methods.items(): |
|
setattr(torch, name, orig) |
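        # Concrete args end up as placeholder nodes in the traced graph, so the placeholders are post-processed
        # below: real inputs are typed, and concrete-arg placeholders (and their users) are erased.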
|
|
|
|
|
|
|
for node in self.graph.nodes: |
|
if node.op == "placeholder": |
|
|
|
if node.target in input_names: |
|
node.args = () |
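                    # Without this, torch.jit.script fails because the input type is Optional[torch.Tensor];
                    # it cannot infer the attributes and methods the input should have.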
|
|
|
|
|
node.type = torch.Tensor |
|
|
|
else: |
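                    # This placeholder is a concrete arg, so it is unused: erase it
                    # along with every node that (transitively) consumes it.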
|
to_visit = [node] |
|
to_delete = collections.OrderedDict() |
|
while to_visit: |
|
n = to_visit.pop(0) |
|
to_delete[n] = None |
|
to_visit += list(n.users.keys()) |
|
|
|
for user in reversed(to_delete.keys()): |
|
self.graph.erase_node(user) |
|
|
|
|
|
|
|
if node.op == "output": |
|
node.type = None |
|
|
|
return self.graph |
|
|
|
def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool: |
|
""" |
|
Whether the module was instantiated with Proxies. If that is the case, such module cannot be a leaf module |
|
because its attributes are input-dependent. |
|
""" |
|
return any(isinstance(attr, Proxy) for attr in mod.__dict__.values()) |
|
|
|
def _insert_module_as_submodule(self, mod: nn.Module) -> str: |
|
""" |
|
Helper method which tries to insert a module that was not declared as submodule. |
|
""" |
|
|
|
|
|
if self._stateless_mod_instanciation_depends_on_proxies(mod): |
|
return "" |
|
idx = 0 |
|
mod_name = mod.__class__.__name__.lower() |
|
path = f"{mod_name}_{idx}" |
|
already_inserted = False |
|
while hasattr(self.root, path): |
|
if getattr(self.root, path) is mod: |
|
already_inserted = True |
|
break |
|
path = f"{mod_name}_{idx}" |
|
idx += 1 |
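        # Avoid inserting the same module twice.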
|
|
|
|
|
if not already_inserted: |
|
self.root.add_module(path, mod) |
|
return path |
|
|
|
def path_of_module(self, mod: nn.Module) -> str: |
|
""" |
|
Helper method to find the qualified name of `mod` in the Module hierarchy of `root`. For example, if `root` has |
|
a submodule named `foo`, which has a submodule named `bar`, passing `bar` into this function will return the |
|
string "foo.bar". |
|
|
|
Args: |
|
            mod (nn.Module): The `Module` to retrieve the qualified name for.
|
""" |
|
try: |
|
return super().path_of_module(mod) |
|
except NameError as e: |
|
if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0: |
|
path = self._insert_module_as_submodule(mod) |
|
return path |
|
raise e |
|
|
|
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: |
|
return (not self._stateless_mod_instanciation_depends_on_proxies(m)) and super().is_leaf_module( |
|
m, module_qualified_name |
|
) |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def keys(self, obj: "Proxy") -> Any: |
|
"""Called when a proxy object is has the keys() method called. |
|
This is what happens when ** is called on a proxy. This should return an iterator if ** is supposed to work in |
|
your custom tracer. |
|
""" |
|
attribute = HFAttribute(obj, "keys")() |
|
if obj.node.target == "**kwargs": |
|
return attribute._metadata |
|
return attribute |
|
|
|
|
|
def get_concrete_args(model: nn.Module, input_names: List[str]): |
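    """
    Maps every parameter of `model.forward` that is not listed in `input_names` to its default value, so that the
    tracer treats those parameters as constants rather than Proxies.

    Example (a minimal sketch; the exact parameter names depend on the model class being traced):

    ```python
    concrete_args = get_concrete_args(model, ["input_ids"])
    # Every other forward parameter, e.g. "attention_mask", is mapped to its default value.
    ```
    """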
|
sig = inspect.signature(model.forward) |
|
|
|
if not (set(input_names) <= set(sig.parameters.keys())): |
|
formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names) |
|
formatted_allowed_input_names = ", ".join(sig.parameters.keys()) |
|
raise ValueError( |
|
f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:" |
|
f" {formatted_allowed_input_names}" |
|
) |
|
|
|
return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} |
|
|
|
|
|
def check_if_model_is_supported(model: PreTrainedModel): |
|
if model.__class__.__name__ not in _SUPPORTED_MODELS: |
|
supported_model_names = ", ".join(_SUPPORTED_MODELS) |
|
raise NotImplementedError( |
|
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}" |
|
) |
|
|
|
|
|
def symbolic_trace( |
|
model: PreTrainedModel, |
|
input_names: Optional[List[str]] = None, |
|
disable_check: bool = False, |
|
tracer_cls: Type[HFTracer] = HFTracer, |
|
) -> GraphModule: |
|
""" |
|
Performs symbolic tracing on the model. |
|
|
|
Args: |
|
model ([`PretrainedModel`]): |
|
The model to trace. |
|
input_names (`List[str]`, *optional*): |
|
            The names of the inputs of the traced model. If unset, `model.dummy_inputs.keys()` are used instead.
|
disable_check (`bool`, *optional*, defaults to `False`): |
|
            If `True`, no check is done before trying to trace the model; this is mostly useful for debugging purposes.
|
tracer_cls (`Type[HFTracer]`, *optional*, defaults to `HFTracer`): |
|
The tracer class to use for instantiating the tracer. If unset, `HFTracer` is used instead. |
|
|
|
Returns: |
|
`torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model. |
|
|
|
Example: |
|
|
|
```python |
|
from transformers.utils.fx import symbolic_trace |
|
|
|
traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"]) |
|
``` |
|
""" |
|
if input_names is None: |
|
input_names = model.dummy_inputs.keys() |
|
|
|
input_names = list(input_names) |
|
concrete_args = get_concrete_args(model, input_names) |
|
|
|
if not disable_check: |
|
check_if_model_is_supported(model) |
|
|
|
|
|
tracer = tracer_cls() |
|
traced_graph = tracer.trace(model, concrete_args=concrete_args) |
|
traced = torch.fx.GraphModule(model, traced_graph) |
|
|
|
traced.config = model.config |
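    # The model class is stored so that deserialization, which re-traces the model and therefore needs
    # _generate_dummy_input, can recover the original architecture.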
|
|
|
|
|
traced.class_for_deserialization = model.__class__ |
|
traced.device = model.device |
|
|
|
return traced |
|
|