emanuelaboros committed
Commit 51adc94 · verified · 1 Parent(s): 471ce47

Update generic_ner.py

Files changed (1)
  1. generic_ner.py +48 -34
generic_ner.py CHANGED
@@ -200,58 +200,72 @@ class MultitaskTokenClassificationPipeline(Pipeline):
         }
         return preprocess_kwargs, {}, {}
 
-    def preprocess(self, text, **kwargs):
-        tokenized_inputs = self.tokenizer(
-            text, padding="max_length", truncation=True, max_length=512
-        )
-        text_sentence = tokenize(add_spaces_around_punctuation(text))
-        return tokenized_inputs, text_sentence, text
+    def chunk_text_exact(self, text, tokenizer, max_subtokens):
+        """
+        Splits text into exact subtoken chunks based on the tokenizer's max length.
+        """
+        subtokens = tokenizer.encode(text, add_special_tokens=False)
+        for i in range(0, len(subtokens), max_subtokens):
+            chunk = subtokens[i : i + max_subtokens]
+            yield tokenizer.decode(chunk, clean_up_tokenization_spaces=False)
+
+    def preprocess(self, text, **kwargs):
+        # Get the model's max input length
+        max_input_length = self.tokenizer.model_max_length - 2  # Reserve space for [CLS] and [SEP]
+
+        # Split the text into subtoken chunks
+        text_chunks = list(self.chunk_text_exact(text, self.tokenizer, max_input_length))
+        print(text_chunks)
+        # Tokenize and add special tokens for each chunk
+        tokenized_chunks = [
+            self.tokenizer(
+                chunk, padding="max_length", truncation=True, max_length=self.tokenizer.model_max_length
+            )
+            for chunk in text_chunks
+        ]
+
+        return tokenized_chunks, text_chunks, text
 
     def _forward(self, inputs):
-        inputs, text_sentences, text = inputs
-        input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
-            self.model.device
-        )
-        attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
-            self.model.device
-        )
+        tokenized_chunks, text_chunks, text = inputs
+        outputs = []
         with torch.no_grad():
-            outputs = self.model(input_ids, attention_mask)
-        return outputs, text_sentences, text
-
+            for tokenized_input in tokenized_chunks:
+                input_ids = torch.tensor([tokenized_input["input_ids"]], dtype=torch.long).to(self.model.device)
+                attention_mask = torch.tensor([tokenized_input["attention_mask"]], dtype=torch.long).to(self.model.device)
+                outputs.append(self.model(input_ids, attention_mask))
+        return outputs, text_chunks, text
 
     def postprocess(self, outputs, **kwargs):
-        """
-        Postprocess the outputs of the model
-        :param outputs:
-        :param kwargs:
-        :return:
-        """
-        tokens_result, text_sentence, text = outputs
-
-        predictions = {}
-        confidence_scores = {}
-        for task, logits in tokens_result.logits.items():
-            predictions[task] = torch.argmax(logits, dim=-1).tolist()
-            confidence_scores[task] = F.softmax(logits, dim=-1).tolist()
-
+        tokens_result, text_chunks, text = outputs
+
+        # Initialize variables for collecting results across chunks
+        predictions = {task: [] for task in self.label_map.keys()}
+        confidence_scores = {task: [] for task in self.label_map.keys()}
+
+        # Collect predictions from each chunk
+        for chunk_result in tokens_result:
+            for task, logits in chunk_result.logits.items():
+                predictions[task].extend(torch.argmax(logits, dim=-1).tolist())
+                confidence_scores[task].extend(F.softmax(logits, dim=-1).tolist())
+        print(predictions)
+        # Decode and process the predictions
         decoded_predictions = {}
         for task, preds in predictions.items():
             decoded_predictions[task] = [
                 [self.id2label[task][label] for label in seq] for seq in preds
             ]
+        print(decoded_predictions)
+        # Extract entities from the combined predictions
         entities = {}
         for task, preds in predictions.items():
             words_list, preds_list, confidence_list = realign(
-                text_sentence,
-                preds[0],
-                confidence_scores[task][0],
+                text_chunks,
+                preds,
+                confidence_scores[task],
                 self.tokenizer,
                 self.id2label[task],
             )
-
             entities[task] = get_entities(words_list, preds_list, confidence_list, text)
 
         return entities
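The heart of this change is the decode-based chunking in chunk_text_exact, which lets the pipeline handle inputs longer than the encoder's 512-subtoken window. Below is a minimal standalone sketch of the same strategy, using bert-base-cased as a placeholder checkpoint (not necessarily the tokenizer this repo ships with):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # placeholder checkpoint

def chunk_text_exact(text, tokenizer, max_subtokens):
    # Encode without [CLS]/[SEP], slice the ids into fixed-size windows,
    # and decode each window back into a plain-text chunk.
    subtokens = tokenizer.encode(text, add_special_tokens=False)
    for i in range(0, len(subtokens), max_subtokens):
        yield tokenizer.decode(
            subtokens[i : i + max_subtokens], clean_up_tokenization_spaces=False
        )

long_text = " ".join(["Geneva was mentioned in the dispatch."] * 300)
max_len = tokenizer.model_max_length - 2  # reserve room for [CLS] and [SEP]
for chunk in chunk_text_exact(long_text, tokenizer, max_len):
    # Decoding and re-encoding is not guaranteed to be round-trip stable, which
    # is why preprocess() above still passes truncation=True as a safety net.
    print(len(tokenizer(chunk)["input_ids"]))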
 
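For the merge step, here is a reduced sketch of how the new postprocess accumulates per-chunk results, using dummy logits and a single hypothetical task name ("ner"); in the real pipeline the task names come from self.label_map and the logits from the multitask model:

import torch
import torch.nn.functional as F

# Two chunks, batch size 1, sequence length 512, 5 labels for one dummy task.
chunk_outputs = [{"ner": torch.randn(1, 512, 5)} for _ in range(2)]

predictions = {"ner": []}
confidence_scores = {"ner": []}
for chunk_result in chunk_outputs:
    for task, logits in chunk_result.items():
        # argmax over the label axis yields one sequence per chunk, so each
        # task accumulates a list with one entry per chunk; this is why the
        # updated realign() call passes preds instead of preds[0].
        predictions[task].extend(torch.argmax(logits, dim=-1).tolist())
        confidence_scores[task].extend(F.softmax(logits, dim=-1).tolist())

print(len(predictions["ner"]))  # 2 prediction sequences, one per chunk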