# Author: KoichiYasuoka -- initial release (commit cf6f740)
# File size at that commit: 1.92 kB
from transformers import TokenClassificationPipeline
class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
    """Token-classification pipeline with constrained best-path decoding.

    Instead of taking the arg-max label of every token independently, the
    label sequence is chosen so that consecutive labels respect the
    transition table built in ``__init__`` from the model's ``label2id``:

    * ``B-X`` may only be followed by ``I-X``;
    * ``I-X`` may be followed by ``I-X`` or by any label not starting
      with ``I-``;
    * any other label may be followed by any label not starting with
      ``I-``.
    """

    def __init__(self, **kwargs):
        """Initialise the pipeline and build ``self.transition``.

        ``self.transition`` is a square ``numpy`` array indexed by label id.
        ``self.transition[a, b]`` is ``0`` when label ``b`` may directly
        follow label ``a`` and ``NaN`` otherwise, so that adding a row to a
        score vector and reducing with ``nanmax`` / ``nanargmax`` ignores
        forbidden successors.

        Args:
            **kwargs: Forwarded unchanged to the parent constructor.

        Raises:
            KeyError: If ``label2id`` contains a ``B-X`` label without the
                matching ``I-X`` label.
        """
        import numpy

        super().__init__(**kwargs)
        label2id = self.model.config.label2id
        non_inside = [name for name in label2id if not name.startswith("I-")]
        size = len(label2id)
        self.transition = numpy.full((size, size), numpy.nan)
        for name, index in label2id.items():
            if name.startswith("B-"):
                successors = ["I-" + name[2:]]
            elif name.startswith("I-"):
                successors = [name] + non_inside
            else:
                successors = non_inside
            for successor in successors:
                self.transition[index, label2id[successor]] = 0

    def check_model_type(self, supported_models):
        """Override that deliberately performs no model-type check."""
        pass

    def postprocess(self, model_outputs, **kwargs):
        """Decode the best constrained label path and build result dicts.

        Args:
            model_outputs: Mapping providing ``"logits"``,
                ``"offset_mapping"`` and ``"sentence"``. If it has no
                ``"logits"`` key it is treated as a sequence and its first
                element is processed instead.
            **kwargs: Only ``aggregation_strategy`` is consulted. When it is
                present and not equal to ``"none"``, ``I-`` items are merged
                into the item before them.

        Returns:
            A list of dicts with ``start``, ``end``, ``score`` and ``text``,
            plus ``entity`` (no aggregation) or ``entity_group``
            (aggregation). ``score`` is the softmax probability of the
            chosen label; for a merged group it is the minimum over its
            members.
        """
        import numpy

        if "logits" not in model_outputs:
            return self.postprocess(model_outputs[0], **kwargs)

        # Work on a copy: the array returned by ``.numpy()`` may share its
        # storage with the tensor held in ``model_outputs``, and the backward
        # pass below accumulates in place.
        scores = model_outputs["logits"][0].numpy().copy()

        # Per-token softmax of the raw logits, taken before they are
        # overwritten with path scores.
        exps = numpy.exp(scores - numpy.max(scores, axis=-1, keepdims=True))
        probs = exps / exps.sum(axis=-1, keepdims=True)

        # Backward pass: afterwards scores[i][a] is the best total score of
        # any allowed label path from position i to the end that has label a
        # at position i.
        for pos in range(scores.shape[0] - 1, 0, -1):
            scores[pos - 1] += numpy.nanmax(scores[pos] + self.transition, axis=1)

        # Forward pass: greedily follow the best allowed successor.
        # NOTE(review): the first position is constrained by the successor row
        # of label id 0 -- confirm this is the intended start constraint for
        # every label set this class is used with.
        path = [numpy.nanargmax(scores[0] + self.transition[0])]
        for pos in range(1, scores.shape[0]):
            path.append(numpy.nanargmax(scores[pos] + self.transition[path[-1]]))

        id2label = self.model.config.id2label
        offsets = model_outputs["offset_mapping"][0].tolist()
        # Positions whose character span is empty are dropped (presumably
        # special tokens -- verify against the tokenizer in use).
        w = [
            {
                "entity": id2label[label_id],
                "start": start,
                "end": end,
                "score": probs[pos, label_id],
            }
            for pos, ((start, end), label_id) in enumerate(zip(offsets, path))
            if start < end
        ]

        if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"] != "none":
            # Walk backwards so that popping item i never shifts an index
            # that is still to be visited.
            for i, item in reversed(list(enumerate(w))):
                label = item.pop("entity")
                if label.startswith("I-") and i > 0:
                    w[i - 1]["score"] = min(w[i - 1]["score"], item["score"])
                    w[i - 1]["end"] = w.pop(i)["end"]
                elif label.startswith(("B-", "I-")):
                    # A leading "I-X" has nothing before it to merge into
                    # (w[i - 1] would wrap around to the last item), so it
                    # opens group X just like "B-X" does.
                    item["entity_group"] = label[2:]
                else:
                    item["entity_group"] = label

        sentence = model_outputs["sentence"]
        for i, item in enumerate(w):
            if item["end"] < len(sentence):
                # U+0F0B (tsheg) / U+0F0C (non-breaking tsheg): extend the
                # span over a directly following mark, unless the next span
                # already starts at that character.
                if sentence[item["end"]] in {"\u0f0b", "\u0f0c"}:
                    if len(w) - i == 1 or item["end"] < w[i + 1]["start"]:
                        item["end"] += 1
            item["text"] = sentence[item["start"]:item["end"]]
        return w