# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import importlib
import json
import os
import pickle
import sys
import warnings
from abc import ABC, abstractmethod
from collections import UserDict
from contextlib import contextmanager
from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from packaging import version
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import ModelOutput, add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard
from ..models.auto.configuration_auto import AutoConfig
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import logging
# Recursive type alias for tensors exchanged between `preprocess`, `_forward` and `postprocess`: a framework
# tensor (PyTorch or TensorFlow) or an arbitrarily nested list of them.
GenericTensor = Union[List["GenericTensor"], "torch.Tensor", "tf.Tensor"]

# Framework imports are optional: each backend is only imported when it is installed.
if is_tf_available():
    import tensorflow as tf

    from ..models.auto.modeling_tf_auto import TFAutoModel

if is_torch_available():
    import torch
    from torch.utils.data import DataLoader, Dataset, IterableDataset

    from ..models.auto.modeling_auto import AutoModel
else:
    # Without PyTorch these names still exist (as None) so code can test `Dataset is not None` instead of
    # failing with a NameError.
    Dataset = None
    KeyDataset = None

if TYPE_CHECKING:
    # Only imported for static type checking; the runtime annotations refer to these names as strings.
    from ..modeling_tf_utils import TFPreTrainedModel
    from ..modeling_utils import PreTrainedModel

logger = logging.get_logger(__name__)
def collate_fn(items):
    """
    Collate function for a ``DataLoader`` used with ``batch_size=1``: unwrap the single element of the batch.

    Args:
        items: The batch produced by the loader; must contain exactly one element.

    Returns:
        The only element of :obj:`items`, unchanged.

    Raises:
        ValueError: If the batch does not contain exactly one element.
    """
    if len(items) == 1:
        return items[0]
    raise ValueError("This collate_fn is meant to be used with batch_size=1")
def infer_framework_load_model(
    model,
    config: AutoConfig,
    model_classes: Optional[Dict[str, Tuple[type]]] = None,
    task: Optional[str] = None,
    framework: Optional[str] = None,
    **model_kwargs
):
    """
    Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).

    If :obj:`model` is instantiated, this function will just infer the framework from the model class. Otherwise
    :obj:`model` is actually a checkpoint name and this method will try to instantiate it using :obj:`model_classes`.
    Since we don't want to instantiate the model twice, this model is returned for use by the pipeline.

    If both frameworks are installed and available for :obj:`model`, PyTorch is selected.

    Args:
        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
            The model to infer the framework from. If :obj:`str`, a checkpoint name.
        config (:class:`~transformers.AutoConfig`):
            The config associated with the model. When :obj:`model` is a checkpoint name, each entry of its
            ``architectures`` field is looked up on the ``transformers`` module as an extra candidate class.
        model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
            A mapping from framework (``"pt"`` / ``"tf"``) to a tuple of candidate classes.
        task (:obj:`str`):
            The task defining which pipeline will be returned. Forwarded as ``_from_pipeline`` to
            ``from_pretrained``.
        framework (:obj:`str`, `optional`):
            ``"pt"`` or ``"tf"`` to only consider classes of that framework; :obj:`None` considers every installed
            framework.
        model_kwargs:
            Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
            **model_kwargs)` function.

    Returns:
        :obj:`Tuple`: A tuple framework, model.

    Raises:
        RuntimeError: If neither TensorFlow nor PyTorch is installed.
        ValueError: If no candidate class can be found, or if none of the candidates can load the checkpoint (the
            message lists why each candidate failed).
    """
    if not is_tf_available() and not is_torch_available():
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )
    if isinstance(model, str):
        model_kwargs["_from_pipeline"] = task
        class_tuple = ()
        look_pt = is_torch_available() and framework in {"pt", None}
        look_tf = is_tf_available() and framework in {"tf", None}
        if model_classes:
            if look_pt:
                class_tuple = class_tuple + model_classes.get("pt", (AutoModel,))
            if look_tf:
                class_tuple = class_tuple + model_classes.get("tf", (TFAutoModel,))
        if config.architectures:
            # The module does not depend on the architecture: resolve it once, outside the loop.
            transformers_module = importlib.import_module("transformers")
            classes = []
            for architecture in config.architectures:
                if look_pt:
                    _class = getattr(transformers_module, architecture, None)
                    if _class is not None:
                        classes.append(_class)
                if look_tf:
                    _class = getattr(transformers_module, f"TF{architecture}", None)
                    if _class is not None:
                        classes.append(_class)
            class_tuple = class_tuple + tuple(classes)

        if len(class_tuple) == 0:
            raise ValueError(f"Pipeline cannot infer suitable model classes from {model}")

        # The cross-framework loading hint depends only on the checkpoint name, so decide (and warn) once
        # instead of once per candidate class.
        if framework == "pt" and model.endswith(".h5"):
            model_kwargs["from_tf"] = True
            logger.warning(
                "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
                "Trying to load the model with PyTorch."
            )
        elif framework == "tf" and model.endswith(".bin"):
            model_kwargs["from_pt"] = True
            logger.warning(
                "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
                "Trying to load the model with Tensorflow."
            )

        load_errors = []
        for model_class in class_tuple:
            try:
                model = model_class.from_pretrained(model, **model_kwargs)
                if hasattr(model, "eval"):
                    model = model.eval()
                # Stop loading on the first successful load.
                break
            except (OSError, ValueError) as err:
                # Keep trying the remaining candidates, but remember why this one failed so the final error
                # is actionable instead of silently dropping the cause.
                load_errors.append(f"- {model_class.__name__}: {type(err).__name__}: {err}")
                continue

        if isinstance(model, str):
            raise ValueError(
                f"Could not load model {model} with any of the following classes: {class_tuple}.\n"
                + "\n".join(load_errors)
            )

    framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
    return framework, model
def infer_framework_from_model(
    model,
    model_classes: Optional[Dict[str, Tuple[type]]] = None,
    task: Optional[str] = None,
    framework: Optional[str] = None,
    **model_kwargs
):
    """
    Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).

    If :obj:`model` is instantiated, this function will just infer the framework from the model class. Otherwise
    :obj:`model` is actually a checkpoint name and this method will try to instantiate it using :obj:`model_classes`.
    Since we don't want to instantiate the model twice, this model is returned for use by the pipeline.

    If both frameworks are installed and available for :obj:`model`, PyTorch is selected.

    Args:
        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
            The model to infer the framework from. If :obj:`str`, a checkpoint name.
        model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
            A mapping framework to class.
        task (:obj:`str`):
            The task defining which pipeline will be returned.
        framework (:obj:`str`, `optional`):
            ``"pt"`` or ``"tf"`` to restrict loading to one framework.
        model_kwargs:
            Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
            **model_kwargs)` function (and to the config's when :obj:`model` is a checkpoint name).

    Returns:
        :obj:`Tuple`: A tuple framework, model.
    """
    if isinstance(model, str):
        # Merge instead of passing `_from_pipeline=task` next to `**model_kwargs`: the latter raises a
        # duplicate-keyword TypeError if the caller already supplied `_from_pipeline`. `task` still wins.
        config = AutoConfig.from_pretrained(model, **{**model_kwargs, "_from_pipeline": task})
    else:
        config = model.config
    # `infer_framework_load_model` sets `_from_pipeline=task` itself whenever it has to load a checkpoint,
    # so it is not passed explicitly here.
    return infer_framework_load_model(
        model, config, model_classes=model_classes, task=task, framework=framework, **model_kwargs
    )
def get_framework(model, revision: Optional[str] = None):
    """
    Select framework (TensorFlow or PyTorch) to use. Deprecated in favor of ``infer_framework_from_model``.

    Args:
        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
            If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
            the model name). When a name is given and both frameworks are installed, a PyTorch load is attempted
            first and TensorFlow is used only if that raises :obj:`OSError`.
        revision (:obj:`str`, `optional`):
            Forwarded to ``from_pretrained`` when :obj:`model` is a checkpoint name.

    Returns:
        :obj:`str`: ``"tf"`` if the (possibly freshly loaded) model's class name starts with ``TF``, else ``"pt"``.

    Raises:
        RuntimeError: If neither TensorFlow nor PyTorch is installed.
    """
    warnings.warn(
        "`get_framework` is deprecated and will be removed in v5, use `infer_framework_from_model` instead.",
        FutureWarning,
    )
    tf_installed = is_tf_available()
    torch_installed = is_torch_available()
    if not (tf_installed or torch_installed):
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )

    if isinstance(model, str):
        if torch_installed and tf_installed:
            # Both available: prefer PyTorch, fall back to TensorFlow if the PyTorch weights cannot be found.
            try:
                model = AutoModel.from_pretrained(model, revision=revision)
            except OSError:
                model = TFAutoModel.from_pretrained(model, revision=revision)
        elif torch_installed:
            model = AutoModel.from_pretrained(model, revision=revision)
        else:
            model = TFAutoModel.from_pretrained(model, revision=revision)

    return "tf" if model.__class__.__name__.startswith("TF") else "pt"
def get_default_model(targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]) -> str:
    """
    Select a default model to use for a given task. Defaults to pytorch if ambiguous.

    Args:
        targeted_task (:obj:`Dict`):
            Dictionary describing the task; its ``"default"`` entry holds either a ``"model"`` mapping or one such
            mapping per task option.
        framework (:obj:`str`, None):
            "pt", "tf" or None. Overridden when only one of the two frameworks is installed.
        task_options (:obj:`Any`, None):
            Extra value needed to fully specify the task, for instance (SRC, TGT) languages for translation.

    Returns:
        :obj:`str`: The model identifier stored under the selected framework key.

    Raises:
        ValueError: If the requested options have no defaults, or if the task is parametrized but no options were
            given.
    """
    torch_installed = is_torch_available()
    # An installed-only backend wins over whatever the caller requested.
    if torch_installed and not is_tf_available():
        framework = "pt"
    elif is_tf_available() and not torch_installed:
        framework = "tf"

    defaults = targeted_task["default"]
    if task_options:
        if task_options not in defaults:
            raise ValueError(f"The task does not provide any default models for options {task_options}")
        default_models = defaults[task_options]["model"]
    elif "model" in defaults:
        default_models = defaults["model"]
    else:
        # XXX This error message needs to be updated to be more generic if more tasks are going to become
        # parametrized
        raise ValueError('The task defaults can\'t be correctly selected. You probably meant "translation_XX_to_YY"')

    # Still ambiguous at this point (both backends, nothing requested): use PyTorch.
    chosen_framework = "pt" if framework is None else framework
    return default_models[chosen_framework]
class PipelineException(Exception):
    """
    Error raised by a :class:`~transformers.Pipeline` while handling ``__call__``.

    Args:
        task (:obj:`str`): The task of the pipeline that failed.
        model (:obj:`str`): The model used by the pipeline.
        reason (:obj:`str`): The error message; it becomes the exception's message (``str(exc)``).
    """

    def __init__(self, task: str, model: str, reason: str):
        # Only `reason` goes to Exception so that str(exc) is exactly the message.
        super().__init__(reason)
        self.task, self.model = task, model
class ArgumentHandler(ABC):
    """
    Abstract interface for the objects a :class:`~transformers.pipelines.Pipeline` uses to parse its arguments.

    Concrete handlers must implement ``__call__``.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs):
        # Abstract: subclasses provide the actual parsing.
        raise NotImplementedError()
class _ScikitCompat(ABC):
    """
    Abstract compatibility layer exposing the Scikit-learn / Keras style ``transform`` and ``predict`` methods.
    """

    @abstractmethod
    def transform(self, X):
        # Abstract hook; implemented by subclasses.
        raise NotImplementedError()

    @abstractmethod
    def predict(self, X):
        # Abstract hook; implemented by subclasses.
        raise NotImplementedError()
# Shared "Arguments" section appended (via `add_end_docstrings`) to the docstring of every pipeline class; it
# documents the keyword arguments of `Pipeline.__init__`.
PIPELINE_INIT_ARGS = r"""
    Arguments:
        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
        tokenizer (:obj:`~transformers.PreTrainedTokenizer`, `optional`):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            :class:`~transformers.PreTrainedTokenizer`.
        feature_extractor (:obj:`~transformers.PreTrainedFeatureExtractor`, `optional`):
            The feature extractor that will be used by the pipeline to encode data for the model.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified and
            both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no model
            is provided.
        task (:obj:`str`, defaults to :obj:`""`):
            A task-identifier for the pipeline.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to -1):
            Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a non-negative value will run the
            model on the associated CUDA device id.
        binary_output (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Flag indicating if the output of the pipeline should be in a binary format (i.e., pickle) or as raw text.
"""
if is_torch_available():

    class PipelineDataset(Dataset):
        """
        Map-style dataset that applies ``process(item, **params)`` lazily to each element of ``dataset``.
        """

        def __init__(self, dataset, process, params):
            self.dataset = dataset
            self.process = process
            self.params = params

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, i):
            return self.process(self.dataset[i], **self.params)

    class PipelineIterator(IterableDataset):
        """
        Iterable wrapper applying ``infer(item, **params)`` to every item produced by ``loader``.

        ``__iter__`` restarts iteration over ``loader`` and returns the wrapper itself, so only one pass can be
        active at a time.
        """

        def __init__(self, loader, infer, params):
            self.loader = loader
            self.infer = infer
            self.params = params

        def __len__(self):
            return len(self.loader)

        def __iter__(self):
            self.iterator = iter(self.loader)
            return self

        def __next__(self):
            # StopIteration from the underlying iterator propagates and ends the loop.
            return self.infer(next(self.iterator), **self.params)

    class KeyDataset(Dataset):
        """
        Dataset view returning ``dataset[i][key]`` for each index ``i``.
        """

        def __init__(self, dataset: Dataset, key: str):
            self.dataset = dataset
            self.key = key

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, i):
            row = self.dataset[i]
            return row[self.key]
@add_end_docstrings(PIPELINE_INIT_ARGS)
class Pipeline(_ScikitCompat):
    """
    The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
    different pipelines.

    Base class implementing pipelined operations. Pipeline workflow is defined as a sequence of the following
    operations:

        Input -> Tokenization -> Model Inference -> Post-Processing (task dependent) -> Output

    Pipeline supports running on CPU or GPU through the device argument (see below).

    Some pipeline, like for instance :class:`~transformers.FeatureExtractionPipeline` (:obj:`'feature-extraction'` )
    output large tensor object as nested-lists. In order to avoid dumping such large structure as textual data we
    provide the :obj:`binary_output` constructor argument. If set to :obj:`True`, the output will be stored in the
    pickle format.
    """

    default_input_names = None

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: Optional[PreTrainedTokenizer] = None,
        feature_extractor: Optional[PreTrainedFeatureExtractor] = None,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        task: str = "",
        args_parser: ArgumentHandler = None,
        device: int = -1,
        binary_output: bool = False,
        **kwargs,
    ):
        if framework is None:
            # Derive "pt"/"tf" from the model object itself.
            framework, model = infer_framework_load_model(model, config=model.config)

        self.task = task
        self.model = model
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.modelcard = modelcard
        self.framework = framework
        # TensorFlow keeps the raw ordinal (used by `device_placement`); PyTorch gets a torch.device where a
        # negative ordinal means CPU.
        self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
        self.binary_output = binary_output

        # Special handling
        if self.framework == "pt" and self.device.type == "cuda":
            self.model = self.model.to(self.device)

        # Update config with task specific parameters
        task_specific_params = self.model.config.task_specific_params
        if task_specific_params is not None and task in task_specific_params:
            self.model.config.update(task_specific_params.get(task))

        self.call_count = 0
        # Remaining constructor kwargs become the default per-stage parameters; `__call__` can override them.
        self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)

    def save_pretrained(self, save_directory: str):
        """
        Save the pipeline's model and tokenizer (plus feature extractor and model card when present).

        Args:
            save_directory (:obj:`str`):
                A path to the directory where to saved. It will be created if it doesn't exist.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return
        os.makedirs(save_directory, exist_ok=True)

        self.model.save_pretrained(save_directory)

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(save_directory)

        if self.feature_extractor is not None:
            self.feature_extractor.save_pretrained(save_directory)

        if self.modelcard is not None:
            self.modelcard.save_pretrained(save_directory)

    def transform(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        # Implements the abstract `_ScikitCompat.transform` so concrete pipelines do not each have to.
        return self(X)

    def predict(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        # Positional on purpose: `__call__(self, inputs, ...)` requires `inputs`, so `self(X=X)` raised TypeError.
        return self(X)

    @contextmanager
    def device_placement(self):
        """
        Context Manager allowing tensor allocation on the user-specified device in framework agnostic way.

        Returns:
            Context manager

        Examples::

            # Explicitly ask for tensor allocation on CUDA device :0
            pipe = pipeline(..., device=0)
            with pipe.device_placement():
                # Every framework specific tensor allocation will be done on the request device
                output = pipe(...)
        """
        if self.framework == "tf":
            with tf.device("/CPU:0" if self.device == -1 else f"/device:GPU:{self.device}"):
                yield
        else:
            if self.device.type == "cuda":
                torch.cuda.set_device(self.device)

            yield

    def ensure_tensor_on_device(self, **inputs):
        """
        Ensure PyTorch tensors are on the specified device.

        Args:
            inputs (keyword arguments that should be :obj:`torch.Tensor`, the rest is ignored):
                The tensors to place on :obj:`self.device`. Nested dicts, lists and tuples are walked recursively.

        Return:
            :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
        """
        return self._ensure_tensor_on_device(inputs, self.device)

    def _ensure_tensor_on_device(self, inputs, device):
        # Recursively move every torch.Tensor found in (nested) containers to `device`; other values pass through.
        # NOTE(review): a ModelOutput subclass is rebuilt as a plain ModelOutput, so its concrete type is not kept.
        if isinstance(inputs, ModelOutput):
            return ModelOutput(
                {name: self._ensure_tensor_on_device(tensor, device) for name, tensor in inputs.items()}
            )
        elif isinstance(inputs, dict):
            return {name: self._ensure_tensor_on_device(tensor, device) for name, tensor in inputs.items()}
        elif isinstance(inputs, UserDict):
            return UserDict({name: self._ensure_tensor_on_device(tensor, device) for name, tensor in inputs.items()})
        elif isinstance(inputs, list):
            return [self._ensure_tensor_on_device(item, device) for item in inputs]
        elif isinstance(inputs, tuple):
            return tuple([self._ensure_tensor_on_device(item, device) for item in inputs])
        elif isinstance(inputs, torch.Tensor):
            return inputs.to(device)
        else:
            return inputs

    def check_model_type(self, supported_models: Union[List[str], dict]):
        """
        Check if the model class is in supported by the pipeline. Only logs an error; nothing is raised.

        Args:
            supported_models (:obj:`List[str]` or :obj:`dict`):
                The list of models supported by the pipeline, or a dictionary with model class values.
        """
        if not isinstance(supported_models, list):  # Create from a model mapping
            supported_models_names = []
            for _config, model in supported_models.items():
                # Mapping can now contain tuples of models for the same configuration.
                if isinstance(model, tuple):
                    supported_models_names.extend([_model.__name__ for _model in model])
                else:
                    supported_models_names.append(model.__name__)
            supported_models = supported_models_names
        if self.model.__class__.__name__ not in supported_models:
            logger.error(
                f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}."
            )

    @abstractmethod
    def _sanitize_parameters(self, **pipeline_parameters):
        """
        _sanitize_parameters will be called with any excessive named arguments from either `__init__` or `__call__`
        methods. It should return 3 dictionaries of the resolved parameters used by the various `preprocess`,
        `forward` and `postprocess` methods. Do not fill dictionaries if the caller didn't specify a kwargs. This
        lets you keep defaults in function signatures, which is more "natural".

        It is not meant to be called directly, it will be automatically called and the final parameters resolved by
        `__init__` and `__call__`
        """
        raise NotImplementedError("_sanitize_parameters not implemented")

    @abstractmethod
    def preprocess(self, input_: Any, **preprocess_parameters: Dict) -> Dict[str, GenericTensor]:
        """
        Preprocess will take the `input_` of a specific pipeline and return a dictionary of everything necessary for
        `_forward` to run properly. It should contain at least one tensor, but might have arbitrary other items.
        """
        raise NotImplementedError("preprocess not implemented")

    @abstractmethod
    def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict) -> ModelOutput:
        """
        _forward will receive the prepared dictionary from `preprocess` and run it on the model. This method might
        involve the GPU or the CPU and should be agnostic to it. Isolating this function is the reason for `preprocess`
        and `postprocess` to exist, so that the hot path, this method generally can run as fast as possible.

        It is not meant to be called directly, `forward` is preferred. It is basically the same but contains additional
        code surrounding `_forward` making sure tensors and models are on the same device, disabling the training part
        of the code (leading to faster inference).
        """
        raise NotImplementedError("_forward not implemented")

    @abstractmethod
    def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters: Dict) -> Any:
        """
        Postprocess will receive the raw outputs of the `_forward` method, generally tensors, and reformat them into
        something more friendly. Generally it will output a list or a dict or results (containing just strings and
        numbers).
        """
        raise NotImplementedError("postprocess not implemented")

    def get_inference_context(self):
        # `torch.inference_mode` exists from PyTorch 1.9.0; older versions fall back to `torch.no_grad`.
        inference_context = (
            torch.inference_mode if version.parse(torch.__version__) >= version.parse("1.9.0") else torch.no_grad
        )
        return inference_context

    def forward(self, model_inputs, **forward_params):
        """
        Run `_forward` on the pipeline's device without gradient tracking (PyTorch) / in inference mode
        (TensorFlow). PyTorch outputs are moved back to CPU.

        Raises:
            ValueError: If :obj:`self.framework` is neither ``"tf"`` nor ``"pt"``.
        """
        with self.device_placement():
            if self.framework == "tf":
                model_inputs["training"] = False
                model_outputs = self._forward(model_inputs, **forward_params)
            elif self.framework == "pt":
                inference_context = self.get_inference_context()
                with inference_context():
                    model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
                    model_outputs = self._forward(model_inputs, **forward_params)
                    model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
            else:
                raise ValueError(f"Framework {self.framework} is not supported")
        return model_outputs

    def get_iterator(self, inputs, num_workers: int, preprocess_params, forward_params, postprocess_params):
        """
        Build a lazy preprocess -> forward -> postprocess iterator over :obj:`inputs` (PyTorch only).
        Preprocessing runs inside a ``DataLoader`` with :obj:`num_workers` workers and ``batch_size=1``.
        """
        if "TOKENIZERS_PARALLELISM" not in os.environ:
            logger.info("Disabling tokenizer parallelism, we're using DataLoader multithreading already")
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
        dataset = PipelineDataset(inputs, self.preprocess, preprocess_params)
        dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, collate_fn=collate_fn)
        model_iterator = PipelineIterator(dataloader, self.forward, forward_params)
        final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
        return final_iterator

    def __call__(self, inputs, *args, num_workers=8, **kwargs):
        """
        Run the pipeline on :obj:`inputs`.

        A list returns a list of results; a torch ``Dataset`` returns a lazy iterator of results; anything else is
        processed as a single item. Extra positional arguments are ignored (with a warning).
        """
        if args:
            logger.warning(f"Ignoring args : {args}")
        preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(**kwargs)

        # Fuse __init__ params and __call__ params without modifying the __init__ ones.
        preprocess_params = {**self._preprocess_params, **preprocess_params}
        forward_params = {**self._forward_params, **forward_params}
        postprocess_params = {**self._postprocess_params, **postprocess_params}

        self.call_count += 1
        if self.call_count > 10 and self.framework == "pt" and self.device.type == "cuda":
            warnings.warn(
                "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset",
                UserWarning,
            )
        if isinstance(inputs, list):
            if self.framework == "pt":
                final_iterator = self.get_iterator(
                    inputs, num_workers, preprocess_params, forward_params, postprocess_params
                )
                return list(final_iterator)
            else:
                return self.run_multi(inputs, preprocess_params, forward_params, postprocess_params)
        elif Dataset is not None and isinstance(inputs, Dataset):
            return self.get_iterator(inputs, num_workers, preprocess_params, forward_params, postprocess_params)
        else:
            return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)

    def run_multi(self, inputs, preprocess_params, forward_params, postprocess_params):
        # Sequential fallback (used for TensorFlow lists): one run_single per item.
        return [self.run_single(item, preprocess_params, forward_params, postprocess_params) for item in inputs]

    def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
        model_inputs = self.preprocess(inputs, **preprocess_params)
        model_outputs = self.forward(model_inputs, **forward_params)
        outputs = self.postprocess(model_outputs, **postprocess_params)
        return outputs