mvy committed
Commit: 10b2302
1 Parent(s): 58e3f16
Files changed (2):
  1. app.py +8 -2
  2. ner.py +111 -59
app.py CHANGED
@@ -21,9 +21,15 @@ examples = [
     ],
 ]
 
+ner = NER('knowledgator/UTC-DeBERTa-small')
+
 gradio_app = gr.Interface(
-    NER.ner,
-    inputs = ['text', gr.Textbox(placeholder="Enter sentence here..."), gr.Number(value=0.0, label="treshold")],
+    ner,
+    inputs = [
+        'text',
+        gr.Textbox(placeholder="Enter sentence here..."),
+        gr.Number(value=0.0, label="threshold")
+    ],
     outputs = [gr.HighlightedText()],
     examples=examples,
     theme="huggingface",
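The substantive change: gr.Interface now receives a callable NER instance instead of the unbound NER.ner classmethod, the one-line inputs list is split for readability, and the misspelled "treshold" label is corrected. Below is a minimal sketch of how the three input components map onto the instance call; the direct invocation and the example strings are illustrative assumptions, not part of the commit.

    from ner import NER

    ner = NER('knowledgator/UTC-DeBERTa-small')

    # gr.Interface passes the three components positionally, matching
    # NER.__call__(labels, text, threshold):
    #   'text' component -> comma-separated label string
    #   gr.Textbox       -> the text to tag
    #   gr.Number        -> minimum entity score
    result = ner('person, city', 'Albert Einstein moved to Princeton.', 0.5)

    # result has the shape {"text": ..., "entities": [...]} that
    # gr.HighlightedText expects.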
ner.py CHANGED
@@ -1,78 +1,130 @@
+from typing import Tuple
+import string
+
 from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
 import spacy
 import torch
 
-nlp = spacy.load('en_core_web_sm', disable = ['lemmatizer', 'parser', 'tagger', 'ner'])
-nlp.add_pipe('sentencizer')
-
-device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-
-
 class NER:
-    model_name = 'knowledgator/UTC-DeBERTa-small'
-    prompt="""
+    prompt: str = """
 Identify entities in the text having the following classes:
 {}
 
 Text:
-"""
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForTokenClassification.from_pretrained(model_name)
-    ner_pipeline = pipeline(
-        "ner",
-        model=model,
-        tokenizer=tokenizer,
-        aggregation_strategy='first',
-        batch_size=12,
-        device=device
-    )
+"""
+
+    def __init__(self, model_name: str, sents_batch: int=10):
+        self.sents_batch = sents_batch
+
+        self.nlp: spacy.Language = spacy.load(
+            'en_core_web_sm',
+            disable = ['lemmatizer', 'parser', 'tagger', 'ner']
+        )
+        self.nlp.add_pipe('sentencizer')
+
+        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForTokenClassification.from_pretrained(model_name)
+
+        self.pipeline = pipeline(
+            "ner",
+            model=model,
+            tokenizer=tokenizer,
+            aggregation_strategy='first',
+            batch_size=12,
+            device=device
+        )
 
-    @classmethod
-    def chunkanize(cls, text, prompt_ = '', n_sents = 10):
-        doc = nlp(text)
+
+    def get_last_sentence_id(self, i: int, sentences_len: int) -> int:
+        return min(i + self.sents_batch, sentences_len) - 1
+
+
+    def chunkanize(self, text: str) -> Tuple[list[str], list[int]]:
+        doc = self.nlp(text)
         chunks = []
         starts = []
-        start = 0
-        end = 0
-        proc = False
-        for id, sent in enumerate(doc.sents, start=1):
-            if not proc:
-                start = sent[0].idx
-                starts.append(start)
-                proc = True
-            end = sent[-1].idx+len(sent[-1].text)
-            if id%n_sents==0:
-                chunk_text = prompt_+text[start:end]
-                chunks.append(chunk_text)
-                proc = False
-        if proc:
-            chunk_text = prompt_+text[start:end]
-            chunks.append(chunk_text)
+        sentences = list(doc.sents)
+
+        for i in range(0, len(sentences), self.sents_batch):
+            start = sentences[i].start_char
+            starts.append(start)
+
+            last_sentence = self.get_last_sentence_id(i, len(sentences))
+            end = sentences[last_sentence].end_char
+
+            chunks.append(text[start:end])
         return chunks, starts
+
+
+    def get_inputs(
+        self, chunks: list[str], labels: list[str]
+    ) -> Tuple[list[str], list[int]]:
+        inputs = []
+        prompts_lens = []
+
+        for label in labels:
+            prompt = self.prompt.format(label)
+            prompts_lens.append(len(prompt))
+            for chunk in chunks:
+                inputs.append(prompt + chunk)
+
+        return inputs, prompts_lens
 
 
     @classmethod
-    def ner(cls, labels, text, treshold = 0.):
-        chunks, starts, classes = [], [], []
-        label2prompt_len = {}
-        for label in labels.split(', '):
-            prompt_ = cls.prompt.format(label)
-            prompt_len = len(prompt_)
-            label2prompt_len[label] = prompt_len
-            curr_chunks, curr_starts = cls.chunkanize(text, prompt_)
-            curr_labels = [label for _ in range(len(curr_chunks))]
-            chunks+=curr_chunks
-            starts+=curr_starts
-            classes+=curr_labels
+    def clean_span(
+        cls, start: int, end: int, span: str
+    ) -> Tuple[int, int, str]:
+        if len(span) >= 1:
+            if span[0] in string.punctuation:
+                return cls.clean_span(start+1, end, span[1:])
+            if span[-1] in string.punctuation:
+                return cls.clean_span(start, end-1, span[:-1])
+        return start, end, span.strip()
+
+
+    def predict(
+        self,
+        text: str,
+        inputs: list[str],
+        labels: list[str],
+        chunks_starts: list[int],
+        prompts_lens: list[int],
+        threshold: float
+    ) -> list[dict[str, any]]:
         outputs = []
-        for id, output in enumerate(cls.ner_pipeline(chunks)):
-            label = classes[id]
-            prompt_len = label2prompt_len[label]
-            start = starts[id]-prompt_len
+
+        for id, output in enumerate(self.pipeline(inputs)):
+            label = labels[id//len(chunks_starts)]
+            shift = chunks_starts[id%len(chunks_starts)] - prompts_lens[id//len(chunks_starts)]
             for ent in output:
-                if ent['score']>treshold:
-                    ent['start'] += start
-                    ent['end'] += start
-                    ent['entity'] = label
-                    outputs.append(ent)
+                start = ent['start'] + shift + 1
+                end = ent['end'] + shift
+                start, end, span = self.clean_span(start, end, text[start:end])
+                if not span:
+                    continue
+
+                if ent['score'] >= threshold:
+                    outputs.append({
+                        'span': span,
+                        'start': start,
+                        'end': end,
+                        'entity': label
+                    })
+        return outputs
+
+
+    def __call__(
+        self, labels: str, text: str, threshold: float=0.
+    ) -> dict[str, any]:
+        labels_list = [label.strip() for label in labels.split(',')]
+
+        chunks, chunks_starts = self.chunkanize(text)
+        inputs, prompts_lens = self.get_inputs(chunks, labels_list)
+
+        outputs = self.predict(
+            text, inputs, labels_list, chunks_starts, prompts_lens, threshold
+        )
+        print(outputs)
         return {"text": text, "entities": outputs}