maximuspowers commited on
Commit
53c6897
1 Parent(s): b8e54ae

Delete pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +0 -30
pipeline.py DELETED
@@ -1,30 +0,0 @@
1
- from transformers import PIPELINE_REGISTRY, TokenClassificationPipeline
2
- import torch
3
-
4
- @PIPELINE_REGISTRY.register_pipeline(task="multi_label_token_classification", pipeline_class=None)
5
- class MultiLabelTokenClassificationPipeline(TokenClassificationPipeline):
6
- def __init__(self, model, tokenizer, **kwargs):
7
- super().__init__(model=model, tokenizer=tokenizer, **kwargs)
8
- self.id2label = {
9
- 0: 'O',
10
- 1: 'B-STEREO',
11
- 2: 'I-STEREO',
12
- 3: 'B-GEN',
13
- 4: 'I-GEN',
14
- 5: 'B-UNFAIR',
15
- 6: 'I-UNFAIR'
16
- }
17
-
18
- def postprocess(self, model_outputs, **kwargs):
19
- results = []
20
- for logits, tokens in zip(model_outputs[0], model_outputs[1]):
21
- probabilities = torch.sigmoid(logits)
22
- predicted_labels = (probabilities > 0.5).int()
23
- token_results = []
24
- for i, token in enumerate(tokens):
25
- if token not in self.tokenizer.all_special_tokens:
26
- label_indices = (predicted_labels[i] == 1).nonzero(as_tuple=False).squeeze(-1)
27
- labels = [self.id2label[idx.item()] for idx in label_indices] if label_indices.numel() > 0 else ['O']
28
- token_results.append({"token": token, "labels": labels})
29
- results.append(token_results)
30
- return results