mvy committed on
Commit
8e19b14
1 Parent(s): 6f3f044

add validation checks

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. ner.py +46 -5
app.py CHANGED
@@ -24,7 +24,7 @@ examples = [
24
  ner = NER('knowledgator/UTC-DeBERTa-small')
25
 
26
  gradio_app = gr.Interface(
27
- ner,
28
  inputs = [
29
  'text',
30
  gr.Textbox(placeholder="Enter sentence here..."),
 
24
  ner = NER('knowledgator/UTC-DeBERTa-small')
25
 
26
  gradio_app = gr.Interface(
27
+ ner.process,
28
  inputs = [
29
  'text',
30
  gr.Textbox(placeholder="Enter sentence here..."),
ner.py CHANGED
@@ -4,6 +4,8 @@ import string
4
  from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
5
  import spacy
6
  import torch
 
 
7
 
8
  class NER:
9
  prompt: str = """
@@ -13,8 +15,14 @@ Identify entities in the text having the following classes:
13
  Text:
14
  """
15
 
16
- def __init__(self, model_name: str, sents_batch: int=10):
 
 
 
 
 
17
  self.sents_batch = sents_batch
 
18
 
19
  self.nlp: spacy.Language = spacy.load(
20
  'en_core_web_sm',
@@ -23,13 +31,13 @@ Text:
23
  self.nlp.add_pipe('sentencizer')
24
 
25
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
26
- tokenizer = AutoTokenizer.from_pretrained(model_name)
27
  model = AutoModelForTokenClassification.from_pretrained(model_name)
28
 
29
  self.pipeline = pipeline(
30
  "ner",
31
  model=model,
32
- tokenizer=tokenizer,
33
  aggregation_strategy='first',
34
  batch_size=12,
35
  device=device
@@ -115,14 +123,47 @@ Text:
115
  return outputs
116
 
117
 
118
- def __call__(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  self, labels: str, text: str, threshold: float=0.
120
  ) -> dict[str, any]:
121
- labels_list = [label.strip() for label in labels.split(',')]
 
 
 
 
 
 
122
 
123
  chunks, chunks_starts = self.chunkanize(text)
124
  inputs, prompts_lens = self.get_inputs(chunks, labels_list)
125
 
 
 
126
  outputs = self.predict(
127
  text, inputs, labels_list, chunks_starts, prompts_lens, threshold
128
  )
 
4
  from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
5
  import spacy
6
  import torch
7
+ import gradio as gr
8
+
9
 
10
  class NER:
11
  prompt: str = """
 
15
  Text:
16
  """
17
 
18
+ def __init__(
19
+ self,
20
+ model_name: str,
21
+ sents_batch: int=10,
22
+ tokens_limit: int=2048
23
+ ):
24
  self.sents_batch = sents_batch
25
+ self.tokens_limit = tokens_limit
26
 
27
  self.nlp: spacy.Language = spacy.load(
28
  'en_core_web_sm',
 
31
  self.nlp.add_pipe('sentencizer')
32
 
33
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
34
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
35
  model = AutoModelForTokenClassification.from_pretrained(model_name)
36
 
37
  self.pipeline = pipeline(
38
  "ner",
39
  model=model,
40
+ tokenizer=self.tokenizer,
41
  aggregation_strategy='first',
42
  batch_size=12,
43
  device=device
 
123
  return outputs
124
 
125
 
126
+ def check_text(self, text: str) -> None:
127
+ if not text:
128
+ raise gr.Error('No text provided. Please provide text.')
129
+
130
+
131
+ def check_labels(self, labels: list[str]) -> None:
132
+ if not labels:
133
+ raise gr.Error(
134
+ 'No labels provided. Please provide labels.'
135
+ ' Multiple labels should be divided by commas.'
136
+ ' See examples below.'
137
+ )
138
+
139
+
140
+ def check_tokens_limit(self, inputs: list[str]) -> None:
141
+ tokens = 0
142
+ for input_ in inputs:
143
+ tokens += len(self.tokenizer.encode(input_))
144
+ if tokens > self.tokens_limit:
145
+ raise gr.Error(
146
+ 'Too many tokens! Please reduce size of text or amount of labels.'
147
+ f' Max tokens count is: {self.tokens_limit}.'
148
+ )
149
+
150
+
151
+ def process(
152
  self, labels: str, text: str, threshold: float=0.
153
  ) -> dict[str, any]:
154
+ labels_list = list({
155
+ l for label in labels.split(',')
156
+ if (l:=label.strip())
157
+ })
158
+
159
+ self.check_labels(labels_list)
160
+ self.check_text(text)
161
 
162
  chunks, chunks_starts = self.chunkanize(text)
163
  inputs, prompts_lens = self.get_inputs(chunks, labels_list)
164
 
165
+ self.check_tokens_limit(inputs)
166
+
167
  outputs = self.predict(
168
  text, inputs, labels_list, chunks_starts, prompts_lens, threshold
169
  )