from typing import List, Union
import numpy as np
from ..file_utils import add_end_docstrings
from ..tokenization_utils import TruncationStrategy
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline
# Module-level logger named after this module, obtained through the package's own logging utilities.
logger = logging.get_logger(__name__)
class ZeroShotClassificationArgumentHandler(ArgumentHandler):
    """
    Handles arguments for zero-shot for text classification by turning each possible label into an NLI
    premise/hypothesis pair.
    """

    def _parse_labels(self, labels):
        """
        Normalize the candidate labels.

        A string is split on commas; each piece is stripped and empty pieces are dropped. Any other value is
        returned unchanged.
        """
        if isinstance(labels, str):
            labels = [label.strip() for label in labels.split(",") if label.strip()]
        return labels

    def __call__(self, sequences, labels, hypothesis_template):
        """
        Build one ``[sequence, hypothesis]`` pair for every (sequence, label) combination.

        Args:
            sequences: A single string or a collection of strings to classify.
            labels: The candidate labels, either as a collection or as a comma-separated string.
            hypothesis_template: A ``str.format`` template into which each label is inserted.

        Returns:
            A tuple ``(sequence_pairs, sequences)``: ``sequence_pairs`` lists the pairs grouped by sequence (all
            labels of the first sequence, then all labels of the second, ...), and ``sequences`` is the input
            sequences, wrapped in a list when a single string was given.

        Raises:
            ValueError: If there is no label or no sequence, or if formatting the template with a label leaves
                the template unchanged (i.e. it has no placeholder for the label).
        """
        # Normalize here as well: without this, a comma-separated string handed directly to the handler would be
        # indexed and iterated character by character below. Non-string values pass through untouched.
        labels = self._parse_labels(labels)

        if len(labels) == 0 or len(sequences) == 0:
            raise ValueError("You must include at least one label and at least one sequence.")
        if hypothesis_template.format(labels[0]) == hypothesis_template:
            raise ValueError(
                (
                    'The provided hypothesis_template "{}" was not able to be formatted with the target labels. '
                    "Make sure the passed template includes formatting syntax such as {{}} where the label should go."
                ).format(hypothesis_template)
            )

        if isinstance(sequences, str):
            sequences = [sequences]

        # Outer loop over sequences, inner loop over labels: pairs belonging to one sequence stay contiguous.
        sequence_pairs = [
            [sequence, hypothesis_template.format(label)] for sequence in sequences for label in labels
        ]
        return sequence_pairs, sequences
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ZeroShotClassificationPipeline(Pipeline):
    """
    NLI-based zero-shot classification pipeline using a :obj:`ModelForSequenceClassification` trained on NLI (natural
    language inference) tasks.

    Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
    pair and passed to the pretrained model. Then, the logit for `entailment` is taken as the logit for the candidate
    label being valid. Any NLI model can be used, but the id of the `entailment` label must be included in the model
    config's :attr:`~transformers.PretrainedConfig.label2id`.

    This NLI pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task identifier:
    :obj:`"zero-shot-classification"`.

    The models that this pipeline can use are models that have been fine-tuned on an NLI task. See the up-to-date list
    of available models on `huggingface.co/models <https://huggingface.co/models?search=nli>`__.
    """

    def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs):
        # The parser must be set before the base initializer runs; remaining arguments are forwarded unchanged.
        self._args_parser = args_parser
        super().__init__(*args, **kwargs)
        if self.entailment_id == -1:
            logger.warning(
                "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to "
                "-1. Define a descriptive label2id mapping in the model config to ensure correct outputs."
            )

    @property
    def entailment_id(self):
        """
        Index of the first label in the model config's ``label2id`` whose lower-cased name starts with
        ``"entail"``, or ``-1`` when no such label exists.
        """
        for label, ind in self.model.config.label2id.items():
            if label.lower().startswith("entail"):
                return ind
        return -1

    def _parse_and_tokenize(
        self, sequence_pairs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.ONLY_FIRST, **kwargs
    ):
        """
        Parse arguments and tokenize only_first so that hypothesis (label) is not truncated

        Returns the tokenizer output for the whole batch, or a list with one tokenizer output per pair when the
        tokenizer has no pad token.
        """
        return_tensors = self.framework
        if getattr(self.tokenizer, "pad_token", None) is None:
            # XXX some tokenizers do not have a padding token, we use simple lists
            # and no padding then
            logger.warning(
                f"The tokenizer {self.tokenizer} does not have a pad token, we're not running it as a batch"
            )
            padding = False
            inputs = []
            for sequence_pair in sequence_pairs:
                model_input = self.tokenizer(
                    text=sequence_pair[0],
                    text_pair=sequence_pair[1],
                    add_special_tokens=add_special_tokens,
                    return_tensors=return_tensors,
                    padding=padding,
                    truncation=truncation,
                )
                inputs.append(model_input)
        else:
            try:
                inputs = self.tokenizer(
                    sequence_pairs,
                    add_special_tokens=add_special_tokens,
                    return_tensors=return_tensors,
                    padding=padding,
                    truncation=truncation,
                )
            except Exception as e:
                if "too short" in str(e):
                    # tokenizers might yell that we want to truncate
                    # to a value that is not even reached by the input.
                    # In that case we don't want to truncate.
                    # It seems there's not a really better way to catch that
                    # exception.
                    inputs = self.tokenizer(
                        sequence_pairs,
                        add_special_tokens=add_special_tokens,
                        return_tensors=return_tensors,
                        padding=padding,
                        truncation=TruncationStrategy.DO_NOT_TRUNCATE,
                    )
                else:
                    # Bare raise re-raises the active exception with its original traceback.
                    raise
        return inputs

    def _sanitize_parameters(self, **kwargs):
        """
        Split the call keyword arguments into per-stage parameter dicts.

        Returns a tuple ``(preprocess_params, {}, postprocess_params)``: ``candidate_labels`` (normalized by the
        args parser) and ``hypothesis_template`` go to preprocessing, ``multi_label`` goes to postprocessing, and
        the middle (forward) dict is always empty. A non-``None`` deprecated ``multi_class`` argument is copied
        to ``multi_label`` with a warning.
        """
        if kwargs.get("multi_class", None) is not None:
            kwargs["multi_label"] = kwargs["multi_class"]
            logger.warning(
                "The `multi_class` argument has been deprecated and renamed to `multi_label`. "
                "`multi_class` will be removed in a future version of Transformers."
            )
        preprocess_params = {}
        if "candidate_labels" in kwargs:
            preprocess_params["candidate_labels"] = self._args_parser._parse_labels(kwargs["candidate_labels"])
        if "hypothesis_template" in kwargs:
            preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]

        postprocess_params = {}
        if "multi_label" in kwargs:
            postprocess_params["multi_label"] = kwargs["multi_label"]
        return preprocess_params, {}, postprocess_params

    def __call__(
        self,
        sequences: Union[str, List[str]],
        *args,
        **kwargs,
    ):
        """
        Classify the sequence(s) given as inputs. See the :obj:`~transformers.ZeroShotClassificationPipeline`
        documentation for more information.

        Args:
            sequences (:obj:`str` or :obj:`List[str]`):
                The sequence(s) to classify, will be truncated if the model input is too large.
            candidate_labels (:obj:`str` or :obj:`List[str]`):
                The set of possible class labels to classify each sequence into. Can be a single label, a string of
                comma-separated labels, or a list of labels.
            hypothesis_template (:obj:`str`, `optional`, defaults to :obj:`"This example is {}."`):
                The template used to turn each label into an NLI-style hypothesis. This template must include a {} or
                similar syntax for the candidate label to be inserted into the template. For example, the default
                template is :obj:`"This example is {}."` With the candidate label :obj:`"sports"`, this would be fed
                into the model like :obj:`"<cls> sequence to classify <sep> This example is sports . <sep>"`. The
                default template works well in many cases, but it may be worthwhile to experiment with different
                templates depending on the task setting.
            multi_label (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not multiple candidate labels can be true. If :obj:`False`, the scores are normalized such
                that the sum of the label likelihoods for each sequence is 1. If :obj:`True`, the labels are considered
                independent and probabilities are normalized for each candidate by doing a softmax of the entailment
                score vs. the contradiction score.

        Return:
            A :obj:`dict` or a list of :obj:`dict`: Each result comes as a dictionary with the following keys:

            - **sequence** (:obj:`str`) -- The sequence for which this is the output.
            - **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood.
            - **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
        """
        # A single extra positional argument is accepted as `candidate_labels`, unless that keyword was also given.
        if len(args) == 0:
            pass
        elif len(args) == 1 and "candidate_labels" not in kwargs:
            kwargs["candidate_labels"] = args[0]
        else:
            raise ValueError(f"Unable to understand extra arguments {args}")
        return super().__call__(sequences, **kwargs)

    def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."):
        """
        Turn the raw inputs into premise/hypothesis pairs with the args parser and tokenize them.

        Returns a dict with the keys ``"candidate_labels"``, ``"sequences"`` (as returned by the args parser) and
        ``"inputs"`` (the tokenized pairs).
        """
        sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template)
        model_inputs = self._parse_and_tokenize(sequence_pairs)
        prepared_inputs = {
            "candidate_labels": candidate_labels,
            "sequences": sequences,
            "inputs": model_inputs,
        }
        return prepared_inputs

    def _forward(self, inputs):
        """
        Run the model on the prepared inputs.

        When ``inputs["inputs"]`` is a list, the model is called once per element and the first item of each model
        output is collected (after ``.cpu()``); otherwise the model is called once and its output is kept as is.
        The labels and sequences are passed through alongside under ``"candidate_labels"`` and ``"sequences"``.
        """
        candidate_labels = inputs["candidate_labels"]
        sequences = inputs["sequences"]
        model_inputs = inputs["inputs"]
        if isinstance(model_inputs, list):
            outputs = []
            for input_ in model_inputs:
                prediction = self.model(**input_)[0].cpu()
                outputs.append(prediction)
        else:
            outputs = self.model(**model_inputs)
        model_outputs = {"candidate_labels": candidate_labels, "sequences": sequences, "outputs": outputs}
        return model_outputs

    @staticmethod
    def _softmax(logits):
        """
        Softmax over the last axis of a NumPy array.

        The per-row maximum is subtracted before exponentiating. This leaves the result mathematically unchanged
        but keeps ``np.exp`` from overflowing on large logits.
        """
        shifted = logits - logits.max(axis=-1, keepdims=True)
        exponentials = np.exp(shifted)
        return exponentials / exponentials.sum(axis=-1, keepdims=True)

    def postprocess(self, model_outputs, multi_label=False):
        """
        Convert the model outputs into per-sequence label rankings.

        Args:
            model_outputs: The dict produced by ``_forward`` (keys ``"candidate_labels"``, ``"sequences"``,
                ``"outputs"``).
            multi_label: When true (or when there is exactly one candidate label), each label is scored
                independently with a softmax over its contradiction/entailment logits. Otherwise the entailment
                logits are softmaxed across all candidate labels of a sequence.

        Returns:
            A dict with the keys ``"sequence"``, ``"labels"`` and ``"scores"`` (labels and scores sorted by
            decreasing score) when there is a single sequence, otherwise a list of such dicts.
        """
        candidate_labels = model_outputs["candidate_labels"]
        sequences = model_outputs["sequences"]
        outputs = model_outputs["outputs"]

        if self.framework == "pt":
            if isinstance(outputs, list):
                logits = np.concatenate([output.cpu().numpy() for output in outputs], axis=0)
            else:
                logits = outputs["logits"].cpu().numpy()
        else:
            if isinstance(outputs, list):
                logits = np.concatenate([output.numpy() for output in outputs], axis=0)
            else:
                logits = outputs["logits"].numpy()

        # Rows are ordered sequence-major (all labels of one sequence are contiguous), so the flat batch can be
        # folded back into (num_sequences, num_labels, num_classes).
        N = logits.shape[0]
        n = len(candidate_labels)
        num_sequences = N // n
        reshaped_outputs = logits.reshape((num_sequences, n, -1))

        if multi_label or len(candidate_labels) == 1:
            # softmax over the entailment vs. contradiction dim for each label independently
            entailment_id = self.entailment_id
            contradiction_id = -1 if entailment_id == 0 else 0
            entail_contr_logits = reshaped_outputs[..., [contradiction_id, entailment_id]]
            scores = self._softmax(entail_contr_logits)
            # Index 1 of the selected pair is the entailment probability.
            scores = scores[..., 1]
        else:
            # softmax the "entailment" logits over all candidate labels
            entail_logits = reshaped_outputs[..., self.entailment_id]
            scores = self._softmax(entail_logits)

        result = []
        for iseq in range(num_sequences):
            # argsort is ascending; reverse it to rank labels from most to least likely.
            top_inds = list(reversed(scores[iseq].argsort()))
            result.append(
                {
                    "sequence": sequences[iseq],
                    "labels": [candidate_labels[i] for i in top_inds],
                    "scores": scores[iseq, top_inds].tolist(),
                }
            )

        if len(result) == 1:
            return result[0]
        return result