Source code for transformers.pipelines.token_classification

import warnings
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np

from ..file_utils import ExplicitEnum, add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard
from ..models.bert.tokenization_bert import BasicTokenizer
from ..tokenization_utils import PreTrainedTokenizer
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline


if TYPE_CHECKING:
    from ..modeling_tf_utils import TFPreTrainedModel
    from ..modeling_utils import PreTrainedModel

if is_tf_available():
    from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING

if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


class TokenClassificationArgumentHandler(ArgumentHandler):
    """
    Handles arguments for token classification.
    """

    def __call__(self, inputs: Union[str, List[str]], **kwargs):
        if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
            inputs = list(inputs)
            batch_size = len(inputs)
        elif isinstance(inputs, str):
            inputs = [inputs]
            batch_size = 1
        else:
            raise ValueError("At least one input is required.")

        offset_mapping = kwargs.get("offset_mapping")
        if offset_mapping:
            if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
                offset_mapping = [offset_mapping]
            if len(offset_mapping) != batch_size:
                raise ValueError("offset_mapping should have the same batch size as the input")
        return inputs, offset_mapping
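
    # A minimal sketch of how this handler normalizes its arguments (the example strings
    # below are illustrative, not part of the library):
    #
    #     handler = TokenClassificationArgumentHandler()
    #     handler("My name is Wolfgang")             # -> (["My name is Wolfgang"], None)
    #     handler(["Sentence one", "Sentence two"])  # -> (["Sentence one", "Sentence two"], None)
    #     handler([])                                # raises ValueError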


class AggregationStrategy(ExplicitEnum):
    """All the valid aggregation strategies for TokenClassificationPipeline"""

    NONE = "none"
    SIMPLE = "simple"
    FIRST = "first"
    AVERAGE = "average"
    MAX = "max"


@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        ignore_labels (:obj:`List[str]`, defaults to :obj:`["O"]`):
            A list of labels to ignore.
        grouped_entities (:obj:`bool`, `optional`, defaults to :obj:`False`):
            DEPRECATED, use :obj:`aggregation_strategy` instead. Whether or not to group the tokens corresponding to
            the same entity together in the predictions.
        aggregation_strategy (:obj:`str`, `optional`, defaults to :obj:`"none"`):
            The strategy to fuse (or not) tokens based on the model prediction.

                - "none" : Will not do any aggregation and simply return the raw results from the model.
                - "simple" : Will attempt to group entities following the default schema. (A, B-TAG), (B, I-TAG), (C,
                  I-TAG), (D, B-TAG2), (E, B-TAG2) will end up being [{"word": "ABC", "entity": "TAG"}, {"word": "D",
                  "entity": "TAG2"}, {"word": "E", "entity": "TAG2"}]. Notice that two consecutive B tags will end up
                  as different entities. In word-based languages, we might end up splitting words undesirably: imagine
                  Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity":
                  "NAME"}]. Look at FIRST, MAX and AVERAGE for ways to mitigate this and disambiguate words (on
                  languages that support that meaning, which is basically tokens separated by a space). These
                  mitigations will only work on real words; "New york" might still be tagged with two different
                  entities.
                - "first" : (works only on word based models) Will use the :obj:`SIMPLE` strategy except that words
                  cannot end up with different tags. Words will simply use the tag of the first token of the word when
                  there is ambiguity.
                - "average" : (works only on word based models) Will use the :obj:`SIMPLE` strategy except that words
                  cannot end up with different tags. Scores will be averaged first across tokens, and then the maximum
                  label is applied.
                - "max" : (works only on word based models) Will use the :obj:`SIMPLE` strategy except that words
                  cannot end up with different tags. The word entity will simply be the token with the maximum score.
    """,
)
class TokenClassificationPipeline(Pipeline):
    """
    Named Entity Recognition pipeline using any :obj:`ModelForTokenClassification`. See the `named entity recognition
    examples <../task_summary.html#named-entity-recognition>`__ for more information.

    This token recognition pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"ner"` (for predicting the classes of tokens in a sequence: person, organisation, location
    or miscellaneous).

    The models that this pipeline can use are models that have been fine-tuned on a token classification task. See the
    up-to-date list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=token-classification>`__.
    """

    default_input_names = "sequences"

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = TokenClassificationArgumentHandler(),
        device: int = -1,
        binary_output: bool = False,
        ignore_labels=["O"],
        task: str = "",
        grouped_entities: Optional[bool] = None,
        ignore_subwords: Optional[bool] = None,
        aggregation_strategy: Optional[AggregationStrategy] = None,
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            device=device,
            binary_output=binary_output,
            task=task,
        )

        self.check_model_type(
            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
        )

        self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
        self._args_parser = args_parser
        self.ignore_labels = ignore_labels

        if aggregation_strategy is None:
            aggregation_strategy = AggregationStrategy.NONE
        if grouped_entities is not None or ignore_subwords is not None:
            if grouped_entities and ignore_subwords:
                aggregation_strategy = AggregationStrategy.FIRST
            elif grouped_entities and not ignore_subwords:
                aggregation_strategy = AggregationStrategy.SIMPLE
            else:
                aggregation_strategy = AggregationStrategy.NONE

            if grouped_entities is not None:
                warnings.warn(
                    "`grouped_entities` is deprecated and will be removed in version v5.0.0, "
                    f'defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
                )
            if ignore_subwords is not None:
                warnings.warn(
                    "`ignore_subwords` is deprecated and will be removed in version v5.0.0, "
                    f'defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
                )
        if isinstance(aggregation_strategy, str):
            aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()]
        if (
            aggregation_strategy
            in {AggregationStrategy.FIRST, AggregationStrategy.MAX, AggregationStrategy.AVERAGE}
            and not self.tokenizer.is_fast
        ):
            raise ValueError(
                "Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option "
                'to `"simple"` or use a fast tokenizer.'
            )
        self.aggregation_strategy = aggregation_strategy
    def __call__(self, inputs: Union[str, List[str]], **kwargs):
        """
        Classify each token of the text(s) given as inputs.

        Args:
            inputs (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of texts) for token classification.

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a list of dictionaries (one for each token
            in the corresponding input, or each entity if this pipeline was instantiated with an aggregation_strategy)
            with the following keys:

            - **word** (:obj:`str`) -- The token/word classified.
            - **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`.
            - **entity** (:obj:`str`) -- The entity predicted for that token/word (it is named `entity_group` when
              `aggregation_strategy` is not :obj:`"none"`).
            - **index** (:obj:`int`, only present when ``aggregation_strategy="none"``) -- The index of the
              corresponding token in the sentence.
            - **start** (:obj:`int`, `optional`) -- The index of the start of the corresponding entity in the
              sentence. Only exists if the offsets are available within the tokenizer.
            - **end** (:obj:`int`, `optional`) -- The index of the end of the corresponding entity in the sentence.
              Only exists if the offsets are available within the tokenizer.
        """

        _inputs, offset_mappings = self._args_parser(inputs, **kwargs)

        answers = []

        for i, sentence in enumerate(_inputs):

            # Manage correct placement of the tensors
            with self.device_placement():

                tokens = self.tokenizer(
                    sentence,
                    return_attention_mask=False,
                    return_tensors=self.framework,
                    truncation=True,
                    return_special_tokens_mask=True,
                    return_offsets_mapping=self.tokenizer.is_fast,
                )
                if self.tokenizer.is_fast:
                    offset_mapping = tokens.pop("offset_mapping").cpu().numpy()[0]
                elif offset_mappings:
                    offset_mapping = offset_mappings[i]
                else:
                    offset_mapping = None

                special_tokens_mask = tokens.pop("special_tokens_mask").cpu().numpy()[0]

                # Forward
                if self.framework == "tf":
                    entities = self.model(tokens.data)[0][0].numpy()
                    input_ids = tokens["input_ids"].numpy()[0]
                else:
                    with torch.no_grad():
                        tokens = self.ensure_tensor_on_device(**tokens)
                        entities = self.model(**tokens)[0][0].cpu().numpy()
                        input_ids = tokens["input_ids"].cpu().numpy()[0]

            # Softmax over the label dimension to turn logits into probabilities
            scores = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
            pre_entities = self.gather_pre_entities(sentence, input_ids, scores, offset_mapping, special_tokens_mask)
            grouped_entities = self.aggregate(pre_entities, self.aggregation_strategy)
            # Filter anything that is in self.ignore_labels
            entities = [
                entity
                for entity in grouped_entities
                if entity.get("entity", None) not in self.ignore_labels
                and entity.get("entity_group", None) not in self.ignore_labels
            ]
            answers.append(entities)

        if len(answers) == 1:
            return answers[0]
        return answers
    def gather_pre_entities(
        self,
        sentence: str,
        input_ids: np.ndarray,
        scores: np.ndarray,
        offset_mapping: Optional[List[Tuple[int, int]]],
        special_tokens_mask: np.ndarray,
    ) -> List[dict]:
        """Fuse various numpy arrays into dicts with all the information needed for aggregation"""
        pre_entities = []
        for idx, token_scores in enumerate(scores):
            # Filter special_tokens, they should only occur
            # at the sentence boundaries since we're not encoding pairs of
            # sentences so we don't have to keep track of those.
            if special_tokens_mask[idx]:
                continue

            word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
            if offset_mapping is not None:
                start_ind, end_ind = offset_mapping[idx]
                word_ref = sentence[start_ind:end_ind]
                if getattr(self.tokenizer._tokenizer.model, "continuing_subword_prefix", None):
                    # This is a BPE, word aware tokenizer, there is a correct way
                    # to fuse tokens
                    is_subword = len(word) != len(word_ref)
                else:
                    # This is a fallback heuristic. It will most likely fail on any kind of
                    # text + punctuation mixture that will be considered a "word". Non word
                    # aware models cannot do better than this, unfortunately.
                    if self.aggregation_strategy in {
                        AggregationStrategy.FIRST,
                        AggregationStrategy.AVERAGE,
                        AggregationStrategy.MAX,
                    }:
                        warnings.warn("Tokenizer does not support real words, using fallback heuristic", UserWarning)
                    is_subword = sentence[start_ind - 1 : start_ind] != " " if start_ind > 0 else False

                if int(input_ids[idx]) == self.tokenizer.unk_token_id:
                    word = word_ref
                    is_subword = False
            else:
                start_ind = None
                end_ind = None
                is_subword = False

            pre_entity = {
                "word": word,
                "scores": token_scores,
                "start": start_ind,
                "end": end_ind,
                "index": idx,
                "is_subword": is_subword,
            }
            pre_entities.append(pre_entity)
        return pre_entities
    def aggregate(self, pre_entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
        if aggregation_strategy in {AggregationStrategy.NONE, AggregationStrategy.SIMPLE}:
            entities = []
            for pre_entity in pre_entities:
                entity_idx = pre_entity["scores"].argmax()
                score = pre_entity["scores"][entity_idx]
                entity = {
                    "entity": self.model.config.id2label[entity_idx],
                    "score": score,
                    "index": pre_entity["index"],
                    "word": pre_entity["word"],
                    "start": pre_entity["start"],
                    "end": pre_entity["end"],
                }
                entities.append(entity)
        else:
            entities = self.aggregate_words(pre_entities, aggregation_strategy)

        if aggregation_strategy == AggregationStrategy.NONE:
            return entities
        return self.group_entities(entities)

    def aggregate_word(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> dict:
        word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities])
        if aggregation_strategy == AggregationStrategy.FIRST:
            scores = entities[0]["scores"]
            idx = scores.argmax()
            score = scores[idx]
            entity = self.model.config.id2label[idx]
        elif aggregation_strategy == AggregationStrategy.MAX:
            max_entity = max(entities, key=lambda entity: entity["scores"].max())
            scores = max_entity["scores"]
            idx = scores.argmax()
            score = scores[idx]
            entity = self.model.config.id2label[idx]
        elif aggregation_strategy == AggregationStrategy.AVERAGE:
            scores = np.stack([entity["scores"] for entity in entities])
            average_scores = np.nanmean(scores, axis=0)
            entity_idx = average_scores.argmax()
            entity = self.model.config.id2label[entity_idx]
            score = average_scores[entity_idx]
        else:
            raise ValueError("Invalid aggregation_strategy")
        new_entity = {
            "entity": entity,
            "score": score,
            "word": word,
            "start": entities[0]["start"],
            "end": entities[-1]["end"],
        }
        return new_entity
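
    # Worked sketch of `aggregate_word` with AggregationStrategy.MAX (the tokens and score
    # values below are made up for illustration):
    #     entities = [
    #         {"word": "Micro",  "scores": np.array([0.1, 0.9]), ...},   # max score 0.9
    #         {"word": "##soft", "scores": np.array([0.6, 0.4]), ...},   # max score 0.6
    #     ]
    #     The first token wins (0.9 > 0.6), so its argmax label is applied to the whole fused
    #     word "Microsoft" (obtained via `convert_tokens_to_string`).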
    def aggregate_words(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
        """
        Override tokens from a given word that disagree to force agreement on word boundaries.

        Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with the FIRST strategy as
        microsoft| company| B-ENT I-ENT
        """
        assert aggregation_strategy not in {
            AggregationStrategy.NONE,
            AggregationStrategy.SIMPLE,
        }, "NONE and SIMPLE strategies are invalid"

        word_entities = []
        word_group = None
        for entity in entities:
            if word_group is None:
                word_group = [entity]
            elif entity["is_subword"]:
                word_group.append(entity)
            else:
                word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
                word_group = [entity]
        # Last item
        word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
        return word_entities
    def group_sub_entities(self, entities: List[dict]) -> dict:
        """
        Group together the adjacent tokens with the same entity predicted.

        Args:
            entities (:obj:`dict`): The entities predicted by the pipeline.
        """
        # Get the first entity in the entity group
        entity = entities[0]["entity"].split("-")[-1]
        scores = np.nanmean([entity["score"] for entity in entities])
        tokens = [entity["word"] for entity in entities]

        entity_group = {
            "entity_group": entity,
            "score": np.mean(scores),
            "word": self.tokenizer.convert_tokens_to_string(tokens),
            "start": entities[0]["start"],
            "end": entities[-1]["end"],
        }
        return entity_group
    def get_tag(self, entity_name: str) -> Tuple[str, str]:
        if entity_name.startswith("B-"):
            bi = "B"
            tag = entity_name[2:]
        elif entity_name.startswith("I-"):
            bi = "I"
            tag = entity_name[2:]
        else:
            # It's not in B-, I- format
            # Default to I- for continuation.
            bi = "I"
            tag = entity_name
        return bi, tag
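
    # Illustrative behaviour of `get_tag` above (example labels only; the actual tags depend
    # on the model's label set):
    #     get_tag("B-PER") -> ("B", "PER")
    #     get_tag("I-PER") -> ("I", "PER")
    #     get_tag("PER")   -> ("I", "PER")   # bare labels are treated as continuations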
    def group_entities(self, entities: List[dict]) -> List[dict]:
        """
        Find and group together the adjacent tokens with the same entity predicted.

        Args:
            entities (:obj:`dict`): The entities predicted by the pipeline.
        """

        entity_groups = []
        entity_group_disagg = []

        for entity in entities:
            if not entity_group_disagg:
                entity_group_disagg.append(entity)
                continue

            # If the current entity is similar and adjacent to the previous entity,
            # append it to the disaggregated entity group
            # The split is meant to account for the "B" and "I" prefixes
            # Shouldn't merge if both entities are B-type
            bi, tag = self.get_tag(entity["entity"])
            last_bi, last_tag = self.get_tag(entity_group_disagg[-1]["entity"])

            if tag == last_tag and bi != "B":
                # Modify subword type to be previous_type
                entity_group_disagg.append(entity)
            else:
                # If the current entity is different from the previous entity
                # aggregate the disaggregated entity group
                entity_groups.append(self.group_sub_entities(entity_group_disagg))
                entity_group_disagg = [entity]
        if entity_group_disagg:
            # it's the last entity, add it to the entity groups
            entity_groups.append(self.group_sub_entities(entity_group_disagg))

        return entity_groups
NerPipeline = TokenClassificationPipeline
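
# A minimal usage sketch (assumptions: the checkpoint name and example sentence are
# illustrative; any token-classification checkpoint with a fast tokenizer would work, and
# the scores shown as `...` depend on the model):
#
#     from transformers import pipeline
#
#     ner = pipeline(
#         "ner",
#         model="dbmdz/bert-large-cased-finetuned-conll03-english",
#         aggregation_strategy="simple",
#     )
#     ner("My name is Wolfgang and I live in Berlin")
#     # e.g. [{"entity_group": "PER", "score": ..., "word": "Wolfgang", "start": 11, "end": 19},
#     #       {"entity_group": "LOC", "score": ..., "word": "Berlin", "start": 34, "end": 40}]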