from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    TokenClassificationPipeline,
)
from transformers.pipelines import AggregationStrategy
import numpy as np


class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that returns de-duplicated keyphrases.

    The model and its tokenizer are both loaded from the same ``model``
    identifier, and the aggregated entities produced by the parent pipeline
    are reduced to an array of unique phrase strings.
    """

    def __init__(self, model, *args, **kwargs):
        """Load the model and tokenizer from ``model`` and build the pipeline.

        Args:
            model: Identifier passed unchanged to both
                ``AutoModelForTokenClassification.from_pretrained`` and
                ``AutoTokenizer.from_pretrained``.
            *args: Extra positional arguments forwarded to the parent
                constructor.
            **kwargs: Extra keyword arguments forwarded to the parent
                constructor.
        """
        # NOTE(review): the model and tokenizer are already supplied by
        # keyword, so extra positional arguments will likely collide with the
        # parent's leading positional parameters -- prefer keyword arguments
        # when customising the pipeline; confirm against the installed
        # transformers version.
        super().__init__(
            *args,
            model=AutoModelForTokenClassification.from_pretrained(model),
            tokenizer=AutoTokenizer.from_pretrained(model),
            **kwargs,
        )

    def postprocess(self, model_outputs) -> np.ndarray:
        """Turn raw model outputs into an array of unique keyphrases.

        The parent ``postprocess`` is invoked with ``SIMPLE`` aggregation when
        the model's ``config.model_type`` is ``"roberta"`` and with ``FIRST``
        aggregation otherwise. The ``"word"`` entry of every aggregated result
        is stripped of surrounding whitespace and duplicates are removed.

        Args:
            model_outputs: Outputs handed over by the pipeline's forward step;
                passed through to the parent implementation untouched.

        Returns:
            A NumPy string array of the distinct keyphrases, in the sorted
            order produced by ``np.unique``. The array is empty (but still
            string-typed) when nothing was extracted.
        """
        if self.model.config.model_type == "roberta":
            strategy = AggregationStrategy.SIMPLE
        else:
            strategy = AggregationStrategy.FIRST

        # Pass the outputs positionally: the parent's first parameter has been
        # named differently across transformers releases, so a keyword call
        # can fail with an unexpected-keyword TypeError.
        results = super().postprocess(
            model_outputs,
            aggregation_strategy=strategy,
        )

        keyphrases = [result.get("word").strip() for result in results]
        # An explicit string dtype keeps the empty case from degrading to a
        # float64 array; non-empty results are unaffected.
        return np.unique(np.array(keyphrases, dtype=str))