import warnings
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import numpy as np
from ..file_utils import ExplicitEnum, add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard
from ..models.bert.tokenization_bert import BasicTokenizer
from ..tokenization_utils import PreTrainedTokenizer
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel
if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
class TokenClassificationArgumentHandler(ArgumentHandler):
"""
Handles arguments for token classification.
"""
def __call__(self, inputs: Union[str, List[str]], **kwargs):
if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
inputs = list(inputs)
batch_size = len(inputs)
elif isinstance(inputs, str):
inputs = [inputs]
batch_size = 1
else:
raise ValueError("At least one input is required.")
offset_mapping = kwargs.get("offset_mapping")
if offset_mapping:
if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
offset_mapping = [offset_mapping]
if len(offset_mapping) != batch_size:
raise ValueError("offset_mapping should have the same batch size as the input")
return inputs, offset_mapping
class AggregationStrategy(ExplicitEnum):
"""All the valid aggregation strategies for TokenClassificationPipeline"""
NONE = "none"
SIMPLE = "simple"
FIRST = "first"
AVERAGE = "average"
MAX = "max"
@add_end_docstrings(
PIPELINE_INIT_ARGS,
r"""
ignore_labels (:obj:`List[str]`, defaults to :obj:`["O"]`):
A list of labels to ignore.
grouped_entities (:obj:`bool`, `optional`, defaults to :obj:`False`):
            DEPRECATED, use :obj:`aggregation_strategy` instead. Whether or not to group the tokens corresponding
            to the same entity together in the predictions.
        aggregation_strategy (:obj:`str`, `optional`, defaults to :obj:`"none"`):
            The strategy to fuse (or not) tokens based on the model prediction.

                - "none" : Will not apply any aggregation and will simply return the raw results from the model.
                - "simple" : Will attempt to group entities following the default schema. (A, B-TAG), (B, I-TAG), (C,
                  I-TAG), (D, B-TAG2), (E, B-TAG2) will end up as [{"word": ABC, "entity": "TAG"}, {"word": "D",
                  "entity": "TAG2"}, {"word": "E", "entity": "TAG2"}]. Notice that two consecutive B tags will end up
                  as different entities. In word-based languages, we might end up splitting words undesirably: imagine
                  Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity":
                  "NAME"}]. Look at the FIRST, MAX and AVERAGE strategies for ways to mitigate this and disambiguate
                  words (on languages that support that notion, which is basically tokens separated by a space). These
                  mitigations will only work on real words; "New york" might still be tagged with two different
                  entities.
                - "first" : (works only on word based models) Will use the :obj:`SIMPLE` strategy, except that words
                  cannot end up with different tags. Words will simply use the tag of the first token of the word when
                  there is ambiguity.
                - "average" : (works only on word based models) Will use the :obj:`SIMPLE` strategy, except that words
                  cannot end up with different tags. Scores are first averaged across the tokens of a word, and then
                  the label with the maximum score is applied.
                - "max" : (works only on word based models) Will use the :obj:`SIMPLE` strategy, except that words
                  cannot end up with different tags. The word entity will simply be the token with the maximum score.
""",
)
class TokenClassificationPipeline(Pipeline):
"""
Named Entity Recognition pipeline using any :obj:`ModelForTokenClassification`. See the `named entity recognition
examples <../task_summary.html#named-entity-recognition>`__ for more information.
This token recognition pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
task identifier: :obj:`"ner"` (for predicting the classes of tokens in a sequence: person, organisation, location
or miscellaneous).
The models that this pipeline can use are models that have been fine-tuned on a token classification task. See the
up-to-date list of available models on `huggingface.co/models
<https://huggingface.co/models?filter=token-classification>`__.
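
    Example (a minimal usage sketch; it downloads the library's default NER checkpoint, and the result shown in
    the comment is an illustrative placeholder, not real output)::

        from transformers import pipeline

        ner = pipeline("ner", aggregation_strategy="simple")
        ner("My name is Wolfgang and I live in Berlin.")
        # e.g. [{"entity_group": "PER", "score": 0.99, "word": "Wolfgang", "start": 11, "end": 19}, ...]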
"""
default_input_names = "sequences"
def __init__(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel"],
tokenizer: PreTrainedTokenizer,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
args_parser: ArgumentHandler = TokenClassificationArgumentHandler(),
device: int = -1,
binary_output: bool = False,
ignore_labels=["O"],
task: str = "",
grouped_entities: Optional[bool] = None,
ignore_subwords: Optional[bool] = None,
aggregation_strategy: Optional[AggregationStrategy] = None,
):
super().__init__(
model=model,
tokenizer=tokenizer,
modelcard=modelcard,
framework=framework,
device=device,
binary_output=binary_output,
task=task,
)
self.check_model_type(
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
if self.framework == "tf"
else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
)
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
self._args_parser = args_parser
self.ignore_labels = ignore_labels
if aggregation_strategy is None:
aggregation_strategy = AggregationStrategy.NONE
if grouped_entities is not None or ignore_subwords is not None:
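            # Map the deprecated flag combinations onto the new strategies:
            # grouped_entities + ignore_subwords -> FIRST, grouped_entities alone -> SIMPLE, otherwise NONE.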
if grouped_entities and ignore_subwords:
aggregation_strategy = AggregationStrategy.FIRST
elif grouped_entities and not ignore_subwords:
aggregation_strategy = AggregationStrategy.SIMPLE
else:
aggregation_strategy = AggregationStrategy.NONE
if grouped_entities is not None:
warnings.warn(
f'`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
)
if ignore_subwords is not None:
warnings.warn(
f'`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
)
if isinstance(aggregation_strategy, str):
aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()]
if (
aggregation_strategy in {AggregationStrategy.FIRST, AggregationStrategy.MAX, AggregationStrategy.AVERAGE}
and not self.tokenizer.is_fast
):
            raise ValueError(
                "Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option "
                'to `"simple"` or use a fast tokenizer.'
            )
self.aggregation_strategy = aggregation_strategy
    def __call__(self, inputs: Union[str, List[str]], **kwargs):
"""
Classify each token of the text(s) given as inputs.
Args:
inputs (:obj:`str` or :obj:`List[str]`):
One or several texts (or one list of texts) for token classification.
Return:
            A list or a list of lists of :obj:`dict`: Each result comes as a list of dictionaries (one for each token
            in the corresponding input, or one for each entity if this pipeline was instantiated with an
            aggregation_strategy) with the following keys:
- **word** (:obj:`str`) -- The token/word classified.
- **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`.
            - **entity** (:obj:`str`) -- The entity predicted for that token/word (it is named `entity_group` when
              `aggregation_strategy` is not :obj:`"none"`).
- **index** (:obj:`int`, only present when ``aggregation_strategy="none"``) -- The index of the
corresponding token in the sentence.
            - **start** (:obj:`int`, `optional`) -- The index of the start of the corresponding entity in the
              sentence. Only exists if the offsets are available within the tokenizer.
            - **end** (:obj:`int`, `optional`) -- The index of the end of the corresponding entity in the sentence.
              Only exists if the offsets are available within the tokenizer.
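
        Example (an illustrative sketch using the default ``aggregation_strategy="none"``; the labels, scores and
        offsets shown are placeholders, not real model output)::

            from transformers import pipeline

            ner = pipeline("ner")
            ner("Hugging Face is based in New York City")
            # [{"word": "Hu", "score": 0.99, "entity": "I-ORG", "index": 1, "start": 0, "end": 2}, ...]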
"""
_inputs, offset_mappings = self._args_parser(inputs, **kwargs)
answers = []
for i, sentence in enumerate(_inputs):
# Manage correct placement of the tensors
with self.device_placement():
tokens = self.tokenizer(
sentence,
return_attention_mask=False,
return_tensors=self.framework,
truncation=True,
return_special_tokens_mask=True,
return_offsets_mapping=self.tokenizer.is_fast,
)
if self.tokenizer.is_fast:
offset_mapping = tokens.pop("offset_mapping").cpu().numpy()[0]
elif offset_mappings:
offset_mapping = offset_mappings[i]
else:
offset_mapping = None
special_tokens_mask = tokens.pop("special_tokens_mask").cpu().numpy()[0]
# Forward
if self.framework == "tf":
entities = self.model(tokens.data)[0][0].numpy()
input_ids = tokens["input_ids"].numpy()[0]
else:
with torch.no_grad():
tokens = self.ensure_tensor_on_device(**tokens)
entities = self.model(**tokens)[0][0].cpu().numpy()
input_ids = tokens["input_ids"].cpu().numpy()[0]
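            # Softmax over the label dimension: turn the raw logits into per-token class probabilities.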
scores = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
pre_entities = self.gather_pre_entities(sentence, input_ids, scores, offset_mapping, special_tokens_mask)
grouped_entities = self.aggregate(pre_entities, self.aggregation_strategy)
# Filter anything that is in self.ignore_labels
entities = [
entity
for entity in grouped_entities
if entity.get("entity", None) not in self.ignore_labels
and entity.get("entity_group", None) not in self.ignore_labels
]
answers.append(entities)
if len(answers) == 1:
return answers[0]
return answers
    def gather_pre_entities(
self,
sentence: str,
input_ids: np.ndarray,
scores: np.ndarray,
offset_mapping: Optional[List[Tuple[int, int]]],
special_tokens_mask: np.ndarray,
) -> List[dict]:
"""Fuse various numpy arrays into dicts with all the information needed for aggregation"""
pre_entities = []
for idx, token_scores in enumerate(scores):
            # Filter out special tokens: they should only occur at the sentence boundaries (we are not
            # encoding pairs of sentences), so we don't have to keep track of them.
if special_tokens_mask[idx]:
continue
word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
if offset_mapping is not None:
start_ind, end_ind = offset_mapping[idx]
word_ref = sentence[start_ind:end_ind]
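                # Heuristic: if the decoded token differs in length from the original text span (e.g. because it
                # carries a subword marker such as "##"), treat it as a subword continuation of the previous word.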
is_subword = len(word_ref) != len(word)
if int(input_ids[idx]) == self.tokenizer.unk_token_id:
word = word_ref
is_subword = False
else:
start_ind = None
end_ind = None
is_subword = False
pre_entity = {
"word": word,
"scores": token_scores,
"start": start_ind,
"end": end_ind,
"index": idx,
"is_subword": is_subword,
}
pre_entities.append(pre_entity)
return pre_entities
def aggregate(self, pre_entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
if aggregation_strategy in {AggregationStrategy.NONE, AggregationStrategy.SIMPLE}:
entities = []
for pre_entity in pre_entities:
entity_idx = pre_entity["scores"].argmax()
score = pre_entity["scores"][entity_idx]
entity = {
"entity": self.model.config.id2label[entity_idx],
"score": score,
"index": pre_entity["index"],
"word": pre_entity["word"],
"start": pre_entity["start"],
"end": pre_entity["end"],
}
entities.append(entity)
else:
entities = self.aggregate_words(pre_entities, aggregation_strategy)
if aggregation_strategy == AggregationStrategy.NONE:
return entities
return self.group_entities(entities)
def aggregate_word(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> dict:
word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities])
if aggregation_strategy == AggregationStrategy.FIRST:
scores = entities[0]["scores"]
idx = scores.argmax()
score = scores[idx]
entity = self.model.config.id2label[idx]
elif aggregation_strategy == AggregationStrategy.MAX:
max_entity = max(entities, key=lambda entity: entity["scores"].max())
scores = max_entity["scores"]
idx = scores.argmax()
score = scores[idx]
entity = self.model.config.id2label[idx]
elif aggregation_strategy == AggregationStrategy.AVERAGE:
scores = np.stack([entity["scores"] for entity in entities])
average_scores = np.nanmean(scores, axis=0)
entity_idx = average_scores.argmax()
entity = self.model.config.id2label[entity_idx]
score = average_scores[entity_idx]
else:
raise ValueError("Invalid aggregation_strategy")
new_entity = {
"entity": entity,
"score": score,
"word": word,
"start": entities[0]["start"],
"end": entities[-1]["end"],
}
return new_entity
    def aggregate_words(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
"""
Override tokens from a given word that disagree to force agreement on word boundaries.
        Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with the FIRST strategy as
        microsoft| company| B-ENT I-ENT
"""
assert aggregation_strategy not in {
AggregationStrategy.NONE,
AggregationStrategy.SIMPLE,
}, "NONE and SIMPLE strategies are invalid"
word_entities = []
word_group = None
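        # Walk through the tokens, accumulating each word's subword tokens into ``word_group`` and collapsing
        # every completed group into a single entity via ``aggregate_word``.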
for entity in entities:
if word_group is None:
word_group = [entity]
elif entity["is_subword"]:
word_group.append(entity)
else:
word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
word_group = [entity]
# Last item
word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
return word_entities
    def group_sub_entities(self, entities: List[dict]) -> dict:
"""
Group together the adjacent tokens with the same entity predicted.
Args:
entities (:obj:`dict`): The entities predicted by the pipeline.
"""
# Get the first entity in the entity group
entity = entities[0]["entity"].split("-")[-1]
scores = np.nanmean([entity["score"] for entity in entities])
tokens = [entity["word"] for entity in entities]
entity_group = {
"entity_group": entity,
"score": np.mean(scores),
"word": self.tokenizer.convert_tokens_to_string(tokens),
"start": entities[0]["start"],
"end": entities[-1]["end"],
}
return entity_group
def get_tag(self, entity_name: str) -> Tuple[str, str]:
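        # Tags are expected in BIO format: "B-" marks the beginning of an entity, "I-" marks a continuation.
        # Anything else is treated as the beginning of an entity, with the raw name used as the tag.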
if entity_name.startswith("B-"):
bi = "B"
tag = entity_name[2:]
elif entity_name.startswith("I-"):
bi = "I"
tag = entity_name[2:]
else:
# It's not in B-, I- format
bi = "B"
tag = entity_name
return bi, tag
    def group_entities(self, entities: List[dict]) -> List[dict]:
"""
Find and group together the adjacent tokens with the same entity predicted.
Args:
entities (:obj:`dict`): The entities predicted by the pipeline.
"""
entity_groups = []
entity_group_disagg = []
for entity in entities:
if not entity_group_disagg:
entity_group_disagg.append(entity)
continue
# If the current entity is similar and adjacent to the previous entity,
# append it to the disaggregated entity group
# The split is meant to account for the "B" and "I" prefixes
# Shouldn't merge if both entities are B-type
bi, tag = self.get_tag(entity["entity"])
last_bi, last_tag = self.get_tag(entity_group_disagg[-1]["entity"])
if tag == last_tag and bi != "B":
# Modify subword type to be previous_type
entity_group_disagg.append(entity)
else:
# If the current entity is different from the previous entity
# aggregate the disaggregated entity group
entity_groups.append(self.group_sub_entities(entity_group_disagg))
entity_group_disagg = [entity]
if entity_group_disagg:
# it's the last entity, add it to the entity groups
entity_groups.append(self.group_sub_entities(entity_group_disagg))
return entity_groups
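
# Backwards-compatible alias: this pipeline was historically exposed under the NER name.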
NerPipeline = TokenClassificationPipeline