maximuspowers committed
Commit 14023f9
Parent: f0d4714

pipeline registry?

Files changed (1)
  1. pipeline.py +16 -28
pipeline.py CHANGED
@@ -1,15 +1,10 @@
-from typing import List, Dict
-import json
+from transformers import PIPELINE_REGISTRY, TokenClassificationPipeline
 import torch
-from transformers import BertTokenizerFast, BertForTokenClassification
-
-class BiasNERPipeline:
-    def __init__(self, model_path: str = 'maximuspowers/bias-detection-ner'):
-        self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
-        self.model = BertForTokenClassification.from_pretrained(model_path)
-        self.model.eval()
-        self.model.to('cuda' if torch.cuda.is_available() else 'cpu')
 
+@PIPELINE_REGISTRY.register_pipeline(task="multi_label_token_classification", pipeline_class=None)
+class MultiLabelTokenClassificationPipeline(TokenClassificationPipeline):
+    def __init__(self, model, tokenizer, **kwargs):
+        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
         self.id2label = {
             0: 'O',
             1: 'B-STEREO',
@@ -20,23 +15,16 @@ class BiasNERPipeline:
             6: 'I-UNFAIR'
         }
 
-    def __call__(self, inputs: str) -> str:
-        tokenized_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, max_length=128)
-        input_ids = tokenized_inputs['input_ids'].to(self.model.device)
-        attention_mask = tokenized_inputs['attention_mask'].to(self.model.device)
-
-        with torch.no_grad():
-            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
-            logits = outputs.logits
+    def postprocess(self, model_outputs, **kwargs):
+        results = []
+        for logits, tokens in zip(model_outputs[0], model_outputs[1]):
             probabilities = torch.sigmoid(logits)
             predicted_labels = (probabilities > 0.5).int()
-
-        result = []
-        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
-        for i, token in enumerate(tokens):
-            if token not in self.tokenizer.all_special_tokens:
-                label_indices = (predicted_labels[0][i] == 1).nonzero(as_tuple=False).squeeze(-1)
-                labels = [self.id2label[idx.item()] for idx in label_indices] if label_indices.numel() > 0 else ['O']
-                result.append({"token": token, "labels": labels})
-
-        return json.dumps(result, indent=4)
+            token_results = []
+            for i, token in enumerate(tokens):
+                if token not in self.tokenizer.all_special_tokens:
+                    label_indices = (predicted_labels[i] == 1).nonzero(as_tuple=False).squeeze(-1)
+                    labels = [self.id2label[idx.item()] for idx in label_indices] if label_indices.numel() > 0 else ['O']
+                    token_results.append({"token": token, "labels": labels})
+            results.append(token_results)
+        return results
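
Note on the registration call: this commit applies PIPELINE_REGISTRY.register_pipeline as a decorator with pipeline_class=None, while the Transformers documentation registers a custom task by calling register_pipeline directly and passing the pipeline class (plus a model class) as arguments. The sketch below follows that documented call pattern; it is not part of this commit, and it assumes MultiLabelTokenClassificationPipeline is defined as in the diff but without the decorator. The checkpoint and tokenizer names are taken from the removed BiasNERPipeline defaults.

# Hypothetical usage sketch, not part of the commit.
from transformers import (
    AutoModelForTokenClassification,
    BertForTokenClassification,
    BertTokenizerFast,
    pipeline,
)
from transformers.pipelines import PIPELINE_REGISTRY

# Assumes pipeline.py is importable and the decorator is not applied to the class.
from pipeline import MultiLabelTokenClassificationPipeline

# Register the custom task with an explicit call, following the
# "adding a new pipeline" pattern from the Transformers docs.
PIPELINE_REGISTRY.register_pipeline(
    "multi_label_token_classification",
    pipeline_class=MultiLabelTokenClassificationPipeline,
    pt_model=AutoModelForTokenClassification,
)

# Model and tokenizer choices mirror the removed BiasNERPipeline defaults.
model = BertForTokenClassification.from_pretrained("maximuspowers/bias-detection-ner")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

nlp = pipeline("multi_label_token_classification", model=model, tokenizer=tokenizer)
# Output structure depends on the custom postprocess defined in the diff above.
result = nlp("Tall people are always so clumsy.")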