not-lain commited on
Commit
e4ca3a6
·
1 Parent(s): 0ebf01d

commit files to HF hub

Browse files
Files changed (1) hide show
  1. tunBertClassificationPipeline.py +18 -4
tunBertClassificationPipeline.py CHANGED
@@ -1,7 +1,11 @@
1
- from transformers import Pipeline
2
  import torch
 
3
 
4
  class TBCP(Pipeline):
 
 
 
5
  def _sanitize_parameters(self, **kwargs):
6
  postprocess_kwargs = {}
7
  if "text_pair" in kwargs:
@@ -19,7 +23,17 @@ class TBCP(Pipeline):
19
  probabilities = torch.nn.functional.softmax(logits, dim=-1)
20
 
21
  best_class = probabilities.argmax().item()
22
- label = self.model.config.id2label[best_class]
23
- score = probabilities.squeeze()[best_class].item()
24
  logits = logits.squeeze().tolist()
25
- return {"label": label, "score": score, "logits": logits}
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Pipeline, AutoModelForSequenceClassification,AutoTokenizer
2
  import torch
3
+ from transformers.pipelines import PIPELINE_REGISTRY
4
 
5
  class TBCP(Pipeline):
6
+ def __init__(self,**kwargs):
7
+ Pipeline.__init__(self,**kwargs)
8
+ self.tokenizer = AutoTokenizer.from_pretrained(kwargs["tokenizer"])
9
  def _sanitize_parameters(self, **kwargs):
10
  postprocess_kwargs = {}
11
  if "text_pair" in kwargs:
 
23
  probabilities = torch.nn.functional.softmax(logits, dim=-1)
24
 
25
  best_class = probabilities.argmax().item()
26
+ label = f"Label_{best_class}"
27
+ # score = probabilities.squeeze()[best_class].item()
28
  logits = logits.squeeze().tolist()
29
+ return {"label": label,
30
+ # "score": score,
31
+ "logits": logits}
32
+
33
+ PIPELINE_REGISTRY.register_pipeline(
34
+ "TunBERT-classifier",
35
+ pipeline_class=TBCP,
36
+ pt_model=AutoModelForSequenceClassification,
37
+ default={"pt": ("not-lain/TunBERT", "main")},
38
+ type="text", # current support type: text, audio, image, multimodal
39
+ )