# File size: 912 Bytes
# 775f69c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from transformers import TokenClassificationPipeline, AutoModelForTokenClassification, AutoTokenizer
from transformers.pipelines import AggregationStrategy
import numpy as np
import configparser

# Settings file; the path is resolved relative to the current working directory.
_CONFIG_PATH = "src/configs/config.cfg"

config = configparser.ConfigParser()
# ``ConfigParser.read`` silently skips files it cannot open and returns the
# list of files it actually parsed. Check that list so a missing file fails
# with a clear message instead of a bare ``KeyError: 'EMBEDDINGS'`` below.
if not config.read(_CONFIG_PATH):
    raise FileNotFoundError(
        f"Configuration file {_CONFIG_PATH!r} was not found or could not be "
        "read (the path is relative to the current working directory)."
    )
# Section holding the KEYWORD_EXTRACTOR entry used by the pipeline class below.
embed_config = config["EMBEDDINGS"]

class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that returns the unique keyphrases of a text.

    The model and tokenizer are both loaded from the checkpoint named by
    ``embed_config["KEYWORD_EXTRACTOR"]``.
    """

    def __init__(self, device=None):
        """Load the configured checkpoint and build the pipeline.

        Args:
            device: Device handed to the underlying pipeline (for example ``0``
                for the first CUDA GPU or ``-1`` for CPU). When ``None`` (the
                default) the first CUDA GPU is used if PyTorch reports one,
                otherwise the CPU.
        """
        checkpoint = str(embed_config["KEYWORD_EXTRACTOR"])

        if device is None:
            # Imported lazily so the module's import block stays untouched;
            # the PyTorch model class used below already depends on torch.
            import torch

            device = 0 if torch.cuda.is_available() else -1

        # Device placement is requested from the pipeline. The previous code
        # passed ``device_map='cuda'`` to the tokenizer loader instead, which
        # does not move the model onto the GPU.
        super().__init__(
            model=AutoModelForTokenClassification.from_pretrained(checkpoint),
            tokenizer=AutoTokenizer.from_pretrained(checkpoint),
            device=device,
        )

    def postprocess(self, all_outputs):
        """Aggregate token predictions into keyphrases and de-duplicate them.

        Args:
            all_outputs: Model outputs, forwarded unchanged to the parent
                ``postprocess``.

        Returns:
            A NumPy array with the unique, whitespace-stripped ``"word"``
            values of the aggregated results, in the sorted order produced by
            ``np.unique``.
        """
        entities = super().postprocess(
            all_outputs=all_outputs,
            aggregation_strategy=AggregationStrategy.FIRST,
        )
        keyphrases = [entity.get("word").strip() for entity in entities]
        # Explicit string dtype: without it an empty result would come back as
        # a float64 array instead of an (empty) array of strings.
        return np.unique(np.array(keyphrases, dtype=str))