from typing import Any, Tuple
import string

from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
import spacy
import torch
import gradio as gr


class NER:
    """Prompt-based NER over a user-defined set of entity labels.

    Each label is rendered into a short prompt that is prepended to the text;
    a token-classification pipeline predicts spans in that combined input, and
    the spans are then mapped back to character offsets in the original text.
    """
    prompt: str = """
Identify entities in the text having the following classes:
{}

Text:
 """

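    # Illustrative example (label and text are hypothetical): for the label
    # 'person' and the text chunk 'Alice met Bob.', the formatted model input
    # (prompt + chunk) looks like:
    #
    #     Identify entities in the text having the following classes:
    #     person
    #
    #     Text:
    #      Alice met Bob.
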
    def __init__(
        self,
        model_name: str,
        sents_batch: int = 10,
        tokens_limit: int = 2048
    ):
        self.sents_batch = sents_batch
        self.tokens_limit = tokens_limit

        # spaCy is used only to split the text into sentences for chunking.
        self.nlp: spacy.Language = spacy.load(
            'en_core_web_sm',
            disable=['lemmatizer', 'parser', 'tagger', 'ner']
        )
        self.nlp.add_pipe('sentencizer')

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForTokenClassification.from_pretrained(model_name)

        self.pipeline = pipeline(
            "ner",
            model=model,
            tokenizer=self.tokenizer,
            aggregation_strategy='first',
            batch_size=12,
            device=device
        )
    

    def get_last_sentence_id(self, i: int, sentences_len: int) -> int:
        # Index of the last sentence in the batch that starts at sentence i.
        return min(i + self.sents_batch, sentences_len) - 1


    def chunkanize(self, text: str) -> Tuple[list[str], list[int]]:
        """Split the text into chunks of up to `sents_batch` sentences.

        Returns the chunks and the character offset at which each chunk starts
        in the original text, so predictions can be mapped back later.
        """
        doc = self.nlp(text)
        chunks = []
        starts = []
        sentences = list(doc.sents)

        for i in range(0, len(sentences), self.sents_batch):
            start = sentences[i].start_char
            starts.append(start)

            last_sentence = self.get_last_sentence_id(i, len(sentences))
            end = sentences[last_sentence].end_char

            chunks.append(text[start:end])
        return chunks, starts
    

    def get_inputs(
        self, chunks: list[str], labels: list[str]
    ) -> Tuple[list[str], list[int]]:
        """Build one prompt+chunk input per (label, chunk) pair.

        Also returns the length of each label's prompt, which is needed later
        to shift predicted offsets back into the original text.
        """
        inputs = []
        prompts_lens = []

        for label in labels:
            prompt = self.prompt.format(label)
            prompts_lens.append(len(prompt))
            for chunk in chunks:
                inputs.append(prompt + chunk)

        return inputs, prompts_lens


    @classmethod
    def clean_span(
        cls, start: int, end: int, span: str
    ) -> Tuple[int, int, str]:
        """Recursively trim leading/trailing punctuation from a predicted span.

        For example, clean_span(10, 17, '"Paris,') returns (11, 16, 'Paris').
        """
        if len(span) >= 1:
            if span[0] in string.punctuation:
                return cls.clean_span(start + 1, end, span[1:])
            if span[-1] in string.punctuation:
                return cls.clean_span(start, end - 1, span[:-1])
        return start, end, span.strip()


    def predict(
        self,
        text: str,
        inputs: list[str],
        labels: list[str],
        chunks_starts: list[int],
        prompts_lens: list[int],
        threshold: float
    ) -> list[dict[str, Any]]:
        outputs = []

        # Inputs are ordered label-major: for each label there is one entry per
        # chunk, so input idx corresponds to label idx // n_chunks and chunk
        # idx % n_chunks.
        for idx, output in enumerate(self.pipeline(inputs)):
            label = labels[idx // len(chunks_starts)]
            # Shift predicted offsets from the prompt+chunk input back into the
            # coordinates of the original text.
            shift = chunks_starts[idx % len(chunks_starts)] - prompts_lens[idx // len(chunks_starts)]
            for ent in output:
                start = ent['start'] + shift + 1
                end = ent['end'] + shift
                start, end, span = self.clean_span(start, end, text[start:end])
                if not span:
                    continue

                if ent['score'] >= threshold:
                    outputs.append({
                        'span': span,
                        'start': start,
                        'end': end,
                        'entity': label
                    })
        return outputs


    def check_text(self, text: str) -> None:
        if not text:
            raise gr.Error('No text provided. Please provide text.')
        
    
    def check_labels(self, labels: list[str]) -> None:
        if not labels:
            raise gr.Error(
                'No labels provided. Please provide labels.'
                ' Multiple labels should be divided by commas.'
                ' See examples below.'
            )
        

    def check_tokens_limit(self, inputs: list[str]) -> None:
        tokens = 0
        for input_ in inputs:
            tokens += len(self.tokenizer.encode(input_))
            if tokens > self.tokens_limit:
                raise gr.Error(
                    'Too many tokens! Please reduce the size of the text'
                    ' or the number of labels.'
                    f' The maximum total token count is {self.tokens_limit}.'
                )


    def process(
        self, labels: str, text: str, threshold: float = 0.
    ) -> dict[str, Any]:
        """Run NER over `text` for a comma-separated string of labels.

        Returns {"text": ..., "entities": [...]}, the dictionary format accepted
        by Gradio's HighlightedText component.
        """
        # Split the comma-separated labels, dropping empty entries and duplicates.
        labels_list = list({
            l for label in labels.split(',')
            if (l := label.strip())
        })

        self.check_labels(labels_list)
        self.check_text(text)

        chunks, chunks_starts = self.chunkanize(text)
        inputs, prompts_lens = self.get_inputs(chunks, labels_list)

        self.check_tokens_limit(inputs)

        outputs = self.predict(
            text, inputs, labels_list, chunks_starts, prompts_lens, threshold
        )
        return {"text": text, "entities": outputs}
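

# Minimal usage sketch (not part of the original module). The model id below is a
# placeholder and the Gradio wiring is one plausible way to expose NER.process;
# adjust both to the checkpoint and UI actually used by the app.
if __name__ == '__main__':
    ner = NER('your-token-classification-model')  # hypothetical model id

    demo = gr.Interface(
        fn=ner.process,
        inputs=[
            gr.Textbox(label='Labels (comma-separated)'),
            gr.Textbox(label='Text', lines=5),
            gr.Slider(0.0, 1.0, value=0.5, label='Threshold'),
        ],
        outputs=gr.HighlightedText(label='Entities'),
    )
    demo.launch()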