Spaces:
Running
Running
File size: 887 Bytes
55dc8b1 e4f39c4 55dc8b1 e4f39c4 f2f4fc6 e4f39c4 f92dd51 e4f39c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import numpy as np
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
TokenClassificationPipeline,
)
from transformers.pipelines import AggregationStrategy
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that returns the unique predicted phrases.

    The pipeline loads both the model and the tokenizer from a single
    identifier and overrides ``postprocess`` so that, instead of the parent's
    list of entity dicts, callers receive an array of the distinct ``"word"``
    strings found in those dicts.
    """

    def __init__(self, model, *args, **kwargs):
        """Load model and tokenizer for ``model`` and initialise the pipeline.

        Args:
            model: Identifier passed to both
                ``AutoModelForTokenClassification.from_pretrained`` and
                ``AutoTokenizer.from_pretrained``.
            *args: Extra positional arguments forwarded to the parent
                constructor.
            **kwargs: Extra keyword arguments forwarded to the parent
                constructor.
        """
        super().__init__(
            model=AutoModelForTokenClassification.from_pretrained(model),
            tokenizer=AutoTokenizer.from_pretrained(model),
            *args,
            **kwargs,
        )

    def postprocess(self, model_outputs):
        """Aggregate token predictions and return the unique phrase strings.

        Args:
            model_outputs: Raw outputs handed over by the pipeline's forward
                step; passed through untouched to the parent ``postprocess``.

        Returns:
            The result of ``np.unique`` over the whitespace-stripped ``"word"``
            value of every aggregated entity, i.e. a sorted array without
            duplicates.
        """
        # RoBERTa checkpoints use SIMPLE aggregation, everything else FIRST.
        if self.model.config.model_type == "roberta":
            strategy = AggregationStrategy.SIMPLE
        else:
            strategy = AggregationStrategy.FIRST

        # Pass the outputs positionally: the parent's first parameter has been
        # named ``model_outputs`` in some transformers releases and
        # ``all_outputs`` in others, so a keyword argument would raise
        # TypeError on whichever version uses the other name.
        results = super().postprocess(
            model_outputs,
            aggregation_strategy=strategy,
        )
        return np.unique([result.get("word").strip() for result in results])
|