File size: 5,277 Bytes
7e39f2e
 
 
 
 
 
 
 
cdcea0f
7e39f2e
 
 
 
 
 
 
 
cdcea0f
 
7e39f2e
 
 
 
cdcea0f
 
7e39f2e
cdcea0f
7e39f2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdcea0f
 
 
7e39f2e
 
 
 
 
 
 
 
 
cdcea0f
7e39f2e
 
cdcea0f
 
7e39f2e
 
 
 
 
 
 
 
 
b6d3162
7e39f2e
b6d3162
7e39f2e
b6d3162
 
 
7e39f2e
 
 
 
cdcea0f
7e39f2e
 
 
cdcea0f
7e39f2e
 
 
 
 
 
 
 
 
cdcea0f
7e39f2e
 
 
 
 
 
 
 
 
 
cdcea0f
 
 
7e39f2e
cdcea0f
7e39f2e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy
from transformers import TokenClassificationPipeline

class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
  """Token-classification pipeline that decodes one label per token under B-/I- constraints.

  Instead of taking the per-token argmax, the label sequence is chosen with a
  backward best-path pass over the logits followed by a greedy forward pass,
  both restricted by ``self.transition`` (built in ``__init__``).
  """
  def __init__(self,**kwargs):
    """Create the pipeline and build the label transition mask.

    ``self.transition[a,b]`` is 0 when label id ``b`` may directly follow
    label id ``a`` and ``-inf`` otherwise:

    * after ``B-X`` only ``I-X`` is allowed;
    * after ``I-X`` either ``I-X`` again or any "starter" label is allowed;
    * after any other label any "starter" label is allowed.

    A "starter" is every ``B-`` label plus every label that is neither an
    ``I-`` label nor a label ending in ``|root`` or containing ``|l-`` /
    ``|r-``.

    Raises:
      KeyError: if a ``B-X`` label has no matching ``I-X`` in ``label2id``.
    """
    super().__init__(**kwargs)
    label2id=self.model.config.label2id
    starters=[k for k in label2id if k.startswith("B-") or not (k.startswith("I-") or k.endswith("|root") or k.find("|l-")>0 or k.find("|r-")>0)]
    self.transition=numpy.full((len(label2id),len(label2id)),-numpy.inf)
    for k,v in label2id.items():
      for j in ["I-"+k[2:]] if k.startswith("B-") else [k]+starters if k.startswith("I-") else starters:
        self.transition[v,label2id[j]]=0
  def check_model_type(self,supported_models):
    """Do nothing, so no model-type check is applied by this override."""
    pass
  def postprocess(self,model_outputs,**kwargs):
    """Decode ``model_outputs`` into a list of token/entity dicts.

    If ``model_outputs`` has no ``"logits"`` entry it is treated as a
    sequence and its first element is decoded instead.
    """
    if "logits" not in model_outputs:
      return self.postprocess(model_outputs[0],**kwargs)
    return self.bellman_ford_token_classification(model_outputs,**kwargs)
  def bellman_ford_token_classification(self,model_outputs,**kwargs):
    """Pick the best constrained label path and turn it into result dicts.

    Args:
      model_outputs: mapping with ``"logits"`` (first item is used and
        converted with ``.numpy()``), ``"offset_mapping"`` (first item gives
        one ``(start,end)`` pair per token) and ``"sentence"`` (the text the
        offsets index into).
      **kwargs: only ``aggregation_strategy`` is read; any value other than
        ``"none"`` merges ``I-`` tokens into the preceding span.

    Returns:
      A list of dicts with ``start``, ``end``, ``score`` and ``text``, plus
      ``entity`` (no aggregation) or ``entity_group`` (with aggregation).
      Tokens whose offsets are empty (``start>=end``) are dropped.  A merged
      span's score is the minimum softmax score of its tokens.
    """
    # Work on a copy: .numpy() shares storage with the tensor, and the
    # backward pass below updates the scores in place.
    m=model_outputs["logits"][0].numpy().copy()
    if m.shape[0]==0:
      return []
    # Softmax of the raw logits, used only for the reported scores.
    e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
    z=e/e.sum(axis=-1,keepdims=True)
    # Backward pass: m[i-1,a] becomes the best total score of any allowed
    # continuation that has label a at position i-1.
    for i in range(m.shape[0]-1,0,-1):
      m[i-1]+=numpy.max(m[i]+self.transition,axis=1)
    # Forward pass: follow the best allowed label at each position.
    # NOTE(review): the first label is restricted to successors of label id 0;
    # presumably id 0 is a neutral label in the intended models - confirm.
    k=[numpy.argmax(m[0]+self.transition[0])]
    for i in range(1,m.shape[0]):
      k.append(numpy.argmax(m[i]+self.transition[k[-1]]))
    w=[{"entity":self.model.config.id2label[j],"start":s,"end":t,"score":z[i,j]} for i,((s,t),j) in enumerate(zip(model_outputs["offset_mapping"][0].tolist(),k)) if s<t]
    if kwargs.get("aggregation_strategy","none")!="none":
      # Walk backwards over a snapshot so that popping merged tokens does not
      # disturb the indices still to be visited.
      for i,t in reversed(list(enumerate(w))):
        p=t.pop("entity")
        if p.startswith("I-") and i>0:
          w[i-1]["score"]=min(w[i-1]["score"],t["score"])
          w[i-1]["end"]=w.pop(i)["end"]
        elif p.startswith(("B-","I-")):
          # An "I-" label on the very first kept token has nothing to merge
          # into (w[i-1] would wrap around to the last span), so it opens its
          # own group.
          t["entity_group"]=p[2:]
        else:
          t["entity_group"]=p
    for t in w:
      t["text"]=model_outputs["sentence"][t["start"]:t["end"]]
    return w

class UniversalDependenciesCausalPipeline(BellmanFordTokenClassificationPipeline):
  """Pipeline that segments a sentence into words and emits a dependency parse as CoNLL-U text.

  Words come from the inherited constrained token classification; arcs are
  scored by re-running the model on word embeddings and a single-rooted tree
  is selected with ``chu_liu_edmonds``.
  """
  def __init__(self,**kwargs):
    """Create the pipeline and build additive masks for arc labels.

    ``self.root``, ``self.left_arc`` and ``self.right_arc`` are vectors over
    label ids holding 0 for labels ending in ``|root``, containing ``|l-`` or
    containing ``|r-`` respectively, and ``-inf`` everywhere else.
    ``aggregation_strategy`` is forced to ``"simple"``.
    """
    kwargs["aggregation_strategy"]="simple"
    super().__init__(**kwargs)
    x=self.model.config.label2id
    self.root=numpy.full((len(x)),-numpy.inf)
    self.left_arc=numpy.full((len(x)),-numpy.inf)
    self.right_arc=numpy.full((len(x)),-numpy.inf)
    for k,v in x.items():
      if k.endswith("|root"):
        self.root[v]=0
      elif k.find("|l-")>0:
        self.left_arc[v]=0
      elif k.find("|r-")>0:
        self.right_arc[v]=0
  def postprocess(self,model_outputs,**kwargs):
    """Turn model outputs into a CoNLL-U formatted string.

    Args:
      model_outputs: mapping with ``"logits"``, ``"offset_mapping"`` and
        ``"sentence"``; if ``"logits"`` is missing, the first element of
        ``model_outputs`` is processed instead.
      **kwargs: ``aggregation_strategy`` is overwritten with ``"simple"``.

    Returns:
      A string made of a ``# text = ...`` line (newlines in the sentence
      replaced by spaces), one tab-separated 10-column line per word, and a
      terminating blank line.  When no word is found only the header and the
      blank line are returned.
    """
    import torch
    kwargs["aggregation_strategy"]="simple"
    if "logits" not in model_outputs:
      return self.postprocess(model_outputs[0],**kwargs)
    w=self.bellman_ford_token_classification(model_outputs,**kwargs)
    text=model_outputs["sentence"].replace("\n"," ")
    if not w:
      # Nothing to parse; torch.stack below would fail on an empty list.
      return "# text = "+text+"\n\n"
    d=[g["text"] for g in w]
    # One embedding per word: the sum of the input embeddings of its subwords
    # (the unknown-token embedding when the tokenizer yields nothing).
    v=self.tokenizer(d,add_special_tokens=False)
    e=self.model.get_input_embeddings().weight
    m=[]
    for x in v["input_ids"]:
      if x==[]:
        x=[self.tokenizer.unk_token_id]
      m.append(e[x,:].sum(axis=0))
    # Two extra rows: index len(d) is the separator, index -1 is the padding.
    m.append(e[self.tokenizer.sep_token_id,:])
    m.append(e[self.tokenizer.pad_token_id,:])
    m=torch.stack(m).to(self.device)
    k=list(range(len(d)+1))
    e=[]
    with torch.no_grad():
      for i in range(len(d)):
        # Input i is: all words, separator, words i.., then i padding rows
        # (constant length 2*len(d)+1); keep the logits of the last len(d)
        # positions.
        e.append(self.model(inputs_embeds=torch.unsqueeze(m[k+list(range(i,len(d)))+[-1]*i,:],0)).logits[0,-len(d):,:])
    e=torch.stack(e).cpu().numpy()
    # Rearrange in place into e[head,dependent,label]: the logits read from
    # run r at offset c are used for the word pair (r,r+c), masked with
    # left_arc for the earlier word as dependent and right_arc for the later
    # word as dependent; offset 0 masked with root fills the diagonal.
    # NOTE(review): pair interpretation inferred from the indexing - confirm
    # against the model's training setup.
    for i in range(len(d)):
      for j in range(i):
        e[-j-1,-i-1],e[-i-1,-j-1]=e[-i-1,i-j]+self.left_arc,e[-i-1,i-j]+self.right_arc
      e[-i-1,-i-1]=e[-i-1,0]+self.root
    # Best label score and label id for every (head,dependent) cell.
    m,p=numpy.max(e,axis=2),numpy.argmax(e,axis=2)
    h=self.chu_liu_edmonds(m)
    z=[i for i,j in enumerate(h) if i==j]
    if len(z)>1:
      # Several words chose themselves as head.  Keep the self-arc only for
      # the root candidate with the best diagonal score and, for every root
      # candidate, penalise all heads that are not root candidates; then
      # parse again.
      k,h=z[numpy.argmax(m[z,z])],numpy.min(m)-numpy.max(m)
      m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
      h=self.chu_liu_edmonds(m)
    # Label of each word's chosen arc, split on "|": first field goes to the
    # UPOS column, last field is the relation, anything between is FEATS.
    q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
    u=["# text = "+text]
    for i,j in enumerate(d):
      # MISC is "_" only when another word follows after a gap in the text.
      u.append("\t".join([str(i+1),j,"_",q[i][0],"_","_" if len(q[i])<3 else "|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),"root" if q[i][-1]=="root" else q[i][-1][2:],"_","_" if i+1<len(d) and w[i]["end"]<w[i+1]["start"] else "SpaceAfter=No"]))
    return "\n".join(u)+"\n\n"
  def chu_liu_edmonds(self,matrix):
    """Choose one head per node from a score matrix, resolving cycles recursively.

    Args:
      matrix: square array where ``matrix[a,b]`` scores ``a`` as the head of
        ``b``; a diagonal entry scores the node heading itself.

    Returns:
      A sequence ``h`` with ``h[i]`` the chosen head of node ``i``
      (``h[i]==i`` for a self-headed node).  It is a NumPy array when the
      column-wise argmax is already cycle-free and a list otherwise.
    """
    # Greedy choice: best-scoring head for every column.
    h=numpy.argmax(matrix,axis=0)
    # Same choice with self-headed nodes marked as -1.
    x=[-1 if i==j else j for i,j in enumerate(h)]
    # Pass 1 repeatedly marks nodes that are nobody's head with -1; if every
    # node ends up marked, the greedy choice has no cycle.  Pass 2 replaces
    # each remaining node's pointer by its head's pointer until stable.
    # NOTE(review): this presumably collapses each cycle onto a single
    # representative index - confirm.
    for b in [lambda x,i,j:-1 if i not in x else x[i],lambda x,i,j:-1 if j<0 else x[j]]:
      y=[]
      while x!=y:
        y=list(x)
        for i,j in enumerate(x):
          x[i]=b(x,i,j)
      if max(x)<0:
        return h
    # y: nodes sharing the largest remaining marker (one cycle, presumably);
    # x: all other nodes.
    y,x=[i for i,j in enumerate(x) if j==max(x)],[i for i,j in enumerate(x) if j<max(x)]
    # Shift every column so that its best incoming score is 0.
    z=matrix-numpy.max(matrix,axis=0)
    # Contracted problem: the nodes in x plus one extra last node standing
    # for all of y, scored by the best arc into/out of y.
    m=numpy.block([[z[x,:][:,x],numpy.max(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.max(z[y,:][:,x],axis=0),numpy.max(z[y,y])]])
    # Solve recursively and map heads back to original indices; a head equal
    # to the contracted node becomes the member of y that scores best.
    k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.argmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
    h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
    # Re-attach one member of y to the head chosen for the contracted node
    # (or make it self-headed when the contracted node heads itself).
    i=y[numpy.argmax(z[x[k[-1]],y] if k[-1]<len(x) else z[y,y])]
    h[i]=x[k[-1]] if k[-1]<len(x) else i
    return h