The Tasks Manager
Exporting a model from one framework to some format (also called a backend here) involves specifying the input and output information that the export function needs. For each backend, optimum.exporters is structured as follows:
- Configuration classes containing the information for each model to perform the export.
- Exporting functions using the proper configuration for the model to export.
The role of the TasksManager is to be the main entry point for loading a model given a name and a task, and for getting the proper configuration for a given (architecture, backend) pair. That way, there is a centralized place to register the task -> model class and (architecture, backend) -> configuration mappings. The export functions can then rely on these mappings and on the various checks the TasksManager provides.
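As a rough mental model, the two mappings can be pictured as plain dictionaries consulted through a single lookup function. The sketch below is hypothetical: TASK_TO_MODEL_CLASS, CONFIG_REGISTRY, get_config and the DistilBertOnnxConfig name are illustrative stand-ins (classes are represented by their names), not the optimum implementation.

```python
from typing import Dict, Tuple

# Hypothetical task -> model class mapping (class names taken from the PyTorch table).
TASK_TO_MODEL_CLASS: Dict[str, str] = {
    "default": "AutoModel",
    "masked-lm": "AutoModelForMaskedLM",
    "sequence-classification": "AutoModelForSequenceClassification",
}

# Hypothetical (architecture, backend) -> configuration mapping.
CONFIG_REGISTRY: Dict[Tuple[str, str], str] = {
    ("distilbert", "onnx"): "DistilBertOnnxConfig",
}


def get_config(architecture: str, backend: str) -> str:
    """Return the configuration registered for an (architecture, backend) pair."""
    key = (architecture, backend)
    if key not in CONFIG_REGISTRY:
        # Centralizing the lookup gives every export function the same check.
        raise KeyError(f"{architecture} is not supported for the {backend} backend.")
    return CONFIG_REGISTRY[key]


print(TASK_TO_MODEL_CLASS["masked-lm"])  # AutoModelForMaskedLM
print(get_config("distilbert", "onnx"))  # DistilBertOnnxConfig
```

Keeping both lookups in one place means an unsupported pair fails with the same error no matter which export function asked for it.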
Task names
The supported tasks may depend on the backend; the tables below map each task name to its auto class for PyTorch and TensorFlow.
To list the tasks a model supports for a given backend, you can run:
>>> from optimum.exporters.tasks import TasksManager
>>> model_type = "distilbert"
>>> # For instance, for the ONNX export.
>>> backend = "onnx"
>>> distilbert_tasks = list(TasksManager.get_supported_tasks_for_model_type(model_type, backend).keys())
>>> print(distilbert_tasks)
['default', 'masked-lm', 'sequence-classification', 'multiple-choice', 'token-classification', 'question-answering']
PyTorch
Task | Auto Class |
---|---|
causal-lm, causal-lm-with-past | AutoModelForCausalLM |
default, default-with-past | AutoModel |
masked-lm | AutoModelForMaskedLM |
question-answering | AutoModelForQuestionAnswering |
seq2seq-lm, seq2seq-lm-with-past | AutoModelForSeq2SeqLM |
sequence-classification | AutoModelForSequenceClassification |
token-classification | AutoModelForTokenClassification |
multiple-choice | AutoModelForMultipleChoice |
image-classification | AutoModelForImageClassification |
object-detection | AutoModelForObjectDetection |
image-segmentation | AutoModelForImageSegmentation |
masked-im | AutoModelForMaskedImageModeling |
semantic-segmentation | AutoModelForSemanticSegmentation |
speech2seq-lm | AutoModelForSpeechSeq2Seq |
TensorFlow
Task | Auto Class |
---|---|
causal-lm, causal-lm-with-past | TFAutoModelForCausalLM |
default, default-with-past | TFAutoModel |
masked-lm | TFAutoModelForMaskedLM |
question-answering | TFAutoModelForQuestionAnswering |
seq2seq-lm, seq2seq-lm-with-past | TFAutoModelForSeq2SeqLM |
sequence-classification | TFAutoModelForSequenceClassification |
token-classification | TFAutoModelForTokenClassification |
multiple-choice | TFAutoModelForMultipleChoice |
semantic-segmentation | TFAutoModelForSemanticSegmentation |
Reference
TasksManager
Handles the task name -> model class and architecture -> configuration mappings.
create_register(backend: str, overwrite_existing: bool = False) → Callable[[str, Tuple[str, ...]], Callable[[Type], Type]]
Parameters
- backend (str): The name of the backend that the register function will handle.
- overwrite_existing (bool, defaults to False): Whether or not the register function is allowed to overwrite an already existing config.
Returns
Callable[[str, Tuple[str, ...]], Callable[[Type], Type]]
A function taking the model type and the supported tasks, and returning a class decorator.
Creates a register function for the specified backend.
determine_framework(model_name_or_path: Union[str, Path], subfolder: str = '', framework: Optional[str] = None) → str
Parameters
- model_name_or_path (Union[str, Path]): Can be either the model id of a model repo on the Hugging Face Hub, or a path to a local directory containing a model.
- subfolder (str, optional, defaults to ""): In case the model files are located inside a subfolder of the model directory / repo on the Hugging Face Hub, you can specify the subfolder name here.
- framework (Optional[str], optional): The framework to use for the export. See the priority order below if none is provided.
Returns
str
The framework to use for the export.
Determines the framework to use for the export.
The priority is in the following order:
- User input via framework.
- If a local checkpoint is provided, use the same framework as the checkpoint.
- If a model repo is provided, try to infer the framework from the Hub.
- If the framework could not be inferred, use the framework available in the environment, with priority given to PyTorch.
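The priority order can be sketched as a plain function. This is a hypothetical simplification, not the library's code: pick_framework takes explicit inputs instead of inspecting the filesystem and the Hub, and it assumes that PyTorch checkpoints are recognized by file names such as pytorch_model.bin or model.safetensors and TensorFlow ones by tf_model.h5.

```python
from typing import Iterable, Optional

# Assumed weight file names used to recognize a checkpoint's framework.
PT_FILES = {"pytorch_model.bin", "model.safetensors"}
TF_FILES = {"tf_model.h5"}


def pick_framework(
    framework: Optional[str] = None,
    local_files: Optional[Iterable[str]] = None,
    hub_framework: Optional[str] = None,
    torch_available: bool = True,
    tf_available: bool = False,
) -> str:
    # 1. Explicit user input always wins.
    if framework is not None:
        return framework
    # 2. Local checkpoint: use the framework of the weight files found.
    if local_files is not None:
        files = set(local_files)
        if files & PT_FILES:
            return "pt"
        if files & TF_FILES:
            return "tf"
    # 3. Model repo: use whatever could be inferred from the Hub.
    if hub_framework is not None:
        return hub_framework
    # 4. Fall back on the environment, preferring PyTorch.
    if torch_available:
        return "pt"
    if tf_available:
        return "tf"
    raise EnvironmentError("Neither PyTorch nor TensorFlow is available.")


print(pick_framework(local_files=["config.json", "tf_model.h5"]))  # tf
```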
Retrieves all the possible tasks.
get_exporter_config_constructor(exporter: str, model: Optional[Union[PreTrainedModel, TFPreTrainedModel]] = None, task: str = 'default', model_type: Optional[str] = None, model_name: Optional[str] = None, exporter_config_kwargs: Optional[Dict[str, Any]] = None) → ExportConfigConstructor
Parameters
- exporter (str): The exporter to use.
- model (Optional[Union[PreTrainedModel, TFPreTrainedModel]], defaults to None): The instance of the model.
- task (str, defaults to "default"): The task to retrieve the config for.
- model_type (Optional[str], defaults to None): The model type to retrieve the config for.
- model_name (Optional[str], defaults to None): The name attribute of the model object, only used for the exception message.
- exporter_config_kwargs (Optional[Dict[str, Any]], defaults to None): Arguments that will be passed to the exporter config class when building the config constructor.
Returns
ExportConfigConstructor
The ExportConfig constructor for the requested backend.
Gets the ExportConfigConstructor for a model (or alternatively for a model type) and task combination.
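Conceptually, the lookup goes from a (model type, exporter, task) triple to a constructor, with exporter_config_kwargs bound ahead of time. A hypothetical sketch using functools.partial: FakeOnnxConfig, REGISTRY, get_constructor and the use_past argument are invented for illustration, and the sketch takes model_type directly rather than deriving it from a model instance.

```python
from functools import partial
from typing import Any, Callable, Dict, Optional


class FakeOnnxConfig:
    """Hypothetical export config that records how it was built."""

    def __init__(self, model_config: Any, task: str = "default", **kwargs: Any):
        self.model_config = model_config
        self.task = task
        self.kwargs = kwargs


# Hypothetical registry: model type -> exporter -> task -> constructor.
REGISTRY: Dict[str, Dict[str, Dict[str, Callable[..., Any]]]] = {
    "distilbert": {
        "onnx": {
            "default": partial(FakeOnnxConfig, task="default"),
            "masked-lm": partial(FakeOnnxConfig, task="masked-lm"),
        }
    }
}


def get_constructor(
    exporter: str,
    task: str = "default",
    model_type: Optional[str] = None,
    model_name: Optional[str] = None,
    exporter_config_kwargs: Optional[Dict[str, Any]] = None,
) -> Callable[..., Any]:
    tasks = REGISTRY.get(model_type or "", {}).get(exporter, {})
    if task not in tasks:
        # model_name, when given, only makes the error message more helpful.
        name = model_name or model_type
        raise ValueError(f"{name} does not support the task {task!r} with the {exporter} exporter.")
    constructor = tasks[task]
    if exporter_config_kwargs:
        # Bind the extra arguments so callers only pass the model config.
        constructor = partial(constructor, **exporter_config_kwargs)
    return constructor


build = get_constructor(
    "onnx", task="masked-lm", model_type="distilbert", exporter_config_kwargs={"use_past": True}
)
config = build({"hidden_size": 768})
print(config.task, config.kwargs)  # masked-lm {'use_past': True}
```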
get_model_class_for_task(task: str, framework: str = 'pt')
Attempts to retrieve an AutoModel class from a task name.
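A hypothetical sketch of such a lookup for a few rows of the tables in the Task names section. It returns the auto class name instead of the class so that it runs without transformers installed; model_class_name_for_task and the dictionaries inside it are illustrative, not optimum code.

```python
from typing import Dict

# Subset of the "Task names" tables, keyed by base task name.
_TASK_TO_AUTO_CLASS: Dict[str, str] = {
    "default": "AutoModel",
    "causal-lm": "AutoModelForCausalLM",
    "masked-lm": "AutoModelForMaskedLM",
    "image-classification": "AutoModelForImageClassification",
}
# Per the tables, the -with-past variants share their base task's auto class.
_ALIASES: Dict[str, str] = {
    "default-with-past": "default",
    "causal-lm-with-past": "causal-lm",
}
# Tasks from this subset that also appear in the TensorFlow table.
_TF_TASKS = {"default", "causal-lm", "masked-lm"}


def model_class_name_for_task(task: str, framework: str = "pt") -> str:
    if framework not in ("pt", "tf"):
        raise ValueError(f"Unknown framework {framework!r}, expected 'pt' or 'tf'.")
    task = _ALIASES.get(task, task)
    if task not in _TASK_TO_AUTO_CLASS or (framework == "tf" and task not in _TF_TASKS):
        raise KeyError(f"No auto class known for task {task!r} with framework {framework!r}.")
    name = _TASK_TO_AUTO_CLASS[task]
    # TensorFlow auto classes are the PyTorch names prefixed with "TF".
    return "TF" + name if framework == "tf" else name


print(model_class_name_for_task("causal-lm-with-past", framework="tf"))  # TFAutoModelForCausalLM
```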
get_model_from_task(task: str, model_name_or_path: Union[str, Path], subfolder: str = '', revision: Optional[str] = None, framework: Optional[str] = None, cache_dir: Optional[str] = None, torch_dtype: Optional[torch.dtype] = None, **model_kwargs)
Parameters
- task (str): The task required.
- model_name_or_path (Union[str, Path]): Can be either the model id of a model repo on the Hugging Face Hub, or a path to a local directory containing a model.
- subfolder (str, optional, defaults to ""): In case the model files are located inside a subfolder of the model directory / repo on the Hugging Face Hub, you can specify the subfolder name here.
- revision (Optional[str], optional): The specific model version to use. It can be a branch name, a tag name, or a commit id.
- framework (Optional[str], optional): The framework to use for the export. See TasksManager.determine_framework for the priority order if none is provided.
- cache_dir (Optional[str], optional): Path to a directory in which downloaded pretrained model weights have been cached, if the standard cache should not be used.
- torch_dtype (Optional[torch.dtype], defaults to None): Data type in which to load the model. PyTorch-only argument.
- model_kwargs (Dict[str, Any], optional): Keyword arguments to pass to the model's from_pretrained() method.
Retrieves a model from its name and the task to be enabled.
Returns the list of architectures supported by the exporter for a given task.
get_supported_tasks_for_model_type(model_type: str, exporter: str, model_name: Optional[str] = None) → TaskNameToExportConfigDict
Parameters
- model_type (str): The model type to retrieve the supported tasks for.
- exporter (str): The name of the exporter.
- model_name (Optional[str], defaults to None): The name attribute of the model object, only used for the exception message.
Returns
TaskNameToExportConfigDict
The dictionary mapping each task to a corresponding ExportConfig constructor.
Retrieves the task -> exporter backend config constructors map from the model type.
infer_task_from_model(model: Union[str, PreTrainedModel, TFPreTrainedModel, Type], subfolder: str = '', revision: Optional[str] = None) → str
Parameters
- model (Union[str, PreTrainedModel, TFPreTrainedModel, Type]): The model to infer the task from. This can either be the name of a repo on the Hugging Face Hub, an instance of a model, or a model class.
- subfolder (str, optional, defaults to ""): In case the model files are located inside a subfolder of the model directory / repo on the Hugging Face Hub, you can specify the subfolder name here.
- revision (Optional[str], optional): The specific model version to use. It can be a branch name, a tag name, or a commit id.
Returns
str
The task name automatically detected from the model.
Infers the task from the model repo, model instance, or model class.
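When the model is a Hub repo, its config.json usually lists an architectures field such as ["DistilBertForMaskedLM"]. A hypothetical heuristic that maps such class-name suffixes to the task names from the Task names section: guess_task_from_config is illustrative only, does not reproduce how infer_task_from_model works internally, and fails on class names that break the convention, such as GPT2LMHeadModel.

```python
from typing import Any, Dict, List

# Class-name suffix -> task name (hypothetical subset following transformers naming).
_SUFFIX_TO_TASK: Dict[str, str] = {
    "ForCausalLM": "causal-lm",
    "ForMaskedLM": "masked-lm",
    "ForSequenceClassification": "sequence-classification",
    "ForTokenClassification": "token-classification",
    "ForQuestionAnswering": "question-answering",
    "ForMultipleChoice": "multiple-choice",
    "ForImageClassification": "image-classification",
}


def guess_task_from_config(config: Dict[str, Any]) -> str:
    """Guess a task from the "architectures" entry of a config.json-like dict."""
    architectures: List[str] = config.get("architectures") or []
    for architecture in architectures:
        for suffix, task in _SUFFIX_TO_TASK.items():
            if architecture.endswith(suffix):
                return task
    raise ValueError(f"Could not infer a task from architectures={architectures!r}.")


print(guess_task_from_config({"architectures": ["DistilBertForMaskedLM"]}))  # masked-lm
```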