KoichiYasuoka commited on
Commit
6de4e01
·
1 Parent(s): b352d69

BellmanFordTokenClassificationPipeline

Browse files
Files changed (2) hide show
  1. config.json +6 -0
  2. upos.py +42 -0
config.json CHANGED
@@ -4,6 +4,12 @@
4
  ],
5
  "attention_probs_dropout_prob": 0.1,
6
  "bos_token_id": 0,
 
 
 
 
 
 
7
  "eos_token_id": 2,
8
  "hidden_act": "gelu",
9
  "hidden_dropout_prob": 0.1,
 
4
  ],
5
  "attention_probs_dropout_prob": 0.1,
6
  "bos_token_id": 0,
7
+ "custom_pipelines": {
8
+ "upos": {
9
+ "impl": "upos.BellmanFordTokenClassificationPipeline",
10
+ "pt": "AutoModelForTokenClassification"
11
+ }
12
+ },
13
  "eos_token_id": 2,
14
  "hidden_act": "gelu",
15
  "hidden_dropout_prob": 0.1,
upos.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import TokenClassificationPipeline
2
+ from transformers.modeling_outputs import TokenClassifierOutput
3
+
4
+ class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
5
+ def __init__(self,**kwargs):
6
+ import numpy
7
+ super().__init__(**kwargs)
8
+ x=self.model.config.label2id
9
+ y=[k for k in x if not k.startswith("I-")]
10
+ self.transition=numpy.full((len(x),len(x)),numpy.nan)
11
+ for k,v in x.items():
12
+ for j in ["I-"+k[2:]] if k.startswith("B-") else [k]+y if k.startswith("I-") else y:
13
+ self.transition[v,x[j]]=0
14
+ def check_model_type(self,supported_models):
15
+ pass
16
+ def postprocess(self,model_outputs,**kwargs):
17
+ import numpy
18
+ if "logits" not in model_outputs:
19
+ return self.postprocess(model_outputs[0],**kwargs)
20
+ m=model_outputs["logits"][0].numpy()
21
+ e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
22
+ z=e/e.sum(axis=-1,keepdims=True)
23
+ for i in range(m.shape[0]-1,0,-1):
24
+ m[i-1]+=numpy.nanmax(m[i]+self.transition,axis=1)
25
+ k=[numpy.nanargmax(m[0])]
26
+ for i in range(1,m.shape[0]):
27
+ k.append(numpy.nanargmax(m[i]+self.transition[k[-1]]))
28
+ w=[{"entity":self.model.config.id2label[j],"start":s,"end":e,"score":z[i,j]} for i,((s,e),j) in enumerate(zip(model_outputs["offset_mapping"][0].tolist(),k)) if s<e]
29
+ if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
30
+ for i,t in reversed(list(enumerate(w))):
31
+ p=t.pop("entity")
32
+ if p.startswith("I-"):
33
+ w[i-1]["score"]=min(w[i-1]["score"],t["score"])
34
+ w[i-1]["end"]=w.pop(i)["end"]
35
+ elif p.startswith("B-"):
36
+ t["entity_group"]=p[2:]
37
+ else:
38
+ t["entity_group"]=p
39
+ for t in w:
40
+ t["text"]=model_outputs["sentence"][t["start"]:t["end"]]
41
+ return w
42
+