File size: 887 Bytes
55dc8b1
e4f39c4
 
 
55dc8b1
e4f39c4
 
 
 
 
 
 
 
f2f4fc6
e4f39c4
 
 
 
 
 
 
f92dd51
 
 
e4f39c4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    TokenClassificationPipeline,
)
from transformers.pipelines import AggregationStrategy


class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that returns distinct keyphrase strings.

    The constructor loads the model and the tokenizer from the same
    identifier; ``postprocess`` reduces the aggregated entities produced by
    the parent pipeline to an array of unique, whitespace-stripped phrases.
    """

    def __init__(self, model, *args, **kwargs):
        """Load model and tokenizer from ``model`` and initialise the parent.

        Args:
            model: Identifier handed to ``from_pretrained`` of both
                ``AutoModelForTokenClassification`` and ``AutoTokenizer``.
            *args: Extra positional arguments forwarded to the parent
                constructor.
            **kwargs: Extra keyword arguments forwarded to the parent
                constructor.
        """
        # Positional unpacking is written first because Python binds it
        # before the keyword arguments regardless of where it appears in
        # the call; this ordering makes that explicit.
        super().__init__(
            *args,
            model=AutoModelForTokenClassification.from_pretrained(model),
            tokenizer=AutoTokenizer.from_pretrained(model),
            **kwargs,
        )

    def postprocess(self, model_outputs):
        """Aggregate token predictions and return the unique keyphrases.

        Args:
            model_outputs: Raw outputs handed over by the pipeline's forward
                step; passed through unchanged to the parent ``postprocess``.

        Returns:
            numpy.ndarray: String array of the distinct ``"word"`` values of
            the aggregated entities, stripped of surrounding whitespace and
            in the sorted order produced by ``np.unique``. Empty (but still
            string-typed) when no entity is found.
        """
        # Aggregation strategy is chosen from the model type recorded in the
        # loaded model's config.
        if self.model.config.model_type == "roberta":
            strategy = AggregationStrategy.SIMPLE
        else:
            strategy = AggregationStrategy.FIRST

        # The outputs are passed positionally: the parent's first parameter
        # has been named differently across transformers releases
        # (``model_outputs`` / ``all_outputs``), so a keyword call can fail.
        # NOTE(review): confirm against the transformers version pinned here.
        results = super().postprocess(model_outputs, aggregation_strategy=strategy)

        keyphrases = [result.get("word").strip() for result in results]
        # An explicit str array keeps the dtype textual even when the list
        # is empty (np.unique([]) would otherwise yield a float64 array).
        return np.unique(np.array(keyphrases, dtype=str))